
ImageGenerationModel

This section contains the API reference for the distilabel image generation models, covering both the synchronous ImageGenerationModel and the asynchronous AsyncImageGenerationModel.

For more information and examples on how to use existing image generation models or create custom ones, please refer to Tutorial - ImageGenerationModel.

base

ImageGenerationModel

Bases: RuntimeParametersModelMixin, BaseModel, _Serializable, ABC

Base class for ImageGeneration models.

To implement a custom image generation model, subclass this class and implement:

- load method to load the image generation model if needed. Don't forget to call super().load(), so that the _logger attribute is initialized.
- model_name property to return the name of the model being used.
- generate method to generate num_generations images for each input in inputs.
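
The contract above can be sketched in plain Python. The snippet below is a minimal stand-in that uses only `abc` and `logging`: `MiniImageGenerationModel` mirrors the documented `load`, `model_name`, `generate` and `generate_outputs` members but is not distilabel's class, and `DummyImageGeneration`, its placeholder strings and the `"images"` key are hypothetical. In a real project you would subclass `ImageGenerationModel` itself.

```python
import logging
from abc import ABC, abstractmethod
from typing import Any, Optional


class MiniImageGenerationModel(ABC):
    """Stand-in mirroring the documented contract (not distilabel's class)."""

    def __init__(self) -> None:
        self._logger: Optional[logging.Logger] = None

    def load(self) -> None:
        # As documented: the logger is created when `load` is called.
        self._logger = logging.getLogger(
            f"distilabel.models.image_generation.{self.model_name}"
        )

    @property
    @abstractmethod
    def model_name(self) -> str: ...

    @abstractmethod
    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]: ...

    def generate_outputs(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # Compatibility shim: simply delegates to `generate`.
        return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)


class DummyImageGeneration(MiniImageGenerationModel):
    """Hypothetical subclass returning placeholder strings instead of images."""

    def load(self) -> None:
        super().load()  # initializes `_logger`, as the docs require
        self._logger.info("Dummy model loaded")

    @property
    def model_name(self) -> str:
        return "dummy-image-model"

    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # One inner list per prompt, one dict per generation.
        return [
            [{"images": [f"<image for {prompt!r} #{i}>"]} for i in range(num_generations)]
            for prompt in inputs
        ]


model = DummyImageGeneration()
model.load()
outputs = model.generate_outputs(["a red fox", "a blue whale"], num_generations=2)
print(len(outputs), len(outputs[0]))  # 2 2
```

Forgetting the `super().load()` call in `DummyImageGeneration.load` would leave `_logger` as `None`, so the `info` call would fail.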

Attributes:

- generation_kwargs (Optional[RuntimeParameter[dict[str, Any]]]): the kwargs to be propagated to either the generate or the agenerate method of each ImageGenerationModel.
- _logger (Logger): the logger to be used for the ImageGenerationModel. It is initialized when the load method is called.

Source code in src/distilabel/models/image_generation/base.py
class ImageGenerationModel(RuntimeParametersModelMixin, BaseModel, _Serializable, ABC):
    """Base class for `ImageGeneration` models.

    To implement an `ImageGeneration` subclass, you need to subclass this class and implement:
        - `load` method to load the `ImageGeneration` model if needed. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the LLM.
        - `generate` method to generate `num_generations` per input in `inputs`.

    Attributes:
        generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
            methods within each `ImageGenerationModel`.
        _logger: the logger to be used for the `ImageGenerationModel`. It will be initialized
            when the `load` method is called.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
        extra="forbid",
    )

    generation_kwargs: Optional[RuntimeParameter[dict[str, Any]]] = Field(
        default_factory=dict,
        description="The kwargs to be propagated to either `generate` or `agenerate`"
        " methods within each `ImageGenerationModel`.",
    )
    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Method to be called to initialize the `ImageGenerationModel`, and its logger."""
        self._logger = logging.getLogger(
            f"distilabel.models.image_generation.{self.model_name}"
        )

    def unload(self) -> None:
        """Method to be called to unload the `ImageGenerationModel` and release any resources."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the `ImageGenerationModel`."""
        pass

    def get_generation_kwargs(self) -> dict[str, Any]:
        """Returns the generation kwargs to be used for the generation. This method can
        be overridden to provide a more complex logic for the generation kwargs.

        Returns:
            The kwargs to be used for the generation.
        """
        return self.generation_kwargs  # type: ignore

    @abstractmethod
    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Generates images from the provided input.

        Args:
            inputs: the prompt text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    def generate_outputs(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """This method is defined for compatibility with the `LLMs`. It calls the `generate`
        method.
        """
        return self.generate(inputs=inputs, num_generations=num_generations, **kwargs)
model_name abstractmethod property

Returns the model name used for the ImageGenerationModel.

load()

Method to be called to initialize the ImageGenerationModel, and its logger.

unload()

Method to be called to unload the ImageGenerationModel and release any resources.

get_generation_kwargs()

Returns the generation kwargs to be used for the generation. This method can be overridden to provide more complex logic for building the generation kwargs.

Returns:

- dict[str, Any]: the kwargs to be used for the generation.
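
Because get_generation_kwargs may be overridden, a subclass can, for example, merge per-model defaults with the user-supplied generation_kwargs. The sketch below is plain Python, not distilabel's pydantic model: `KwargsSketch`, its `DEFAULTS` dict and the `size`/`quality` keys are all hypothetical, since the real keys depend on the backend.

```python
from typing import Any, Optional


class KwargsSketch:
    """Hypothetical stand-in holding `generation_kwargs` like the base class does."""

    # Hypothetical per-model defaults; real keys depend on the backend.
    DEFAULTS: dict[str, Any] = {"size": "1024x1024", "quality": "standard"}

    def __init__(self, generation_kwargs: Optional[dict[str, Any]] = None) -> None:
        self.generation_kwargs = generation_kwargs or {}

    def get_generation_kwargs(self) -> dict[str, Any]:
        # Build a new dict; user-provided values take precedence over defaults.
        return {**self.DEFAULTS, **self.generation_kwargs}


print(KwargsSketch({"quality": "hd"}).get_generation_kwargs())
# {'size': '1024x1024', 'quality': 'hd'}
```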

generate(inputs, num_generations=1, **kwargs) abstractmethod

Generates images from the provided inputs.

Parameters:

- inputs (list[str], required): the list of prompts to generate images from.
- num_generations (int, default 1): the number of images to generate for each input.

Returns:

- list[list[dict[str, Any]]]: one list per input, each containing dictionaries with the generated images.
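
The nested return type above is easy to get wrong in a custom implementation. A small hypothetical helper, `check_generate_output` (not part of distilabel), can assert that a result has one inner list per input and that every item is a dictionary:

```python
from typing import Any


def check_generate_output(inputs: list[str], outputs: Any) -> None:
    """Hypothetical sanity check for the `list[list[dict[str, Any]]]` shape."""
    if not isinstance(outputs, list) or len(outputs) != len(inputs):
        raise ValueError("expected one inner list per input")
    for i, per_input in enumerate(outputs):
        if not isinstance(per_input, list):
            raise ValueError(f"outputs[{i}] should be a list")
        for j, item in enumerate(per_input):
            if not isinstance(item, dict):
                raise ValueError(f"outputs[{i}][{j}] should be a dict")


check_generate_output(["a cat"], [[{"images": ["..."]}]])  # passes silently
```

It only checks the container structure; the keys inside each dictionary depend on the concrete model.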

generate_outputs(inputs, num_generations=1, **kwargs)

This method is defined for compatibility with the LLMs: it simply forwards its arguments to the generate method.

AsyncImageGenerationModel

Bases: ImageGenerationModel

Abstract class for asynchronous ImageGenerationModels, to benefit from the async capabilities of each implementation. This class is meant to be subclassed by each asynchronous ImageGenerationModel, which must implement the agenerate method to generate images asynchronously.

Attributes:

- _event_loop (AbstractEventLoop): the event loop to be used for the asynchronous generation of responses.
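
The pattern this class implements (an async agenerate per input, fanned out concurrently and driven synchronously by generate) can be sketched with the standard library alone. `MiniAsyncImageGeneration` and `SleepyImageGeneration` are hypothetical stand-ins, not distilabel classes, and `asyncio.sleep` stands in for a real network call. For brevity the sketch uses `asyncio.run`, whereas the real class obtains a loop through its event_loop property and calls `run_until_complete`.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any


class MiniAsyncImageGeneration(ABC):
    """Stand-in mirroring the async contract (not distilabel's class)."""

    @abstractmethod
    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]: ...

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # One task per input; `gather` returns results in input order.
        tasks = [
            asyncio.create_task(
                self.agenerate(input=prompt, num_generations=num_generations, **kwargs)
            )
            for prompt in inputs
        ]
        return list(await asyncio.gather(*tasks))

    def generate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        # Synchronous entry point that blocks until every task has finished.
        return asyncio.run(self._agenerate(inputs, num_generations, **kwargs))


class SleepyImageGeneration(MiniAsyncImageGeneration):
    """Hypothetical backend returning placeholder strings."""

    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        await asyncio.sleep(0.01)  # simulated request latency
        return [{"images": [f"<{input} #{i}>" for i in range(num_generations)]}]


print(SleepyImageGeneration().generate(["a fox", "a whale"], num_generations=2))
```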

Source code in src/distilabel/models/image_generation/base.py
class AsyncImageGenerationModel(ImageGenerationModel):
    """Abstract class for asynchronous `ImageGenerationModels`, to benefit from the async capabilities
    of each LLM implementation. This class is meant to be subclassed by each `ImageGenerationModel`, and the
    method `agenerate` needs to be implemented to provide the asynchronous generation of
    responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    _num_generations_param_supported = True
    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> list[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
                    self._new_event_loop = True
            except RuntimeError:
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        """Generates images from the provided input.

        Args:
            input: the input text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list with a dictionary with the list of images generated.
        """
        pass

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Internal function to concurrently generate images for a list of inputs.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        if self._num_generations_param_supported:
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            return await asyncio.gather(*tasks)

        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
        return [
            list(group)
            for group in grouper(outputs, n=num_generations, incomplete="ignore")
        ]

    def generate(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """Method to generate a list of images asynchronously, returning the output
        synchronously awaiting for the image of each input sent to `agenerate`.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the images for each input.
        """
        return self.event_loop.run_until_complete(
            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
        )

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        if sys.meta_path is None:
            return

        if self._new_event_loop:
            if self._event_loop.is_running():
                self._event_loop.stop()
            self._event_loop.close()
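
When _num_generations_param_supported is False, _agenerate above calls agenerate once per generation, keeps the first element of each response, and regroups the flat list into chunks of num_generations. The standalone sketch below reproduces that bookkeeping. `fake_agenerate` is hypothetical, and the local `grouper` follows the itertools recipe with incomplete="ignore" semantics instead of importing distilabel's helper.

```python
import asyncio
from typing import Any, Iterable, Iterator


def grouper(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]:
    """Chunks of size n; an incomplete tail is dropped (like incomplete="ignore")."""
    return zip(*[iter(iterable)] * n)


async def fake_agenerate(prompt: str) -> list[dict[str, Any]]:
    # Hypothetical backend that can only return one image per call.
    await asyncio.sleep(0)
    return [{"images": [f"<{prompt}>"]}]


async def fan_out(inputs: list[str], num_generations: int) -> list[list[dict[str, Any]]]:
    # len(inputs) * num_generations tasks, ordered input-major: a, a, b, b, ...
    tasks = [
        asyncio.create_task(fake_agenerate(prompt))
        for prompt in inputs
        for _ in range(num_generations)
    ]
    # Keep only the first dict of each response, then regroup per input.
    flat = [response[0] for response in await asyncio.gather(*tasks)]
    return [list(group) for group in grouper(flat, n=num_generations)]


result = asyncio.run(fan_out(["a fox", "a whale"], num_generations=3))
print([len(group) for group in result])  # [3, 3]
```

Grouping works only because `asyncio.gather` preserves task order, so the consecutive generations for one input stay adjacent in `flat`.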
generate_parameters property

Returns the parameters of the agenerate method.

Returns:

- list[Parameter]: a list containing the parameters of the agenerate method.
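
generate_parameters is a thin wrapper over `inspect.signature`, which works on coroutine methods just like on regular ones. The hypothetical `Example` class below only illustrates what such a call returns; on a bound method, `self` is excluded.

```python
import inspect
from typing import Any


class Example:
    """Hypothetical class with the same `agenerate` signature as documented."""

    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        return []


params = list(inspect.signature(Example().agenerate).parameters.values())
print([p.name for p in params])  # ['input', 'num_generations', 'kwargs']
print(params[1].default)  # 1
```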

generate_parsed_docstring cached property

Returns the parsed docstring of the agenerate method.

Returns:

- Docstring: the parsed docstring of the agenerate method.

agenerate(input, num_generations=1, **kwargs) abstractmethod async

Generates images from the provided input.

Parameters:

- input (str, required): the input text to generate the image from.
- num_generations (int, default 1): the number of images to generate.

Returns:

- list[dict[str, Any]]: a list of dictionaries containing the generated images.

generate(inputs, num_generations=1, **kwargs)

Generates images for a list of inputs asynchronously, and returns the results synchronously once every agenerate call has completed.

Parameters:

- inputs (list[str], required): the list of inputs to generate images for.
- num_generations (int, default 1): the number of images to generate per input.
- **kwargs (Any, default {}): the additional kwargs to be used for the generation.

Returns:

- list[list[dict[str, Any]]]: a list containing the images for each input.

__del__()

Closes the event loop when the object is deleted.
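
The event_loop property reuses the running loop when there is one and otherwise creates a new loop, recording in _new_event_loop that the object owns it; __del__ then closes only a loop the object created itself. The hypothetical `LoopOwner` class below sketches that ownership rule with the standard library. It uses an explicit `close()` method instead of `__del__` so the behaviour is deterministic, and it omits the `asyncio.set_event_loop` call, the closed-loop branch, the stop-if-running step and the `sys.meta_path` shutdown guard of the real code.

```python
import asyncio
from typing import Optional


class LoopOwner:
    """Hypothetical sketch of the loop-ownership rule (not distilabel's class)."""

    def __init__(self) -> None:
        self._event_loop: Optional[asyncio.AbstractEventLoop] = None
        self._new_event_loop = False

    @property
    def event_loop(self) -> asyncio.AbstractEventLoop:
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
            except RuntimeError:  # no loop is running in this thread
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        return self._event_loop

    def close(self) -> None:
        # Only close a loop this object created itself.
        if self._new_event_loop and self._event_loop is not None:
            self._event_loop.close()


owner = LoopOwner()
print(owner.event_loop.run_until_complete(asyncio.sleep(0, result="done")))  # done
owner.close()
print(owner.event_loop.is_closed())  # True
```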
