Skip to content

Saving step generated artifacts

Some Steps might need to produce an auxiliary artifact that is not a result of the computation, but is needed for the computation. For example, the FaissNearestNeighbour needs to create a Faiss index to compute the output of the step which are the top k nearest neighbours for each input. Generating the Faiss index takes time and it could potentially be reused outside of the distilabel pipeline, so it would be a shame not saving it.

For this reason, Steps have a method called save_artifact that allows saving artifacts that will be included along the outputs of the pipeline in the generated Distiset. The generated artifacts will be uploaded and saved when using Distiset.push_to_hub or Distiset.save_to_disk respectively. Let's see how to use it with a simple example.

from typing import List, TYPE_CHECKING
from distilabel.steps import GlobalStep, StepInput, StepOutput
import matplotlib.pyplot as plt

if TYPE_CHECKING:
    from distilabel.steps import StepOutput


class CountTextCharacters(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        return ["text_character_count"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        character_counts = []

        for input in inputs:
            text_character_count = len(input["text"])
            input["text_character_count"] = text_character_count
            character_counts.append(text_character_count)

        # Generate plot with the distribution of text character counts
        plt.figure(figsize=(10, 6))
        plt.hist(character_counts, bins=30, edgecolor="black")
        plt.title("Distribution of Text Character Counts")
        plt.xlabel("Character Count")
        plt.ylabel("Frequency")

        # Save the plot as an artifact of the step
        self.save_artifact(
            name="text_character_count_distribution",
            write_function=lambda path: plt.savefig(path / "figure.png"),
            metadata={"type": "image", "library": "matplotlib"},
        )

        plt.close()

        yield inputs

As it can be seen in the example above, we have created a simple step that counts the number of characters in each input text and generates a histogram with the distribution of the character counts. We save the histogram as an artifact of the step using the save_artifact method. The method takes three arguments:

  • name: The name we want to give to the artifact.
  • write_function: A function that writes the artifact to the desired path. The function will receive a path argument which is a pathlib.Path object pointing to the directory where the artifact should be saved.
  • metadata: A dictionary with metadata about the artifact. This metadata will be saved along with the artifact.

Let's execute the step with a simple pipeline and push the resulting Distiset to the Hugging Face Hub:

Example full code
from typing import TYPE_CHECKING, List

import matplotlib.pyplot as plt
from datasets import load_dataset
from distilabel.pipeline import Pipeline
from distilabel.steps import GlobalStep, StepInput, StepOutput

if TYPE_CHECKING:
    from distilabel.steps import StepOutput


class CountTextCharacters(GlobalStep):
    @property
    def inputs(self) -> List[str]:
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        return ["text_character_count"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        character_counts = []

        for input in inputs:
            text_character_count = len(input["text"])
            input["text_character_count"] = text_character_count
            character_counts.append(text_character_count)

        # Generate plot with the distribution of text character counts
        plt.figure(figsize=(10, 6))
        plt.hist(character_counts, bins=30, edgecolor="black")
        plt.title("Distribution of Text Character Counts")
        plt.xlabel("Character Count")
        plt.ylabel("Frequency")

        # Save the plot as an artifact of the step
        self.save_artifact(
            name="text_character_count_distribution",
            write_function=lambda path: plt.savefig(path / "figure.png"),
            metadata={"type": "image", "library": "matplotlib"},
        )

        plt.close()

        yield inputs


with Pipeline() as pipeline:
    count_text_characters = CountTextCharacters()

if __name__ == "__main__":
    distiset = pipeline.run(
        dataset=load_dataset(
            "HuggingFaceH4/instruction-dataset", split="test"
        ).rename_column("prompt", "text"),
    )

    distiset.push_to_hub("distilabel-internal-testing/distilabel-artifacts-example")

The generated distilabel-internal-testing/distilabel-artifacts-example dataset repository has a section in its card describing the artifacts generated by the pipeline and the generated plot can be seen here.