Generate synthetic text classification data¶
- Goal: Generate synthetic text classification data to augment an imbalanced and limited dataset for training a topic classifier. In addition, generate new data for training a fact-based versus opinion-based classifier to add a new label.
- Libraries: argilla, hf-inference-endpoints, SetFit
- Components: LoadDataFromDicts, EmbeddingTaskGenerator, GenerateTextClassificationData
Getting started¶
Install the dependencies¶
To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using the free but rate-limited Hugging Face serverless Inference API for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:
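A minimal version of that command might look like the following (setfit and argilla are added here because they are used later in this tutorial; adjust versions and extras to your setup):

!pip install "distilabel[hf-inference-endpoints]" setfit argilla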
Let's make the required imports:
import random
from collections import Counter

from datasets import load_dataset, Dataset
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import (
    GenerateTextClassificationData,
)
from setfit import SetFitModel, Trainer, sample_dataset
You'll need an HF_TOKEN to use the HF Inference Endpoints. Log in to use it directly within this notebook.
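One way to log in from within the notebook is via huggingface_hub, assuming the token is exposed as the HF_TOKEN environment variable:

import os

from huggingface_hub import login

# Assumes HF_TOKEN is set in the environment; you can also pass the token string directly
login(token=os.getenv("HF_TOKEN"))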
(optional) Deploy Argilla¶
You can skip this step or replace it with any other data evaluation tool, but the quality of your model ultimately depends on the quality of your data, so we do recommend looking at it. If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla by following this guide.
Along with that, you will need to install Argilla as a distilabel extra.
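For example:

!pip install "distilabel[argilla]"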
The dataset¶
We will use the fancyzhx/ag_news dataset from the Hugging Face Hub as our original data source. To simulate a real-world scenario with imbalanced and limited data, we will load only 20 samples from this dataset.
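One way to do this is to load a small slice of the training split (the exact label distribution you see will depend on the slice you choose):

# Load only 20 samples to simulate an imbalanced, limited dataset
hf_dataset = load_dataset("fancyzhx/ag_news", split="train[-20:]")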
Now, we can retrieve the available labels in the dataset and examine the current data distribution.
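A quick way to do both, using an id2str helper (our own naming) to map the integer labels of AG News to their string names:

# The "label" feature is a ClassLabel, so its names give us the topic labels
labels_topic = hf_dataset.features["label"].names
id2str = {i: name for i, name in enumerate(labels_topic)}
print(labels_topic)

# Inspect the label distribution of the 20 loaded samples
print(Counter(id2str[label] for label in hf_dataset["label"]))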
As observed, the dataset is imbalanced, with most samples falling under the World category, while the Sci/Tech category is entirely missing. Moreover, there are insufficient samples to effectively train a topic classification model.
We will also define the labels for the new classification task.
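For instance (the exact label strings are our choice; any consistent pair works):

labels_fact_opinion = ["Fact-based", "Opinion-based"]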
Define the text classification task¶
To generate the data we will use the GenerateTextClassificationData task. This task takes classification task descriptions as input, and lets us define the language, difficulty, and clarity required for the generated data.
task = GenerateTextClassificationData(
    language="English",
    difficulty="college",
    clarity="clear",
    num_generations=1,
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        generation_kwargs={"max_new_tokens": 512, "temperature": 0.4},
    ),
    input_batch_size=5,
)
task.load()

result = next(
    task.process([{"task": "Classify the news article as fact-based or opinion-based"}])
)
print(result[0]["distilabel_metadata"]["raw_input_generate_text_classification_data_0"])
For our use case, we only need to generate data for two tasks: a topic classification task and a fact versus opinion classification task. Therefore, we will define the tasks accordingly. As we will be using a smaller model for generation, we will select 2 random labels for each topic classification task and vary the label order for the fact versus opinion classification task, ensuring more diversity in the generated data.
task_templates = [
    "Determine the news article as {}",
    "Classify news article as {}",
    "Identify the news article as {}",
    "Categorize the news article as {}",
    "Label the news article using {}",
    "Annotate the news article based on {}",
    "Determine the theme of a news article from {}",
    "Recognize the topic of the news article as {}",
]

classification_tasks = [
    {"task": action.format(" or ".join(random.sample(labels_topic, 2)))}
    for action in task_templates
    for _ in range(4)
] + [
    {"task": action.format(" or ".join(random.sample(labels_fact_opinion, 2)))}
    for action in task_templates
]
Run the pipeline¶
Now, it's time to define and run the pipeline. As mentioned, we will load the written tasks and feed them into the GenerateTextClassificationData task. For our use case, we will be using Meta-Llama-3.1-8B-Instruct via the InferenceEndpointsLLM, with different degrees of difficulty and clarity.
difficulties = ["college", "high school", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]

with Pipeline("texcat-generation-pipeline") as pipeline:

    tasks_generator = LoadDataFromDicts(data=classification_tasks)

    generate_data = []
    for difficulty in difficulties:
        for clarity_level in clarity:
            task = GenerateTextClassificationData(
                language="English",
                difficulty=difficulty,
                clarity=clarity_level,
                num_generations=2,
                llm=InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
                ),
                input_batch_size=5,
            )
            generate_data.append(task)

    for task in generate_data:
        tasks_generator.connect(task)
Let's now run the pipeline and generate the synthetic data.
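A single call is enough; it returns a distiset with one leaf dataset per GenerateTextClassificationData step:

distiset = pipeline.run()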
You can push the dataset to the Hub for sharing with the community and embed it to explore the data.
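For example (the repository name is a placeholder):

distiset.push_to_hub("[your-owner-name]/example-texcat-generation-dataset")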
By examining the distiset distribution, we can confirm that it includes at least the 8 required samples for each label to train our classification models with SetFit.
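One way to check is to count the labels across all leaf datasets:

all_labels = [
    entry["label"]
    for dataset_name in distiset
    for entry in distiset[dataset_name]["train"]
]
print(Counter(all_labels))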
We will create two datasets with the required labels and data for our use cases.
def extract_rows(distiset, labels):
    return [
        {
            "text": entry["input_text"],
            "label": entry["label"],
            "id": i,
        }
        for dataset_name in distiset
        for i, entry in enumerate(distiset[dataset_name]["train"])
        if entry["label"] in labels
    ]


data_topic = extract_rows(distiset, labels_topic)
data_fact_opinion = extract_rows(distiset, labels_fact_opinion)
(Optional) Evaluate with Argilla¶
Get started in Argilla
If you are not familiar with Argilla, we recommend taking a look at the Argilla quickstart docs. Alternatively, you can use your Hugging Face account to log in to the Argilla demo Space.
To get the most out of our data, we will use Argilla. First, we need to connect to the Argilla instance.
import argilla as rg

# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
We will create a Dataset for each task, with an input TextField for the text classification text and a LabelQuestion to ensure the generated labels are correct.
def create_texcat_dataset(dataset_name, labels):
    settings = rg.Settings(
        fields=[rg.TextField("text")],
        questions=[
            rg.LabelQuestion(
                name="label",
                title="Classify the texts according to the following labels",
                labels=labels,
            ),
        ],
    )
    return rg.Dataset(name=dataset_name, settings=settings).create()


rg_dataset_topic = create_texcat_dataset("topic-classification", labels_topic)
rg_dataset_fact_opinion = create_texcat_dataset(
    "fact-opinion-classification", labels_fact_opinion
)
Now, we can upload the generated data to Argilla and evaluate it. We will use the generated labels as suggestions.
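Because the dictionaries produced by extract_rows use the same keys as the Argilla settings ("text" for the field and "label" for the question), logging them directly should attach the generated labels as suggestions:

# The "label" key is picked up as a suggestion for the LabelQuestion of the same name
rg_dataset_topic.records.log(records=data_topic)
rg_dataset_fact_opinion.records.log(records=data_fact_opinion)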
Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on Submit. Otherwise, you can select the correct label.
Note
Check this how-to guide to learn more about annotating in the UI.
Once you have the annotations, let's continue by retrieving the data from Argilla and formatting it as a dataset with the required fields.
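A sketch of this retrieval step, assuming we only keep records whose responses were submitted in the UI (the flattened "label.responses" key is how to_list(flatten=True) exposes the annotated labels):

# Keep only the records that were submitted during annotation
status_filter = rg.Query(filter=rg.Filter(("response.status", "==", "submitted")))

submitted_topic = rg_dataset_topic.records(status_filter).to_list(flatten=True)
submitted_fact_opinion = rg_dataset_fact_opinion.records(status_filter).to_list(flatten=True)


def format_submitted(records):
    # Use the first annotated response of each record as its final label
    return [
        {"text": r["text"], "label": r["label.responses"][0], "id": i}
        for i, r in enumerate(records)
    ]


data_topic = format_submitted(submitted_topic)
data_fact_opinion = format_submitted(submitted_fact_opinion)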
Train your models¶
In our case, we will fine-tune using SetFit. However, you can select the framework that best fits your requirements.
Formatting the data¶
The next step will be to format the data to be compatible with SetFit. In the case of the topic classification, we will need to combine the synthetic data with the original data.
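A minimal sketch of that combination, reusing the hf_dataset and id2str objects defined earlier (the id offset simply keeps the identifiers unique across synthetic and original samples):

# Append the 20 original AG News samples, mapping their integer labels back to strings
data_topic = data_topic + [
    {"text": sample["text"], "label": id2str[sample["label"]], "id": len(data_topic) + i}
    for i, sample in enumerate(hf_dataset)
]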
If we check the data distribution now, we can see that we have enough samples for each label to train our models.
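For example:

print(Counter(sample["label"] for sample in data_topic))
print(Counter(sample["label"] for sample in data_fact_opinion))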
Now, let's create our training and validation datasets. The training dataset will gather 8 samples per label. In this case, the validation datasets will contain the remaining samples not included in the training datasets.
def sample_and_split(dataset, label_column, num_samples):
    train_dataset = sample_dataset(
        dataset, label_column=label_column, num_samples=num_samples
    )
    eval_dataset = dataset.filter(lambda x: x["id"] not in set(train_dataset["id"]))
    return train_dataset, eval_dataset


dataset_topic_full = Dataset.from_list(data_topic)
dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)

train_dataset_topic, eval_dataset_topic = sample_and_split(
    dataset_topic_full, "label", 8
)
train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(
    dataset_fact_opinion_full, "label", 8
)
The actual training¶
Let's train our models for each task! We will use TaylorAI/bge-micro-v2, available in the Hugging Face Hub. You can check the MTEB leaderboard to select the best model for your use case.
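The train_model helper used below is not part of SetFit itself; a minimal sketch with SetFit's default Trainer settings could look like this, followed by the training of the topic classifier:

def train_model(model_name, dataset, eval_dataset):
    # Load the pretrained embedding model and fine-tune it with SetFit's defaults
    model = SetFitModel.from_pretrained(model_name)
    trainer = Trainer(
        model=model,
        train_dataset=dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    return model


model_topic = train_model(
    model_name="TaylorAI/bge-micro-v2",
    dataset=train_dataset_topic,
    eval_dataset=eval_dataset_topic,
)
model_topic.save_pretrained("topic_classification_model")
model_topic = SetFitModel.from_pretrained("topic_classification_model")

And the same for the fact versus opinion classifier: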
model_fact_opinion = train_model(
    model_name="TaylorAI/bge-micro-v2",
    dataset=train_dataset_fact_opinion,
    eval_dataset=eval_dataset_fact_opinion,
)
model_fact_opinion.save_pretrained("fact_opinion_classification_model")
model_fact_opinion = SetFitModel.from_pretrained("fact_opinion_classification_model")
Voilà! The models are now trained and ready to be used. You can start making predictions to check the model's performance and add the new label. Optionally, you can continue using distilabel to generate additional data or Argilla to verify the quality of the predictions.
Conclusions¶
In this tutorial, we showcased the detailed steps to build a pipeline for generating text classification data using distilabel. You can customize this pipeline for your own use cases and share your datasets with the community through the Hugging Face Hub.
We defined two text classification tasks—a topic classification task and a fact versus opinion classification task—and generated new data using various models via the serverless Hugging Face Inference API. Then, we curated the generated data with Argilla. Finally, we trained the models with SetFit using both the original and synthetic data.