
Generate synthetic text classification data

Getting started

Install the dependencies

To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using the free but rate-limited Hugging Face serverless Inference API for this tutorial, so we need to install the hf-inference-endpoints extra of distilabel. You can install everything by running the following commands:

!pip install "distilabel[hf-inference-endpoints]"
!pip install "transformers~=4.40" "torch~=2.0" "setfit~=1.0"

Let's make the required imports:

import random
from collections import Counter

from datasets import load_dataset, Dataset
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import (
    GenerateTextClassificationData,
)
from setfit import SetFitModel, Trainer, sample_dataset

You'll need an HF_TOKEN to use the HF Inference Endpoints. Log in to use it directly within this notebook.

import os
from huggingface_hub import login

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True)

(Optional) Deploy Argilla

You can skip this step or replace it with any other data evaluation tool, but your model's quality will suffer if your data quality is poor, so we recommend reviewing your data. If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla by following this guide.

Along with that, you will need to install Argilla as a distilabel extra.

!pip install "distilabel[argilla, hf-inference-endpoints]"

The dataset

We will use the fancyzhx/ag_news dataset from the Hugging Face Hub as our original data source. To simulate a real-world scenario with imbalanced and limited data, we will load only 20 samples from this dataset.

hf_dataset = load_dataset("fancyzhx/ag_news", split="train[-20:]")

Now, we can retrieve the available labels in the dataset and examine the current data distribution.

labels_topic = hf_dataset.features["label"].names
id2str = {i: labels_topic[i] for i in range(len(labels_topic))}
print(id2str)
print(Counter(hf_dataset["label"]))
{0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
Counter({0: 12, 1: 6, 2: 2})

As observed, the dataset is imbalanced, with most samples falling under the World category, while the Sci/Tech category is entirely missing. Moreover, there are insufficient samples to effectively train a topic classification model.
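To make "insufficient" concrete, here is a small worked check. It hard-codes the distribution printed above and assumes the target of 8 samples per label that we use later when sampling the SetFit training set.

```python
from collections import Counter

# Label names and counts copied from the output above
id2str = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
counts = Counter({0: 12, 1: 6, 2: 2})

# Few-shot target per label (the num_samples value used for SetFit later on)
target_per_label = 8

# Extra samples each label needs to reach the target (Counter returns 0 for missing ids)
shortfall = {
    name: max(0, target_per_label - counts[label_id])
    for label_id, name in id2str.items()
}
print(shortfall)  # {'World': 0, 'Sports': 2, 'Business': 6, 'Sci/Tech': 8}
```

Even with the full original sample, the topic task is 16 examples short of a balanced 8-per-label set, and Sci/Tech has no examples at all.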

We will also define the labels for the new classification task.

labels_fact_opinion = ["Fact-based", "Opinion-based"]

Define the text classification task

To generate the data, we will use the GenerateTextClassificationData task. This task takes a classification task description as input and lets us define the language, difficulty, and clarity of the generated data.

task = GenerateTextClassificationData(
    language="English",
    difficulty="college",
    clarity="clear",
    num_generations=1,
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
        generation_kwargs={"max_new_tokens": 512, "temperature": 0.4},
    ),
    input_batch_size=5,
)
task.load()
result = next(
    task.process([{"task": "Classify the news article as fact-based or opinion-based"}])
)
print(result[0]["distilabel_metadata"]["raw_input_generate_text_classification_data_0"])
[{'role': 'user', 'content': 'You have been assigned a text classification task: Classify the news article as fact-based or opinion-based\n\nYour mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:\n - "input_text": a string, the input text specified by the classification task.\n - "label": a string, the correct label of the input text.\n - "misleading_label": a string, an incorrect label that is related to the task.\n\nPlease adhere to the following guidelines:\n - The "input_text" should be diverse in expression.\n - The "misleading_label" must be a valid label for the given task, but not as appropriate as the "label" for the "input_text".\n - The values for all fields should be in English.\n - Avoid including the values of the "label" and "misleading_label" fields in the "input_text", that would make the task too easy.\n - The "input_text" is clear and requires college level education to comprehend.\n\nYour output must always be a JSON object only, do not explain yourself or output anything else. Be creative!'}]
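The prompt asks the model for a bare JSON object with three keys. The sketch below shows one way to turn such a raw response into a row. `parse_generation` is a hypothetical helper, not part of distilabel, and it assumes that a response which is not a JSON object should yield `None` values rather than raise. Malformed responses of this kind are one plausible source of the `None` labels that appear in the label counts later on.

```python
import json
from typing import Optional

EXPECTED_KEYS = ("input_text", "label", "misleading_label")


def parse_generation(raw: Optional[str]) -> dict:
    """Hypothetical helper: parse one raw LLM response into the three requested fields.

    Anything that is not a JSON object yields None for every field.
    """
    empty = {key: None for key in EXPECTED_KEYS}
    if raw is None:
        return empty
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        # The model added prose or produced broken JSON
        return empty
    if not isinstance(parsed, dict):
        return empty
    # Keep only the expected keys; missing keys become None
    return {key: parsed.get(key) for key in EXPECTED_KEYS}


raw = '{\n  "input_text": "Stocks rallied after the report.",\n  "label": "Business",\n  "misleading_label": "World"\n}'
print(parse_generation(raw)["label"])                 # Business
print(parse_generation("Sure! Here is...")["label"])  # None
```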

For our use case, we only need to generate data for two tasks: a topic classification task and a fact versus opinion classification task. Therefore, we will define the tasks accordingly. Because we will be using a smaller model for generation, each topic classification task will mention 2 randomly selected labels, and each fact versus opinion task will list both labels in a random order, which adds diversity to the generated data.

task_templates = [
    "Determine the news article as {}",
    "Classify news article as {}",
    "Identify the news article as {}",
    "Categorize the news article as {}",
    "Label the news article using {}",
    "Annotate the news article based on {}",
    "Determine the theme of a news article from {}",
    "Recognize the topic of the news article as {}",
]

classification_tasks = [
    {"task": action.format(" or ".join(random.sample(labels_topic, 2)))}
    for action in task_templates for _ in range(4)
] + [
    {"task": action.format(" or ".join(random.sample(labels_fact_opinion, 2)))}
    for action in task_templates
]
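The comprehension above produces 8 × 4 = 32 topic tasks plus 8 fact versus opinion tasks, 40 in total. Because it calls the global random module without a seed, the label pairs change on every run. If you want a reproducible task list, one option (our addition, not part of the original recipe) is a dedicated seeded generator; the labels are hard-coded here so the sketch runs on its own.

```python
import random

labels_topic = ["World", "Sports", "Business", "Sci/Tech"]
labels_fact_opinion = ["Fact-based", "Opinion-based"]
task_templates = [
    "Determine the news article as {}",
    "Classify news article as {}",
    "Identify the news article as {}",
    "Categorize the news article as {}",
    "Label the news article using {}",
    "Annotate the news article based on {}",
    "Determine the theme of a news article from {}",
    "Recognize the topic of the news article as {}",
]

# A local, seeded generator makes the sampled label pairs reproducible
rng = random.Random(42)

classification_tasks = [
    # Two distinct topic labels per task, four tasks per template
    {"task": action.format(" or ".join(rng.sample(labels_topic, 2)))}
    for action in task_templates
    for _ in range(4)
] + [
    # Both fact/opinion labels, in a random order
    {"task": action.format(" or ".join(rng.sample(labels_fact_opinion, 2)))}
    for action in task_templates
]

print(len(classification_tasks))  # 40
```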

Run the pipeline

Now, it's time to define and run the pipeline. As mentioned, we will load the written tasks and feed them into the GenerateTextClassificationData task. For our use case, we will be using Meta-Llama-3.1-8B-Instruct via the InferenceEndpointsLLM, with different degrees of difficulty and clarity.

difficulties = ["college", "high school", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]

with Pipeline("texcat-generation-pipeline") as pipeline:

    tasks_generator = LoadDataFromDicts(data=classification_tasks)

    generate_data = []
    for difficulty in difficulties:
        for clarity_level in clarity:
            task = GenerateTextClassificationData(
                language="English",
                difficulty=difficulty,
                clarity=clarity_level,
                num_generations=2,
                llm=InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
                    generation_kwargs={"max_new_tokens": 512, "temperature": 0.7},
                ),
                input_batch_size=5,
            )
            generate_data.append(task)

    for task in generate_data:
        tasks_generator.connect(task)
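It helps to know how many rows to expect before launching the run. Every one of the 3 × 3 = 9 task steps is connected to the same generator, so each receives all 40 tasks and, with num_generations=2, emits two rows per task. This assumes one output row per generation, which is consistent with the label counts shown further down.

```python
from itertools import product

difficulties = ["college", "high school", "PhD"]
clarity = ["clear", "understandable with some effort", "ambiguous"]

num_steps = len(list(product(difficulties, clarity)))  # 9 task configurations
num_tasks = 8 * 4 + 8  # 32 topic tasks + 8 fact/opinion tasks
num_generations = 2

rows_per_step = num_tasks * num_generations  # 80
total_rows = num_steps * rows_per_step  # 720
print(num_steps, rows_per_step, total_rows)  # 9 80 720
```

That is roughly 720 generations against the rate-limited serverless API, so expect the run to take a while.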

Let's now run the pipeline and generate the synthetic data.

distiset = pipeline.run()
distiset["generate_text_classification_data_0"]["train"][0]
{'task': 'Determine the news article as Business or World',
 'input_text': "The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone's economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.",
 'label': 'Business',
 'misleading_label': 'World',
 'distilabel_metadata': {'raw_output_generate_text_classification_data_0': '{\n  "input_text": "The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone\'s economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.",\n  "label": "Business",\n  "misleading_label": "World"\n}'},
 'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct'}

You can push the dataset to the Hub for sharing with the community and embed it to explore the data.

distiset.push_to_hub("[your-owner-name]/example-texcat-generation-dataset")

By examining the distiset distribution, we can confirm that it includes at least the 8 required samples for each label to train our classification models with SetFit.

all_labels = [
    entry["label"]
    for dataset_name in distiset
    for entry in distiset[dataset_name]["train"]
]

Counter(all_labels)
Counter({'Sci/Tech': 275,
         'Business': 130,
         'World': 86,
         'Fact-based': 86,
         'Sports': 64,
         'Opinion-based': 54,
         None: 20,
         'Opinion Based': 1,
         'News/Opinion': 1,
         'Science': 1,
         'Environment': 1,
         'Opinion': 1})

We will create two datasets with the required labels and data for our use cases.

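def extract_rows(distiset, labels):
    # Keep only the rows whose generated label belongs to the given label set
    rows = [
        {"text": entry["input_text"], "label": entry["label"]}
        for dataset_name in distiset
        for entry in distiset[dataset_name]["train"]
        if entry["label"] in labels
    ]
    # Assign ids after filtering so they are unique across all distiset subsets
    # (enumerating inside each subset would repeat ids 0, 1, 2, ... per subset)
    return [{**row, "id": i} for i, row in enumerate(rows)]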
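Because extract_rows only keeps rows whose label belongs to the requested label set, the malformed labels in the count above (None, 'Opinion Based', 'Science', and so on) are silently dropped. Using the counts printed above, we can work out how many rows each dataset should end up with:

```python
from collections import Counter

# Label counts copied from the generated distiset above
all_label_counts = Counter({
    "Sci/Tech": 275, "Business": 130, "World": 86, "Fact-based": 86,
    "Sports": 64, "Opinion-based": 54, None: 20, "Opinion Based": 1,
    "News/Opinion": 1, "Science": 1, "Environment": 1, "Opinion": 1,
})
labels_topic = ["World", "Sports", "Business", "Sci/Tech"]
labels_fact_opinion = ["Fact-based", "Opinion-based"]

kept_topic = sum(all_label_counts[label] for label in labels_topic)
kept_fact_opinion = sum(all_label_counts[label] for label in labels_fact_opinion)
dropped = sum(all_label_counts.values()) - kept_topic - kept_fact_opinion
print(kept_topic, kept_fact_opinion, dropped)  # 555 140 25
```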

data_topic = extract_rows(distiset, labels_topic)
data_fact_opinion = extract_rows(distiset, labels_fact_opinion)

(Optional) Evaluate with Argilla

Get started in Argilla

If you are not familiar with Argilla, we recommend taking a look at the Argilla quickstart docs. Alternatively, you can use your Hugging Face account to log in to the Argilla demo Space.

To get the most out of our data, we will use Argilla. First, we need to connect to the Argilla instance.

import argilla as rg

# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

We will create a Dataset for each task, with a TextField for the input text and a LabelQuestion to verify that the generated labels are correct.

def create_texcat_dataset(dataset_name, labels):
    settings = rg.Settings(
        fields=[rg.TextField("text")],
        questions=[
            rg.LabelQuestion(
                name="label",
                title="Classify the texts according to the following labels",
                labels=labels,
            ),
        ],
    )
    return rg.Dataset(name=dataset_name, settings=settings).create()


rg_dataset_topic = create_texcat_dataset("topic-classification", labels_topic)
rg_dataset_fact_opinion = create_texcat_dataset(
    "fact-opinion-classification", labels_fact_opinion
)

Now, we can upload the generated data to Argilla and evaluate it. We will use the generated labels as suggestions.

rg_dataset_topic.records.log(data_topic)
rg_dataset_fact_opinion.records.log(data_fact_opinion)

Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on Submit. Otherwise, you can select the correct label.

Note

Check this how-to guide to learn more about annotating in the UI.

Once you have the annotations, retrieve the data from Argilla and format it as a dataset with the required fields.

rg_dataset_topic = client.datasets("topic-classification")
rg_dataset_fact_opinion = client.datasets("fact-opinion-classification")
status_filter = rg.Query(filter=rg.Filter(("response.status", "==", "submitted")))

submitted_topic = rg_dataset_topic.records(status_filter).to_list(flatten=True)
submitted_fact_opinion = rg_dataset_fact_opinion.records(status_filter).to_list(
    flatten=True
)


def format_submitted(submitted):
    return [
        {
            "text": r["text"],
            "label": r["label.responses"][0],
            "id": i,
        }
        for i, r in enumerate(submitted)
    ]

data_topic = format_submitted(submitted_topic)
data_fact_opinion = format_submitted(submitted_fact_opinion)
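format_submitted keeps the first entry in label.responses, which is fine with a single annotator. If several annotators label the same record, you may prefer a majority vote. A minimal sketch, assuming the flattened records look like the ones used above (a text key and a label.responses list); majority_label and format_submitted_majority are hypothetical helpers, and ties go to the label seen first.

```python
from collections import Counter


def majority_label(responses):
    """Return the most frequent label; ties go to the label seen first."""
    # Counter.most_common orders equal counts by first insertion
    return Counter(responses).most_common(1)[0][0]


def format_submitted_majority(submitted):
    return [
        {"text": r["text"], "label": majority_label(r["label.responses"]), "id": i}
        for i, r in enumerate(submitted)
    ]


records = [
    {"text": "Markets fell sharply.", "label.responses": ["Business", "World", "Business"]},
    {"text": "The striker scored twice.", "label.responses": ["Sports"]},
]
print(format_submitted_majority(records))
```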

Train your models

In our case, we will fine-tune using SetFit. However, you can use whichever training framework best fits your requirements.

Formatting the data

The next step will be to format the data to be compatible with SetFit. In the case of the topic classification, we will need to combine the synthetic data with the original data.

hf_topic = hf_dataset.to_list()
num = len(data_topic)

data_topic.extend(
    [
        {
            "text": r["text"],
            "label": id2str[r["label"]],
            "id": num + i,
        }
        for i, r in enumerate(hf_topic)
    ]
)

If we check the data distribution now, we can see that we have enough samples for each label to train our models.

labels = [record["label"] for record in data_topic]
Counter(labels)
Counter({'Sci/Tech': 275, 'Business': 132, 'World': 98, 'Sports': 70})
labels = [record["label"] for record in data_fact_opinion]
Counter(labels)
Counter({'Fact-based': 86, 'Opinion-based': 54})
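As a quick sanity check, the merged topic counts are simply the synthetic counts plus the 20 original samples. The sketch assumes every synthetic record was kept with its generated label (for example, because you skipped the Argilla review); if you relabeled records in Argilla, your numbers will differ.

```python
from collections import Counter

# Synthetic topic rows kept by extract_rows (from the label counts above)
synthetic_topic = Counter({"Sci/Tech": 275, "Business": 130, "World": 86, "Sports": 64})
# Original ag_news sample, with integer ids mapped to label names via id2str
original_topic = Counter({"World": 12, "Sports": 6, "Business": 2})

merged = synthetic_topic + original_topic
print(merged)  # Counter({'Sci/Tech': 275, 'Business': 132, 'World': 98, 'Sports': 70})
```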

Now, let's create our training and validation datasets. The training dataset will contain 8 samples per label. In this case, the validation datasets will contain the remaining samples not included in the training datasets.

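def sample_and_split(dataset, label_column, num_samples):
    # Draw num_samples rows per label for few-shot training
    train_dataset = sample_dataset(
        dataset, label_column=label_column, num_samples=num_samples
    )
    # Build the set of training ids once instead of once per row
    train_ids = set(train_dataset["id"])
    # Every row not sampled for training goes to the evaluation split
    eval_dataset = dataset.filter(lambda x: x["id"] not in train_ids)
    return train_dataset, eval_dataset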


dataset_topic_full = Dataset.from_list(data_topic)
dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)

train_dataset_topic, eval_dataset_topic = sample_and_split(
    dataset_topic_full, "label", 8
)
train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(
    dataset_fact_opinion_full, "label", 8
)
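For intuition, here is a dependency-free sketch of what sample_and_split does: draw up to num_samples rows per label for training and keep every other row for evaluation. sample_and_split_py is a hypothetical helper that approximates setfit's sample_dataset rather than reproducing its implementation, and, like the original, it relies on every row having a unique id.

```python
import random
from collections import defaultdict


def sample_and_split_py(rows, label_key="label", num_samples=8, seed=42):
    """Per-label few-shot sampling; returns (train_rows, eval_rows)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for row in rows:
        by_label[row[label_key]].append(row)

    train = []
    for label_rows in by_label.values():
        # Take at most num_samples rows for each label
        k = min(num_samples, len(label_rows))
        train.extend(rng.sample(label_rows, k))

    # Build the id set once, then keep every row that was not sampled
    train_ids = {row["id"] for row in train}
    eval_rows = [row for row in rows if row["id"] not in train_ids]
    return train, eval_rows


# Toy data: 20 rows labeled "A" and 10 rows labeled "B"
rows = [{"id": i, "text": f"text {i}", "label": "A" if i % 3 else "B"} for i in range(30)]
train, eval_rows = sample_and_split_py(rows)
print(len(train), len(eval_rows))  # 16 14
```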

The actual training

Let's train a model for each task! We will use TaylorAI/bge-micro-v2, available on the Hugging Face Hub. You can check the MTEB leaderboard to select the best model for your use case.

def train_model(model_name, dataset, eval_dataset):
    model = SetFitModel.from_pretrained(model_name)

    trainer = Trainer(
        model=model,
        train_dataset=dataset,
    )
    trainer.train()
    metrics = trainer.evaluate(eval_dataset)
    print(metrics)

    return model


model_topic = train_model(
    model_name="TaylorAI/bge-micro-v2",
    dataset=train_dataset_topic,
    eval_dataset=eval_dataset_topic,
)
model_topic.save_pretrained("topic_classification_model")
model_topic = SetFitModel.from_pretrained("topic_classification_model")
***** Running training *****
  Num unique pairs = 768
  Batch size = 16
  Num epochs = 1
  Total optimization steps = 48

{'embedding_loss': 0.1873, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.02}

***** Running evaluation *****

{'train_runtime': 4.9767, 'train_samples_per_second': 154.318, 'train_steps_per_second': 9.645, 'epoch': 1.0}
{'accuracy': 0.8333333333333334}

model_fact_opinion = train_model(
    model_name="TaylorAI/bge-micro-v2",
    dataset=train_dataset_fact_opinion,
    eval_dataset=eval_dataset_fact_opinion,
)
model_fact_opinion.save_pretrained("fact_opinion_classification_model")
model_fact_opinion = SetFitModel.from_pretrained("fact_opinion_classification_model")
***** Running training *****
  Num unique pairs = 144
  Batch size = 16
  Num epochs = 1
  Total optimization steps = 9

{'embedding_loss': 0.2985, 'learning_rate': 2e-05, 'epoch': 0.11}

***** Running evaluation *****

{'train_runtime': 0.8327, 'train_samples_per_second': 172.931, 'train_steps_per_second': 10.808, 'epoch': 1.0}
{'accuracy': 0.9090909090909091}
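The "Num unique pairs" values in the two logs can be reproduced by hand. The sketch below assumes that SetFit's default pair sampling counts each sample paired with itself as well as with every other sample, then oversamples the smaller of the positive (same label) and negative (different label) groups to match the larger one. This is our reconstruction from the logged numbers, not SetFit's code, so treat it as an approximation of its internals.

```python
from itertools import combinations_with_replacement


def count_unique_pairs(samples_per_label):
    """Reconstruct the logged 'Num unique pairs' for a balanced few-shot set."""
    labels = [label for label, n in samples_per_label.items() for _ in range(n)]
    positive = negative = 0
    # Every unordered pair of samples, including a sample paired with itself
    for a, b in combinations_with_replacement(range(len(labels)), 2):
        if labels[a] == labels[b]:
            positive += 1
        else:
            negative += 1
    # The smaller group is oversampled up to the size of the larger one
    return 2 * max(positive, negative)


batch_size = 16
topic_pairs = count_unique_pairs({"World": 8, "Sports": 8, "Business": 8, "Sci/Tech": 8})
fact_pairs = count_unique_pairs({"Fact-based": 8, "Opinion-based": 8})
print(topic_pairs, topic_pairs // batch_size)  # 768 48
print(fact_pairs, fact_pairs // batch_size)    # 144 9
```

Dividing the pair counts by the batch size of 16 gives the 48 and 9 optimization steps reported for one epoch.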

Voilà! The models are now trained and ready to be used. You can start making predictions to check each model's performance and to label new texts. Optionally, you can continue using distilabel to generate additional data or Argilla to verify the quality of the predictions.

def predict(model, text, labels):
    # Set the label names on the model, then predict a single text
    model.labels = labels
    prediction = model.predict([text])
    return prediction[0]


predict(
    model_topic, "The new iPhone is expected to be released next month.", labels_topic
)
'Sci/Tech'
predict(
    model_fact_opinion,
    "The new iPhone is expected to be released next month.",
    labels_fact_opinion,
)
'Opinion-based'

Conclusions

In this tutorial, we showcased the detailed steps to build a pipeline for generating text classification data using distilabel. You can customize this pipeline for your own use cases and share your datasets with the community through the Hugging Face Hub.

We defined two text classification tasks (a topic classification task and a fact versus opinion classification task) and generated new data with Meta-Llama-3.1-8B-Instruct via the serverless Hugging Face Inference API, using several levels of difficulty and clarity. Then, we curated the generated data with Argilla. Finally, we trained the models with SetFit using both the original and synthetic data.