🦒 Improving Text Embeddings with LLMs

In this tutorial, we will be replicating the process described in "Improving Text Embeddings with Large Language Models" by Liang Wang et al. for synthetically generating multilingual data to be used for training a sentence similarity model.

Installation

We will start by installing distilabel with the openai extra, since the authors used only GPT-4 and GPT-3.5 to generate the synthetic data. We also install the argilla extra for better visualization and curation of the results.

Note that we upgrade typing_extensions first, since openai may have conflicts with the default installed version of typing_extensions if it's outdated.

%pip install --upgrade typing_extensions --quiet
%pip install "distilabel[openai,argilla]" --quiet

Introduction

In "Improving Text Embeddings with Large Language Models" the authors leverage OpenAI's proprietary LLMs, GPT-4 and GPT-3.5, to generate synthetic data for a wide range of diverse text embedding tasks, achieving competitive performance without using any labeled data. When fine-tuned with a mixture of synthetic data and data from MS-MARCO, their model sets state-of-the-art results on the BEIR and MTEB benchmarks.

The authors divide the generation process into two phases:

  • Synthetic generation of the task definition/name, following certain criteria.
  • Synthetic generation of the data for the task (to be used for fine-tuning), using the task definition and some other sampling parameters.

After those two phases, the authors end up with a dataset that is suitable for model fine-tuning.

Phase 1: Generating task definitions

We will start with the first phase, which consists of generating synthetic task definitions for asymmetric tasks. In this case, we will focus only on the text classification task pool, which uses the following prompt:

Brainstorm a list of potentially useful text classification tasks.

Please adhere to the following guidelines:
- Tasks should cover a diverse range of domains and task types.

Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!
from distilabel.llm import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.tasks import TextGenerationTask

First, we need to define a custom Task that removes the default system_prompt and parses the output using eval, since the prompt asks the LLM to format its answer as a Python list.

from typing import Dict, List
from dataclasses import dataclass

@dataclass
class TaskGenerationTask(TextGenerationTask):
    system_prompt: str = ""

    def parse_output(self, output: str) -> Dict[str, List[str]]:
        return {"generations": eval(output)}
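Because eval executes whatever text the model returns, a more defensive parser may be preferable outside of a quick experiment. The following is a minimal sketch, not part of distilabel: parse_task_list is a hypothetical helper that uses ast.literal_eval from the standard library, which only accepts Python literals, and then checks that the result really is a list of strings.

```python
import ast
from typing import List


def parse_task_list(output: str) -> List[str]:
    """Parse an LLM response that is expected to be a Python list of strings.

    Hypothetical helper (not a distilabel API): unlike eval, ast.literal_eval
    only evaluates literals, so function calls and other expressions fail.
    """
    try:
        parsed = ast.literal_eval(output.strip())
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"Response is not a Python literal: {exc}") from exc
    # Reject anything that is not a flat list of strings.
    if not isinstance(parsed, list) or not all(isinstance(t, str) for t in parsed):
        raise ValueError("Expected a list of strings")
    return parsed
```

Under these assumptions, the body of parse_output could return {"generations": parse_task_list(output)} instead of calling eval directly.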

Once the custom Task (TaskGenerationTask) is defined, we can initialize the LLM, in this case OpenAILLM using GPT-4, and pass the newly defined task to it as an argument. Additionally, we set the generation kwargs temperature and top_p to the values given in Appendix C of the paper, to encourage more diversity in the generations.

llm = OpenAILLM(
    model="gpt-4",
    api_key="sk-***",
    task=TaskGenerationTask(),
    prompt_format="openai",
    max_new_tokens=1024,
    # Using the following kwargs as stated in Appendix C of the paper
    temperature=1.0,
    top_p=1.0,
)
pipeline = Pipeline(generator=llm)

Before calling the Pipeline, we need to prepare the input data. In this case it is a single prompt with no formatting required, since we simply want to run the same TextGenerationTask prompt multiple times.

from datasets import Dataset

prompt = """Brainstorm a list of potentially useful text classification tasks.
Please adhere to the following guidelines:
- Tasks should cover a diverse range of domains and task types.
Your output must always be a Python list of strings only, with about 20 elements, and each element corresponds to a distinct text classification task in one sentence. Do not explain yourself or output anything else. Be creative!
"""

dataset = Dataset.from_dict({"input": [prompt]})

Then we're ready to call the Pipeline.generate method, so that the prompt is sent to GPT-4 and N task definitions are generated synthetically.

In this case, N should be equal or close to num_generations × 20 (about 100 for num_generations=5), since the prompt asks the LLM to generate a Python list of about 20 elements.

new_dataset = pipeline.generate(dataset, num_generations=5, skip_dry_run=True)

Finally, once the generation has completed, we apply some post-processing before proceeding to the next phase: we remove the columns that are not required and explode the columns containing the tasks. This unwraps the nested lists, initially contained within a single row, so that we end up with a dataset with N rows, where N is the total number of tasks.

df_dataset = new_dataset.to_pandas()
df_dataset = df_dataset.drop(["generation_prompt", "raw_generation_responses"], axis=1)
df_dataset = df_dataset.explode(["generation_model", "generations"])
df_dataset = df_dataset.explode(["generations"])
df_dataset = df_dataset.reset_index(drop=True)
new_dataset = Dataset.from_pandas(df_dataset)
new_dataset = new_dataset.rename_columns({"generation_model": "model", "generations": "task"})
new_dataset.push_to_hub("alvarobartt/improving-text-embeddings-with-llms", config_name="task-generation")
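To make the unwrapping step concrete, here is a small sketch of what the two explode calls and the column renaming do to a single row. It is written in plain Python so that it runs without pandas; unwrap_tasks and the toy values are illustrative assumptions, not part of the tutorial's pipeline.

```python
from typing import Any, Dict, List


def unwrap_tasks(row: Dict[str, Any]) -> List[Dict[str, str]]:
    """Flatten one generated row into one record per task definition."""
    records = []
    # First level: one (model, task list) pair per generation.
    for model, tasks in zip(row["generation_model"], row["generations"]):
        # Second level: one record per task inside that generation.
        for task in tasks:
            records.append({"input": row["input"], "model": model, "task": task})
    return records


# Toy row with 2 generations holding 2 and 1 tasks (made-up values).
row = {
    "input": "Brainstorm a list of potentially useful text classification tasks.",
    "generation_model": ["gpt-4", "gpt-4"],
    "generations": [
        ["Sentiment analysis of reviews", "Spam detection"],
        ["Topic labeling of news"],
    ],
}
records = unwrap_tasks(row)
print(len(records))  # 3
```

Each record keeps the original input and the model name, which mirrors how the exploded DataFrame repeats those values on every resulting row.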

Phase 2: Generating data for each task

Once all the task definitions have been generated, we can proceed to the next phase, which consists of generating the data for a given task.

In this case, the LLM needs to generate an input text that suits each task, along with the correct label for that entry and a misleading label. In addition, we sample some of the arguments within each prompt to keep the generations diverse.

The prompt to be used is the following:

You have been assigned a text classification task: {task}

Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
- "input_text": a string, the input text specified by the classification task.
- "label": a string, the correct label of the input text.
- "misleading_label": a string, an incorrect label that is related to the task.

Please adhere to the following guidelines:
- The "input_text" should be {num_words} words and diverse in expression.
- The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the
"input_text".
- The values for all fields should be in {language}.
- Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make
the task too easy.
- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.

Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!

The possible values for each argument are:

  • task is the task definition generated in the previous phase.
  • language is the name of any language supported by XLM-R.
  • num_words is a constraint on the number of words that the generated input_text should contain (e.g. "less than 10" or "at least 100").
  • difficulty is the education level required to comprehend the generated input_text.
  • clarity is how easy or hard the generated input_text is to understand.
num_words = ["less than 10", "at least 10", "at least 50", "at least 100", "at least 200"]
difficulty = ["high school", "college", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]
prompt = """You have been assigned a text classification task: {task}
Your mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:
- "input_text": a string, the input text specified by the classification task.
- "label": a string, the correct label of the input text.
- "misleading_label": a string, an incorrect label that is related to the task.
Please adhere to the following guidelines:
- The "input_text" should be {num_words} words and diverse in expression.
- The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the
"input_text".
- The values for all fields should be in {language}.
- Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make
the task too easy.
- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.
Your output must always be a JSON object only, do not explain yourself or output anything else. Be creative!
"""

Just as before, we now need to create a custom Task. It parses the output via parse_output, which is straightforward here because the prompt already instructs the LLM to return valid JSON, and it also builds the prompt via generate_prompt, since that is where we introduce the sampling described in the paper.

import json
from random import choice
from typing import Any

from distilabel.tasks.prompt import Prompt

@dataclass
class ExampleGenerationTask(TextGenerationTask):
    system_prompt: str = ""

    @property
    def input_args_names(self) -> List[str]:
        return ["task"]

    def generate_prompt(self, task: str) -> Prompt:
        return Prompt(
            system_prompt=self.system_prompt,
            formatted_prompt=prompt.format(
                task=task,
                language="english",
                num_words=choice(num_words),
                difficulty=choice(difficulty),
                clarity=choice(clarity),
            ),
        )

    @property
    def output_args_names(self) -> List[str]:
        return ["input_text", "label", "misleading_label"]

    def parse_output(self, output: str) -> Dict[str, Any]:
        return json.loads(output)
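The sampling performed inside generate_prompt can also be tried in isolation. The sketch below is an illustration under a few assumptions: TEMPLATE is a shortened stand-in for the full prompt defined earlier, LANGUAGES is a small made-up subset (the task above fixes the language to English), and a seeded random.Random is used so that the same arguments are drawn on every run.

```python
import random
from typing import Dict

NUM_WORDS = ["less than 10", "at least 10", "at least 50", "at least 100", "at least 200"]
DIFFICULTY = ["high school", "college", "PhD"]
CLARITY = ["clear", "understandable with some effort", "ambiguous"]
# Illustrative subset only (assumption); the task above always uses English.
LANGUAGES = ["English", "Spanish", "German"]

# Shortened stand-in for the full prompt template defined earlier.
TEMPLATE = (
    "You have been assigned a text classification task: {task}\n"
    '- The "input_text" should be {num_words} words and diverse in expression.\n'
    "- The values for all fields should be in {language}.\n"
    '- The "input_text" is {clarity} and requires {difficulty} level education to comprehend.\n'
)


def sample_prompt_args(rng: random.Random) -> Dict[str, str]:
    """Draw one value per sampled argument, independently and uniformly."""
    return {
        "num_words": rng.choice(NUM_WORDS),
        "difficulty": rng.choice(DIFFICULTY),
        "clarity": rng.choice(CLARITY),
        "language": rng.choice(LANGUAGES),
    }


rng = random.Random(42)  # seeded so the draws are reproducible
args = sample_prompt_args(rng)
filled = TEMPLATE.format(task="Classify movie reviews by sentiment", **args)
print(filled)
```

Passing an explicit random.Random instance, rather than relying on the module-level choice, is a design choice that makes a generation run repeatable when debugging prompts.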

With that, we are all set to instantiate the OpenAILLM with the newly created task and call the Pipeline.generate method with the previously generated datasets.Dataset.

llm = OpenAILLM(
    model="gpt-4",
    api_key="sk-***",
    task=ExampleGenerationTask(),
    prompt_format="openai",
    max_new_tokens=1024,
    # Using the following kwargs as stated in Appendix C of the paper
    temperature=1.0,
    top_p=1.0,
)
pipeline = Pipeline(generator=llm)
final_dataset = pipeline.generate(new_dataset, num_generations=1, skip_dry_run=True)
final_dataset.push_to_hub("alvarobartt/improving-text-embeddings-with-llms", config_name="task-completion")
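Responses from chat models are not guaranteed to be bare JSON, so the json.loads call in parse_output can fail on some rows, for example when the model wraps its answer in a Markdown code fence. As a hedged sketch, the hypothetical helper parse_example below (not a distilabel API) strips such a fence if present and checks that the three keys requested in the prompt are there.

```python
import json
from typing import Any, Dict

REQUIRED_KEYS = ("input_text", "label", "misleading_label")


def parse_example(output: str) -> Dict[str, Any]:
    """Parse a JSON object from an LLM response and validate its keys."""
    text = output.strip()
    # Some models wrap JSON in a Markdown code fence; remove the backticks.
    if text.startswith("`"):
        text = text.strip("`")
        # Drop an optional language tag such as "json" after the fence.
        if text.lower().startswith("json"):
            text = text[4:]
    data = json.loads(text)  # raises json.JSONDecodeError (a ValueError) if invalid
    if not isinstance(data, dict):
        raise ValueError("Expected a JSON object")
    missing = [key for key in REQUIRED_KEYS if key not in data]
    if missing:
        raise ValueError(f"Missing keys: {missing}")
    # Keep only the fields the task declares as outputs.
    return {key: data[key] for key in REQUIRED_KEYS}
```

A helper along these lines could be called from parse_output, with rows that raise ValueError being dropped or regenerated, depending on how strict the final dataset needs to be.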

Human Feedback with Argilla

You can use the AI feedback created by distilabel directly, but we have seen that enhancing it with human feedback improves the quality of your LLM. The datasets.Dataset generated by Pipeline.generate provides the to_argilla method, which creates a dataset for Argilla along with out-of-the-box tailored metadata filters and semantic search, so that you can provide human feedback as quickly and engagingly as possible. You can check the Argilla docs to get it up and running.

If you are running Argilla using the Docker quickstart image or Hugging Face Spaces, you need to init the Argilla client with the URL and API_KEY:

import argilla as rg

# Replace api_url with your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="http://localhost:6900",
    api_key="owner.apikey",
    workspace="admin"
)

Now we can convert our dataset to a formatted Argilla dataset and push it.

# Convert the dataset to Argilla format
rg_dataset = final_dataset.to_argilla()

# Push the dataset to Argilla
rg_dataset.push_to_argilla(name="my-dataset", workspace="admin")

Conclusion

With distilabel, generating synthetic datasets is easier than ever and can be customized for a wide variety of use cases. In this tutorial, we showed how to replicate "Improving Text Embeddings with Large Language Models", but the same approach could be adapted to your own needs.

The authors mention that, in future work, they aim to further improve the multilingual performance of their model and to explore the possibility of using open-source LLMs to generate synthetic data instead of OpenAI's proprietary ones.

They also intend to investigate ways to improve the inference efficiency and lower the storage cost of LLM-based text embeddings.