
Index

GeneratorStepOutput = Iterator[Tuple[List[Dict[str, Any]], bool]] module-attribute

GeneratorStepOutput is an alias of typing.Iterator[Tuple[List[Dict[str, Any]], bool]]: each item yielded by a GeneratorStep is a batch of rows together with a boolean indicating whether it is the last batch.

StepInput = Annotated[List[Dict[str, Any]], _STEP_INPUT_ANNOTATION] module-attribute

StepInput is an Annotated alias of typing.List[Dict[str, Any]] with extra metadata that allows distilabel to perform validations over the process method defined in each Step.

StepOutput = Iterator[List[Dict[str, Any]]] module-attribute

StepOutput is an alias of typing.Iterator[List[Dict[str, Any]]], i.e. an iterator over the batches of rows yielded by a Step's process method.
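To illustrate how these aliases fit together, here is a minimal custom Step sketch (the class and column names are invented; the import paths assume the public distilabel.steps API):

from typing import List

from distilabel.steps import Step, StepInput
from distilabel.steps.typing import StepOutput


class UppercaseInstruction(Step):
    """Illustrative step that upper-cases the `instruction` column."""

    @property
    def inputs(self) -> List[str]:
        # Columns that each input row must contain.
        return ["instruction"]

    @property
    def outputs(self) -> List[str]:
        # Columns that this step produces or modifies.
        return ["instruction"]

    def process(self, inputs: StepInput) -> "StepOutput":
        # `inputs` is a batch of rows (a list of dicts); `process` yields batches.
        for row in inputs:
            row["instruction"] = row["instruction"].upper()
        yield inputs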

CombineColumns

Bases: Step

Combines columns from a list of StepInput.

CombineColumns is a Step that implements the process method, which calls the combine_dicts function to handle and combine a list of StepInput. CombineColumns also provides two attributes, columns and output_columns, to specify the columns to merge and the names of the output columns; these override the default values of the inputs and outputs properties, respectively.

Attributes:
  • columns (List[str]): List of strings with the names of the columns to merge.
  • output_columns (Optional[List[str]]): Optional list of strings with the names of the output columns.

Input columns
  • dynamic (determined by columns attribute): The columns to merge.
Output columns
  • dynamic (determined by columns and output_columns attributes): The columns that were merged.
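A minimal usage sketch (the column values are invented, and the step is called directly outside a Pipeline purely for illustration):

from distilabel.steps import CombineColumns

combine = CombineColumns(columns=["generation"], output_columns=["generations"])
combine.load()

# Two upstream batches are merged row-by-row: each output row gathers the
# values of the merged column from every input batch into a single list.
result = next(
    combine.process(
        [{"generation": "first answer"}],
        [{"generation": "second answer"}],
    )
)
# result is roughly [{"generations": ["first answer", "second answer"]}]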
Source code in src/distilabel/steps/combine.py
class CombineColumns(Step):
    """Combines columns from a list of `StepInput`.

    `CombineColumns` is a `Step` that implements the `process` method, which calls the `combine_dicts`
    function to handle and combine a list of `StepInput`. `CombineColumns` also provides two attributes,
    `columns` and `output_columns`, to specify the columns to merge and the names of the output columns;
    these override the default values of the `inputs` and `outputs` properties, respectively.

    Attributes:
        columns: List of strings with the names of the columns to merge.
        output_columns: Optional list of strings with the names of the output columns.

    Input columns:
        - dynamic (determined by `columns` attribute): The columns to merge.

    Output columns:
        - dynamic (determined by `columns` and `output_columns` attributes): The columns
            that were merged.
    """

    columns: List[str]
    output_columns: Optional[List[str]] = None

    @property
    def inputs(self) -> List[str]:
        """The inputs for the task are the column names in `columns`."""
        return self.columns

    @property
    def outputs(self) -> List[str]:
        """The outputs for the task are the column names in `output_columns` or
        `merged_{column}` for each column in `columns`."""
        return (
            self.output_columns
            if self.output_columns is not None
            else [f"merged_{column}" for column in self.columns]
        )

    @override
    def process(self, *inputs: StepInput) -> "StepOutput":
        """The `process` method calls the `combine_dicts` function to handle and combine a list of `StepInput`.

        Args:
            *inputs: A list of `StepInput` to be combined.

        Yields:
            A `StepOutput` with the combined `StepInput` using the `combine_dicts` function.
        """
        yield combine_dicts(
            *inputs,
            merge_keys=self.inputs,
            output_merge_keys=self.outputs,
        )

inputs: List[str] property

The inputs for the task are the column names in columns.

outputs: List[str] property

The outputs for the task are the column names in output_columns or merged_{column} for each column in columns.

process(*inputs)

The process method calls the combine_dicts function to handle and combine a list of StepInput.

Parameters:
  • *inputs (StepInput): A list of StepInput to be combined. Default: ()

Yields:
  • StepOutput: A StepOutput with the combined StepInput using the combine_dicts function.

Source code in src/distilabel/steps/combine.py
@override
def process(self, *inputs: StepInput) -> "StepOutput":
    """The `process` method calls the `combine_dicts` function to handle and combine a list of `StepInput`.

    Args:
        *inputs: A list of `StepInput` to be combined.

    Yields:
        A `StepOutput` with the combined `StepInput` using the `combine_dicts` function.
    """
    yield combine_dicts(
        *inputs,
        merge_keys=self.inputs,
        output_merge_keys=self.outputs,
    )

ConversationTemplate

Bases: Step

Generate a conversation template from an instruction and a response.

Input columns
  • instruction (str): The instruction to be used in the conversation.
  • response (str): The response to be used in the conversation.
Output columns
  • conversation (ChatType): The conversation template.
Categories
  • format
  • chat
  • template
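A minimal usage sketch (values invented; the step is called directly rather than inside a Pipeline):

from distilabel.steps import ConversationTemplate

conversation = ConversationTemplate()
conversation.load()

result = next(
    conversation.process([{"instruction": "Hello", "response": "Hi! How can I help?"}])
)
# result[0]["conversation"] == [
#     {"role": "user", "content": "Hello"},
#     {"role": "assistant", "content": "Hi! How can I help?"},
# ]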
Source code in src/distilabel/steps/formatting/conversation.py
class ConversationTemplate(Step):
    """Generate a conversation template from an instruction and a response.

    Input columns:
        - instruction (`str`): The instruction to be used in the conversation.
        - response (`str`): The response to be used in the conversation.

    Output columns:
        - conversation (`ChatType`): The conversation template.

    Categories:
        - format
        - chat
        - template
    """

    @property
    def inputs(self) -> List[str]:
        """The instruction and response."""
        return ["instruction", "response"]

    @property
    def outputs(self) -> List[str]:
        """The conversation template."""
        return ["conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Generate a conversation template from an instruction and a response.

        Args:
            inputs: The input data.

        Yields:
            The input data with the conversation template.
        """
        for input in inputs:
            input["conversation"] = [
                {"role": "user", "content": input["instruction"]},
                {"role": "assistant", "content": input["response"]},
            ]
        yield inputs

inputs: List[str] property

The instruction and response.

outputs: List[str] property

The conversation template.

process(inputs)

Generate a conversation template from an instruction and a response.

Parameters:
  • inputs (StepInput): The input data. Required.

Yields:
  • StepOutput: The input data with the conversation template.

Source code in src/distilabel/steps/formatting/conversation.py
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Generate a conversation template from an instruction and a response.

    Args:
        inputs: The input data.

    Yields:
        The input data with the conversation template.
    """
    for input in inputs:
        input["conversation"] = [
            {"role": "user", "content": input["instruction"]},
            {"role": "assistant", "content": input["response"]},
        ]
    yield inputs

DeitaFiltering

Bases: GlobalStep

Filter dataset rows using DEITA filtering strategy.

Filter the dataset based on the DEITA score and the cosine distance between the embeddings. It's an implementation of the filtering step from the paper 'What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.

Attributes:
  • data_budget (RuntimeParameter[int]): The desired size of the dataset after filtering.
  • diversity_threshold (RuntimeParameter[float]): If a row has a cosine distance with respect to its nearest neighbor greater than this value, it will be included in the filtered dataset. Defaults to 0.9.
  • normalize_embeddings (RuntimeParameter[bool]): Whether to normalize the embeddings before computing the cosine distance. Defaults to True.

Runtime parameters
  • data_budget: The desired size of the dataset after filtering.
  • diversity_threshold: If a row has a cosine distance with respect to its nearest neighbor greater than this value, it will be included in the filtered dataset.
Input columns
  • evol_instruction_score (float): The score of the instruction generated by ComplexityScorer step.
  • evol_response_score (float): The score of the response generated by QualityScorer step.
  • embedding (List[float]): The embedding generated for the conversation of the instruction-response pair using GenerateEmbeddings step.
Output columns
  • deita_score (float): The DEITA score for the instruction-response pair.
  • deita_score_computed_with (List[str]): The scores used to compute the DEITA score.
  • nearest_neighbor_distance (float): The cosine distance between the embeddings of the instruction-response pair.
Categories
  • filtering
References
  • What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning: https://arxiv.org/abs/2312.15685
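A minimal usage sketch (scores and embeddings are invented; the step is called directly for illustration):

from distilabel.steps import DeitaFiltering

deita = DeitaFiltering(data_budget=1)
deita.load()

result = next(
    deita.process(
        [
            {
                "evol_instruction_score": 0.5,
                "evol_response_score": 0.5,
                "embedding": [-8.1, -5.3, 9.7],
            },
            {
                "evol_instruction_score": 0.8,
                "evol_response_score": 0.8,
                "embedding": [2.2, 10.1, -4.0],
            },
        ]
    )
)
# At most `data_budget` rows are kept, preferring higher `deita_score` and
# requiring `nearest_neighbor_distance` >= `diversity_threshold`.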
Source code in src/distilabel/steps/deita.py
class DeitaFiltering(GlobalStep):
    """Filter dataset rows using DEITA filtering strategy.

    Filter the dataset based on the DEITA score and the cosine distance between the embeddings.
    It's an implementation of the filtering step from the paper 'What Makes Good Data
    for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.

    Attributes:
        data_budget: The desired size of the dataset after filtering.
        diversity_threshold: If a row has a cosine distance with respect to its nearest
            neighbor greater than this value, it will be included in the filtered dataset.
            Defaults to `0.9`.
        normalize_embeddings: Whether to normalize the embeddings before computing the cosine
            distance. Defaults to `True`.

    Runtime parameters:
        - `data_budget`: The desired size of the dataset after filtering.
        - `diversity_threshold`: If a row has a cosine distance with respect to its nearest
            neighbor greater than this value, it will be included in the filtered dataset.

    Input columns:
        - evol_instruction_score (`float`): The score of the instruction generated by
            `ComplexityScorer` step.
        - evol_response_score (`float`): The score of the response generated by
            `QualityScorer` step.
        - embedding (`List[float]`): The embedding generated for the conversation of the
            instruction-response pair using `GenerateEmbeddings` step.

    Output columns:
        - deita_score (`float`): The DEITA score for the instruction-response pair.
        - deita_score_computed_with (`List[str]`): The scores used to compute the DEITA
            score.
        - nearest_neighbor_distance (`float`): The cosine distance between the embeddings
            of the instruction-response pair.

    Categories:
        - filtering

    References:
        - [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
    """

    data_budget: RuntimeParameter[int] = Field(
        default=None, description="The desired size of the dataset after filtering."
    )
    diversity_threshold: RuntimeParameter[float] = Field(
        default=0.9,
        description="If a row has a cosine distance with respect to it's nearest neighbor"
        " greater than this value, it will be included in the filtered dataset.",
    )
    normalize_embeddings: RuntimeParameter[bool] = Field(
        default=True,
        description="Whether to normalize the embeddings before computing the cosine distance.",
    )
    distance_metric: RuntimeParameter[Literal["cosine", "manhattan"]] = Field(
        default="cosine",
        description="The distance metric to use. Currently only 'cosine' is supported.",
    )

    @property
    def inputs(self) -> List[str]:
        return ["evol_instruction_score", "evol_response_score", "embedding"]

    @property
    def outputs(self) -> List[str]:
        return ["deita_score", "nearest_neighbor_distance", "deita_score_computed_with"]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Filter the dataset based on the DEITA score and the cosine distance between the
        embeddings.

        Args:
            inputs: The input data.

        Yields:
            The filtered dataset.
        """
        inputs = self._compute_deita_score(inputs)
        inputs = self._compute_nearest_neighbor(inputs)
        inputs.sort(key=lambda x: x["deita_score"], reverse=True)

        selected_rows = []
        for input in inputs:
            if len(selected_rows) >= self.data_budget:  # type: ignore
                break
            if input["nearest_neighbor_distance"] >= self.diversity_threshold:
                selected_rows.append(input)
        yield selected_rows

    def _compute_deita_score(self, inputs: StepInput) -> StepInput:
        """Computes the DEITA score for each instruction-response pair. The DEITA score is
        the product of the instruction score and the response score.

        Args:
            inputs: The input data.

        Returns:
            The input data with the DEITA score computed.
        """
        for input_ in inputs:
            evol_instruction_score = input_.get("evol_instruction_score")
            evol_response_score = input_.get("evol_response_score")

            if evol_instruction_score and evol_response_score:
                deita_score = evol_instruction_score * evol_response_score
                score_computed_with = ["evol_instruction_score", "evol_response_score"]
            elif evol_instruction_score:
                self._logger.warning(
                    "Response score is missing for the instruction-response pair. Using"
                    " instruction score as DEITA score."
                )
                deita_score = evol_instruction_score
                score_computed_with = ["evol_instruction_score"]
            elif evol_response_score:
                self._logger.warning(
                    "Instruction score is missing for the instruction-response pair. Using"
                    " response score as DEITA score."
                )
                deita_score = evol_response_score
                score_computed_with = ["evol_response_score"]
            else:
                self._logger.warning(
                    "Instruction and response scores are missing for the instruction-response"
                    " pair. Setting DEITA score to 0."
                )
                deita_score = 0
                score_computed_with = []

            input_.update(
                {
                    "deita_score": deita_score,
                    "deita_score_computed_with": score_computed_with,
                }
            )
        return inputs

    def _compute_nearest_neighbor(self, inputs: StepInput) -> StepInput:
        """Computes the cosine distance between the embeddings of the instruction-response
        pairs and the nearest neighbor.

        Args:
            inputs: The input data.

        Returns:
            The input data with the cosine distance computed.
        """
        embeddings = np.array([input["embedding"] for input in inputs])
        if self.normalize_embeddings:
            embeddings = self._normalize_embeddings(embeddings)
        self._logger.info("📏 Computing nearest neighbor distance...")

        if self.distance_metric == "cosine":
            self._logger.info("📏 Using cosine distance.")
            distances = self._cosine_distance(embeddings)
        else:
            self._logger.info("📏 Using manhattan distance.")
            distances = self._manhattan_distance(embeddings)

        for distance, input in zip(distances, inputs):
            input["nearest_neighbor_distance"] = distance
        return inputs

    def _normalize_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
        """Normalize the embeddings.

        Args:
            embeddings: The embeddings to normalize.

        Returns:
            The normalized embeddings.
        """
        self._logger.info("⚖️ Normalizing embeddings...")
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings / norms

    def _cosine_distance(self, embeddings: np.array) -> np.array:  # type: ignore
        """Computes the cosine distance between the embeddings.

        Args:
            embeddings: The embeddings.

        Returns:
            The cosine distance between the embeddings.
        """
        cosine_similarity = np.dot(embeddings, embeddings.T)
        cosine_distance = 1 - cosine_similarity
        # Ignore self-distance
        np.fill_diagonal(cosine_distance, np.inf)
        return np.min(cosine_distance, axis=1)

    def _manhattan_distance(self, embeddings: np.array) -> np.array:  # type: ignore
        """Computes the manhattan distance between the embeddings.

        Args:
            embeddings: The embeddings.

        Returns:
            The manhattan distance between the embeddings.
        """
        manhattan_distance = np.abs(embeddings[:, None] - embeddings).sum(-1)
        # Ignore self-distance
        np.fill_diagonal(manhattan_distance, np.inf)
        return np.min(manhattan_distance, axis=1)

process(inputs)

Filter the dataset based on the DEITA score and the cosine distance between the embeddings.

Parameters:
  • inputs (StepInput): The input data. Required.

Yields:
  • StepOutput: The filtered dataset.

Source code in src/distilabel/steps/deita.py
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Filter the dataset based on the DEITA score and the cosine distance between the
    embeddings.

    Args:
        inputs: The input data.

    Yields:
        The filtered dataset.
    """
    inputs = self._compute_deita_score(inputs)
    inputs = self._compute_nearest_neighbor(inputs)
    inputs.sort(key=lambda x: x["deita_score"], reverse=True)

    selected_rows = []
    for input in inputs:
        if len(selected_rows) >= self.data_budget:  # type: ignore
            break
        if input["nearest_neighbor_distance"] >= self.diversity_threshold:
            selected_rows.append(input)
    yield selected_rows

ExpandColumns

Bases: Step

Expand columns that contain lists into multiple rows.

ExpandColumns is a Step that takes a list of columns and expands them into multiple rows. The new rows will have the same data as the original row, except for the expanded column, which will contain a single item from the original list.

Attributes:
  • columns (Union[Dict[str, str], List[str]]): A dictionary that maps the column to be expanded to the new column name, or a list of columns to be expanded. If a list is provided, the new column name will be the same as the column name.

Input columns
  • dynamic (determined by columns attribute): The columns to be expanded into multiple rows.
Output columns
  • dynamic (determined by columns attribute): The expanded columns.
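A minimal usage sketch (values invented; called directly for illustration):

from distilabel.steps import ExpandColumns

expand = ExpandColumns(columns=["generation"])
expand.load()

result = next(
    expand.process(
        [{"instruction": "Write a haiku", "generation": ["haiku 1", "haiku 2"]}]
    )
)
# result is roughly [
#     {"instruction": "Write a haiku", "generation": "haiku 1"},
#     {"instruction": "Write a haiku", "generation": "haiku 2"},
# ]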
Source code in src/distilabel/steps/expand.py
class ExpandColumns(Step):
    """Expand columns that contain lists into multiple rows.

    `ExpandColumns` is a `Step` that takes a list of columns and expands them into multiple
    rows. The new rows will have the same data as the original row, except for the expanded
    column, which will contain a single item from the original list.

    Attributes:
        columns: A dictionary that maps the column to be expanded to the new column name
            or a list of columns to be expanded. If a list is provided, the new column name
            will be the same as the column name.

    Input columns:
        - dynamic (determined by `columns` attribute): The columns to be expanded into
            multiple rows.

    Output columns:
        - dynamic (determined by `columns` attribute): The expanded columns.
    """

    columns: Union[Dict[str, str], List[str]]

    @field_validator("columns")
    @classmethod
    def always_dict(cls, value: Union[Dict[str, str], List[str]]) -> Dict[str, str]:
        """Ensure that the columns are always a dictionary.

        Args:
            value: The columns to be expanded.

        Returns:
            The columns to be expanded as a dictionary.
        """
        if isinstance(value, list):
            return {col: col for col in value}

        return value

    @property
    def inputs(self) -> List[str]:
        """The columns to be expanded."""
        return list(self.columns.keys())

    @property
    def outputs(self) -> List[str]:
        """The expanded columns."""
        return [
            new_column if new_column else expand_column
            for expand_column, new_column in self.columns.items()
        ]

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Expand the columns in the input data.

        Args:
            inputs: The input data.

        Yields:
            The expanded rows.
        """
        yield [row for input in inputs for row in self._expand_columns(input)]

    def _expand_columns(self, input: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Expand the columns in the input data.

        Args:
            input: The input data.

        Returns:
            The expanded rows.
        """
        expanded_rows = []
        for expand_column, new_column in self.columns.items():  # type: ignore
            data = input.get(expand_column)
            rows = []
            for item, expanded in zip_longest(*[data, expanded_rows], fillvalue=input):
                rows.append({**expanded, new_column: item})
            expanded_rows = rows
        return expanded_rows

inputs: List[str] property

The columns to be expanded.

outputs: List[str] property

The expanded columns.

always_dict(value) classmethod

Ensure that the columns are always a dictionary.

Parameters:
  • value (Union[Dict[str, str], List[str]]): The columns to be expanded. Required.

Returns:
  • Dict[str, str]: The columns to be expanded as a dictionary.

Source code in src/distilabel/steps/expand.py
@field_validator("columns")
@classmethod
def always_dict(cls, value: Union[Dict[str, str], List[str]]) -> Dict[str, str]:
    """Ensure that the columns are always a dictionary.

    Args:
        value: The columns to be expanded.

    Returns:
        The columns to be expanded as a dictionary.
    """
    if isinstance(value, list):
        return {col: col for col in value}

    return value

process(inputs)

Expand the columns in the input data.

Parameters:
  • inputs (StepInput): The input data. Required.

Yields:
  • StepOutput: The expanded rows.

Source code in src/distilabel/steps/expand.py
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Expand the columns in the input data.

    Args:
        inputs: The input data.

    Yields:
        The expanded rows.
    """
    yield [row for input in inputs for row in self._expand_columns(input)]

FormatChatGenerationDPO

Bases: Step

Format the output of a combination of a ChatGeneration + a preference task such as UltraFeedback, for Direct Preference Optimization (DPO) following the standard formatting from frameworks such as axolotl or alignment-handbook.

FormatChatGenerationDPO is a Step that formats the output of the combination of a ChatGeneration task with a preference Task, i.e. a task generating ratings, so that the ratings are used to rank the existing generations and provide the chosen and rejected generations.

Note

The messages column should contain at least one message from the user, the generations column should contain at least two generations, and the ratings column should contain the same number of ratings as generations.

Input columns
  • messages (List[Dict[str, str]]): The conversation messages.
  • generations (List[str]): The generations produced by the LLM.
  • generation_models (List[str], optional): The model names used to generate the generations; only available if the model_name from the ChatGeneration task/s is combined into a single column named this way, otherwise it will be ignored.
  • ratings (List[float]): The ratings for each of the generations, produced by a preference task such as UltraFeedback.
Output columns
  • prompt (str): The user message used to generate the generations with the LLM.
  • prompt_id (str): The SHA256 hash of the prompt.
  • chosen (List[Dict[str, str]]): The chosen generation based on the ratings.
  • chosen_model (str, optional): The model name used to generate the chosen generation, if the generation_models are available.
  • chosen_rating (float): The rating of the chosen generation.
  • rejected (List[Dict[str, str]]): The rejected generation based on the ratings.
  • rejected_model (str, optional): The model name used to generate the rejected generation, if the generation_models are available.
  • rejected_rating (float): The rating of the rejected generation.
Categories
  • format
  • chat-generation
  • preference
  • messages
  • generations
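A minimal usage sketch (messages, generations and ratings are invented; called directly for illustration):

from distilabel.steps import FormatChatGenerationDPO

format_dpo = FormatChatGenerationDPO()
format_dpo.load()

result = next(
    format_dpo.process(
        [
            {
                "messages": [{"role": "user", "content": "What's 2+2?"}],
                "generations": ["4", "5", "6"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# Each row gains `prompt`, `prompt_id`, `chosen`/`chosen_rating` (highest rating)
# and `rejected`/`rejected_rating` (lowest rating).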
Source code in src/distilabel/steps/formatting/dpo.py
class FormatChatGenerationDPO(Step):
    """Format the output of a combination of a `ChatGeneration` + a preference task such as
    `UltraFeedback`, for Direct Preference Optimization (DPO) following the standard formatting
    from frameworks such as `axolotl` or `alignment-handbook`.

    `FormatChatGenerationDPO` is a `Step` that formats the output of the combination of a `ChatGeneration`
    task with a preference `Task` i.e. a task generating `ratings`, so that those are used to rank the
    existing generations and provide the `chosen` and `rejected` generations based on the `ratings`.

    Note:
        The `messages` column should contain at least one message from the user, the `generations`
        column should contain at least two generations, and the `ratings` column should contain the
        same number of ratings as generations.

    Input columns:
        - messages (`List[Dict[str, str]]`): The conversation messages.
        - generations (`List[str]`): The generations produced by the `LLM`.
        - generation_models (`List[str]`, optional): The model names used to generate the `generations`,
            only available if the `model_name` from the `ChatGeneration` task/s is combined into a single
            column named this way, otherwise, it will be ignored.
        - ratings (`List[float]`): The ratings for each of the `generations`, produced by a preference
            task such as `UltraFeedback`.

    Output columns:
        - prompt (`str`): The user message used to generate the `generations` with the `LLM`.
        - prompt_id (`str`): The `SHA256` hash of the `prompt`.
        - chosen (`List[Dict[str, str]]`): The `chosen` generation based on the `ratings`.
        - chosen_model (`str`, optional): The model name used to generate the `chosen` generation,
            if the `generation_models` are available.
        - chosen_rating (`float`): The rating of the `chosen` generation.
        - rejected (`List[Dict[str, str]]`): The `rejected` generation based on the `ratings`.
        - rejected_model (`str`, optional): The model name used to generate the `rejected` generation,
            if the `generation_models` are available.
        - rejected_rating (`float`): The rating of the `rejected` generation.

    Categories:
        - format
        - chat-generation
        - preference
        - messages
        - generations
    """

    @property
    def inputs(self) -> List[str]:
        """List of inputs required by the `Step`, which in this case are: `messages`, `generations`,
        and `ratings`."""
        return ["messages", "generations", "ratings"]

    @property
    def optional_inputs(self) -> List[str]:
        """List of optional inputs, which are not required by the `Step` but used if available,
        which in this case is: `generation_models`."""
        return ["generation_models"]

    @property
    def outputs(self) -> List[str]:
        """List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `chosen`,
        `chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`. Both
        the `chosen_model` and `rejected_model` being optional and only used if `generation_models`
        is available.

        Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
        """
        return [
            "prompt",
            "prompt_id",
            "chosen",
            "chosen_model",
            "chosen_rating",
            "rejected",
            "rejected_model",
            "rejected_rating",
        ]

    def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
        """The `process` method formats the received `StepInput` or list of `StepInput`
        according to the DPO formatting standard.

        Args:
            *inputs: A list of `StepInput` to be combined.

        Yields:
            A `StepOutput` with batches of formatted `StepInput` following the DPO standard.
        """
        for input in inputs:
            for item in input:
                item["prompt"] = next(
                    (
                        turn["content"]
                        for turn in item["messages"]
                        if turn["role"] == "user"
                    ),
                    None,
                )
                item["prompt_id"] = hashlib.sha256(
                    item["prompt"].encode("utf-8")  # type: ignore
                ).hexdigest()

                chosen_idx = max(enumerate(item["ratings"]), key=lambda x: x[1])[0]
                item["chosen"] = item["messages"] + [
                    {
                        "role": "assistant",
                        "content": item["generations"][chosen_idx],
                    }
                ]
                if "generation_models" in item:
                    item["chosen_model"] = item["generation_models"][chosen_idx]
                item["chosen_rating"] = item["ratings"][chosen_idx]

                rejected_idx = min(enumerate(item["ratings"]), key=lambda x: x[1])[0]
                item["rejected"] = item["messages"] + [
                    {
                        "role": "assistant",
                        "content": item["generations"][rejected_idx],
                    }
                ]
                if "generation_models" in item:
                    item["rejected_model"] = item["generation_models"][rejected_idx]
                item["rejected_rating"] = item["ratings"][rejected_idx]

            yield input

inputs: List[str] property

List of inputs required by the Step, which in this case are: messages, generations, and ratings.

optional_inputs: List[str] property

List of optional inputs, which are not required by the Step but used if available, which in this case is: generation_models.

outputs: List[str] property

List of outputs generated by the Step, which are: prompt, prompt_id, chosen, chosen_model, chosen_rating, rejected, rejected_model, rejected_rating. Both the chosen_model and rejected_model being optional and only used if generation_models is available.

Reference
  • Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

process(*inputs)

The process method formats the received StepInput or list of StepInput according to the DPO formatting standard.

Parameters:
  • *inputs (StepInput): A list of StepInput to be combined. Default: ()

Yields:
  • StepOutput: A StepOutput with batches of formatted StepInput following the DPO standard.

Source code in src/distilabel/steps/formatting/dpo.py
def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
    """The `process` method formats the received `StepInput` or list of `StepInput`
    according to the DPO formatting standard.

    Args:
        *inputs: A list of `StepInput` to be combined.

    Yields:
        A `StepOutput` with batches of formatted `StepInput` following the DPO standard.
    """
    for input in inputs:
        for item in input:
            item["prompt"] = next(
                (
                    turn["content"]
                    for turn in item["messages"]
                    if turn["role"] == "user"
                ),
                None,
            )
            item["prompt_id"] = hashlib.sha256(
                item["prompt"].encode("utf-8")  # type: ignore
            ).hexdigest()

            chosen_idx = max(enumerate(item["ratings"]), key=lambda x: x[1])[0]
            item["chosen"] = item["messages"] + [
                {
                    "role": "assistant",
                    "content": item["generations"][chosen_idx],
                }
            ]
            if "generation_models" in item:
                item["chosen_model"] = item["generation_models"][chosen_idx]
            item["chosen_rating"] = item["ratings"][chosen_idx]

            rejected_idx = min(enumerate(item["ratings"]), key=lambda x: x[1])[0]
            item["rejected"] = item["messages"] + [
                {
                    "role": "assistant",
                    "content": item["generations"][rejected_idx],
                }
            ]
            if "generation_models" in item:
                item["rejected_model"] = item["generation_models"][rejected_idx]
            item["rejected_rating"] = item["ratings"][rejected_idx]

        yield input

FormatChatGenerationSFT

Bases: Step

Format the output of a ChatGeneration task for Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as axolotl or alignment-handbook.

FormatChatGenerationSFT is a Step that formats the output of a ChatGeneration task for Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as axolotl or alignment-handbook. The generation produced by the ChatGeneration task is appended to the conversation messages as the assistant message, and the first user message is extracted as the prompt.

Input columns
  • messages (List[Dict[str, str]]): The conversation messages.
  • generation (str): The generation produced by the LLM.
Output columns
  • prompt (str): The user message used to generate the generation with the LLM.
  • prompt_id (str): The SHA256 hash of the prompt.
  • messages (List[Dict[str, str]]): The chat-like conversation with the generation appended as the assistant message.
Categories
  • format
  • chat-generation
  • instruction
  • generation
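A minimal usage sketch (values invented; called directly for illustration):

from distilabel.steps import FormatChatGenerationSFT

format_sft = FormatChatGenerationSFT()
format_sft.load()

result = next(
    format_sft.process(
        [
            {
                "messages": [{"role": "user", "content": "What's 2+2?"}],
                "generation": "4",
            }
        ]
    )
)
# The generation is appended as the assistant turn, and `prompt` / `prompt_id`
# are added: result[0]["messages"][-1] == {"role": "assistant", "content": "4"}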
Source code in src/distilabel/steps/formatting/sft.py
class FormatChatGenerationSFT(Step):
    """Format the output of a `ChatGeneration` task for Supervised Fine-Tuning (SFT) following the
    standard formatting from frameworks such as `axolotl` or `alignment-handbook`.

    `FormatChatGenerationSFT` is a `Step` that formats the output of a `ChatGeneration` task for
    Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as `axolotl`
    or `alignment-handbook`. The `generation` produced by the `ChatGeneration` task is appended to
    the conversation `messages` as the assistant message, and the first user message is extracted
    as the `prompt`.

    Input columns:
        - messages (`List[Dict[str, str]]`): The conversation messages.
        - generation (`str`): The generation produced by the `LLM`.

    Output columns:
        - prompt (`str`): The user message used to generate the `generation` with the `LLM`.
        - prompt_id (`str`): The `SHA256` hash of the `prompt`.
        - messages (`List[Dict[str, str]]`): The chat-like conversation with the `generation`
            appended as the assistant message.

    Categories:
        - format
        - chat-generation
        - instruction
        - generation
    """

    @property
    def inputs(self) -> List[str]:
        """List of inputs required by the `Step`, which in this case are: `instruction`, and `generation`."""
        return ["messages", "generation"]

    @property
    def outputs(self) -> List[str]:
        """List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `messages`.

        Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
        """
        return ["prompt", "prompt_id", "messages"]

    def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
        """The `process` method formats the received `StepInput` or list of `StepInput`
        according to the SFT formatting standard.

        Args:
            *inputs: A list of `StepInput` to be combined.

        Yields:
            A `StepOutput` with batches of formatted `StepInput` following the SFT standard.
        """
        for input in inputs:
            for item in input:
                item["prompt"] = next(
                    (
                        turn["content"]
                        for turn in item["messages"]
                        if turn["role"] == "user"
                    ),
                    None,
                )

                item["prompt_id"] = hashlib.sha256(
                    item["prompt"].encode("utf-8")  # type: ignore
                ).hexdigest()

                item["messages"] = item["messages"] + [
                    {"role": "assistant", "content": item["generation"]},  # type: ignore
                ]
            yield input

inputs: List[str] property

List of inputs required by the Step, which in this case are: messages, and generation.

outputs: List[str] property

List of outputs generated by the Step, which are: prompt, prompt_id, messages.

Reference
  • Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

process(*inputs)

The process method formats the received StepInput or list of StepInput according to the SFT formatting standard.

Parameters:
  • *inputs (StepInput): A list of StepInput to be combined. Default: ()

Yields:
  • StepOutput: A StepOutput with batches of formatted StepInput following the SFT standard.

Source code in src/distilabel/steps/formatting/sft.py
def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
    """The `process` method formats the received `StepInput` or list of `StepInput`
    according to the SFT formatting standard.

    Args:
        *inputs: A list of `StepInput` to be combined.

    Yields:
        A `StepOutput` with batches of formatted `StepInput` following the SFT standard.
    """
    for input in inputs:
        for item in input:
            item["prompt"] = next(
                (
                    turn["content"]
                    for turn in item["messages"]
                    if turn["role"] == "user"
                ),
                None,
            )

            item["prompt_id"] = hashlib.sha256(
                item["prompt"].encode("utf-8")  # type: ignore
            ).hexdigest()

            item["messages"] = item["messages"] + [
                {"role": "assistant", "content": item["generation"]},  # type: ignore
            ]
        yield input

FormatTextGenerationDPO

Bases: Step

Format the output of your LLMs for Direct Preference Optimization (DPO).

FormatTextGenerationDPO is a Step that formats the output of the combination of a TextGeneration task with a preference Task, i.e. a task generating ratings, so that the ratings are used to rank the existing generations and provide the chosen and rejected generations. Use this step to transform the output of a combination of a TextGeneration + a preference task such as UltraFeedback, following the standard formatting from frameworks such as axolotl or alignment-handbook.

Note

The generations column should contain at least two generations, and the ratings column should contain the same number of ratings as generations.

Input columns
  • system_prompt (str, optional): The system prompt used within the LLM to generate the generations, if available.
  • instruction (str): The instruction used to generate the generations with the LLM.
  • generations (List[str]): The generations produced by the LLM.
  • generation_models (List[str], optional): The model names used to generate the generations; only available if the model_name from the TextGeneration task/s is combined into a single column named this way, otherwise it will be ignored.
  • ratings (List[float]): The ratings for each of the generations, produced by a preference task such as UltraFeedback.
Output columns
  • prompt (str): The instruction used to generate the generations with the LLM.
  • prompt_id (str): The SHA256 hash of the prompt.
  • chosen (List[Dict[str, str]]): The chosen generation based on the ratings.
  • chosen_model (str, optional): The model name used to generate the chosen generation, if the generation_models are available.
  • chosen_rating (float): The rating of the chosen generation.
  • rejected (List[Dict[str, str]]): The rejected generation based on the ratings.
  • rejected_model (str, optional): The model name used to generate the rejected generation, if the generation_models are available.
  • rejected_rating (float): The rating of the rejected generation.
Categories
  • format
  • text-generation
  • preference
  • instruction
  • generations
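A minimal usage sketch (values invented; called directly for illustration):

from distilabel.steps import FormatTextGenerationDPO

format_dpo = FormatTextGenerationDPO()
format_dpo.load()

result = next(
    format_dpo.process(
        [
            {
                "instruction": "What's 2+2?",
                "generations": ["4", "5", "6"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# result[0]["chosen"] ends with {"role": "assistant", "content": "4"} and
# result[0]["rejected"] ends with {"role": "assistant", "content": "6"}.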
Source code in src/distilabel/steps/formatting/dpo.py
class FormatTextGenerationDPO(Step):
    """Format the output of your LLMs for Direct Preference Optimization (DPO).

    `FormatTextGenerationDPO` is a `Step` that formats the output of the combination of a `TextGeneration`
    task with a preference `Task` i.e. a task generating `ratings`, so that those are used to rank the
    existing generations and provide the `chosen` and `rejected` generations based on the `ratings`.
    Use this step to transform the output of a combination of a `TextGeneration` + a preference task such as
    `UltraFeedback` following the standard formatting from frameworks such as `axolotl` or `alignment-handbook`.

    Note:
        The `generations` column should contain at least two generations, and the `ratings` column
        should contain the same number of ratings as generations.

    Input columns:
        - system_prompt (`str`, optional): The system prompt used within the `LLM` to generate the
            `generations`, if available.
        - instruction (`str`): The instruction used to generate the `generations` with the `LLM`.
        - generations (`List[str]`): The generations produced by the `LLM`.
        - generation_models (`List[str]`, optional): The model names used to generate the `generations`,
            only available if the `model_name` from the `TextGeneration` task/s is combined into a single
            column named this way, otherwise, it will be ignored.
        - ratings (`List[float]`): The ratings for each of the `generations`, produced by a preference
            task such as `UltraFeedback`.

    Output columns:
        - prompt (`str`): The instruction used to generate the `generations` with the `LLM`.
        - prompt_id (`str`): The `SHA256` hash of the `prompt`.
        - chosen (`List[Dict[str, str]]`): The `chosen` generation based on the `ratings`.
        - chosen_model (`str`, optional): The model name used to generate the `chosen` generation,
            if the `generation_models` are available.
        - chosen_rating (`float`): The rating of the `chosen` generation.
        - rejected (`List[Dict[str, str]]`): The `rejected` generation based on the `ratings`.
        - rejected_model (`str`, optional): The model name used to generate the `rejected` generation,
            if the `generation_models` are available.
        - rejected_rating (`float`): The rating of the `rejected` generation.

    Categories:
        - format
        - text-generation
        - preference
        - instruction
        - generations
    """

    @property
    def inputs(self) -> List[str]:
        """List of inputs required by the `Step`, which in this case are: `instruction`, `generations`,
        and `ratings`."""
        return ["instruction", "generations", "ratings"]

    @property
    def optional_inputs(self) -> List[str]:
        """List of optional inputs, which are not required by the `Step` but used if available,
        which in this case are: `system_prompt`, and `generation_models`."""
        return ["system_prompt", "generation_models"]

    @property
    def outputs(self) -> List[str]:
        """List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `chosen`,
        `chosen_model`, `chosen_rating`, `rejected`, `rejected_model`, `rejected_rating`. Both
        the `chosen_model` and `rejected_model` being optional and only used if `generation_models`
        is available.

        Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
        """
        return [
            "prompt",
            "prompt_id",
            "chosen",
            "chosen_model",
            "chosen_rating",
            "rejected",
            "rejected_model",
            "rejected_rating",
        ]

    def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
        """The `process` method formats the received `StepInput` or list of `StepInput`
        according to the DPO formatting standard.

        Args:
            *inputs: A list of `StepInput` to be combined.

        Yields:
            A `StepOutput` with batches of formatted `StepInput` following the DPO standard.
        """
        for input in inputs:
            for item in input:
                messages = [
                    {"role": "user", "content": item["instruction"]},  # type: ignore
                ]
                if (
                    "system_prompt" in item
                    and isinstance(item["system_prompt"], str)  # type: ignore
                    and len(item["system_prompt"]) > 0  # type: ignore
                ):
                    messages.insert(
                        0,
                        {"role": "system", "content": item["system_prompt"]},  # type: ignore
                    )

                item["prompt"] = item["instruction"]
                item["prompt_id"] = hashlib.sha256(
                    item["prompt"].encode("utf-8")  # type: ignore
                ).hexdigest()

                chosen_idx = max(enumerate(item["ratings"]), key=lambda x: x[1])[0]
                item["chosen"] = messages + [
                    {
                        "role": "assistant",
                        "content": item["generations"][chosen_idx],
                    }
                ]
                if "generation_models" in item:
                    item["chosen_model"] = item["generation_models"][chosen_idx]
                item["chosen_rating"] = item["ratings"][chosen_idx]

                rejected_idx = min(enumerate(item["ratings"]), key=lambda x: x[1])[0]
                item["rejected"] = messages + [
                    {
                        "role": "assistant",
                        "content": item["generations"][rejected_idx],
                    }
                ]
                if "generation_models" in item:
                    item["rejected_model"] = item["generation_models"][rejected_idx]
                item["rejected_rating"] = item["ratings"][rejected_idx]

            yield input

inputs: List[str] property

List of inputs required by the Step, which in this case are: instruction, generations, and ratings.

optional_inputs: List[str] property

List of optional inputs, which are not required by the Step but used if available, which in this case are: system_prompt, and generation_models.

outputs: List[str] property

List of outputs generated by the Step, which are: prompt, prompt_id, chosen, chosen_model, chosen_rating, rejected, rejected_model, rejected_rating. Both the chosen_model and rejected_model being optional and only used if generation_models is available.

Reference
  • Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

process(*inputs)

The process method formats the received StepInput or list of StepInput according to the DPO formatting standard.

Parameters:
  • *inputs (StepInput): A list of StepInput to be combined. Default: ()

Yields:
  • StepOutput: A StepOutput with batches of formatted StepInput following the DPO standard.

Source code in src/distilabel/steps/formatting/dpo.py
def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
    """The `process` method formats the received `StepInput` or list of `StepInput`
    according to the DPO formatting standard.

    Args:
        *inputs: A list of `StepInput` to be combined.

    Yields:
        A `StepOutput` with batches of formatted `StepInput` following the DPO standard.
    """
    for input in inputs:
        for item in input:
            messages = [
                {"role": "user", "content": item["instruction"]},  # type: ignore
            ]
            if (
                "system_prompt" in item
                and isinstance(item["system_prompt"], str)  # type: ignore
                and len(item["system_prompt"]) > 0  # type: ignore
            ):
                messages.insert(
                    0,
                    {"role": "system", "content": item["system_prompt"]},  # type: ignore
                )

            item["prompt"] = item["instruction"]
            item["prompt_id"] = hashlib.sha256(
                item["prompt"].encode("utf-8")  # type: ignore
            ).hexdigest()

            chosen_idx = max(enumerate(item["ratings"]), key=lambda x: x[1])[0]
            item["chosen"] = messages + [
                {
                    "role": "assistant",
                    "content": item["generations"][chosen_idx],
                }
            ]
            if "generation_models" in item:
                item["chosen_model"] = item["generation_models"][chosen_idx]
            item["chosen_rating"] = item["ratings"][chosen_idx]

            rejected_idx = min(enumerate(item["ratings"]), key=lambda x: x[1])[0]
            item["rejected"] = messages + [
                {
                    "role": "assistant",
                    "content": item["generations"][rejected_idx],
                }
            ]
            if "generation_models" in item:
                item["rejected_model"] = item["generation_models"][rejected_idx]
            item["rejected_rating"] = item["ratings"][rejected_idx]

        yield input

FormatTextGenerationSFT

Bases: Step

Format the output of a TextGeneration task for Supervised Fine-Tuning (SFT).

FormatTextGenerationSFT is a Step that formats the output of a TextGeneration task for Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as axolotl or alignment-handbook. The output of the TextGeneration task is formatted into a chat-like conversation with the instruction as the user message and the generation as the assistant message. Optionally, if the system_prompt is available, it is included as the first message in the conversation.

Input columns
  • system_prompt (str, optional): The system prompt used within the LLM to generate the generation, if available.
  • instruction (str): The instruction used to generate the generation with the LLM.
  • generation (str): The generation produced by the LLM.
Output columns
  • prompt (str): The instruction used to generate the generation with the LLM.
  • prompt_id (str): The SHA256 hash of the prompt.
  • messages (List[Dict[str, str]]): The chat-like conversation with the instruction as the user message and the generation as the assistant message.
Categories
  • format
  • text-generation
  • instruction
  • generation
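A minimal usage sketch (values invented; called directly for illustration):

from distilabel.steps import FormatTextGenerationSFT

format_sft = FormatTextGenerationSFT()
format_sft.load()

result = next(
    format_sft.process([{"instruction": "What's 2+2?", "generation": "4"}])
)
# result[0]["messages"] == [
#     {"role": "user", "content": "What's 2+2?"},
#     {"role": "assistant", "content": "4"},
# ]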
Source code in src/distilabel/steps/formatting/sft.py
class FormatTextGenerationSFT(Step):
    """Format the output of a `TextGeneration` task for Supervised Fine-Tuning (SFT).

    `FormatTextGenerationSFT` is a `Step` that formats the output of a `TextGeneration` task for
    Supervised Fine-Tuning (SFT) following the standard formatting from frameworks such as `axolotl`
    or `alignment-handbook`. The output of the `TextGeneration` task is formatted into a chat-like
    conversation with the `instruction` as the user message and the `generation` as the assistant
    message. Optionally, if the `system_prompt` is available, it is included as the first message
    in the conversation.

    Input columns:
        - system_prompt (`str`, optional): The system prompt used within the `LLM` to generate the
            `generation`, if available.
        - instruction (`str`): The instruction used to generate the `generation` with the `LLM`.
        - generation (`str`): The generation produced by the `LLM`.

    Output columns:
        - prompt (`str`): The instruction used to generate the `generation` with the `LLM`.
        - prompt_id (`str`): The `SHA256` hash of the `prompt`.
        - messages (`List[Dict[str, str]]`): The chat-like conversation with the `instruction` as
            the user message and the `generation` as the assistant message.

    Categories:
        - format
        - text-generation
        - instruction
        - generation
    """

    @property
    def inputs(self) -> List[str]:
        """List of inputs required by the `Step`, which in this case are: `instruction`, and `generation`."""
        return ["instruction", "generation"]

    @property
    def optional_inputs(self) -> List[str]:
        """List of optional inputs, which are not required by the `Step` but used if available,
        which in this case is: `system_prompt`."""
        return ["system_prompt"]

    @property
    def outputs(self) -> List[str]:
        """List of outputs generated by the `Step`, which are: `prompt`, `prompt_id`, `messages`.

        Reference:
            - Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k
        """
        return ["prompt", "prompt_id", "messages"]

    def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
        """The `process` method formats the received `StepInput` or list of `StepInput`
        according to the SFT formatting standard.

        Args:
            *inputs: A list of `StepInput` to be combined.

        Yields:
            A `StepOutput` with batches of formatted `StepInput` following the SFT standard.
        """
        for input in inputs:
            for item in input:
                item["prompt"] = item["instruction"]

                item["prompt_id"] = hashlib.sha256(
                    item["prompt"].encode("utf-8")  # type: ignore
                ).hexdigest()

                item["messages"] = [
                    {"role": "user", "content": item["instruction"]},  # type: ignore
                    {"role": "assistant", "content": item["generation"]},  # type: ignore
                ]
                if (
                    "system_prompt" in item
                    and isinstance(item["system_prompt"], str)  # type: ignore
                    and len(item["system_prompt"]) > 0  # type: ignore
                ):
                    item["messages"].insert(
                        0,
                        {"role": "system", "content": item["system_prompt"]},  # type: ignore
                    )

            yield input

inputs: List[str] property

List of inputs required by the Step, which in this case are: instruction, and generation.

optional_inputs: List[str] property

List of optional inputs, which are not required by the Step but used if available, which in this case is: system_prompt.

outputs: List[str] property

List of outputs generated by the Step, which are: prompt, prompt_id, messages.

Reference
  • Format inspired by https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

process(*inputs)

The process method formats the received StepInput or list of StepInput according to the SFT formatting standard.

Parameters:
  • *inputs (StepInput): A list of StepInput to be combined.

Yields:
  • StepOutput: A StepOutput with batches of formatted StepInput following the SFT standard.

Source code in src/distilabel/steps/formatting/sft.py
def process(self, *inputs: StepInput) -> "StepOutput":  # type: ignore
    """The `process` method formats the received `StepInput` or list of `StepInput`
    according to the SFT formatting standard.

    Args:
        *inputs: A list of `StepInput` to be combined.

    Yields:
        A `StepOutput` with batches of formatted `StepInput` following the SFT standard.
    """
    for input in inputs:
        for item in input:
            item["prompt"] = item["instruction"]

            item["prompt_id"] = hashlib.sha256(
                item["prompt"].encode("utf-8")  # type: ignore
            ).hexdigest()

            item["messages"] = [
                {"role": "user", "content": item["instruction"]},  # type: ignore
                {"role": "assistant", "content": item["generation"]},  # type: ignore
            ]
            if (
                "system_prompt" in item
                and isinstance(item["system_prompt"], str)  # type: ignore
                and len(item["system_prompt"]) > 0  # type: ignore
            ):
                item["messages"].insert(
                    0,
                    {"role": "system", "content": item["system_prompt"]},  # type: ignore
                )

        yield input
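
For reference, the transformation above can be traced on a single row with the following minimal sketch; the sample values are made up, and the step itself applies the same logic to every row of each incoming batch:

import hashlib

row = {
    "system_prompt": "You are a helpful assistant.",  # optional
    "instruction": "What is 2 + 2?",
    "generation": "2 + 2 equals 4.",
}

# Mirror the formatting performed by `process` above.
row["prompt"] = row["instruction"]
row["prompt_id"] = hashlib.sha256(row["prompt"].encode("utf-8")).hexdigest()
row["messages"] = [
    {"role": "user", "content": row["instruction"]},
    {"role": "assistant", "content": row["generation"]},
]
# The system prompt is only prepended when present and non-empty.
if isinstance(row.get("system_prompt"), str) and row["system_prompt"]:
    row["messages"].insert(0, {"role": "system", "content": row["system_prompt"]})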

GeneratorStep

Bases: _Step, ABC

A special kind of Step that is able to generate data i.e. it doesn't receive any input from the previous steps.

Attributes:
  • batch_size (RuntimeParameter[int]): The number of rows that will contain the batches generated by the step. Defaults to 50.

Runtime parameters
  • batch_size: The number of rows that will contain the batches generated by the step. Defaults to 50.
Source code in src/distilabel/steps/base.py
class GeneratorStep(_Step, ABC):
    """A special kind of `Step` that is able to generate data i.e. it doesn't receive
    any input from the previous steps.

    Attributes:
        batch_size: The number of rows that will contain the batches generated by the
            step. Defaults to `50`.

    Runtime parameters:
        - `batch_size`: The number of rows that will contain the batches generated by
            the step. Defaults to `50`.
    """

    batch_size: RuntimeParameter[int] = Field(
        default=50,
        description="The number of rows that will contain the batches generated by the"
        " step.",
    )

    @abstractmethod
    def process(self, offset: int = 0) -> "GeneratorStepOutput":
        """Method that defines the generation logic of the step. It should yield the
        output rows and a boolean indicating if it's the last batch or not.

        Args:
            offset: The offset to start the generation from. Defaults to 0.

        Yields:
            The output rows and a boolean indicating if it's the last batch or not.
        """
        pass

    def process_applying_mappings(self, offset: int = 0) -> "GeneratorStepOutput":
        """Runs the `process` method of the step applying the `outputs_mappings` to the
        output rows. This is the function that should be used to run the generation logic
        of the step.

        Args:
            offset: The offset to start the generation from. Defaults to 0.

        Yields:
            The output rows and a boolean indicating if it's the last batch or not.
        """

        # If the `Step` was built using the `@step` decorator, then we need to pass
        # the runtime parameters as `kwargs`, so they can be used within the processing
        # function
        generator = (
            self.process(offset=offset)
            if not self._built_from_decorator
            else self.process(offset=offset, **self._runtime_parameters)
        )

        for output_rows, last_batch in generator:
            yield (
                [
                    {self.output_mappings.get(k, k): v for k, v in row.items()}
                    for row in output_rows
                ],
                last_batch,
            )

process(offset=0) abstractmethod

Method that defines the generation logic of the step. It should yield the output rows and a boolean indicating if it's the last batch or not.

Parameters:
  • offset (int): The offset to start the generation from. Defaults to 0.

Yields:
  • GeneratorStepOutput: The output rows and a boolean indicating if it's the last batch or not.

Source code in src/distilabel/steps/base.py
@abstractmethod
def process(self, offset: int = 0) -> "GeneratorStepOutput":
    """Method that defines the generation logic of the step. It should yield the
    output rows and a boolean indicating if it's the last batch or not.

    Args:
        offset: The offset to start the generation from. Defaults to 0.

    Yields:
        The output rows and a boolean indicating if it's the last batch or not.
    """
    pass

process_applying_mappings(offset=0)

Runs the process method of the step applying the outputs_mappings to the output rows. This is the function that should be used to run the generation logic of the step.

Parameters:
  • offset (int): The offset to start the generation from. Defaults to 0.

Yields:
  • GeneratorStepOutput: The output rows and a boolean indicating if it's the last batch or not.

Source code in src/distilabel/steps/base.py
def process_applying_mappings(self, offset: int = 0) -> "GeneratorStepOutput":
    """Runs the `process` method of the step applying the `outputs_mappings` to the
    output rows. This is the function that should be used to run the generation logic
    of the step.

    Args:
        offset: The offset to start the generation from. Defaults to 0.

    Yields:
        The output rows and a boolean indicating if it's the last batch or not.
    """

    # If the `Step` was built using the `@step` decorator, then we need to pass
    # the runtime parameters as `kwargs`, so they can be used within the processing
    # function
    generator = (
        self.process(offset=offset)
        if not self._built_from_decorator
        else self.process(offset=offset, **self._runtime_parameters)
    )

    for output_rows, last_batch in generator:
        yield (
            [
                {self.output_mappings.get(k, k): v for k, v in row.items()}
                for row in output_rows
            ],
            last_batch,
        )
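
To illustrate the contract above, the following is a minimal sketch of a custom generator step; the class name, the total attribute, and the import paths are assumptions for the example:

from typing import List

from distilabel.steps import GeneratorStep
from distilabel.steps.typing import GeneratorStepOutput


class NumberGenerator(GeneratorStep):
    """Hypothetical generator step that yields the integers `0..total-1` in batches."""

    total: int = 10

    @property
    def outputs(self) -> List[str]:
        return ["number"]

    def process(self, offset: int = 0) -> "GeneratorStepOutput":
        rows = [{"number": i} for i in range(offset, self.total)]
        while rows:
            # Honour `batch_size` and flag the last batch with the boolean.
            batch, rows = rows[: self.batch_size], rows[self.batch_size :]
            yield batch, len(rows) == 0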

GlobalStep

Bases: Step, ABC

A special kind of Step whose process method receives all the data processed by its previous steps at once, instead of receiving it in batches. This kind of step is useful when the processing logic requires all the data at once, for example to train a model or to perform a global aggregation.

Source code in src/distilabel/steps/base.py
class GlobalStep(Step, ABC):
    """A special kind of `Step` which it's `process` method receives all the data processed
    by their previous steps at once, instead of receiving it in batches. This kind of steps
    are useful when the processing logic requires to have all the data at once, for example
    to train a model, to perform a global aggregation, etc.
    """

    @property
    def inputs(self) -> List[str]:
        return []

    @property
    def outputs(self) -> List[str]:
        return []
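
A minimal sketch of a custom GlobalStep follows; the class, the column names, and the import paths are hypothetical and only illustrate that process receives every row at once:

from typing import List

from distilabel.steps import GlobalStep, StepInput
from distilabel.steps.typing import StepOutput


class ComputeMeanLength(GlobalStep):
    """Hypothetical global step that annotates every row with a dataset-wide statistic."""

    @property
    def inputs(self) -> List[str]:
        return ["generation"]

    @property
    def outputs(self) -> List[str]:
        return ["mean_generation_length"]

    def process(self, inputs: StepInput) -> "StepOutput":
        # `inputs` contains *all* the rows produced by the previous steps at once.
        mean_length = sum(len(row["generation"]) for row in inputs) / len(inputs)
        for row in inputs:
            row["mean_generation_length"] = mean_length
        yield inputs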

KeepColumns

Bases: Step

Keeps selected columns in the dataset.

KeepColumns is a Step that implements the process method that keeps only the columns specified in the columns attribute. Also KeepColumns provides an attribute columns to specify the columns to keep which will override the default value for the properties inputs and outputs.

Note

The order in which the columns are provided is important, as the output will be sorted using the provided order, which is useful before pushing either a dataset.Dataset via the PushToHub step or a distilabel.Distiset via the Pipeline.run output variable.

Attributes:
  • columns (List[str]): List of strings with the names of the columns to keep.

Input columns
  • dynamic (determined by columns attribute): The columns to keep.
Output columns
  • dynamic (determined by columns attribute): The columns that were kept.
Source code in src/distilabel/steps/keep.py
class KeepColumns(Step):
    """Keeps selected columns in the dataset.

    `KeepColumns` is a `Step` that implements the `process` method that keeps only the columns
    specified in the `columns` attribute. Also `KeepColumns` provides an attribute `columns` to
    specify the columns to keep which will override the default value for the properties `inputs`
    and `outputs`.

    Note:
        The order in which the columns are provided is important, as the output will be sorted
        using the provided order, which is useful before pushing either a `dataset.Dataset` via
        the `PushToHub` step or a `distilabel.Distiset` via the `Pipeline.run` output variable.

    Attributes:
        columns: List of strings with the names of the columns to keep.

    Input columns:
        - dynamic (determined by `columns` attribute): The columns to keep.

    Output columns:
        - dynamic (determined by `columns` attribute): The columns that were kept.
    """

    columns: List[str]

    @property
    def inputs(self) -> List[str]:
        """The inputs for the task are the column names in `columns`."""
        return self.columns

    @property
    def outputs(self) -> List[str]:
        """The outputs for the task are the column names in `columns`."""
        return self.columns

    @override
    def process(self, *inputs: StepInput) -> "StepOutput":
        """The `process` method keeps only the columns specified in the `columns` attribute.

        Args:
            *inputs: A list of dictionaries with the input data.

        Yields:
            A list of dictionaries with the output data.
        """
        for input in inputs:
            outputs = []
            for item in input:
                outputs.append({col: item[col] for col in self.columns})
            yield outputs

inputs: List[str] property

The inputs for the task are the column names in columns.

outputs: List[str] property

The outputs for the task are the column names in columns.

process(*inputs)

The process method keeps only the columns specified in the columns attribute.

Parameters:
  • *inputs (StepInput): A list of dictionaries with the input data.

Yields:
  • StepOutput: A list of dictionaries with the output data.

Source code in src/distilabel/steps/keep.py
@override
def process(self, *inputs: StepInput) -> "StepOutput":
    """The `process` method keeps only the columns specified in the `columns` attribute.

    Args:
        *inputs: A list of dictionaries with the input data.

    Yields:
        A list of dictionaries with the output data.
    """
    for input in inputs:
        outputs = []
        for item in input:
            outputs.append({col: item[col] for col in self.columns})
        yield outputs
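
A minimal usage sketch, assuming the Pipeline context manager and the connect method from the same library (the step and column names are made up):

from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, LoadDataFromDicts

with Pipeline(name="keep-columns-example") as pipeline:
    load = LoadDataFromDicts(
        name="load",
        data=[{"instruction": "Hi", "generation": "Hello!", "model_name": "my-model"}],
    )
    # Only `instruction` and `generation` survive, in this exact order.
    keep = KeepColumns(name="keep", columns=["instruction", "generation"])
    load.connect(keep)

distiset = pipeline.run()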

LoadDataFromDicts

Bases: GeneratorStep

Loads a dataset from a list of dictionaries.

GeneratorStep that loads a dataset from a list of dictionaries and yields it in batches.

Attributes:
  • data (List[Dict[str, Any]]): The list of dictionaries to load the data from.

Runtime parameters
  • batch_size: The batch size to use when processing the data.
Output columns
  • dynamic (based on the keys found in the first dictionary of the list): The columns of the dataset.
Categories
  • load
Source code in src/distilabel/steps/generators/data.py
class LoadDataFromDicts(GeneratorStep):
    """Loads a dataset from a list of dictionaries.

    `GeneratorStep` that loads a dataset from a list of dictionaries and yields it in
    batches.

    Attributes:
        data: The list of dictionaries to load the data from.

    Runtime parameters:
        - `batch_size`: The batch size to use when processing the data.

    Output columns:
        - dynamic (based on the keys found in the first dictionary of the list): The columns
            of the dataset.

    Categories:
        - load
    """

    data: List[Dict[str, Any]]

    @override
    def process(self, offset: int = 0) -> "GeneratorStepOutput":  # type: ignore
        """Yields batches from a list of dictionaries.

        Args:
            offset: The offset to start the generation from. Defaults to `0`.

        Yields:
            A list of Python dictionaries as read from the inputs (propagated in batches)
            and a flag indicating whether the yielded batch is the last one.
        """
        if offset:
            self.data = self.data[offset:]

        while self.data:
            batch = self.data[: self.batch_size]
            self.data = self.data[self.batch_size :]
            yield (
                batch,
                True if len(self.data) == 0 else False,
            )

    @property
    def outputs(self) -> List[str]:
        """Returns a list of strings with the names of the columns that the step will generate."""
        return list(self.data[0].keys())

outputs: List[str] property

Returns a list of strings with the names of the columns that the step will generate.

process(offset=0)

Yields batches from a list of dictionaries.

Parameters:
  • offset (int): The offset to start the generation from. Defaults to 0.

Yields:
  • GeneratorStepOutput: A list of Python dictionaries as read from the inputs (propagated in batches) and a flag indicating whether the yielded batch is the last one.

Source code in src/distilabel/steps/generators/data.py
@override
def process(self, offset: int = 0) -> "GeneratorStepOutput":  # type: ignore
    """Yields batches from a list of dictionaries.

    Args:
        offset: The offset to start the generation from. Defaults to `0`.

    Yields:
        A list of Python dictionaries as read from the inputs (propagated in batches)
        and a flag indicating whether the yielded batch is the last one.
    """
    if offset:
        self.data = self.data[offset:]

    while self.data:
        batch = self.data[: self.batch_size]
        self.data = self.data[self.batch_size :]
        yield (
            batch,
            True if len(self.data) == 0 else False,
        )
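
The batching behaviour can be observed directly by calling process, as in this minimal sketch (the data and names are made up):

from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts

with Pipeline(name="load-dicts-example"):
    loader = LoadDataFromDicts(
        name="load",
        data=[{"instruction": f"instruction {i}"} for i in range(10)],
        batch_size=4,
    )

# Prints batches of 4, 4 and 2 rows; the boolean flags the last batch.
for batch, is_last_batch in loader.process():
    print(len(batch), is_last_batch)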

LoadHubDataset

Bases: GeneratorStep

Loads a dataset from the Hugging Face Hub.

GeneratorStep that loads a dataset from the Hugging Face Hub using the datasets library.

Attributes:
  • repo_id (RuntimeParameter[str]): The Hugging Face Hub repository ID of the dataset to load.
  • split (RuntimeParameter[str]): The split of the dataset to load.
  • config (Optional[RuntimeParameter[str]]): The configuration of the dataset to load. This is optional and only needed if the dataset has multiple configurations.

Runtime parameters
  • batch_size: The batch size to use when processing the data.
  • repo_id: The Hugging Face Hub repository ID of the dataset to load.
  • split: The split of the dataset to load. Defaults to 'train'.
  • config: The configuration of the dataset to load. This is optional and only needed if the dataset has multiple configurations.
  • streaming: Whether to load the dataset in streaming mode or not. Defaults to False.
  • num_examples: The number of examples to load from the dataset. By default, all examples will be loaded.
Output columns
  • dynamic (all): The columns that will be generated by this step, based on the datasets loaded from the Hugging Face Hub.
Categories
  • load
Source code in src/distilabel/steps/generators/huggingface.py
class LoadHubDataset(GeneratorStep):
    """Loads a dataset from the Hugging Face Hub.

    `GeneratorStep` that loads a dataset from the Hugging Face Hub using the `datasets`
    library.

    Attributes:
        repo_id: The Hugging Face Hub repository ID of the dataset to load.
        split: The split of the dataset to load.
        config: The configuration of the dataset to load. This is optional and only needed
            if the dataset has multiple configurations.

    Runtime parameters:
        - `batch_size`: The batch size to use when processing the data.
        - `repo_id`: The Hugging Face Hub repository ID of the dataset to load.
        - `split`: The split of the dataset to load. Defaults to 'train'.
        - `config`: The configuration of the dataset to load. This is optional and only
            needed if the dataset has multiple configurations.
        - `streaming`: Whether to load the dataset in streaming mode or not. Defaults to
            `False`.
        - `num_examples`: The number of examples to load from the dataset.
            By default will load all examples.

    Output columns:
        - dynamic (`all`): The columns that will be generated by this step, based on the
            datasets loaded from the Hugging Face Hub.

    Categories:
        - load
    """

    repo_id: RuntimeParameter[str] = Field(
        default=None,
        description="The Hugging Face Hub repository ID of the dataset to load.",
    )
    split: RuntimeParameter[str] = Field(
        default="train",
        description="The split of the dataset to load. Defaults to 'train'.",
    )
    config: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The configuration of the dataset to load. This is optional and only"
        " needed if the dataset has multiple configurations.",
    )
    streaming: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to load the dataset in streaming mode or not. Defaults to False.",
    )
    num_examples: Optional[RuntimeParameter[int]] = Field(
        default=None,
        description="The number of examples to load from the dataset. By default will load all examples.",
    )

    _dataset: Union[IterableDataset, None] = PrivateAttr(...)

    def load(self) -> None:
        """Load the dataset from the Hugging Face Hub"""
        super().load()

        self._dataset = load_dataset(
            self.repo_id,  # type: ignore
            self.config,
            split=self.split,
            streaming=self.streaming,
        )
        num_examples = self._get_dataset_num_examples()
        self.num_examples = (
            min(self.num_examples, num_examples) if self.num_examples else num_examples
        )

        if not self.streaming:
            self._dataset = self._dataset.select(range(self.num_examples))

    def process(self, offset: int = 0) -> "GeneratorStepOutput":
        """Yields batches from the loaded dataset from the Hugging Face Hub.

        Args:
            offset: The offset to start yielding the data from. Will be used during the caching
                process to help skipping already processed data.

        Yields:
            A tuple containing a batch of rows and a boolean indicating if the batch is
            the last one.
        """
        num_returned_rows = 0
        for batch_num, batch in enumerate(
            self._dataset.iter(batch_size=self.batch_size)  # type: ignore
        ):
            if batch_num * self.batch_size < offset:
                continue
            transformed_batch = self._transform_batch(batch)
            batch_size = len(transformed_batch)
            num_returned_rows += batch_size
            yield transformed_batch, num_returned_rows >= self.num_examples

    @property
    def outputs(self) -> List[str]:
        """The columns that will be generated by this step, based on the datasets loaded
        from the Hugging Face Hub.

        Returns:
            The columns that will be generated by this step.
        """
        return self._get_dataset_columns()

    def _transform_batch(self, batch: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Transform a batch of data from the Hugging Face Hub into a list of rows.

        Args:
            batch: The batch of data from the Hugging Face Hub.

        Returns:
            A list of rows, where each row is a dictionary of column names and values.
        """
        length = len(next(iter(batch.values())))
        rows = []
        for i in range(length):
            rows.append({col: values[i] for col, values in batch.items()})
        return rows

    def _get_dataset_num_examples(self) -> int:
        """Get the number of examples in the dataset, based on the `split` and `config`
        runtime parameters provided.

        Returns:
            The number of examples in the dataset.
        """
        dataset_info = self._get_dataset_info()
        split = self.split
        if self.config:
            return dataset_info["splits"][split]["num_examples"]
        return dataset_info["default"]["splits"][split]["num_examples"]

    def _get_dataset_columns(self) -> List[str]:
        """Get the columns of the dataset, based on the `config` runtime parameter provided.

        Returns:
            The columns of the dataset.
        """
        dataset_info = self._get_dataset_info()

        if isinstance(dataset_info, DatasetInfo):
            if self.config:
                return list(self._dataset[self.config].info.features.keys())
            return list(self._dataset.info.features.keys())

        if self.config:
            return list(dataset_info["features"].keys())
        return list(dataset_info["default"]["features"].keys())

    def _get_dataset_info(self) -> Dict[str, Any]:
        """Calls the Datasets Server API from Hugging Face to obtain the dataset information.

        Returns:
            The dataset information.
        """
        repo_id = self.repo_id
        config = self.config

        try:
            return _get_hf_dataset_info(repo_id, config)
        except ConnectionError:
            # The previous call could fail in case of internet connection issues.
            # Assuming the dataset is already loaded, we can get the info from it; otherwise it will fail anyway.
            self.load()
            if config:
                return self._dataset[config].info
            return self._dataset.info

outputs: List[str] property

The columns that will be generated by this step, based on the datasets loaded from the Hugging Face Hub.

Returns:
  • List[str]: The columns that will be generated by this step.

load()

Load the dataset from the Hugging Face Hub

Source code in src/distilabel/steps/generators/huggingface.py
def load(self) -> None:
    """Load the dataset from the Hugging Face Hub"""
    super().load()

    self._dataset = load_dataset(
        self.repo_id,  # type: ignore
        self.config,
        split=self.split,
        streaming=self.streaming,
    )
    num_examples = self._get_dataset_num_examples()
    self.num_examples = (
        min(self.num_examples, num_examples) if self.num_examples else num_examples
    )

    if not self.streaming:
        self._dataset = self._dataset.select(range(self.num_examples))

process(offset=0)

Yields batches from the loaded dataset from the Hugging Face Hub.

Parameters:
  • offset (int): The offset to start yielding the data from. Will be used during the caching process to help skipping already processed data. Defaults to 0.

Yields:
  • GeneratorStepOutput: A tuple containing a batch of rows and a boolean indicating if the batch is the last one.

Source code in src/distilabel/steps/generators/huggingface.py
def process(self, offset: int = 0) -> "GeneratorStepOutput":
    """Yields batches from the loaded dataset from the Hugging Face Hub.

    Args:
        offset: The offset to start yielding the data from. Will be used during the caching
            process to help skipping already processed data.

    Yields:
        A tuple containing a batch of rows and a boolean indicating if the batch is
        the last one.
    """
    num_returned_rows = 0
    for batch_num, batch in enumerate(
        self._dataset.iter(batch_size=self.batch_size)  # type: ignore
    ):
        if batch_num * self.batch_size < offset:
            continue
        transformed_batch = self._transform_batch(batch)
        batch_size = len(transformed_batch)
        num_returned_rows += batch_size
        yield transformed_batch, num_returned_rows >= self.num_examples
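
A minimal usage sketch; the repository ID is a placeholder, and the direct load and process calls are only for illustration, since a Pipeline normally drives them:

from distilabel.pipeline import Pipeline
from distilabel.steps import LoadHubDataset

with Pipeline(name="load-hub-example"):
    loader = LoadHubDataset(
        name="load_dataset",
        repo_id="my-username/my-dataset",  # placeholder repository ID
        split="train",
        num_examples=10,
    )

loader.load()  # downloads the dataset and resolves `num_examples`
for batch, is_last_batch in loader.process():
    ...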

PreferenceToArgilla

Bases: Argilla

Creates a preference dataset in Argilla.

Step that creates a dataset in Argilla during the load phase, and then pushes the input batches into it as records. This dataset is a preference dataset, where there's one field for the instruction and one extra field for each generation within the same record, and then a rating question for each of the generation fields. The rating question asks the annotator to set a rating from 1 to 5 for each of the provided generations.

Note

This step is meant to be used in conjunction with the UltraFeedback step, or any other step generating both ratings and responses for a given instruction and its generations. Alternatively, it can also be used with any other task or step generating only the instruction and generations, as the ratings and rationales are optional.

Attributes:
  • num_generations (int): The number of generations to include in the dataset.
  • dataset_name: The name of the dataset in Argilla.
  • dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to None, which means it will be created in the default workspace.
  • api_url: The URL of the Argilla API. Defaults to None, which means it will be read from the ARGILLA_API_URL environment variable.
  • api_key: The API key to authenticate with Argilla. Defaults to None, which means it will be read from the ARGILLA_API_KEY environment variable.

Runtime parameters
  • api_url: The base URL to use for the Argilla API requests.
  • api_key: The API key to authenticate the requests to the Argilla API.
Input columns
  • instruction (str): The instruction that was used to generate the completion.
  • generations (List[str]): The completion that was generated based on the input instruction.
  • ratings (List[str], optional): The ratings for the generations. If not provided, the generated ratings won't be pushed to Argilla.
  • rationales (List[str], optional): The rationales for the ratings. If not provided, the generated rationales won't be pushed to Argilla.
Source code in src/distilabel/steps/argilla/preference.py
class PreferenceToArgilla(Argilla):
    """Creates a preference dataset in Argilla.

    Step that creates a dataset in Argilla during the load phase, and then pushes the input
    batches into it as records. This dataset is a preference dataset, where there's one field
    for the instruction and one extra field per each generation within the same record, and then
    a rating question per each of the generation fields. The rating question asks the annotator to
    set a rating from 1 to 5 for each of the provided generations.

    Note:
        This step is meant to be used in conjunction with the `UltraFeedback` step, or any other step
        generating both ratings and responses for a given set of instruction and generations for the
        given instruction. But alternatively, it can also be used with any other task or step generating
        only the `instruction` and `generations`, as the `ratings` and `rationales` are optional.

    Attributes:
        num_generations: The number of generations to include in the dataset.
        dataset_name: The name of the dataset in Argilla.
        dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
            `None`, which means it will be created in the default workspace.
        api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
            the `ARGILLA_API_URL` environment variable.
        api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
            be read from the `ARGILLA_API_KEY` environment variable.

    Runtime parameters:
        - `api_url`: The base URL to use for the Argilla API requests.
        - `api_key`: The API key to authenticate the requests to the Argilla API.

    Input columns:
        - instruction (`str`): The instruction that was used to generate the completion.
        - generations (`List[str]`): The completion that was generated based on the input instruction.
        - ratings (`List[str]`, optional): The ratings for the generations. If not provided, the
            generated ratings won't be pushed to Argilla.
        - rationales (`List[str]`, optional): The rationales for the ratings. If not provided, the
            generated rationales won't be pushed to Argilla.
    """

    num_generations: int

    _id: str = PrivateAttr(default="id")
    _instruction: str = PrivateAttr(...)
    _generations: str = PrivateAttr(...)
    _ratings: str = PrivateAttr(...)
    _rationales: str = PrivateAttr(...)

    def load(self) -> None:
        """Sets the `_instruction` and `_generations` attributes based on the `inputs_mapping`, otherwise
        uses the default values; and then uses those values to create a `FeedbackDataset` suited for
        the text-generation scenario. And then it pushes it to Argilla.
        """
        super().load()

        # Both `instruction` and `generations` will be used as the fields of the dataset
        self._instruction = self.input_mappings.get("instruction", "instruction")
        self._generations = self.input_mappings.get("generations", "generations")
        # Both `ratings` and `rationales` will be used as suggestions to the default questions of the dataset
        self._ratings = self.input_mappings.get("ratings", "ratings")
        self._rationales = self.input_mappings.get("rationales", "rationales")

        if self._rg_dataset_exists():
            _rg_dataset = rg.FeedbackDataset.from_argilla(  # type: ignore
                name=self.dataset_name,
                workspace=self.dataset_workspace,
            )

            for field in _rg_dataset.fields:
                if (
                    field.name
                    not in [self._id, self._instruction]
                    + [
                        f"{self._generations}-{idx}"
                        for idx in range(self.num_generations)
                    ]
                    and field.required
                ):
                    raise ValueError(
                        f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
                        f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`,"
                        f" nor `{self._generations}`."
                    )

            self._rg_dataset = _rg_dataset
        else:
            _rg_dataset = rg.FeedbackDataset(  # type: ignore
                fields=[
                    rg.TextField(name=self._id, title=self._id),  # type: ignore
                    rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                    *self._generation_fields(),  # type: ignore
                ],
                questions=self._rating_rationale_pairs(),  # type: ignore
            )
            self._rg_dataset = _rg_dataset.push_to_argilla(
                name=self.dataset_name,  # type: ignore
                workspace=self.dataset_workspace,
            )

    def _generation_fields(self) -> List["TextField"]:
        """Method to generate the fields for each of the generations."""
        return [
            rg.TextField(  # type: ignore
                name=f"{self._generations}-{idx}",
                title=f"{self._generations}-{idx}",
                required=True if idx == 0 else False,
            )
            for idx in range(self.num_generations)
        ]

    def _rating_rationale_pairs(
        self,
    ) -> List[Union["RatingQuestion", "TextQuestion"]]:
        """Method to generate the rating and rationale questions for each of the generations."""
        questions = []
        for idx in range(self.num_generations):
            questions.extend(
                [
                    rg.RatingQuestion(  # type: ignore
                        name=f"{self._generations}-{idx}-rating",
                        title=f"Rate {self._generations}-{idx} given {self._instruction}.",
                        description=f"Ignore this question if the corresponding `{self._generations}-{idx}` field is not available."
                        if idx != 0
                        else None,
                        values=[1, 2, 3, 4, 5],
                        required=True if idx == 0 else False,
                    ),
                    rg.TextQuestion(  # type: ignore
                        name=f"{self._generations}-{idx}-rationale",
                        title=f"Specify the rationale for {self._generations}-{idx}'s rating.",
                        description=f"Ignore this question if the corresponding `{self._generations}-{idx}` field is not available."
                        if idx != 0
                        else None,
                        required=False,
                    ),
                ]
            )
        return questions

    @property
    def inputs(self) -> List[str]:
        """The inputs for the step are the `instruction` and the `generations`. Optionally, one could also
        provide the `ratings` and the `rationales` for the generations."""
        return ["instruction", "generations"]

    def _add_suggestions_if_any(
        self, input: Dict[str, Any]
    ) -> List["SuggestionSchema"]:
        """Method to generate the suggestions for the `FeedbackRecord` based on the input."""
        # Since the `suggestions` i.e. answers to the `questions` are optional, they default to an empty list
        suggestions = []
        # If `ratings` is in `input`, then add those as suggestions
        if self._ratings in input:
            suggestions.extend(
                [
                    {
                        "question_name": f"{self._generations}-{idx}-rating",
                        "value": rating,
                    }
                    for idx, rating in enumerate(input[self._ratings])
                    if rating is not None
                    and isinstance(rating, int)
                    and rating in [1, 2, 3, 4, 5]
                ],
            )
        # If `rationales` is in `input`, then add those as suggestions
        if self._rationales in input:
            suggestions.extend(
                [
                    {
                        "question_name": f"{self._generations}-{idx}-rationale",
                        "value": rationale,
                    }
                    for idx, rationale in enumerate(input[self._rationales])
                    if rationale is not None and isinstance(rationale, str)
                ],
            )
        return suggestions

    @override
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Creates and pushes the records as FeedbackRecords to the Argilla dataset.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Returns:
            A list of Python dictionaries with the outputs of the task.
        """
        records = []
        for input in inputs:
            # Generate the SHA-256 hash of the instruction to use it as the metadata
            instruction_id = hashlib.sha256(
                input["instruction"].encode("utf-8")  # type: ignore
            ).hexdigest()

            generations = {
                f"{self._generations}-{idx}": generation
                for idx, generation in enumerate(input["generations"])  # type: ignore
            }

            records.append(  # type: ignore
                rg.FeedbackRecord(  # type: ignore
                    fields={
                        "id": instruction_id,
                        "instruction": input["instruction"],  # type: ignore
                        **generations,
                    },
                    suggestions=self._add_suggestions_if_any(input),  # type: ignore
                )
            )
        self._rg_dataset.add_records(records)  # type: ignore
        yield inputs

inputs: List[str] property

The inputs for the step are the instruction and the generations. Optionally, one could also provide the ratings and the rationales for the generations.

load()

Sets the _instruction and _generations attributes based on the inputs_mapping, otherwise uses the default values; and then uses those values to create a FeedbackDataset suited for the text-generation scenario. And then it pushes it to Argilla.

Source code in src/distilabel/steps/argilla/preference.py
def load(self) -> None:
    """Sets the `_instruction` and `_generations` attributes based on the `inputs_mapping`, otherwise
    uses the default values; and then uses those values to create a `FeedbackDataset` suited for
    the text-generation scenario. And then it pushes it to Argilla.
    """
    super().load()

    # Both `instruction` and `generations` will be used as the fields of the dataset
    self._instruction = self.input_mappings.get("instruction", "instruction")
    self._generations = self.input_mappings.get("generations", "generations")
    # Both `ratings` and `rationales` will be used as suggestions to the default questions of the dataset
    self._ratings = self.input_mappings.get("ratings", "ratings")
    self._rationales = self.input_mappings.get("rationales", "rationales")

    if self._rg_dataset_exists():
        _rg_dataset = rg.FeedbackDataset.from_argilla(  # type: ignore
            name=self.dataset_name,
            workspace=self.dataset_workspace,
        )

        for field in _rg_dataset.fields:
            if (
                field.name
                not in [self._id, self._instruction]
                + [
                    f"{self._generations}-{idx}"
                    for idx in range(self.num_generations)
                ]
                and field.required
            ):
                raise ValueError(
                    f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
                    f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`,"
                    f" nor `{self._generations}`."
                )

        self._rg_dataset = _rg_dataset
    else:
        _rg_dataset = rg.FeedbackDataset(  # type: ignore
            fields=[
                rg.TextField(name=self._id, title=self._id),  # type: ignore
                rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                *self._generation_fields(),  # type: ignore
            ],
            questions=self._rating_rationale_pairs(),  # type: ignore
        )
        self._rg_dataset = _rg_dataset.push_to_argilla(
            name=self.dataset_name,  # type: ignore
            workspace=self.dataset_workspace,
        )

process(inputs)

Creates and pushes the records as FeedbackRecords to the Argilla dataset.

Parameters:
  • inputs (StepInput): A list of Python dictionaries with the inputs of the task. Required.

Returns:
  • StepOutput: A list of Python dictionaries with the outputs of the task.

Source code in src/distilabel/steps/argilla/preference.py
@override
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Creates and pushes the records as FeedbackRecords to the Argilla dataset.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Returns:
        A list of Python dictionaries with the outputs of the task.
    """
    records = []
    for input in inputs:
        # Generate the SHA-256 hash of the instruction to use it as the metadata
        instruction_id = hashlib.sha256(
            input["instruction"].encode("utf-8")  # type: ignore
        ).hexdigest()

        generations = {
            f"{self._generations}-{idx}": generation
            for idx, generation in enumerate(input["generations"])  # type: ignore
        }

        records.append(  # type: ignore
            rg.FeedbackRecord(  # type: ignore
                fields={
                    "id": instruction_id,
                    "instruction": input["instruction"],  # type: ignore
                    **generations,
                },
                suggestions=self._add_suggestions_if_any(input),  # type: ignore
            )
        )
    self._rg_dataset.add_records(records)  # type: ignore
    yield inputs
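
A minimal configuration sketch; every value is a placeholder, and the API credentials can alternatively be provided through the ARGILLA_API_URL and ARGILLA_API_KEY environment variables:

from distilabel.pipeline import Pipeline
from distilabel.steps import PreferenceToArgilla

with Pipeline(name="preference-to-argilla-example"):
    to_argilla = PreferenceToArgilla(
        name="to_argilla",
        num_generations=2,  # one rating/rationale pair per generation field
        dataset_name="preference-dataset",
        dataset_workspace="admin",  # placeholder workspace
        api_url="https://my-argilla.example.com",  # placeholder URL
        api_key="my-argilla-api-key",  # placeholder key
    )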

PushToHub

Bases: GlobalStep

Push data to a Hugging Face Hub dataset.

A GlobalStep which creates a datasets.Dataset with the input data and pushes it to the Hugging Face Hub.

Attributes:
  • repo_id (RuntimeParameter[str]): The Hugging Face Hub repository ID where the dataset will be uploaded.
  • split (RuntimeParameter[str]): The split of the dataset that will be pushed. Defaults to "train".
  • private (RuntimeParameter[bool]): Whether the dataset to be pushed should be private or not. Defaults to False.
  • token (Optional[RuntimeParameter[str]]): The token that will be used to authenticate in the Hub. If not provided, it will be read from the HF_TOKEN environment variable; failing that, the huggingface_hub library will try to use the token from the local Hugging Face CLI configuration. Defaults to None.

Runtime parameters
  • repo_id: The Hugging Face Hub repository ID where the dataset will be uploaded.
  • split: The split of the dataset that will be pushed.
  • private: Whether the dataset to be pushed should be private or not.
  • token: The token that will be used to authenticate in the Hub.
Input columns
  • dynamic (all): all columns from the input will be used to create the dataset.
Categories
  • save
  • dataset
  • huggingface
Source code in src/distilabel/steps/globals/huggingface.py
class PushToHub(GlobalStep):
    """Push data to a Hugging Face Hub dataset.

    A `GlobalStep` which creates a `datasets.Dataset` with the input data and pushes
    it to the Hugging Face Hub.

    Attributes:
        repo_id: The Hugging Face Hub repository ID where the dataset will be uploaded.
        split: The split of the dataset that will be pushed. Defaults to `"train"`.
        private: Whether the dataset to be pushed should be private or not. Defaults to
            `False`.
        token: The token that will be used to authenticate in the Hub. If not provided, the
            token will be tried to be obtained from the environment variable `HF_TOKEN`.
            If not provided using one of the previous methods, then `huggingface_hub` library
            will try to use the token from the local Hugging Face CLI configuration. Defaults
            to `None`.

    Runtime parameters:
        - `repo_id`: The Hugging Face Hub repository ID where the dataset will be uploaded.
        - `split`: The split of the dataset that will be pushed.
        - `private`: Whether the dataset to be pushed should be private or not.
        - `token`: The token that will be used to authenticate in the Hub.

    Input columns:
        - dynamic (`all`): all columns from the input will be used to create the dataset.

    Categories:
        - save
        - dataset
        - huggingface
    """

    repo_id: RuntimeParameter[str] = Field(
        default=None,
        description="The Hugging Face Hub repository ID where the dataset will be uploaded.",
    )
    split: RuntimeParameter[str] = Field(
        default="train",
        description="The split of the dataset that will be pushed. Defaults to 'train'.",
    )
    private: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether the dataset to be pushed should be private or not. Defaults"
        " to `False`.",
    )
    token: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The token that will be used to authenticate in the Hub. If not provided,"
        " the token will be tried to be obtained from the environment variable `HF_TOKEN`."
        " If not provided using one of the previous methods, then `huggingface_hub` library"
        " will try to use the token from the local Hugging Face CLI configuration. Defaults"
        " to `None`",
    )

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Method that processes the input data, respecting the `datasets.Dataset` formatting,
        and pushes it to the Hugging Face Hub based on the `RuntimeParameter`s attributes.

        Args:
            inputs: the input data within a single object (as it's a GlobalStep) that
                will be transformed into a `datasets.Dataset`.

        Yields:
            Propagates the received inputs so that the `Distiset` can be generated if this is
            the last step of the `Pipeline`, or if this is not a leaf step and has follow up
            steps.
        """
        dataset_dict = defaultdict(list)
        for input in inputs:
            for key, value in input.items():
                dataset_dict[key].append(value)
        dataset_dict = dict(dataset_dict)
        dataset = Dataset.from_dict(dataset_dict)
        dataset.push_to_hub(
            self.repo_id,  # type: ignore
            split=self.split,
            private=self.private,
            token=self.token or os.getenv("HF_TOKEN"),
        )
        yield inputs

process(inputs)

Method that processes the input data, respecting the datasets.Dataset formatting, and pushes it to the Hugging Face Hub based on the RuntimeParameters attributes.

Parameters:
  • inputs (StepInput): The input data within a single object (as it's a GlobalStep) that will be transformed into a datasets.Dataset. Required.

Yields:
  • StepOutput: Propagates the received inputs so that the Distiset can be generated if this is the last step of the Pipeline, or if this is not a leaf step and has follow-up steps.

Source code in src/distilabel/steps/globals/huggingface.py
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Method that processes the input data, respecting the `datasets.Dataset` formatting,
    and pushes it to the Hugging Face Hub based on the `RuntimeParameter`s attributes.

    Args:
        inputs: the input data within a single object (as it's a GlobalStep) that
            will be transformed into a `datasets.Dataset`.

    Yields:
        Propagates the received inputs so that the `Distiset` can be generated if this is
        the last step of the `Pipeline`, or if this is not a leaf step and has follow up
        steps.
    """
    dataset_dict = defaultdict(list)
    for input in inputs:
        for key, value in input.items():
            dataset_dict[key].append(value)
    dataset_dict = dict(dataset_dict)
    dataset = Dataset.from_dict(dataset_dict)
    dataset.push_to_hub(
        self.repo_id,  # type: ignore
        split=self.split,
        private=self.private,
        token=self.token or os.getenv("HF_TOKEN"),
    )
    yield inputs
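
A minimal configuration sketch; the repository ID is a placeholder, and token falls back to the HF_TOKEN environment variable when omitted:

from distilabel.pipeline import Pipeline
from distilabel.steps import PushToHub

with Pipeline(name="push-to-hub-example"):
    push = PushToHub(
        name="push_to_hub",
        repo_id="my-username/my-distilabel-dataset",  # placeholder repository ID
        split="train",
        private=True,
        # `token` is read from the `HF_TOKEN` environment variable when not set here.
    )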

Step

Bases: _Step, ABC

Base class for the steps that can be included in a Pipeline.

Attributes:
  • input_batch_size (RuntimeParameter[PositiveInt]): The number of rows that will contain the batches processed by the step. Defaults to 50.

Runtime parameters
  • input_batch_size: The number of rows that will contain the batches processed by the step. Defaults to 50.
Source code in src/distilabel/steps/base.py
class Step(_Step, ABC):
    """Base class for the steps that can be included in a `Pipeline`.

    Attributes:
        input_batch_size: The number of rows that will contain the batches processed by
            the step. Defaults to `50`.

    Runtime parameters:
        - `input_batch_size`: The number of rows that will contain the batches processed
            by the step. Defaults to `50`.
    """

    input_batch_size: RuntimeParameter[PositiveInt] = Field(
        default=DEFAULT_INPUT_BATCH_SIZE,
        description="The number of rows that will contain the batches processed by the"
        " step.",
    )

    @abstractmethod
    def process(self, *inputs: StepInput) -> "StepOutput":
        """Method that defines the processing logic of the step. It should yield the
        output rows.

        Args:
            *inputs: An argument used to receive the outputs of the previous steps. The
                number of arguments depends on the number of previous steps. It doesn't
                need to be an `*args` argument, it can be a regular argument annotated
                with `StepInput` if the step has only one previous step.
        """
        pass

    def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput":
        """Runs the `process` method of the step applying the `input_mappings` to the input
        rows and the `outputs_mappings` to the output rows. This is the function that
        should be used to run the processing logic of the step.

        Yields:
            The output rows.
        """

        inputs = self._apply_input_mappings(args) if self.input_mappings else args

        # If the `Step` was built using the `@step` decorator, then we need to pass
        # the runtime parameters as kwargs, so they can be used within the processing
        # function
        generator = (
            self.process(*inputs)
            if not self._built_from_decorator
            else self.process(*inputs, **self._runtime_parameters)
        )

        for output_rows in generator:
            yield [
                {
                    # Apply output mapping and revert input mapping
                    self.output_mappings.get(k, None)
                    or self.input_mappings.get(k, None)
                    or k: v
                    for k, v in row.items()
                }
                for row in output_rows
            ]

    def _revert_input_mappings(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Reverts the `input_mappings` of the step to the input row.

        Args:
            input: The input row.

        Returns:
            The input row with the `input_mappings` reverted.
        """
        return {self.input_mappings.get(k, k): v for k, v in input.items()}

    def _apply_input_mappings(
        self, inputs: Tuple[List[Dict[str, Any]], ...]
    ) -> List[List[Dict[str, Any]]]:
        """Applies the `input_mappings` to the input rows.

        Args:
            inputs: The input rows.

        Returns:
            The input rows with the `input_mappings` applied.
        """
        reverted_input_mappings = {v: k for k, v in self.input_mappings.items()}

        return [
            [
                {reverted_input_mappings.get(k, k): v for k, v in row.items()}
                for row in row_inputs
            ]
            for row_inputs in inputs
        ]

process(*inputs) abstractmethod

Method that defines the processing logic of the step. It should yield the output rows.

Parameters:
  • *inputs (StepInput): An argument used to receive the outputs of the previous steps. The number of arguments depends on the number of previous steps. It doesn't need to be an *args argument, it can be a regular argument annotated with StepInput if the step has only one previous step.
Source code in src/distilabel/steps/base.py
@abstractmethod
def process(self, *inputs: StepInput) -> "StepOutput":
    """Method that defines the processing logic of the step. It should yield the
    output rows.

    Args:
        *inputs: An argument used to receive the outputs of the previous steps. The
            number of arguments depends on the number of previous steps. It doesn't
            need to be an `*args` argument, it can be a regular argument annotated
            with `StepInput` if the step has only one previous step.
    """
    pass

process_applying_mappings(*args)

Runs the process method of the step applying the input_mappings to the input rows and the outputs_mappings to the output rows. This is the function that should be used to run the processing logic of the step.

Yields:
  • StepOutput: The output rows.

Source code in src/distilabel/steps/base.py
def process_applying_mappings(self, *args: List[Dict[str, Any]]) -> "StepOutput":
    """Runs the `process` method of the step applying the `input_mappings` to the input
    rows and the `outputs_mappings` to the output rows. This is the function that
    should be used to run the processing logic of the step.

    Yields:
        The output rows.
    """

    inputs = self._apply_input_mappings(args) if self.input_mappings else args

    # If the `Step` was built using the `@step` decorator, then we need to pass
    # the runtime parameters as kwargs, so they can be used within the processing
    # function
    generator = (
        self.process(*inputs)
        if not self._built_from_decorator
        else self.process(*inputs, **self._runtime_parameters)
    )

    for output_rows in generator:
        yield [
            {
                # Apply output mapping and revert input mapping
                self.output_mappings.get(k, None)
                or self.input_mappings.get(k, None)
                or k: v
                for k, v in row.items()
            }
            for row in output_rows
        ]
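
The sketch below shows the single-previous-step variant mentioned above, with process taking one regular argument annotated with StepInput; the class, attribute, and column names are hypothetical:

from typing import List

from distilabel.steps import Step, StepInput
from distilabel.steps.typing import StepOutput


class AppendSuffix(Step):
    """Hypothetical step that appends a suffix to the `instruction` column."""

    suffix: str = " Please answer concisely."

    @property
    def inputs(self) -> List[str]:
        return ["instruction"]

    @property
    def outputs(self) -> List[str]:
        return ["instruction"]

    def process(self, inputs: StepInput) -> "StepOutput":
        for row in inputs:
            row["instruction"] += self.suffix
        yield inputs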

TextGenerationToArgilla

Bases: Argilla

Creates a text generation dataset in Argilla.

Step that creates a dataset in Argilla during the load phase, and then pushes the input batches into it as records. This dataset is a text-generation dataset, where there's one field for each input, and then a label question to rate the quality of the completion as either bad (represented with 👎) or good (represented with 👍).

Note

This step is meant to be used in conjunction with a TextGeneration step and no column mapping is needed, as it will use the default values for the instruction and generation columns.

Attributes:
  • dataset_name: The name of the dataset in Argilla.
  • dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to None, which means it will be created in the default workspace.
  • api_url: The URL of the Argilla API. Defaults to None, which means it will be read from the ARGILLA_API_URL environment variable.
  • api_key: The API key to authenticate with Argilla. Defaults to None, which means it will be read from the ARGILLA_API_KEY environment variable.

Runtime parameters
  • api_url: The base URL to use for the Argilla API requests.
  • api_key: The API key to authenticate the requests to the Argilla API.
Input columns
  • instruction (str): The instruction that was used to generate the completion.
  • generation (str or List[str]): The completions that were generated based on the input instruction.
Source code in src/distilabel/steps/argilla/text_generation.py
class TextGenerationToArgilla(Argilla):
    """Creates a text generation dataset in Argilla.

    `Step` that creates a dataset in Argilla during the load phase, and then pushes the input
    batches into it as records. This dataset is a text-generation dataset, where there's one field
    per each input, and then a label question to rate the quality of the completion in either bad
    (represented with 👎) or good (represented with 👍).

    Note:
        This step is meant to be used in conjunction with a `TextGeneration` step and no column mapping
        is needed, as it will use the default values for the `instruction` and `generation` columns.

    Attributes:
        dataset_name: The name of the dataset in Argilla.
        dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to
            `None`, which means it will be created in the default workspace.
        api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from
            the `ARGILLA_API_URL` environment variable.
        api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will
            be read from the `ARGILLA_API_KEY` environment variable.

    Runtime parameters:
        - `api_url`: The base URL to use for the Argilla API requests.
        - `api_key`: The API key to authenticate the requests to the Argilla API.

    Input columns:
        - instruction (`str`): The instruction that was used to generate the completion.
        - generation (`str` or `List[str]`): The completions that were generated based on the input instruction.
    """

    _id: str = PrivateAttr(default="id")
    _instruction: str = PrivateAttr(...)
    _generation: str = PrivateAttr(...)

    def load(self) -> None:
        """Sets the `_instruction` and `_generation` attributes based on the `inputs_mapping`, otherwise
        uses the default values; and then uses those values to create a `FeedbackDataset` suited for
        the text-generation scenario. And then it pushes it to Argilla.
        """
        super().load()

        self._instruction = self.input_mappings.get("instruction", "instruction")
        self._generation = self.input_mappings.get("generation", "generation")

        if self._rg_dataset_exists():
            _rg_dataset = rg.FeedbackDataset.from_argilla(  # type: ignore
                name=self.dataset_name,
                workspace=self.dataset_workspace,
            )

            for field in _rg_dataset.fields:
                if (
                    field.name not in [self._id, self._instruction, self._generation]
                    and field.required
                ):
                    raise ValueError(
                        f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
                        f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`"
                        f", nor `{self._generation}`."
                    )

            self._rg_dataset = _rg_dataset
        else:
            _rg_dataset = rg.FeedbackDataset(  # type: ignore
                fields=[
                    rg.TextField(name=self._id, title=self._id),  # type: ignore
                    rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                    rg.TextField(name=self._generation, title=self._generation),  # type: ignore
                ],
                questions=[
                    rg.LabelQuestion(  # type: ignore
                        name="quality",
                        title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
                        labels={"bad": "👎", "good": "👍"},
                    )
                ],
            )
            self._rg_dataset = _rg_dataset.push_to_argilla(
                name=self.dataset_name,  # type: ignore
                workspace=self.dataset_workspace,
            )

    @property
    def inputs(self) -> List[str]:
        """The inputs for the step are the `instruction` and the `generation`."""
        return ["instruction", "generation"]

    @override
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Creates and pushes the records as FeedbackRecords to the Argilla dataset.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Yields:
            A list of Python dictionaries with the outputs of the task.
        """
        records = []
        for input in inputs:
            # Generate the SHA-256 hash of the instruction to use it as the record `id` field
            instruction_id = hashlib.sha256(
                input["instruction"].encode("utf-8")
            ).hexdigest()

            generations = input["generation"]

            # If the `generation` is not a list, then convert it into a list
            if not isinstance(generations, list):
                generations = [generations]

            # Create a `generations_set` to avoid adding duplicates
            generations_set = set()

            for generation in generations:
                # If the generation is already in the set, then skip it
                if generation in generations_set:
                    continue
                # Otherwise, add it to the set
                generations_set.add(generation)

                records.append(
                    rg.FeedbackRecord(  # type: ignore
                        fields={
                            self._id: instruction_id,
                            self._instruction: input["instruction"],
                            self._generation: generation,
                        },
                    )
                )
        self._rg_dataset.add_records(records)  # type: ignore
        yield inputs

inputs: List[str] property

The inputs for the step are the instruction and the generation.

load()

Sets the _instruction and _generation attributes based on the input_mappings, falling back to the default values; then uses those values to create a FeedbackDataset suited for the text-generation scenario and pushes it to Argilla.

Source code in src/distilabel/steps/argilla/text_generation.py
def load(self) -> None:
    """Sets the `_instruction` and `_generation` attributes based on the `inputs_mapping`, otherwise
    uses the default values; and then uses those values to create a `FeedbackDataset` suited for
    the text-generation scenario. And then it pushes it to Argilla.
    """
    super().load()

    self._instruction = self.input_mappings.get("instruction", "instruction")
    self._generation = self.input_mappings.get("generation", "generation")

    if self._rg_dataset_exists():
        _rg_dataset = rg.FeedbackDataset.from_argilla(  # type: ignore
            name=self.dataset_name,
            workspace=self.dataset_workspace,
        )

        for field in _rg_dataset.fields:
            if (
                field.name not in [self._id, self._instruction, self._generation]
                and field.required
            ):
                raise ValueError(
                    f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
                    f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`"
                    f", nor `{self._generation}`."
                )

        self._rg_dataset = _rg_dataset
    else:
        _rg_dataset = rg.FeedbackDataset(  # type: ignore
            fields=[
                rg.TextField(name=self._id, title=self._id),  # type: ignore
                rg.TextField(name=self._instruction, title=self._instruction),  # type: ignore
                rg.TextField(name=self._generation, title=self._generation),  # type: ignore
            ],
            questions=[
                rg.LabelQuestion(  # type: ignore
                    name="quality",
                    title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
                    labels={"bad": "👎", "good": "👍"},
                )
            ],
        )
        self._rg_dataset = _rg_dataset.push_to_argilla(
            name=self.dataset_name,  # type: ignore
            workspace=self.dataset_workspace,
        )

process(inputs)

Creates and pushes the records as FeedbackRecords to the Argilla dataset.

Parameters:

Name Type Description Default
inputs StepInput

A list of Python dictionaries with the inputs of the task.

required

Yields:

Type Description
StepOutput

A list of Python dictionaries with the outputs of the task.

Source code in src/distilabel/steps/argilla/text_generation.py
@override
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Creates and pushes the records as FeedbackRecords to the Argilla dataset.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Yields:
        A list of Python dictionaries with the outputs of the task.
    """
    records = []
    for input in inputs:
        # Generate the SHA-256 hash of the instruction to use it as the record `id` field
        instruction_id = hashlib.sha256(
            input["instruction"].encode("utf-8")
        ).hexdigest()

        generations = input["generation"]

        # If the `generation` is not a list, then convert it into a list
        if not isinstance(generations, list):
            generations = [generations]

        # Create a `generations_set` to avoid adding duplicates
        generations_set = set()

        for generation in generations:
            # If the generation is already in the set, then skip it
            if generation in generations_set:
                continue
            # Otherwise, add it to the set
            generations_set.add(generation)

            records.append(
                rg.FeedbackRecord(  # type: ignore
                    fields={
                        self._id: instruction_id,
                        self._instruction: input["instruction"],
                        self._generation: generation,
                    },
                )
            )
    self._rg_dataset.add_records(records)  # type: ignore
    yield inputs
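The hashing and deduplication performed by process can be sketched standalone (plain Python, no Argilla calls; the sample row is illustrative):

import hashlib

row = {"instruction": "Write a haiku", "generation": ["draft", "draft", "final"]}

# Hash of the instruction, used as the `id` field of every record for this row
instruction_id = hashlib.sha256(row["instruction"].encode("utf-8")).hexdigest()

generations = row["generation"]
if not isinstance(generations, list):  # a single string becomes a one-element list
    generations = [generations]

seen = set()
unique = []
for generation in generations:
    if generation in seen:
        continue
    seen.add(generation)
    unique.append(generation)

print(unique)  # ['draft', 'final'] -> one record per unique generation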

step(inputs=None, outputs=None, step_type='normal')

Creates a Step from a processing function.

Parameters:

Name Type Description Default
inputs Union[List[str], None]

a list containing the names of the input columns/keys expected by this step. If not provided, the default will be an empty list [] and it will be assumed that the step doesn't need any specific columns. Defaults to None.

None
outputs Union[List[str], None]

a list containing the names of the output columns/keys that the step will generate. If not provided, the default will be an empty list [] and it will be assumed that the step doesn't generate any specific columns. Defaults to None.

None
step_type Literal['normal', 'global', 'generator']

the kind of step to create. Valid choices are: "normal" (Step), "global" (GlobalStep) or "generator" (GeneratorStep). Defaults to "normal".

'normal'

Returns:

Type Description
Callable[..., Type[_Step]]

A callable that will generate the type given the processing function.

Example:

# Normal step
@step(inputs=["instruction"], outputs=["generation"])
def GenerationStep(inputs: StepInput, dummy_generation: RuntimeParameter[str]) -> StepOutput:
    for input in inputs:
        input["generation"] = dummy_generation
    yield inputs

# Global step
@step(inputs=["instruction"], step_type="global")
def FilteringStep(inputs: StepInput, max_length: RuntimeParameter[int] = 256) -> StepOutput:
    yield [
        input
        for input in inputs
        if len(input["instruction"]) <= max_length
    ]

# Generator step
@step(outputs=["num"], step_type="generator")
def RowGenerator(num_rows: RuntimeParameter[int] = 500) -> GeneratorStepOutput:
    data = list(range(num_rows))
    for i in range(0, len(data), 100):
        last_batch = i + 100 >= len(data)
        yield [{"num": num} for num in data[i : i + 100]], last_batch
Source code in src/distilabel/steps/decorator.py
def step(
    inputs: Union[List[str], None] = None,
    outputs: Union[List[str], None] = None,
    step_type: Literal["normal", "global", "generator"] = "normal",
) -> Callable[..., Type["_Step"]]:
    """Creates an `Step` from a processing function.

    Args:
        inputs: a list containing the names of the input columns/keys expected by this step.
            If not provided, the default will be an empty list `[]` and it will be assumed
            that the step doesn't need any specific columns. Defaults to `None`.
        outputs: a list containing the names of the output columns/keys that the step
            will generate. If not provided, the default will be an empty list `[]` and it
            will be assumed that the step doesn't generate any specific columns. Defaults
            to `None`.
        step_type: the kind of step to create. Valid choices are: "normal" (`Step`),
            "global" (`GlobalStep`) or "generator" (`GeneratorStep`). Defaults to
            `"normal"`.

    Returns:
        A callable that will generate the type given the processing function.

    Example:

    ```python
    # Normal step
    @step(inputs=["instruction"], outputs=["generation"])
    def GenerationStep(inputs: StepInput, dummy_generation: RuntimeParameter[str]) -> StepOutput:
        for input in inputs:
            input["generation"] = dummy_generation
        yield inputs

    # Global step
    @step(inputs=["instruction"], step_type="global")
    def FilteringStep(inputs: StepInput, max_length: RuntimeParameter[int] = 256) -> StepOutput:
        yield [
            input
            for input in inputs
            if len(input["instruction"]) <= max_length
        ]

    # Generator step
    @step(outputs=["num"], step_type="generator")
    def RowGenerator(num_rows: RuntimeParameter[int] = 500) -> GeneratorStepOutput:
        data = list(range(num_rows))
        for i in range(0, len(data), 100):
            last_batch = i + 100 >= len(data)
            yield [{"num": num} for num in data[i : i + 100]], last_batch
    ```
    """

    inputs = inputs or []
    outputs = outputs or []

    def decorator(func: ProcessingFunc) -> Type["_Step"]:
        if step_type not in _STEP_MAPPING:
            raise ValueError(
                f"Invalid step type '{step_type}'. Please, review the '{func.__name__}'"
                " function decorated with the `@step` decorator and provide a valid"
                " `step_type`. Valid choices are: 'normal', 'global' or 'generator'."
            )

        BaseClass = _STEP_MAPPING[step_type]

        signature = inspect.signature(func)

        runtime_parameters = {}
        step_input_parameter = None
        for name, param in signature.parameters.items():
            if is_parameter_annotated_with(param, _RUNTIME_PARAMETER_ANNOTATION):
                runtime_parameters[name] = (
                    param.annotation,
                    param.default if param.default != param.empty else None,
                )

            if not step_type == "generator" and is_parameter_annotated_with(
                param, _STEP_INPUT_ANNOTATION
            ):
                if step_input_parameter is not None:
                    raise ValueError(
                        f"Function '{func.__name__}' has more than one parameter annotated"
                        f" with `StepInput`. Please, review the '{func.__name__}' function"
                        " decorated with the `@step` decorator and provide only one"
                        " argument annotated with `StepInput`."
                    )
                step_input_parameter = param

        RuntimeParametersModel = create_model(  # type: ignore
            "RuntimeParametersModel",
            **runtime_parameters,  # type: ignore
        )

        def inputs_property(self) -> List[str]:
            return inputs

        def outputs_property(self) -> List[str]:
            return outputs

        def process(
            self, *args: Any, **kwargs: Any
        ) -> Union["StepOutput", "GeneratorStepOutput"]:
            return func(*args, **kwargs)

        return type(  # type: ignore
            func.__name__,
            (
                BaseClass,
                RuntimeParametersModel,
            ),
            {
                "process": process,
                "inputs": property(inputs_property),
                "outputs": property(outputs_property),
                "__module__": func.__module__,
                "__doc__": func.__doc__,
                "_built_from_decorator": True,
                # Override the `get_process_step_input` method to return the parameter
                # of the original function annotated with `StepInput`.
                "get_process_step_input": lambda self: step_input_parameter,
            },
        )

    return decorator
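The closing type(...) call is ordinary dynamic class creation. A minimal standalone illustration of the same mechanism, stripped of the distilabel machinery:

def process(self, inputs):
    yield inputs

# Build a class named `MyStep` with a `process` method and an `inputs`
# property, just as the decorator assembles the step type from the
# decorated function.
MyStep = type(
    "MyStep",
    (object,),
    {
        "process": process,
        "inputs": property(lambda self: ["instruction"]),
    },
)

instance = MyStep()
print(instance.inputs)  # ['instruction']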