
Diagram: a GeneratorTask that produces output

Working with GeneratorTasks

The GeneratorTask is a custom implementation of a Task based on the GeneratorStep. As with a Task, it is normally used within a Pipeline but can also be used standalone.

Warning

This task is still experimental and may be subject to changes in the future.

from typing import Any, Dict, List, Union
from typing_extensions import override

from distilabel.steps.tasks.base import GeneratorTask
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import GeneratorOutput


class MyCustomTask(GeneratorTask):
    instruction: str

    @override
    def process(self, offset: int = 0) -> GeneratorOutput:
        # Generate a completion for the instruction; `generate` returns one
        # list of generations per input conversation.
        generations = self.llm.generate(
            inputs=[
                [
                    {"role": "user", "content": self.instruction},
                ],
            ],
        )
        # Build the output row: inject the model name and the formatted output.
        output = {"model_name": self.llm.model_name}
        output.update(
            self.format_output(output=generations[0][0], input=None)
        )
        yield output

    @property
    def outputs(self) -> List[str]:
        return ["output_field", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {"output_field": output}

We can then use it as follows:

from distilabel.llms import OpenAILLM

task = MyCustomTask(
    name="custom-generation",
    instruction="Tell me a joke.",
    llm=OpenAILLM(model="gpt-4"),
)
task.load()

next(task.process())
# {'output_field': 'Why did the scarecrow win an award? Because he was outstanding!', 'model_name': 'gpt-4'}

Note

Most of the time you will need to override the default process method, since the default implementation is suited to the standard Task rather than the GeneratorTask. Within the process method, however, you are free to use the llm attribute to generate data in whatever way fits your use case.
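
As an illustration, the following is a minimal sketch (not part of the library's documentation) of a process method that batches several prompts into a single llm.generate call and yields one row per prompt. The two instructions are hypothetical, and the sketch assumes generate returns one list of generations per input conversation:

    @override
    def process(self, offset: int = 0) -> GeneratorOutput:
        # Hypothetical instructions, used only for illustration.
        instructions = ["Tell me a joke.", "Tell me a riddle."]
        # One conversation per instruction, sent in a single call to the LLM.
        generations = self.llm.generate(
            inputs=[
                [{"role": "user", "content": instruction}]
                for instruction in instructions
            ],
        )
        # Yield one row per instruction, reusing `format_output` as above.
        for generation in generations:
            row = {"model_name": self.llm.model_name}
            row.update(self.format_output(output=generation[0], input=None))
            yield row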

Note

Step.load() always needs to be called when the task is used standalone. Within a pipeline, this is handled automatically during pipeline execution.
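
For reference, this is roughly how the same task could be wired into a Pipeline, in which case load() is handled for you. It is a minimal sketch assuming the distilabel 1.x import paths (distilabel.pipeline.Pipeline, distilabel.steps.KeepColumns and distilabel.llms.OpenAILLM):

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns

with Pipeline(name="custom-generation-pipeline") as pipeline:
    task = MyCustomTask(
        name="custom-generation",
        instruction="Tell me a joke.",
        llm=OpenAILLM(model="gpt-4"),
    )
    # Keep only the columns declared in the task's `outputs`.
    keep = KeepColumns(columns=["output_field", "model_name"])
    task >> keep

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)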

Defining custom GeneratorTasks

We can define a custom generator task by creating a new subclass of the GeneratorTask and defining the following:

  • process: a method that generates the data based on the LLM and the instruction provided within the class instance, and yields dictionaries with the output data formatted as needed, i.e. with the values for the columns in outputs. Note that the inputs argument is not allowed in this method since this is a GeneratorTask; the signature only expects the offset argument, which is used to keep track of the current iteration of the generator.

  • outputs: a property that returns a list of strings with the names of the output fields. This property should always include model_name as one of the outputs, since that value is automatically injected from the LLM.

  • format_output: a method that receives the output from the LLM and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed, i.e. with the values for the columns in outputs. Note that there is no need to include the model_name in the output.

from typing import Any, Dict, List, Union
from typing_extensions import override

from distilabel.steps.tasks.base import GeneratorTask
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.typing import GeneratorOutput


class MyCustomTask(GeneratorTask):
    @override
    def process(self, offset: int = 0) -> GeneratorOutput:
        # Generate a completion for the prompt; `generate` returns one list
        # of generations per input conversation.
        generations = self.llm.generate(
            inputs=[
                [{"role": "user", "content": "Tell me a joke."}],
            ],
        )
        # Build the output row: inject the model name and the formatted output.
        output = {"model_name": self.llm.model_name}
        output.update(
            self.format_output(output=generations[0][0], input=None)
        )
        yield output

    @property
    def outputs(self) -> List[str]:
        return ["output_field", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {"output_field": output}