🦒 Improving Text Embeddings with LLMs

In this tutorial, we will be replicating the process described in "Improving Text Embeddings with Large Language Models" by Liang Wang et al. for synthetically generating multilingual data to be used for training a sentence similarity model.

Installation

We start by installing distilabel with the openai extra. The authors generated their synthetic data with GPT-4 and GPT-3.5 only, so no other LLM integration is needed.

Note that we upgrade typing_extensions first, since openai may have conflicts with the default installed version of typing_extensions if outdated.

%pip install --upgrade typing_extensions --quiet
%pip install "distilabel[openai]" --quiet

Introduction

In "Improving Text Embeddings with Large Language Models", the authors leverage proprietary OpenAI LLMs such as GPT-4 and GPT-3.5 to generate synthetic data for a wide range of diverse text embedding tasks, achieving competitive performance without using any labeled data. When fine-tuned on a mixture of synthetic data and data from MS MARCO, their model sets SOTA results on the BEIR and MTEB benchmarks.

The authors divide the generation process into two phases:

  • Synthetic generation of the task definitions (names), following certain criteria.
  • Synthetic generation of the data for each task (to be used for fine-tuning), using the task definition and some other sampled parameters.

After those two phases, the authors end up with a dataset that is suitable for model fine-tuning.
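To make the data flow between the two phases concrete, here is a minimal sketch in plain Python. It does not call distilabel or OpenAI: fake_llm is a hypothetical stand-in that returns canned replies (as JSON, for simplicity), and the function names generate_tasks and generate_example are illustrative only.

```python
import json
from typing import Callable, Dict, List


# Hypothetical stand-in for a real LLM call: maps a prompt to a canned reply
def fake_llm(prompt: str) -> str:
    if prompt.startswith("Brainstorm"):
        return json.dumps(
            ["Classify movie reviews by sentiment", "Detect spam in SMS messages"]
        )
    return json.dumps(
        {"input_text": "placeholder text", "label": "A", "misleading_label": "B"}
    )


def generate_tasks(llm: Callable[[str], str]) -> List[str]:
    """Phase 1: ask the LLM for a list of task definitions."""
    prompt = "Brainstorm a list of potentially useful text classification tasks."
    return json.loads(llm(prompt))


def generate_example(llm: Callable[[str], str], task: str) -> Dict[str, str]:
    """Phase 2: ask the LLM for one training example for a given task."""
    prompt = f"You have been assigned a text classification task: {task}"
    return {"task": task, **json.loads(llm(prompt))}


tasks = generate_tasks(fake_llm)
dataset = [generate_example(fake_llm, task) for task in tasks]
print(len(dataset), "rows with keys", sorted(dataset[0]))
```

Each task produced in the first phase becomes the input of one call in the second phase, which is the same shape the rest of this tutorial builds with real LLM calls.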

Phase 1: Generating task definitions

We start with the first phase, which consists of generating synthetic task definitions for asymmetric tasks. In this case we focus only on the text classification task pool, whose prompt has the following format:

Brainstorm a list of potentially useful text classification tasks.

Please adhere to the following guidelines:
- Tasks should cover a diverse range of domains and task types.

Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!

from distilabel.llm import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.tasks import TextGenerationTask

First, we need to define a custom Task that removes the default system_prompt and parses the output with eval, since the prompt asks the LLM to reply with a Python list literal. Keep in mind that eval executes whatever expression it receives, so it should only be applied to output you trust.

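from dataclasses import dataclass
from typing import Dict, List

@dataclass
class TaskGenerationTask(TextGenerationTask):
    # Override the default system prompt with an empty one
    system_prompt: str = ""

    def parse_output(self, output: str) -> Dict[str, List[str]]:
        # The LLM is asked for a Python list literal, so eval turns it into a list of strings
        return {"generations": eval(output)}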
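Because eval runs any expression the model returns, a more defensive way to parse the list is ast.literal_eval from the standard library, which only accepts plain literals. The sketch below is a standalone illustration rather than part of the original notebook, and parse_task_list is a hypothetical helper name.

```python
import ast
from typing import List


def parse_task_list(output: str) -> List[str]:
    """Parse a Python list literal of strings without executing arbitrary code."""
    parsed = ast.literal_eval(output.strip())
    if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed):
        raise ValueError("expected a Python list of strings")
    return parsed


print(parse_task_list("['Sentiment analysis of tweets', 'Spam detection in emails']"))

# Anything that is not a plain literal is rejected instead of being executed
try:
    parse_task_list("__import__('os').getcwd()")
except (ValueError, SyntaxError) as error:
    print("rejected:", type(error).__name__)
```

A helper like this could be called from parse_output in place of eval, at the cost of failing on replies that are not a clean list literal.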

Once the custom Task (TaskGenerationTask) is defined, we can initialize the LLM, in this case OpenAILLM using GPT-4, and pass the task to it as an argument. We also set the generation kwargs temperature and top_p to the values given in Appendix C of the paper, to encourage more diversity in the generations.

llm = OpenAILLM(
    model="gpt-4",
    api_key="sk-***",
    task=TaskGenerationTask(),
    prompt_format="openai",
    max_new_tokens=1024,
    # Using the following kwargs as stated in Appendix C of the paper
    temperature=1.0,
    top_p=1.0,
)
pipeline = Pipeline(generator=llm)

Before calling the Pipeline, we'll need to prepare the input data, which in this case is only a prompt with no formatting required, as it's a simple TextGenerationTask we want to call multiple times.

from datasets import Dataset

prompt = """Brainstorm a list of potentially useful text classification tasks.
Please adhere to the following guidelines:
- Tasks should cover a diverse range of domains and task types.
Your output must always be a Python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!
"""

dataset = Dataset.from_dict({"input": [prompt]})

Then, we're ready to call the Pipeline.generate method so that the prompt is sent to GPT-4 and N task definitions are generated synthetically.

In this case, N should be equal or close to num_generations × 20, since the prompt asks the LLM for a Python list of about 20 elements; with num_generations=5 that is roughly 100 task definitions.

new_dataset = pipeline.generate(dataset, num_generations=5, skip_dry_run=True)
INFO:distilabel:Processing batch 1 of 1...
INFO:distilabel:Calling generator for batch 1...



INFO:distilabel:Final dataset saved at /content/ckpt


Finally, once the generation has completed, we apply some post-processing before proceeding to the next phase: we remove the columns that are not required and explode the columns holding the tasks, so that the nested lists initially contained in a single row are unwrapped into a dataset with N rows, where N is the total number of tasks.

df_dataset = new_dataset.to_pandas()
# Drop the columns that are not needed for the next phase
df_dataset = df_dataset.drop(["generation_prompt", "raw_generation_responses"], axis=1)
# First explode: one row per generation (each still holding a list of tasks)
df_dataset = df_dataset.explode(["generation_model", "generations"])
# Second explode: one row per task
df_dataset = df_dataset.explode(["generations"])
df_dataset = df_dataset.reset_index(drop=True)
new_dataset = Dataset.from_pandas(df_dataset)
new_dataset = new_dataset.rename_columns({"generation_model": "model", "generations": "task"})
new_dataset.push_to_hub("alvarobartt/improving-text-embeddings-with-llms", config_name="task-generation")
CommitInfo(commit_url='https://huggingface.co/datasets/alvarobartt/improving-text-embeddings-with-llms/commit/e762103a4eaa250749900030463907cd36b773d3', commit_message='Upload dataset', commit_description='', oid='e762103a4eaa250749900030463907cd36b773d3', pr_url=None, pr_revision=None, pr_num=None)
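To picture what the two explode calls above do to the nested lists, the following plain-Python sketch performs the same unwrapping on a toy row. The row contents are made up and shortened, and unwrap is a hypothetical helper, not a pandas or distilabel function.

```python
from typing import Dict, List

# Toy stand-in for the single row returned by the pipeline: one model name and
# one list of tasks per generation (shortened, made-up values).
row = {
    "input": "Brainstorm a list of potentially useful text classification tasks.",
    "generation_model": ["gpt-4", "gpt-4"],
    "generations": [
        ["Classify tweets by sentiment", "Detect spam emails"],
        ["Identify the language of a text"],
    ],
}


def unwrap(row: Dict) -> List[Dict[str, str]]:
    """Turn one row with nested lists into one flat row per task."""
    flat_rows = []
    # Outer loop, like the first explode: pair each generation with its model name
    for model, tasks in zip(row["generation_model"], row["generations"]):
        # Inner loop, like the second explode: one row per task in that generation
        for task in tasks:
            flat_rows.append({"input": row["input"], "model": model, "task": task})
    return flat_rows


for flat_row in unwrap(row):
    print(flat_row["model"], "|", flat_row["task"])
```

With the real data, the outer level has num_generations entries and each inner list has about 20 tasks, which is where the roughly 100 rows come from.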

Phase 2: Generating data for each task

Once all the task definitions have been generated, we can proceed to the next phase, which consists of generating the data for each task.

In this case, the LLM needs to generate, for each task, an input text that suits it, along with the correct label for that text and a misleading label. Besides that, we also randomly sample some of the arguments of each prompt, to keep the generations diverse.

The prompt to be used is the following:

You have been assigned a text classification task: {task}

Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
- "input_text": a string, the input text specified by the classification task.
- "label": a string, the correct label of the input text.
- "misleading_label": a string, an incorrect label that is related to the task.

Please adhere to the following guidelines:
- The "input_text" should be {num_words} words and diverse in expression.
- The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the
"input_text".
- The values for all fields should be in {language}.
- Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make
the task too easy.
- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.

Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!

The possible values for each argument are:

  • task is a task definition generated in the previous phase.
  • language is any language in the XLM-R language list.
  • num_words is the length constraint, in words, on the input_text to generate (e.g. "at least 50").
  • difficulty is the education level required to comprehend the input_text to generate.
  • clarity is how easy or hard the input_text is to understand.

# Values to sample from for each prompt argument
num_words = ["less than 10", "at least 10", "at least 50", "at least 100", "at least 200"]
difficulty = ["high school", "college", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]
prompt = """You have been assigned a text classification task: {task}
Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
- "input_text": a string, the input text specified by the classification task.
- "label": a string, the correct label of the input text.
- "misleading_label": a string, an incorrect label that is related to the task.
Please adhere to the following guidelines:
- The "input_text" should be {num_words} words and diverse in expression.
- The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the
"input_text".
- The values for all fields should be in {language}.
- Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make
the task too easy.
- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.
Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
"""

Just as before, we need to create a custom Task. This time it not only parses the output via parse_output (the prompt already instructs the LLM to reply with valid JSON), but also builds the prompt via generate_prompt, since that is where we introduce the sampling mentioned in the paper.
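Before wiring the sampling into a Task, it can help to preview it in isolation. The sketch below is a standalone illustration: TEMPLATE is a shortened, hypothetical version of the prompt, LANGUAGES is a small made-up subset of languages, and sample_prompt is an illustrative helper name.

```python
import random

# Value pools for each placeholder; the language pool is a small hypothetical
# subset used only for this preview.
NUM_WORDS = ["less than 10", "at least 10", "at least 50", "at least 100", "at least 200"]
DIFFICULTY = ["high school", "college", "PhD"]
CLARITY = ["clear", "understandable with some effort", "ambiguous"]
LANGUAGES = ["English", "Spanish", "German"]

# Shortened, hypothetical template that keeps only the placeholders
TEMPLATE = (
    "Task: {task}\n"
    'The "input_text" should be {num_words} words, written in {language}, '
    "{clarity}, and require {difficulty} level education to comprehend."
)


def sample_prompt(task: str, rng: random.Random) -> str:
    """Fill the template with one random draw per argument."""
    return TEMPLATE.format(
        task=task,
        num_words=rng.choice(NUM_WORDS),
        language=rng.choice(LANGUAGES),
        clarity=rng.choice(CLARITY),
        difficulty=rng.choice(DIFFICULTY),
    )


rng = random.Random(42)  # seeded only to make the preview reproducible
for _ in range(3):
    print(sample_prompt("Classify news headlines by topic", rng))
    print("---")
```

The task defined next applies the same idea inside generate_prompt, with the language fixed to English.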

import json
from random import choice
from typing import Any

from distilabel.tasks.prompt import Prompt

@dataclass
class ExampleGenerationTask(TextGenerationTask):
    system_prompt: str = ""

    @property
    def input_args_names(self) -> List[str]:
        return ["task"]

    def generate_prompt(self, task: str) -> Prompt:
        # Sample the prompt arguments at random on every call; the language is fixed to English
        return Prompt(
            system_prompt=self.system_prompt,
            formatted_prompt=prompt.format(
                task=task,
                language="english",
                num_words=choice(num_words),
                difficulty=choice(difficulty),
                clarity=choice(clarity),
            ),
        )

    @property
    def output_args_names(self) -> List[str]:
        return ["input_text", "label", "misleading_label"]

    def parse_output(self, output: str) -> Dict[str, Any]:
        # The LLM is instructed to reply with a single JSON object
        return json.loads(output)
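Since json.loads accepts any valid JSON, a reply that is missing one of the three expected keys would pass through parse_output unnoticed. A stricter parser could look like the sketch below; it is a standalone illustration under that assumption, and parse_example and REQUIRED_KEYS are hypothetical names.

```python
import json
from typing import Dict

REQUIRED_KEYS = ("input_text", "label", "misleading_label")


def parse_example(output: str) -> Dict[str, str]:
    """Parse one generated example and check that it has the expected keys."""
    data = json.loads(output)  # raises json.JSONDecodeError on invalid JSON
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")
    missing = [key for key in REQUIRED_KEYS if key not in data]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    # Keep only the expected keys and normalize their values to strings
    return {key: str(data[key]) for key in REQUIRED_KEYS}


raw = (
    '{"input_text": "The battery died after two days.", '
    '"label": "negative", "misleading_label": "neutral"}'
)
print(parse_example(raw))
```

json.JSONDecodeError is a subclass of ValueError, so a caller can catch ValueError to handle both malformed JSON and missing keys.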

With that in place, we are all set to instantiate OpenAILLM with the newly created task, and to call the Pipeline.generate method on the datasets.Dataset generated in the previous phase.

llm = OpenAILLM(
    model="gpt-4",
    api_key="sk-***",
    task=ExampleGenerationTask(),
    prompt_format="openai",
    max_new_tokens=1024,
    # Using the following kwargs as stated in Appendix C of the paper
    temperature=1.0,
    top_p=1.0,
)
pipeline = Pipeline(generator=llm)
final_dataset = pipeline.generate(new_dataset, num_generations=1, skip_dry_run=True)
INFO:distilabel:Processing batch 1 of 100...
INFO:distilabel:Calling generator for batch 1...
INFO:distilabel:Processing batch 2 of 100...
INFO:distilabel:Calling generator for batch 2...
[... similar log lines for the intermediate batches omitted ...]
INFO:distilabel:Processing batch 100 of 100...
INFO:distilabel:Calling generator for batch 100...



INFO:distilabel:Final dataset saved at /content/ckpt


final_dataset.push_to_hub("alvarobartt/improving-text-embeddings-with-llms", config_name="task-completion")
CommitInfo(commit_url='https://huggingface.co/datasets/alvarobartt/improving-text-embeddings-with-llms/commit/1d5ce88974b56cbc1f410ed625e15ac96fe1497d', commit_message='Upload dataset', commit_description='', oid='1d5ce88974b56cbc1f410ed625e15ac96fe1497d', pr_url=None, pr_revision=None, pr_num=None)

Annotate with Argilla

The datasets.Dataset generated by Pipeline.generate contains a pre-implemented method to easily export it to a FeedbackDataset in Argilla, so that any user can easily incorporate feedback into the previously generated synthetic dataset.

Adding Argilla as a curation tool with humans in the loop can push the quality of the generated synthetic data even further.

Before calling the to_argilla method over the generated dataset, one should first install argilla. It can be installed either from the extra within distilabel as pip install "distilabel[argilla]", which is the recommended way, or just as pip install argilla --upgrade.

rg_dataset = final_dataset.to_argilla()

Besides converting it into a FeedbackDataset, it can also be pushed to Argilla, so as to use the Argilla UI to annotate the records recently generated with distilabel. To do so, you first need a running Argilla instance (see Argilla Documentation - Installation); you are then free to push the converted dataset and annotate it.

import argilla as rg

rg.init(api_url="...", api_key="...")
rg_dataset.push_to_argilla("my-dataset", workspace="admin")

Conclusion

With distilabel, generating synthetic datasets is easier than ever, and also customizable to a wide variety of use cases. In this tutorial, we showcased how to replicate "Improving Text Embeddings with Large Language Models", but the same approach can be adapted to your own needs.

The authors mention that, for future work, they aim to further improve the multilingual performance of their model and explore the possibility of using open-source LLMs to generate synthetic data, instead of proprietary OpenAI ones.

Additionally, they also intend to investigate ways to improve the inference efficiency and lower the storage cost for LLM-based text embeddings.