ImageTask to work with Image Generation Models
Working with ImageTasks
The ImageTask is a custom implementation of a Task designed specifically to deal with images. These tasks behave exactly like any other Task, but instead of relying on an LLM, they work with an ImageGenerationModel.
New in version 1.5.0
This task is new and is expected to work with Image Generation Models.
These tasks take an image_generation_model attribute instead of the llm attribute used by the standard Task, but everything else remains the same. Let's see an example with ImageGeneration:
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.steps.tasks import ImageGeneration

task = ImageGeneration(
    name="image-generation",
    image_generation_model=InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell"),
)
task.load()

next(task.process([{"prompt": "a white siamese cat"}]))
# [{"image": "iVBORw0KGgoAAAANSUhEUgA...", "model_name": "black-forest-labs/FLUX.1-schnell"}]
Visualize the image in a notebook
If you are testing the ImageGeneration task in a notebook, you can do the following to see the rendered image:
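A minimal sketch of the decoding step, assuming the image field holds a base64-encoded PNG (the iVBORw0KGgo prefix in the output above is what a PNG signature looks like in base64). It uses only the standard library and builds a tiny 1x1 PNG as stand-in data; decode_image is a hypothetical helper name. In a notebook, the image_from_str helper from distilabel.models.image_generation.utils (imported in the custom task example on this page) presumably covers this conversion and returns an object that renders inline.

```python
import base64
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _chunk(kind: bytes, data: bytes) -> bytes:
    """Build one PNG chunk: length, type, data, CRC over type + data."""
    crc = zlib.crc32(kind + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + kind + data + struct.pack(">I", crc)


def make_tiny_png() -> bytes:
    """Stand-in for a generated image: a 1x1 white RGB PNG."""
    # width, height, bit depth, color type (2 = RGB), compression, filter, interlace
    ihdr = struct.pack(">IIBBBBB", 1, 1, 8, 2, 0, 0, 0)
    raw = b"\x00" + b"\xff\xff\xff"  # filter byte + one white pixel
    return (
        PNG_SIGNATURE
        + _chunk(b"IHDR", ihdr)
        + _chunk(b"IDAT", zlib.compress(raw))
        + _chunk(b"IEND", b"")
    )


def decode_image(image_str: str) -> bytes:
    """Hypothetical helper: turn the base64 string back into raw image bytes."""
    return base64.b64decode(image_str)


# Simulated task output (illustrative values, not real model output).
result = [
    {
        "image": base64.b64encode(make_tiny_png()).decode("utf-8"),
        "model_name": "stub-model",
    }
]

image_bytes = decode_image(result[0]["image"])
# With Pillow installed, PIL.Image.open(io.BytesIO(image_bytes)) would give
# an object that a notebook renders inline.
```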
Running ImageGeneration in a Pipeline
When running a pipeline, the same conversion from an image encoded as a string to a PIL object can be applied to the whole dataset by calling the transform_columns_to_image method on the final distiset and passing the name (or list of names) of the image column(s).
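The column-wise conversion can be pictured with plain Python. The sketch below is a stand-in, not distilabel's implementation: transform_columns is a hypothetical function that takes rows as dictionaries plus a column name or list of names, mirroring the call shape described above, and decodes each base64 string to bytes (whereas the real method is described as producing PIL objects).

```python
import base64
from typing import Any, Dict, List, Union


def transform_columns(
    rows: List[Dict[str, Any]], columns: Union[str, List[str]]
) -> List[Dict[str, Any]]:
    """Hypothetical stand-in: decode base64 image strings in the given column(s)."""
    names = [columns] if isinstance(columns, str) else list(columns)
    transformed = []
    for row in rows:
        new_row = dict(row)  # copy so the input rows are left untouched
        for name in names:
            if new_row.get(name) is not None:  # leave missing images as they are
                new_row[name] = base64.b64decode(new_row[name])
        transformed.append(new_row)
    return transformed


rows = [
    {"prompt": "a white siamese cat", "image": base64.b64encode(b"fake-bytes").decode()},
    {"prompt": "a black dog", "image": None},
]
decoded = transform_columns(rows, "image")
```

With a real distiset, the call described above would take the same kind of argument, for example "image" or ["image"].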
Defining custom ImageTasks
We can define a custom image task by creating a new subclass of the ImageTask and defining the following:
- process: is a method that generates the data based on the ImageGenerationModel and the prompt provided within the class instance, and returns a dictionary with the output data formatted as needed, i.e. with the values for the columns in outputs.
- inputs: is a property that returns a list of strings with the names of the required input fields, or a dictionary in which the keys are the names of the columns and the values are booleans indicating whether the column is required or not.
- outputs: is a property that returns a list of strings with the names of the output fields, or a dictionary in which the keys are the names of the columns and the values are booleans indicating whether the column is required or not. This property should always include model_name as one of the outputs, since that is automatically injected from the model.
- format_input: is a method that receives a dictionary with the input data and returns a prompt to be passed to the model.
- format_output: is a method that receives the output from the ImageGenerationModel and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed, i.e. with the values for the columns in outputs.
from typing import TYPE_CHECKING, Any, Dict, Union

from typing_extensions import override

from distilabel.models.image_generation.utils import image_from_str, image_to_str
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import ImageTask

if TYPE_CHECKING:
    from distilabel.typing import StepColumns, StepOutput


class MyCustomImageTask(ImageTask):
    @override
    def process(self, inputs: StepInput) -> "StepOutput":
        # Turn each input row into a prompt via `format_input`.
        formatted_inputs = self._format_inputs(inputs)

        # `self.llm` is the image generation model (see the warning below).
        outputs = self.llm.generate_outputs(
            inputs=formatted_inputs,
            num_generations=self.num_generations,
            **self.llm.get_generation_kwargs(),
        )

        # Merge every generation back into its originating input row.
        task_outputs = []
        for input, input_outputs in zip(inputs, outputs):
            formatted_outputs = self._format_outputs(input_outputs, input)
            for formatted_output in formatted_outputs:
                task_outputs.append(
                    {**input, **formatted_output, "model_name": self.llm.model_name}
                )
        yield task_outputs

    @property
    def inputs(self) -> "StepColumns":
        return ["prompt"]

    @property
    def outputs(self) -> "StepColumns":
        return ["image", "model_name"]

    def format_input(self, input: Dict[str, Any]) -> str:
        return input["prompt"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Extract/generate/modify the image from the output
        return {"image": ..., "model_name": self.llm.model_name}
Warning
Note that in the process method we are not dealing with the image_generation_model attribute but with llm. This is not a bug but intended: internally, image_generation_model is renamed to llm so that the existing Task code can be reused.
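To make the flow concrete, here is a dependency-free sketch of how the pieces could fit together. StubImageModel and SketchImageTask are hypothetical stand-ins, not distilabel classes, and the constructor that stores image_generation_model under self.llm is only an assumption about what such a rename could look like, not the library's actual mechanism.

```python
import base64
from typing import Any, Dict, Iterator, List, Optional


class StubImageModel:
    """Hypothetical model: returns fake base64 'images', one per generation."""

    model_name = "stub-image-model"

    def generate_outputs(
        self, inputs: List[str], num_generations: int = 1
    ) -> List[List[str]]:
        return [
            [
                base64.b64encode(f"{prompt}#{i}".encode()).decode()
                for i in range(num_generations)
            ]
            for prompt in inputs
        ]


class SketchImageTask:
    """Hypothetical task mirroring format_input -> generate -> format_output."""

    def __init__(
        self, image_generation_model: StubImageModel, num_generations: int = 1
    ) -> None:
        # Assumed aliasing: the model arrives as `image_generation_model`
        # but is stored as `llm`, so LLM-oriented code paths can be reused.
        self.llm = image_generation_model
        self.num_generations = num_generations

    def format_input(self, input: Dict[str, Any]) -> str:
        return input["prompt"]

    def format_output(
        self, output: Optional[str], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {"image": output, "model_name": self.llm.model_name}

    def process(
        self, inputs: List[Dict[str, Any]]
    ) -> Iterator[List[Dict[str, Any]]]:
        prompts = [self.format_input(row) for row in inputs]
        outputs = self.llm.generate_outputs(
            prompts, num_generations=self.num_generations
        )
        rows = []
        for row, generations in zip(inputs, outputs):
            for generation in generations:
                # Merge the input columns with the formatted output columns.
                rows.append({**row, **self.format_output(generation, row)})
        yield rows


task = SketchImageTask(image_generation_model=StubImageModel(), num_generations=2)
batch = next(task.process([{"prompt": "a white siamese cat"}]))
```

Each input row yields one output row per generation, and every output row carries the original prompt alongside the image and model_name columns.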