Synthetic data generation for fine-tuning custom retrieval and reranking models¶
- Goal: Bootstrap, optimize and maintain your embedding models and rerankers through synthetic data generation and human feedback.
- Libraries: argilla, hf-inference-endpoints, sentence-transformers
- Components: LoadDataFromHub, GenerateSentencePair, InferenceEndpointsLLM
Getting started¶
Install the dependencies¶
To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using the free but rate-limited Hugging Face serverless Inference API for this tutorial, so we need to install it as an extra distilabel dependency. You can install it by running the following command:
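    # install distilabel with the Hugging Face serverless Inference API extra
    # (extras name taken from the libraries listed above; pin versions as needed)
    !pip install "distilabel[hf-inference-endpoints]" --upgrade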
Let's make the needed imports:
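A sketch of the imports used throughout this tutorial (module paths follow distilabel 1.x; adjust them to your installed version):

    from distilabel.llms import InferenceEndpointsLLM
    from distilabel.pipeline import Pipeline
    from distilabel.steps import LoadDataFromHub
    from distilabel.steps.tasks import GenerateSentencePair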
You'll need an HF_TOKEN to use the HF Inference Endpoints. Log in to use it directly within this notebook.
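One way to log in from the notebook, assuming HF_TOKEN is set in your environment:

    import os
    from huggingface_hub import login

    login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)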
(optional) Deploy Argilla¶
You can skip this step or replace it with any other data evaluation tool, but your model quality will suffer if your data quality is poor, so we do recommend looking at your data. If you have already deployed Argilla, you can skip this step; otherwise, you can quickly deploy it following this guide.
Along with that, you will need to install Argilla as a distilabel extra.
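A sketch of the install command, assuming the argilla extra on top of the previous dependencies:

    !pip install "distilabel[argilla, hf-inference-endpoints]" --upgrade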
Let's make the extra needed imports:
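    # Argilla client, used later for data exploration and annotation
    import argilla as rg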
The dataset¶
Before starting any project, it is always important to look at your data. Our data is publicly available on the Hugging Face Hub so we can have a quick look through their dataset viewer within an embedded iFrame.
As we can see, our dataset contains a column called chunks, which was obtained from the Argilla docs. Normally, you would need to download and chunk the data yourself, but we will not cover that in this tutorial. To read a full explanation of how this dataset was generated, please refer to How we leveraged distilabel to create an Argilla 2.0 Chatbot.

Alternatively, we can load the entire dataset to disk with datasets.load_dataset.
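    # load the seed data with the datasets library
    from datasets import load_dataset

    hf_dataset = load_dataset("plaguss/argilla_sdk_docs_raw_unstructured", split="train")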
Synthetic data generation¶
The GenerateSentencePair component from distilabel can be used to generate training datasets for embedding models.

It is a pre-defined Task that, given an anchor sentence, generates data for a specific action. Supported actions are: "paraphrase", "semantically-similar", "query", "answer". In our case, the chunks column corresponds to the anchor. This means we will use query to generate potential queries for fine-tuning a retrieval model, and semantically-similar to generate texts that are similar to the initial anchor for fine-tuning a reranking model.

We will set triplet=True in order to generate both positive and negative examples, which should help the model generalize better during fine-tuning, and we will set hard_negative=True to generate more challenging examples that are closer to the anchor and the discussed topics.

Lastly, we can seed the LLM with context to generate more relevant examples.
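For example, a short description of the source documents works well as context. The wording below is only illustrative; this context variable is the one passed to both tasks in the pipeline further down.

    # illustrative context string; adapt it to describe your own source documents
    context = (
        "The text is a chunk from the technical documentation of the Argilla SDK. "
        "Argilla is a collaboration tool for AI engineers and domain experts "
        "who need to build high-quality datasets for their projects."
    )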
Retrieval¶
For retrieval, we will thus generate queries that are similar to the chunks column. We will use the query action to generate potential queries for fine-tuning a retrieval model.
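A minimal configuration of the task for this step could look as follows; the llm and context variables are the ones defined for the combined pipeline further down.

    generate_retrieval_pairs = GenerateSentencePair(
        triplet=True,  # also generate negative examples
        hard_negative=True,  # negatives should be hard, i.e. close to the anchor
        action="query",
        llm=llm,
        input_batch_size=10,
        context=context,
    )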
Reranking¶
For reranking, we will generate texts that are similar to the initial anchor. We will use the semantically-similar action for this. In this case, we set hard_negative=False to generate more diverse and potentially wrong examples, which can be used as negative examples for similarity fine-tuning because rerankers cannot be fine-tuned using triplets.
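The configuration mirrors the retrieval one, only changing the action and the hard_negative flag:

    generate_reranking_pairs = GenerateSentencePair(
        triplet=True,
        hard_negative=False,  # allow more diverse, potentially non-relevant negatives
        action="semantically-similar",
        llm=llm,
        input_batch_size=10,
        context=context,
    )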
Combined pipeline¶
We will now use the GenerateSentencePair task to generate synthetic data for both retrieval and reranking models in a single pipeline. Note that we map the chunks column to the anchor argument.
llm = InferenceEndpointsLLM(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
    tokenizer_id="mistralai/Mistral-7B-Instruct-v0.2",
)

with Pipeline(name="generate") as pipeline:
    load_dataset = LoadDataFromHub(
        num_examples=15,
        output_mappings={"chunks": "anchor"},
    )
    generate_retrieval_pairs = GenerateSentencePair(
        name="generate_retrieval_pairs",
        triplet=True,
        hard_negative=True,
        action="query",
        llm=llm,
        input_batch_size=10,
        context=context,
    )
    generate_reranking_pairs = GenerateSentencePair(
        name="generate_reranking_pairs",
        triplet=True,
        hard_negative=False,  # to potentially generate non-relevant pairs
        action="semantically-similar",
        llm=llm,
        input_batch_size=10,
        context=context,
    )

    load_dataset.connect(generate_retrieval_pairs, generate_reranking_pairs)
Next, we can execute this using pipeline.run. We will provide some parameters to specific components within our pipeline.
generation_kwargs = {
    "llm": {
        "generation_kwargs": {
            "temperature": 0.7,
            "max_new_tokens": 512,
        }
    }
}

distiset = pipeline.run(
    parameters={
        load_dataset.name: {
            "repo_id": "plaguss/argilla_sdk_docs_raw_unstructured",
            "split": "train",
        },
        generate_retrieval_pairs.name: generation_kwargs,
        generate_reranking_pairs.name: generation_kwargs,
    },
    use_cache=False,  # False for demo
)
Data generation can be expensive, so it is recommended to store the data somewhere. For now, we will store it on the Hugging Face Hub using the push_to_hub method.
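    # push the generated distiset to the Hugging Face Hub; replace the repo id with your own
    distiset.push_to_hub("[your-owner-name]/example-retrieval-reranking-dataset")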
Because we have two different leaf/end nodes, the resulting distiset contains two configurations we can access: one for the retrieval data and one for the reranking data.
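We can take a quick look at one record from each configuration; the subset names correspond to the step names used in the pipeline:

    # inspect the first record of each leaf node
    distiset["generate_retrieval_pairs"]["train"][0]
    distiset["generate_reranking_pairs"]["train"][0]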
Looking at these initial examples, we can see they nicely capture the essence of the chunks column, but we will need to evaluate the quality of the data a bit more before we can use it for fine-tuning.
Data quality evaluation¶
Data is never as clean as it could be, and this holds for synthetically generated data too; therefore, it is always good to spend some time looking at your data.
Feature engineering¶
In order to evaluate the quality of our data, we will use features of the models that we intend to fine-tune as a proxy for data quality. We can then use these features to filter out the best examples.
In order to choose a good default model, we will use the Massive Text Embedding Benchmark (MTEB) Leaderboard. We want to optimize for size and speed, so we will set the model size to <100M and then filter for Retrieval and Reranking based on the highest average score, resulting in Snowflake/snowflake-arctic-embed-s and sentence-transformers/all-MiniLM-L12-v2 respectively.
Retrieval¶
For retrieval, we will compute similarities for the current embeddings of anchor-positive, positive-negative and anchor-negative pairs. We assume that an overlap of these similarities will cause the model to have difficulties generalizing and therefore we can use these features to evaluate the quality of our data.
Next, we will encode the generated text pairs and compute the similarities.
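We first load the embedding model selected above; the model_retrieval variable name is an assumption chosen to match the snippet below.

    # load the embedding model selected from the MTEB leaderboard
    import torch
    from sentence_transformers import SentenceTransformer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_retrieval = SentenceTransformer("Snowflake/snowflake-arctic-embed-s", device=device)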
from sklearn.metrics.pairwise import cosine_similarity


def get_embeddings(texts):
    vectors = model_retrieval.encode(texts)
    return [vector.tolist() for vector in vectors]


def get_similarities(vector_batch_a, vector_batch_b):
    similarities = []
    for vector_a, vector_b in zip(vector_batch_a, vector_batch_b):
        similarity = cosine_similarity([vector_a], [vector_b])[0][0]
        similarities.append(similarity)
    return similarities


def format_data_retriever(batch):
    # add embedding vectors and cosine-similarity features for each triplet
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    batch["positive-vector"] = get_embeddings(batch["positive"])
    batch["negative-vector"] = get_embeddings(batch["negative"])
    batch["similarity-positive-negative"] = get_similarities(batch["positive-vector"], batch["negative-vector"])
    batch["similarity-anchor-positive"] = get_similarities(batch["anchor-vector"], batch["positive-vector"])
    batch["similarity-anchor-negative"] = get_similarities(batch["anchor-vector"], batch["negative-vector"])
    return batch


dataset_generate_retrieval_pairs = distiset["generate_retrieval_pairs"]["train"].map(format_data_retriever, batched=True, batch_size=250)
Reranking¶
For reranking, we will compute the relevance scores from an existing reranker model for anchor-positive, positive-negative and anchor-negative pairs, and make a similar assumption as for the retrieval model.

Next, we will compute the similarity for the generated text pairs using the reranker. On top of that, we will compute an anchor-vector to allow for doing semantic search.
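We load the reranking model selected above through the sentence-transformers CrossEncoder class; the model variable name is an assumption chosen to match the snippet below.

    # load the reranking model selected from the MTEB leaderboard
    from sentence_transformers import CrossEncoder

    model = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2")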
def format_data_reranker(batch):
    # reuse the retrieval embeddings for the anchor to enable semantic search in Argilla
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    # the CrossEncoder scores raw text pairs rather than embeddings
    batch["similarity-positive-negative"] = model.predict(list(zip(batch["positive"], batch["negative"])))
    batch["similarity-anchor-positive"] = model.predict(list(zip(batch["anchor"], batch["positive"])))
    batch["similarity-anchor-negative"] = model.predict(list(zip(batch["anchor"], batch["negative"])))
    return batch


dataset_generate_reranking_pairs = distiset["generate_reranking_pairs"]["train"].map(format_data_reranker, batched=True, batch_size=250)
And voila, we have our proxies for quality evaluation which we can use to filter out the best and worst examples.
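As one illustrative filter (the rule and the threshold are assumptions to tune on your own data), we could keep only retrieval triplets where the anchor is clearly closer to the positive than to the negative:

    # hypothetical quality filter: require a margin between anchor-positive and anchor-negative similarity
    margin = 0.05  # illustrative threshold
    filtered_retrieval_pairs = dataset_generate_retrieval_pairs.filter(
        lambda row: row["similarity-anchor-positive"] - row["similarity-anchor-negative"] > margin
    )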
(Optional) Argilla¶
To get the most out of your data and actually look at it, we will use Argilla. If you are not familiar with Argilla, we recommend taking a look at the Argilla quickstart docs. Alternatively, you can use your Hugging Face account to log in to the Argilla demo Space.
To start exploring data, we first need to define an argilla.Dataset. We will create a basic dataset with some input TextFields for the anchor and output TextQuestions for the positive and negative pairs. Additionally, we will use the filename as MetaDataProperty. Lastly, we will be re-using the vectors obtained from our previous step to allow for semantic search, and we will add the similarity scores for some basic filtering and sorting.
First, we need to define the setting for our Argilla dataset. We will create two different datasets, one for the retrieval data and one for the reranking data to ensure our annotators can focus on the task at hand.
import argilla as rg
from argilla._exceptions import ConflictError

api_key = "ohh so secret"
api_url = "https://[your-owner-name]-[your-space-name].hf.space"

client = rg.Argilla(api_url=api_url, api_key=api_key)

settings = rg.Settings(
    fields=[
        rg.TextField("anchor")
    ],
    questions=[
        rg.TextQuestion("positive"),
        rg.TextQuestion("negative"),
        rg.LabelQuestion(
            name="is_positive_relevant",
            title="Is the positive query relevant?",
            labels=["yes", "no"],
        ),
        rg.LabelQuestion(
            name="is_negative_irrelevant",
            title="Is the negative query irrelevant?",
            labels=["yes", "no"],
        )
    ],
    metadata=[
        rg.TermsMetadataProperty("filename"),
        rg.FloatMetadataProperty("similarity-positive-negative"),
        rg.FloatMetadataProperty("similarity-anchor-positive"),
        rg.FloatMetadataProperty("similarity-anchor-negative"),
    ],
    vectors=[
        # dimensions must match the embedding model used to compute anchor-vector
        rg.VectorField("anchor-vector", dimensions=model_retrieval.get_sentence_embedding_dimension())
    ]
)
rg_datasets = []
for dataset_name in ["generate_retrieval_pairs", "generate_reranking_pairs"]:
    ds = rg.Dataset(
        name=dataset_name,
        settings=settings,
    )
    try:
        ds.create()
    except ConflictError:
        ds = client.datasets(dataset_name)
    rg_datasets.append(ds)
Now that we have our dataset definitions set up in Argilla, we can upload our data.
ds_datasets = [dataset_generate_retrieval_pairs, dataset_generate_reranking_pairs]

for rg_dataset, ds_dataset in zip(rg_datasets, ds_datasets):
    records = []  # reset per dataset so records are not duplicated across datasets
    for idx, entry in enumerate(ds_dataset):
        records.append(
            rg.Record(
                id=idx,
                fields={"anchor": entry["anchor"]},
                suggestions=[
                    rg.Suggestion("positive", value=entry["positive"], agent="mistralai/Mistral-7B-Instruct-v0.2", type="model"),
                    rg.Suggestion("negative", value=entry["negative"], agent="mistralai/Mistral-7B-Instruct-v0.2", type="model"),
                ],
                metadata={
                    "filename": entry["filename"],
                    "similarity-positive-negative": entry["similarity-positive-negative"],
                    "similarity-anchor-positive": entry["similarity-anchor-positive"],
                    "similarity-anchor-negative": entry["similarity-anchor-negative"],
                },
                vectors={"anchor-vector": entry["anchor-vector"]},
            )
        )
    rg_dataset.records.log(records)
Now, we can explore the UI and add a final human touch to get the most out of our dataset.
Fine-tuning¶
At last, we can fine-tune our models. We will use the sentence-transformers library to do so.
Retrieval¶
For retrieval, we have created a script that fine-tunes a model on the generated data: https://github.com/argilla-io/argilla-sdk-chatbot/blob/main/train_embedding.ipynb. You can also open it in Google Colab directly.
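As a minimal sketch of what this looks like with the sentence-transformers v3 trainer (hyperparameters and the output directory are placeholders; see the linked notebook for the full setup):

    # minimal fine-tuning sketch on the generated triplets (placeholder hyperparameters)
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
    )
    from sentence_transformers.losses import TripletLoss

    # keep only the columns the triplet loss expects
    train_dataset = dataset_generate_retrieval_pairs.select_columns(["anchor", "positive", "negative"])

    embedding_model = SentenceTransformer("Snowflake/snowflake-arctic-embed-s")
    loss = TripletLoss(embedding_model)

    args = SentenceTransformerTrainingArguments(
        output_dir="./retrieval-finetuned",
        num_train_epochs=1,
        per_device_train_batch_size=16,
    )

    trainer = SentenceTransformerTrainer(
        model=embedding_model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()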
Reranking¶
For reranking, sentence-transformers provides a script that shows how to fine-tune a CrossEncoder model. As of now, there is some uncertainty over fine-tuning CrossEncoder models with triplets, but you can still use the positive and anchor pairs, treating the generated negatives as non-relevant pairs.
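A minimal sketch with the classic CrossEncoder training API, assuming binary relevance labels (1 for anchor-positive pairs, 0 for anchor-negative pairs) and placeholder hyperparameters:

    # minimal reranker fine-tuning sketch with assumed binary labels
    from torch.utils.data import DataLoader
    from sentence_transformers import CrossEncoder, InputExample

    train_examples = []
    for row in dataset_generate_reranking_pairs:
        train_examples.append(InputExample(texts=[row["anchor"], row["positive"]], label=1))
        train_examples.append(InputExample(texts=[row["anchor"], row["negative"]], label=0))

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

    reranker = CrossEncoder("sentence-transformers/all-MiniLM-L12-v2", num_labels=1)
    reranker.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=100)
    reranker.save("./reranker-finetuned")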
Conclusions¶
In this tutorial, we present an end-to-end example of fine-tuning retrievers and rerankers for RAG. This serves as a good starting point for optimizing and maintaining your data and models but needs to be adapted to your specific use case.
We started with some seed data from the Argilla docs, generated synthetic data for retrieval and reranking models, evaluated the quality of the data, and showed how to fine-tune the models. We also used Argilla to get a human touch on the data.