
ImageTask

This section contains the API reference for the distilabel image generation tasks.

For more information on how the ImageTask works, along with some examples, see the Tutorial - Task - ImageTask page.

ImageTask

Bases: _Task, Step

ImageTask is a class that implements the _Task abstract class and adds the Step interface to be used as a step in the pipeline. It differs from the Task in that it's expected to work with ImageGenerationModels instead of LLMs.

Attributes:

| Name | Type | Description |
|---|---|---|
| image_generation_model | ImageGenerationModel | The ImageGenerationModel to be used to generate the outputs. |
| llm | Union[LLM, ImageGenerationModel, None] | This attribute is here to respect the _Task interface, but it's used internally only. |
| group_generations | bool | Whether to group the num_generations generated per input in a list or create a row per generation. Defaults to False. |
| num_generations | RuntimeParameter[int] | The number of generations to be produced per input. |
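The difference between the two group_generations settings can be illustrated with a small standalone sketch. This is not distilabel's implementation: the helper `expand_rows` and the `image` column name are hypothetical and only mirror the behaviour described for the attribute.

```python
from typing import Any, Dict, List


def expand_rows(
    row: Dict[str, Any],
    generations: List[Any],
    group_generations: bool = False,
    output_key: str = "image",  # hypothetical output column name
) -> List[Dict[str, Any]]:
    """Return the output rows for one input row and its generations."""
    if group_generations:
        # One row whose output column holds every generation in a list.
        return [{**row, output_key: list(generations)}]
    # One row per generation, each repeating the input columns.
    return [{**row, output_key: generation} for generation in generations]


row = {"prompt": "a red bicycle"}
generations = ["img-0", "img-1"]  # placeholders for generated images

grouped = expand_rows(row, generations, group_generations=True)
ungrouped = expand_rows(row, generations, group_generations=False)
```

With num_generations set to 2, the grouped form yields a single row per input, while the default form yields two rows per input.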

Source code in src/distilabel/steps/tasks/base.py
class ImageTask(_Task, Step):
    """`ImageTask` is a class that implements the `_Task` abstract class and adds the `Step`
    interface to be used as a step in the pipeline. It differs from the `Task` in that it's
    expected to work with `ImageGenerationModel`s instead of `LLM`s.

    Attributes:
        image_generation_model: the `ImageGenerationModel` to be used to generate the outputs.
        llm: This attribute is here to respect the `_Task` interface, but it's used internally only.
        group_generations: whether to group the `num_generations` generated per input in
            a list or create a row per generation. Defaults to `False`.
        num_generations: The number of generations to be produced per input.
    """

    llm: Union[LLM, ImageGenerationModel, None] = None
    image_generation_model: ImageGenerationModel

    def model_post_init(self, __context: Any) -> None:
        assert self.llm is None, (
            "`ImageTask` cannot use an `LLM` attribute given by the user, pass "
            "the `image_generation_model` attribute instead."
        )
        self.llm = self.image_generation_model
        # Call the post init from the Step, as we don't want to call specific behaviour
        # from the task, that may need to deal with specific attributes from the LLM
        # not in the ImageGenerationModel
        super(Step, self).model_post_init(__context)

    @abstractmethod
    def format_input(self, input: dict[str, Any]) -> str:
        """Abstract method to format the inputs of the task. It needs to receive an input
        as a Python dictionary, and generates a string to be used as the prompt for the model."""
        pass

    def _format_inputs(self, inputs: list[dict[str, Any]]) -> List["FormattedInput"]:
        """Formats the inputs of the task using the `format_input` method.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Returns:
            A list containing the formatted inputs, which are `ChatType`-like following
            the OpenAI formatting.
        """
        return [self.format_input(input) for input in inputs]

    def _format_outputs(
        self,
        outputs: list[Union[str, None]],
        input: Union[Dict[str, Any], None] = None,
    ) -> List[Dict[str, Any]]:
        """Formats the outputs of the task using the `format_output` method. If the output
        is `None` (i.e. the LLM failed to generate a response), then the outputs will be
        set to `None` as well.

        Args:
            outputs: The outputs (`n` generations) for the provided `input`.
            input: The input used to generate the output.

        Returns:
            A list containing a dictionary with the outputs of the task for each input.
        """
        inputs = [None] if input is None else [input]
        formatted_outputs = []

        for output, input in zip(outputs, inputs):  # type: ignore
            try:
                formatted_output = self.format_output(output, input)
                formatted_output = self._create_metadata(
                    formatted_output,
                    output,
                    input,
                    add_raw_output=self.add_raw_output,  # type: ignore
                    add_raw_input=self.add_raw_input,  # type: ignore
                    statistics=None,
                )
                formatted_outputs.append(formatted_output)
            except Exception as e:
                self._logger.warning(  # type: ignore
                    f"Task '{self.name}' failed to format output: {e}. Saving raw response."  # type: ignore
                )
                formatted_outputs.append(self._output_on_failure(output, input))
        return formatted_outputs

    @abstractmethod
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Yields:
            A list of Python dictionaries with the outputs of the task.
        """
        pass
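The aliasing done in model_post_init above (reject a user-supplied llm, then point llm at image_generation_model) can be mimicked in plain Python. The sketch below is an assumption-laden stand-in: `FakeImageModel` and `SketchImageTask` are hypothetical names, and the real class is a pydantic model rather than a class with a hand-written `__init__`.

```python
from typing import Any, Optional


class FakeImageModel:
    """Hypothetical stand-in for an ImageGenerationModel."""


class SketchImageTask:
    """Plain-Python stand-in that mirrors the aliasing in model_post_init."""

    def __init__(self, image_generation_model: Any, llm: Optional[Any] = None) -> None:
        # Mirrors the assertion in the source: the user must not pass `llm`.
        assert llm is None, (
            "`ImageTask` cannot use an `LLM` attribute given by the user, pass "
            "the `image_generation_model` attribute instead."
        )
        self.image_generation_model = image_generation_model
        # Code written against the `llm` attribute keeps working,
        # because both names now refer to the same object.
        self.llm = image_generation_model


model = FakeImageModel()
task = SketchImageTask(image_generation_model=model)

try:
    SketchImageTask(image_generation_model=model, llm=object())
    rejected = False
except AssertionError:
    rejected = True
```

After construction, `task.llm` and `task.image_generation_model` are the same object, which is why the llm attribute is described as internal only.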

format_input(input) abstractmethod

Abstract method to format the inputs of the task. It receives an input as a Python dictionary and returns a string to be used as the prompt for the model.

Source code in src/distilabel/steps/tasks/base.py
@abstractmethod
def format_input(self, input: dict[str, Any]) -> str:
    """Abstract method to format the inputs of the task. It needs to receive an input
    as a Python dictionary, and generates a string to be used as the prompt for the model."""
    pass
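As a rough illustration of what a concrete format_input might look like, the sketch below uses a minimal abstract base class in place of ImageTask so that it runs without distilabel installed. The class names and the `prompt` and `style` columns are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SketchImageTask(ABC):
    """Hypothetical stand-in exposing only the abstract `format_input`."""

    @abstractmethod
    def format_input(self, input: Dict[str, Any]) -> str:
        """Turn one input row into the prompt string for the model."""


class PromptWithStyle(SketchImageTask):
    def format_input(self, input: Dict[str, Any]) -> str:
        # Combine two (assumed) columns into a single prompt string.
        style = input.get("style", "photorealistic")
        return f"{input['prompt']}, {style}"


task = PromptWithStyle()
prompt = task.format_input({"prompt": "a lighthouse at dusk", "style": "watercolor"})
default_prompt = task.format_input({"prompt": "a lighthouse at dusk"})

try:
    SketchImageTask()  # abstract classes cannot be instantiated
    instantiated = True
except TypeError:
    instantiated = False
```

Because format_input is abstract, a subclass that does not override it cannot be instantiated, as the last lines of the sketch show.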

process(inputs) abstractmethod

Processes the inputs of the task and generates the outputs using the ImageGenerationModel.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | StepInput | A list of Python dictionaries with the inputs of the task. | required |

Yields:

| Type | Description |
|---|---|
| StepOutput | A list of Python dictionaries with the outputs of the task. |

Source code in src/distilabel/steps/tasks/base.py
@abstractmethod
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Processes the inputs of the task and generates the outputs using the `ImageGenerationModel`.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Yields:
        A list of Python dictionaries with the outputs of the task.
    """
    pass
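A concrete process is expected to be a generator that yields a list of output dictionaries. The sketch below shows that shape under stated assumptions: `FakeImageModel.generate` and its signature are invented for illustration and are not the distilabel ImageGenerationModel API, and `SketchTask` is a plain class rather than an ImageTask subclass.

```python
from typing import Any, Dict, Iterator, List


class FakeImageModel:
    """Hypothetical model: returns placeholder strings instead of images."""

    def generate(self, prompts: List[str], num_generations: int = 1) -> List[List[str]]:
        return [
            [f"<image {i} for {prompt!r}>" for i in range(num_generations)]
            for prompt in prompts
        ]


class SketchTask:
    def __init__(self, model: FakeImageModel, num_generations: int = 1) -> None:
        self.model = model
        self.num_generations = num_generations

    def format_input(self, input: Dict[str, Any]) -> str:
        return input["prompt"]

    def process(self, inputs: List[Dict[str, Any]]) -> Iterator[List[Dict[str, Any]]]:
        # Build one prompt per input row, then ask the model for generations.
        prompts = [self.format_input(input) for input in inputs]
        outputs = self.model.generate(prompts, self.num_generations)
        rows = []
        for input, generations in zip(inputs, outputs):
            for generation in generations:
                # One output row per generation (the ungrouped layout).
                rows.append({**input, "image": generation})
        yield rows


task = SketchTask(FakeImageModel(), num_generations=2)
batches = list(task.process([{"prompt": "a"}, {"prompt": "b"}]))
```

Here a single batch of two inputs produces one yielded list with four rows, two per input.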