Improving Text Embeddings with LLMs
In this tutorial, we will be replicating the process described in "Improving Text Embeddings with Large Language Models" by Liang Wang et al. for synthetically generating multilingual data to be used for training a sentence similarity model.
Installation
We will start off by installing distilabel with the openai extra, since the authors used only OpenAI models (GPT-4 and GPT-3.5) to generate the synthetic data. We also install Argilla for better visualization and curation of the results.
Note that we upgrade typing_extensions first, since openai may conflict with the installed version of typing_extensions if it is outdated.
Introduction
In "Improving Text Embeddings with Large Language Models", the authors leverage OpenAI's proprietary LLMs, GPT-4 and GPT-3.5, to generate synthetic data for a wide range of diverse text embedding tasks, achieving competitive performance without using any labeled data. When fine-tuned on a mixture of synthetic data and data from MS MARCO, their model sets new state-of-the-art results on the BEIR and MTEB benchmarks.
To do so, the authors divide the generation process into two phases:
* Synthetic generation of the task definitions (names), following certain criteria.
* Synthetic generation of the data for each task (to be used for fine-tuning), using the task definition and some other sampled parameters.

After both phases, the authors end up with a dataset suitable for fine-tuning a model.
Phase 1: Generating task definitions
We will start off with the first phase, which involves generating synthetic task definitions for asymmetric tasks. Here, we will focus only on the text classification task pool, which uses the following brainstorming prompt:
Brainstorm a list of potentially useful text classification tasks.
Please adhere to the following guidelines:
- Tasks should cover a diverse range of domains and task types.
Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!
First, we need to define a custom Task that removes the default system_prompt and parses the output using eval, since the prompt asks the LLM to produce a Python list.
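The custom task described above turns the raw reply into a list by calling eval on it. Because eval executes whatever expression the model returns, here is a minimal, standalone sketch of the same parsing step that swaps in ast.literal_eval, which only accepts Python literals. The function name parse_task_list is hypothetical and not part of distilabel; inside a custom Task you would call something like it from parse_output.

```python
import ast
from typing import List


def parse_task_list(output: str) -> List[str]:
    """Parse a reply that should be a Python list of strings.

    ast.literal_eval only evaluates literals (lists, strings, numbers, ...),
    so a reply containing function calls raises ValueError instead of running.
    """
    parsed = ast.literal_eval(output.strip())
    if not isinstance(parsed, list) or not all(isinstance(item, str) for item in parsed):
        raise ValueError("expected a Python list of strings")
    return parsed


reply = '["Classify movie reviews by sentiment.", "Detect spam in SMS messages."]'
tasks = parse_task_list(reply)
print(len(tasks))  # 2
```

Replies that are not valid Python at all will raise SyntaxError, so you may want to catch both exceptions and drop those generations.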
Once the custom Task (TaskGenerationTask) is defined, we can initialize the LLM, in this case OpenAILLM using GPT-4, and pass the newly defined task to it. We also include some generation kwargs, temperature and top_p, as defined in Appendix C of the paper, to encourage more diversity in the generations.
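To see why these two parameters influence diversity, here is a minimal sketch of how temperature scaling and nucleus (top-p) filtering are commonly applied to a next-token distribution during decoding. It is a generic illustration on toy logits, not OpenAI's implementation, and it does not reproduce the paper's settings; sampling_distribution is a hypothetical helper.

```python
import math
from typing import Dict


def sampling_distribution(
    logits: Dict[str, float], temperature: float, top_p: float
) -> Dict[str, float]:
    """Return the distribution a sampler would draw from after both filters."""
    # Temperature scaling: divide the logits by T before the softmax.
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    max_scaled = max(scaled.values())  # subtract the max for numerical stability
    weights = {tok: math.exp(v - max_scaled) for tok, v in scaled.items()}
    total = sum(weights.values())
    probs = {tok: w / total for tok, w in weights.items()}

    # Nucleus (top-p) filtering: keep the most likely tokens until their
    # cumulative probability reaches top_p, then renormalise the survivors.
    kept: Dict[str, float] = {}
    cumulative = 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = p
        cumulative += p
        if cumulative >= top_p:
            break
    kept_total = sum(kept.values())
    return {tok: p / kept_total for tok, p in kept.items()}


logits = {"sports": 2.0, "politics": 1.0, "weather": 0.0}
print(sorted(sampling_distribution(logits, temperature=1.0, top_p=0.9)))
# ['politics', 'sports']
print(sorted(sampling_distribution(logits, temperature=2.0, top_p=0.9)))
# ['politics', 'sports', 'weather']
```

Raising the temperature flattens the distribution, so more tokens fit inside the top_p mass and can be sampled.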
Before calling the Pipeline, we need to prepare the input data. In this case, it is just a single prompt with no formatting required, since it is a plain text generation task that we want to run multiple times.
from datasets import Dataset
prompt = """Brainstorm a list of potentially useful text classification tasks.
Please adhere to the following guidelines:
- Tasks should cover a diverse range of domains and task types.
Your output must always be a Python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!
"""
dataset = Dataset.from_dict({"input": [prompt]})
Then, we are ready to call the Pipeline.generate method, so that the prompt is sent to GPT-4 and N task definitions are generated synthetically.
Here, N should be equal or close to num_generations × 20, since the prompt asks the LLM to generate a Python list of about 20 elements.
Finally, once the generation has been completed, we apply some post-processing before proceeding to the next phase: we remove the columns that are not required and explode the column containing the tasks. This unwraps the nested lists, which are initially contained within a single row, so that we end up with a dataset of N rows, one per task.
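The post-processing step can be sketched in plain Python as follows. It assumes (hypothetically) that each generated row keeps the parsed lists under a generations column, one inner list per generation; check the column names of your own output before adapting it. With pandas, DataFrame.explode performs the same unwrapping one nesting level at a time.

```python
from typing import Any, Dict, List


def explode_tasks(rows: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Turn rows holding nested task lists into one row per task.

    Only the task text is kept, so every other column is dropped.
    """
    flat: List[Dict[str, str]] = []
    for row in rows:
        # One inner list per generation, each with about 20 task definitions.
        for generation in row["generations"]:
            for task in generation:
                flat.append({"task": task})
    return flat


# Toy output: one input row, num_generations = 2 (hypothetical column layout).
rows = [
    {
        "generation_model": ["gpt-4", "gpt-4"],
        "generations": [["Task A", "Task B"], ["Task C"]],
    }
]
tasks = explode_tasks(rows)
print(tasks)  # [{'task': 'Task A'}, {'task': 'Task B'}, {'task': 'Task C'}]
```

With real output, the length of the flattened list is the N discussed above.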
Phase 2: Generating data for each task
Once all the task definitions have been generated, we can proceed to the next phase, which consists of generating the data for each task.
In this case, the LLM needs to generate an input text that suits each task, along with its correct label and a misleading label. Besides that, we also sample some of the arguments within each prompt to make the generations more diverse.
The prompt to be used is the following:
You have been assigned a text classification task: {task}
Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
- "input_text": a string, the input text specified by the classification task.
- "label": a string, the correct label of the input text.
- "misleading_label": a string, an incorrect label that is related to the task.
Please adhere to the following guidelines:
- The "input_text" should be {num_words} words and diverse in expression.
- The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the "input_text".
- The values for all fields should be in {language}.
- Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make the task too easy.
- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.
Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
The placeholders in the prompt take the following values:
- task: the task definition generated in the previous phase.
- language: any language name within XLM-R.
- num_words: the length constraint, in words, for the input_text to generate.
- difficulty: the education level required to comprehend the input_text to generate.
- clarity: how easy or hard the input_text is to understand.
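The per-prompt sampling can be sketched without any framework. The snippet below uses an abbreviated stand-in for the full prompt (only the placeholders are kept) and hypothetical value pools, including a small illustrative subset of languages; swap in the full prompt and the values you actually want to sample. Passing a seeded random.Random makes the sampled arguments reproducible.

```python
import random
from typing import Dict

# Abbreviated stand-in for the full prompt: only the placeholders are kept.
TEMPLATE = (
    "Task: {task}\n"
    "Words: {num_words} | Language: {language} | "
    "Clarity: {clarity} | Difficulty: {difficulty}"
)

# Hypothetical value pools; replace them with the ones you want to sample from.
POOLS = {
    "language": ["English", "Spanish", "German", "Swahili"],
    "num_words": ["less than 10", "at least 10", "at least 50", "at least 100", "at least 200"],
    "clarity": ["clear", "understandable with some effort", "ambiguous"],
    "difficulty": ["high school", "college", "PhD"],
}


def sample_prompt(task: str, rng: random.Random) -> Dict[str, str]:
    """Pick one value per placeholder and fill in the template."""
    args = {name: rng.choice(values) for name, values in POOLS.items()}
    return {"prompt": TEMPLATE.format(task=task, **args), **args}


example = sample_prompt("Classify news headlines by topic.", random.Random(0))
print(example["prompt"])
```

Keeping the sampled arguments next to the prompt makes it easy to inspect later how the generated examples are distributed across languages, lengths, and difficulty levels.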
prompt = """You have been assigned a text classification task: {task}
Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
- "input_text": a string, the input text specified by the classification task.
- "label": a string, the correct label of the input text.
- "misleading_label": a string, an incorrect label that is related to the task.
Please adhere to the following guidelines:
- The "input_text" should be {num_words} words and diverse in expression.
- The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the
"input_text".
- The values for all fields should be in {language}.
- Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make
the task too easy.
- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.
Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
"""
Just as before, we need to create a custom Task, this time not only to parse the output via parse_output (which is straightforward, since we already instruct the LLM to generate valid JSON) but also to build the prompt via generate_prompt, since that is where the sampling described in the paper is introduced.
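import json
from dataclasses import dataclass
from random import choice
from typing import Any, Dict, List
from distilabel.tasks import TextGenerationTask
from distilabel.tasks.prompt import Prompt

# Sampling pools for the prompt placeholders (values assumed from our reading of the paper; adjust as needed).
num_words = ["less than 10", "at least 10", "at least 50", "at least 100", "at least 200"]
difficulty = ["high school", "college", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]

@dataclass
class ExampleGenerationTask(TextGenerationTask):
    system_prompt: str = ""  # the prompt above needs no system prompt

    @property
    def input_args_names(self) -> List[str]:
        return ["task"]  # each input row provides a "task" column

    def generate_prompt(self, task: str) -> Prompt:
        # Sample new placeholder values on every call to diversify the generations.
        # The language is fixed to English here, while the paper samples XLM-R languages.
        return Prompt(
            system_prompt=self.system_prompt,
            formatted_prompt=prompt.format(
                task=task,
                language="english",
                num_words=choice(num_words),
                difficulty=choice(difficulty),
                clarity=choice(clarity),
            ),
        )

    @property
    def output_args_names(self) -> List[str]:
        return ["input_text", "label", "misleading_label"]

    def parse_output(self, output: str) -> Dict[str, Any]:
        return json.loads(output)  # the prompt asks for a bare JSON object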
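parse_output above trusts the model to return a bare JSON object. As an optional hardening step, here is a standalone sketch (parse_example and strip_code_fence are hypothetical names, not distilabel APIs) that also accepts a reply wrapped in a Markdown code fence, checks that the three expected keys are present, and rejects examples whose misleading_label equals the label, since the prompt asks for an incorrect label.

```python
import json
from typing import Dict

REQUIRED_KEYS = ("input_text", "label", "misleading_label")
FENCE = "`" * 3  # a Markdown code fence, which models sometimes add around JSON


def strip_code_fence(text: str) -> str:
    """Remove a surrounding code fence and optional language tag, if present."""
    text = text.strip()
    if len(text) >= 2 * len(FENCE) and text.startswith(FENCE) and text.endswith(FENCE):
        inner = text[len(FENCE):-len(FENCE)]
        head, sep, rest = inner.partition("\n")
        if sep and "{" not in head:  # first line is a tag such as "json"
            inner = rest
        text = inner.strip()
    return text


def parse_example(output: str) -> Dict[str, str]:
    """Parse and validate one generated text classification example."""
    data = json.loads(strip_code_fence(output))
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")
    missing = [key for key in REQUIRED_KEYS if key not in data]
    if missing:
        raise ValueError(f"missing keys: {missing}")
    if data["label"] == data["misleading_label"]:
        raise ValueError("misleading_label should differ from label")
    return {key: data[key] for key in REQUIRED_KEYS}


reply = FENCE + 'json\n{"input_text": "The match ended 2-1.", "label": "sports", "misleading_label": "politics"}\n' + FENCE
print(parse_example(reply)["label"])  # sports
```

Raising an error, rather than returning partial data, leaves it to you to decide whether to retry or drop a bad generation.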
Other than that, we are all set to instantiate the OpenAILLM with the newly created task and call the Pipeline.generate method on the dataset of tasks generated in the previous phase.
Human Feedback with Argilla
You can use the AI feedback created by distilabel directly, but we have seen that enhancing it with human feedback improves the quality of your LLM. The datasets.Dataset generated by Pipeline.generate has a to_argilla method, which creates an Argilla dataset with out-of-the-box, tailored metadata filters and semantic search, so that you can provide human feedback as quickly and engagingly as possible. Check the Argilla docs to get it up and running.
If you are running Argilla using the Docker quickstart image or Hugging Face Spaces, you need to initialize the Argilla client with the URL and API_KEY.
Now we can convert our dataset to a formatted Argilla dataset and push it.
Conclusion
With distilabel, generating synthetic datasets is easier than ever, and it is customizable to a wide variety of use cases. In this tutorial, we showcased how to replicate "Improving Text Embeddings with Large Language Models", but the same approach can be adapted to your own needs.
The authors mention that, as future work, they aim to further improve the multilingual performance of their model and to explore using open-source LLMs, instead of OpenAI's proprietary ones, to generate the synthetic data.
Additionally, they also intend to investigate ways to improve the inference efficiency and lower the storage cost for LLM-based text embeddings.