
Synthetic data generation for fine-tuning custom retrieval and reranking models

GenerateSentencePair pipeline overview

Note

For a comprehensive overview of optimizing retrieval performance in a RAG pipeline, check out this guide, written in collaboration with ZenML, an open-source MLOps framework designed for building portable and production-ready machine learning pipelines.

Getting started

Install the dependencies

To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using the free but rate-limited Hugging Face serverless Inference API for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:

!pip install "distilabel[hf-inference-endpoints]"
!pip install "sentence-transformers~=3.0"

Let's make the needed imports:

from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.steps import LoadDataFromHub

from sentence_transformers import SentenceTransformer, CrossEncoder
import torch

You'll need an HF_TOKEN to use the HF Inference Endpoints. Log in to use it directly within this notebook.

import os
from huggingface_hub import login

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

(optional) Deploy Argilla

You can skip this step or replace Argilla with any other data evaluation tool, but poor data quality will hurt your model, so we do recommend looking at your data. If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following this guide.

Along with that, you will need to install Argilla as a distilabel extra.

!pip install "distilabel[argilla, hf-inference-endpoints]"

Let's make the extra needed imports:

import argilla as rg

The dataset

Before starting any project, it is always important to look at your data. Our data is publicly available on the Hugging Face Hub, so we can have a quick look at it through the dataset viewer.

As we can see, our dataset contains a column called chunks, which was obtained from the Argilla docs. Normally, you would need to download and chunk the data, but we will not cover that in this tutorial. For a full explanation of how this dataset was generated, please refer to How we leveraged distilabel to create an Argilla 2.0 Chatbot.

Alternatively, we can load the entire dataset to disk with datasets.load_dataset.
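
For example, here is a minimal sketch of loading the seed data locally for a quick inspection (the repo ID matches the one used in the pipeline further below):

from datasets import load_dataset

# Load the seed data (chunks from the Argilla docs) for a quick local look
dataset = load_dataset("plaguss/argilla_sdk_docs_raw_unstructured", split="train")
print(dataset.column_names)
print(dataset[0]["chunks"][:200])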

Synthetic data generation

The GenerateSentencePair component from distilabel can be used to generate training datasets for embeddings models.

It is a pre-defined Task that, given an anchor sentence, generates data for a specific action. Supported actions are: "paraphrase", "semantically-similar", "query", "answer". In our case, the chunks column corresponds to the anchor. This means we will use query to generate potential queries for fine-tuning a retrieval model, and semantically-similar to generate texts that are similar to the initial anchor for fine-tuning a reranking model.

We will set triplet=True in order to generate both positive and negative examples, which should help the model generalize better during fine-tuning, and we will set hard_negative=True to generate more challenging examples that are closer to the anchor and the topics it discusses.

Lastly, we can seed the LLM with context to generate more relevant examples.

context = (
"""
The text is a chunk from technical Python SDK documentation of Argilla.
Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets.
Along with prose explanations, the text chunk may include code snippets and Python references.
"""
)

Retrieval

For retrieval, we will generate queries that are similar to the chunks column, using the query action to produce potential queries for fine-tuning a retrieval model.

generate_sentence_pair = GenerateSentencePair(
    triplet=True,  
    hard_negative=True,
    action="query",
    llm=llm,
    input_batch_size=10,
    context=context,
)

Reranking

For reranking, we will use the semantically-similar action to generate texts that are similar to the initial anchor, which is what we need for fine-tuning a reranking model. In this case, we set hard_negative=False to generate more diverse and potentially wrong examples, which can be used as negative examples for similarity fine-tuning because rerankers cannot be fine-tuned using triplets.

generate_sentence_pair = GenerateSentencePair(
    triplet=True,
    hard_negative=False,
    action="semantically-similar",
    llm=llm,
    input_batch_size=10,
    context=context,
)

Combined pipeline

We will now use the GenerateSentencePair task to generate synthetic data for both retrieval and reranking models in a single pipeline. Note that we map the chunks column to the anchor argument.

llm = InferenceEndpointsLLM(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
    tokenizer_id="mistralai/Mistral-7B-Instruct-v0.2",
)

with Pipeline(name="generate") as pipeline:
    load_dataset = LoadDataFromHub(
        num_examples=15,
        output_mappings={"chunks": "anchor"},
    )
    generate_retrieval_pairs = GenerateSentencePair(
        name="generate_retrieval_pairs",
        triplet=True,
        hard_negative=True,
        action="query",
        llm=llm,
        input_batch_size=10,
        context=context,
    )
    generate_reranking_pairs = GenerateSentencePair(
        name="generate_reranking_pairs",
        triplet=True,
        hard_negative=False,  # to potentially generate non-relevant pairs
        action="semantically-similar",
        llm=llm,
        input_batch_size=10,
        context=context,
    )

    load_dataset.connect(generate_retrieval_pairs, generate_reranking_pairs)

Next, we can execute this using pipeline.run. We will provide some parameters to specific components within our pipeline.

generation_kwargs = {
    "llm": {
        "generation_kwargs": {
            "temperature": 0.7,
            "max_new_tokens": 512,
        }
    }
}

distiset = pipeline.run(  
    parameters={
        load_dataset.name: {
            "repo_id": "plaguss/argilla_sdk_docs_raw_unstructured",
            "split": "train",
        },
        generate_retrieval_pairs.name: generation_kwargs,
        generate_reranking_pairs.name: generation_kwargs,
    },
    use_cache=False,  # False for demo
)

Data generation can be expensive, so it is recommended to store the data somewhere. For now, we will store it on the Hugging Face Hub, using the push_to_hub method of the resulting Distiset.

distiset.push_to_hub("[your-owner-name]/example-retrieval-reranking-dataset")

We have two different leaf/end nodes; therefore, the resulting Distiset contains two configurations we can access: one for the retrieval data and one for the reranking data.
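
For instance, a quick way to peek at each configuration (the configuration names follow the step names defined above):

# Each leaf step produces its own configuration in the Distiset
retrieval_data = distiset["generate_retrieval_pairs"]["train"]
reranking_data = distiset["generate_reranking_pairs"]["train"]

print(retrieval_data.column_names)  # anchor, positive, negative, ...
print(retrieval_data[0]["anchor"][:200])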

Looking at these initial examples, we can see they nicely capture the essence of the chunks column but we will need to evaluate the quality of the data a bit more before we can use it for fine-tuning.

Data quality evaluation

Data is never as clean as it could be, and this holds for synthetically generated data too, so it is always good to spend some time looking at your data.

Feature engineering

In order to evaluate the quality of our data, we will use features of the models that we intend to fine-tune as a proxy for data quality. We can then use these features to filter out the best examples.

In order to choose a good default model, we will use the Massive Text Embedding Benchmark (MTEB) Leaderboard. We want to optimize for size and speed, so we will set model size <100M and then filter for Retrieval and Reranking based on the highest average score, resulting in Snowflake/snowflake-arctic-embed-s and sentence-transformers/all-MiniLM-L12-v2 respectively.

Retrieval

For retrieval, we will compute similarities for the current embeddings of anchor-positive, positive-negative and anchor-negative pairs. We assume that an overlap of these similarities will cause the model to have difficulties generalizing and therefore we can use these features to evaluate the quality of our data.

model_id = "Snowflake/snowflake-arctic-embed-m"  # Hugging Face model ID

model_retrieval = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)

Next, we will encode the generated text pairs and compute the similarities.

from sklearn.metrics.pairwise import cosine_similarity

def get_embeddings(texts):
    vectors = model_retrieval.encode(texts)
    return [vector.tolist() for vector in vectors]


def get_similarities(vector_batch_a, vector_batch_b):
    similarities = []
    for vector_a, vector_b in zip(vector_batch_a, vector_batch_b):
        similarity = cosine_similarity([vector_a], [vector_b])[0][0]
        similarities.append(similarity)
    return similarities

def format_data_retriever(batch):  # -> Any:
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    batch["positive-vector"] = get_embeddings(batch["positive"])
    batch["negative-vector"] = get_embeddings(batch["negative"])    
    batch["similarity-positive-negative"] = get_similarities(batch["positive-vector"], batch["negative-vector"])
    batch["similarity-anchor-positive"] = get_similarities(batch["anchor-vector"], batch["positive-vector"])
    batch["similarity-anchor-negative"] = get_similarities(batch["anchor-vector"], batch["negative-vector"])
    return batch

dataset_generate_retrieval_pairs = distiset["generate_retrieval_pairs"]["train"].map(format_data_retriever, batched=True, batch_size=250)
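
As a quick sanity check on the overlap assumption above, here is a minimal sketch comparing the average anchor-positive and anchor-negative similarities (the column names follow the mapping function above):

import numpy as np

# A clear gap between these averages suggests positives and negatives are well separated
mean_pos = np.mean(dataset_generate_retrieval_pairs["similarity-anchor-positive"])
mean_neg = np.mean(dataset_generate_retrieval_pairs["similarity-anchor-negative"])
print(f"anchor-positive: {mean_pos:.3f}, anchor-negative: {mean_neg:.3f}")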

Reranking

For reranking, we will compute the relevance scores from an existing reranker model for the anchor-positive, positive-negative and anchor-negative pairs, and make a similar assumption as for the retrieval model.

model_id = "sentence-transformers/all-MiniLM-L12-v2"

model = CrossEncoder(model_id)

Next, we will compute the similarity for the generated text pairs using the reranker. On top of that, we will compute an anchor-vector to allow for doing semantic search.

def format_data_reranker(batch):  # -> Any:
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    # CrossEncoder.predict expects pairs of texts, not embeddings
    batch["similarity-positive-negative"] = model.predict(list(zip(batch["positive"], batch["negative"])))
    batch["similarity-anchor-positive"] = model.predict(list(zip(batch["anchor"], batch["positive"])))
    batch["similarity-anchor-negative"] = model.predict(list(zip(batch["anchor"], batch["negative"])))
    return batch

dataset_generate_reranking_pairs = distiset["generate_reranking_pairs"]["train"].map(format_data_reranker, batched=True, batch_size=250)

And voila, we have our proxies for quality evaluation which we can use to filter out the best and worst examples.
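
For example, a minimal filtering sketch under the assumption that a good retrieval triplet should score the positive closer to the anchor than the negative (the margin is arbitrary and worth tuning on your own data):

# Keep triplets where the positive is clearly closer to the anchor than the negative
margin = 0.05  # hypothetical margin

dataset_filtered = dataset_generate_retrieval_pairs.filter(
    lambda row: row["similarity-anchor-positive"] - row["similarity-anchor-negative"] > margin
)
print(len(dataset_generate_retrieval_pairs), "->", len(dataset_filtered))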

(Optional) Argilla

To get the most out of your data and actually look at it, we will use Argilla. If you are not familiar with Argilla, we recommend taking a look at the Argilla quickstart docs. Alternatively, you can use your Hugging Face account to log in to the Argilla demo Space.

To start exploring the data, we first need to define an argilla.Dataset. We will create a basic dataset with an input TextField for the anchor and output TextQuestions for the positive and negative pairs. Additionally, we will use the file name as a metadata property. Lastly, we will re-use the vectors obtained in the previous step to allow for semantic search, and we will add the similarity scores for some basic filtering and sorting.

First, we need to define the settings for our Argilla datasets. We will create two different datasets, one for the retrieval data and one for the reranking data, to ensure our annotators can focus on the task at hand.

import argilla as rg
from argilla._exceptions import ConflictError

api_key = "ohh so secret"
api_url = "https://[your-owner-name]-[your-space-name].hf.space"

client = rg.Argilla(api_url=api_url, api_key=api_key)

settings = rg.Settings(
    fields=[
        rg.TextField("anchor")
    ],
    questions=[
        rg.TextQuestion("positive"),
        rg.TextQuestion("negative"),
        rg.LabelQuestion(
            name="is_positive_relevant",
            title="Is the positive query relevant?",
            labels=["yes", "no"],
        ),
        rg.LabelQuestion(
            name="is_negative_irrelevant",
            title="Is the negative query irrelevant?",
            labels=["yes", "no"],
        )
    ],
    metadata=[
        rg.TermsMetadataProperty("filename"),
        rg.FloatMetadataProperty("similarity-positive-negative"),
        rg.FloatMetadataProperty("similarity-anchor-positive"),
        rg.FloatMetadataProperty("similarity-anchor-negative"),
    ],
    vectors=[
        rg.VectorField("anchor-vector", dimensions=model.get_sentence_embedding_dimension())
    ]
)
rg_datasets = []
for dataset_name in ["generate_retrieval_pairs", "generate_reranking_pairs"]:
    ds = rg.Dataset(
        name=dataset_name,
        settings=settings
    )
    try:
        ds.create()
    except ConflictError:
        ds = client.datasets(dataset_name)
    rg_datasets.append(ds)

Now that we've got our dataset definitions set up in Argilla, we can upload our data.

ds_datasets = [dataset_generate_retrieval_pairs, dataset_generate_reranking_pairs]

for rg_dataset, ds_dataset in zip(rg_datasets, ds_datasets):
    records = []  # reset per dataset so records are not logged twice
    for idx, entry in enumerate(ds_dataset):
        records.append(
            rg.Record(
                id=idx,
                fields={"anchor": entry["anchor"]},
                suggestions=[
                    rg.Suggestion("positive", value=entry["positive"], agent="gpt-4o", type="model"),
                    rg.Suggestion("negative", value=entry["negative"], agent="gpt-4o", type="model"),
                ],
                metadata={
                    "filename": entry["filename"],
                    "similarity-positive-negative": entry["similarity-positive-negative"],
                    "similarity-anchor-positive": entry["similarity-anchor-positive"],
                    "similarity-anchor-negative": entry["similarity-anchor-negative"]
                },
                vectors={"anchor-vector": entry["anchor-vector"]}
            )
        )
    rg_dataset.records.log(records)

Now, we can explore the UI and add a final human touch to get the most out of our dataset.

Fine-tuning

At last, we can fine-tune our models. We will use the sentence-transformers library to fine-tune our models.

Retrieval

For retrieval, we have created a notebook that fine-tunes a model on the generated data: https://github.com/argilla-io/argilla-sdk-chatbot/blob/main/train_embedding.ipynb. You can also open it in Google Colab directly.
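
If you prefer to stay in this notebook, the following is a minimal sketch of fine-tuning a bi-encoder on the generated triplets with sentence-transformers v3 (the output path is arbitrary and hyperparameters are left at their defaults):

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import TripletLoss

# The trainer maps the dataset columns to the loss inputs by position: (anchor, positive, negative)
train_dataset = dataset_generate_retrieval_pairs.select_columns(["anchor", "positive", "negative"])

model = SentenceTransformer("Snowflake/snowflake-arctic-embed-s")
loss = TripletLoss(model)

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
model.save("snowflake-arctic-embed-s-argilla-docs")  # hypothetical output directory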

Reranking

For reranking, sentence-transformers provides a script that shows how to fine-tune a CrossEncoder model. As of now, there is some uncertainty over fine-tuning CrossEncoder models with triplets, but you can still use the anchor and positive pairs, optionally treating the negatives as non-relevant pairs.
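
As an illustration, here is a minimal sketch under the assumption that anchor-positive pairs are labeled as relevant (1.0) and anchor-negative pairs as non-relevant (0.0); the output path and hyperparameters are arbitrary:

from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample

train_examples = []
for entry in dataset_generate_reranking_pairs:
    # Assumption: positives count as relevant, negatives as non-relevant
    train_examples.append(InputExample(texts=[entry["anchor"], entry["positive"]], label=1.0))
    train_examples.append(InputExample(texts=[entry["anchor"], entry["negative"]], label=0.0))

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2", num_labels=1)
reranker.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=100)
reranker.save("all-MiniLM-L12-v2-argilla-docs-reranker")  # hypothetical output directory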

Conclusions

In this tutorial, we presented an end-to-end example of fine-tuning retrievers and rerankers for RAG. This serves as a good starting point for optimizing and maintaining your data and models, but it needs to be adapted to your specific use case.

We started with some seed data from the Argilla docs, generated synthetic data for retrieval and reranking models, evaluated the quality of the data, and showed how to fine-tune the models. We also used Argilla to get a human touch on the data.