ImageGenerationModel Gallery

This section contains the existing ImageGenerationModel subclasses implemented in distilabel.

image_generation

AsyncImageGenerationModel

Bases: ImageGenerationModel

Abstract class for asynchronous ImageGenerationModels, to benefit from the async capabilities of each implementation. This class is meant to be subclassed by each ImageGenerationModel, and the method agenerate needs to be implemented to provide the asynchronous generation of responses.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `_event_loop` | `AbstractEventLoop` | the event loop to be used for the asynchronous generation of responses. |

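The contract described above is small: a concrete backend only has to implement `agenerate`, and the base class contributes request batching, event-loop management, and the synchronous `generate` wrapper. A minimal sketch, where `EchoImageGeneration` and its placeholder output are illustrative rather than part of distilabel:

```python
from typing import Any

from distilabel.models.image_generation.base import AsyncImageGenerationModel


class EchoImageGeneration(AsyncImageGenerationModel):
    """Toy backend that returns placeholder strings instead of real images."""

    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        # A real backend would await an async image API here and return
        # base64-encoded images under the "images" key.
        return [{"images": [f"<image for: {input}>"] * num_generations}]
```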
Source code in src/distilabel/models/image_generation/base.py
class AsyncImageGenerationModel(ImageGenerationModel):
    """Abstract class for asynchronous `ImageGenerationModels`, to benefit from the async capabilities
    of each LLM implementation. This class is meant to be subclassed by each `ImageGenerationModel`, and the
    method `agenerate` needs to be implemented to provide the asynchronous generation of
    responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    _num_generations_param_supported = True
    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)
    _new_event_loop: bool = PrivateAttr(default=False)

    @property
    def generate_parameters(self) -> list[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
                    self._new_event_loop = True
            except RuntimeError:
                self._event_loop = asyncio.new_event_loop()
                self._new_event_loop = True
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: str, num_generations: int = 1, **kwargs: Any
    ) -> list[dict[str, Any]]:
        """Generates images from the provided input.

        Args:
            input: the input text to generate the image from.
            num_generations: the number of images to generate. Defaults to `1`.

        Returns:
            A list containing a dictionary with the list of generated images.
        """
        pass

    async def _agenerate(
        self, inputs: list[str], num_generations: int = 1, **kwargs: Any
    ) -> list[list[dict[str, Any]]]:
        """Internal function to concurrently generate images for a list of inputs.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        if self._num_generations_param_supported:
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            return await asyncio.gather(*tasks)

        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
        return [
            list(group)
            for group in grouper(outputs, n=num_generations, incomplete="ignore")
        ]

    def generate(
        self,
        inputs: list[str],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> list[list[dict[str, Any]]]:
        """Method to generate a list of images asynchronously, returning the output
        synchronously awaiting for the image of each input sent to `agenerate`.

        Args:
            inputs: the list of inputs to generate images for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the images for each input.
        """
        return self.event_loop.run_until_complete(
            self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
        )

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        if sys.meta_path is None:
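            # the interpreter is shutting down; skip cleanup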
            return

        if self._new_event_loop:
            if self._event_loop.is_running():
                self._event_loop.stop()
            self._event_loop.close()
generate_parameters property

Returns the parameters of the agenerate method.

Returns:

| Type | Description |
| --- | --- |
| `list[Parameter]` | A list containing the parameters of the `agenerate` method. |
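`generate_parameters` is plain `inspect` introspection over the subclass's `agenerate`, so it can be used to discover what a given model accepts at generation time. A small sketch, reusing the illustrative `EchoImageGeneration` subclass from the top of this page:

```python
import inspect

igm = EchoImageGeneration()

# Prints each `agenerate` parameter, its kind, and its default (if any), e.g.:
# input POSITIONAL_OR_KEYWORD <class 'inspect._empty'>
# num_generations POSITIONAL_OR_KEYWORD 1
# kwargs VAR_KEYWORD <class 'inspect._empty'>
for param in igm.generate_parameters:
    print(param.name, param.kind, param.default)
```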

generate_parsed_docstring cached property

Returns the parsed docstring of the agenerate method.

Returns:

| Type | Description |
| --- | --- |
| `Docstring` | The parsed docstring of the `agenerate` method. |

agenerate(input, num_generations=1, **kwargs) abstractmethod async

Generates images from the provided input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `str` | the input text to generate the image from. | *required* |
| `num_generations` | `int` | the number of images to generate. Defaults to `1`. | `1` |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Any]]` | A list containing a dictionary with the list of generated images. |

Source code in src/distilabel/models/image_generation/base.py
@abstractmethod
async def agenerate(
    self, input: str, num_generations: int = 1, **kwargs: Any
) -> list[dict[str, Any]]:
    """Generates images from the provided input.

    Args:
        input: the input text to generate the image from.
        num_generations: the number of images to generate. Defaults to `1`.

    Returns:
        A list containing a dictionary with the list of generated images.
    """
    pass
_agenerate(inputs, num_generations=1, **kwargs) async

Internal function to concurrently generate images for a list of inputs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `list[str]` | the list of inputs to generate images for. | *required* |
| `num_generations` | `int` | the number of generations to generate per input. | `1` |
| `**kwargs` | `Any` | the additional kwargs to be used for the generation. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `list[list[dict[str, Any]]]` | A list containing the generations for each input. |

Source code in src/distilabel/models/image_generation/base.py
async def _agenerate(
    self, inputs: list[str], num_generations: int = 1, **kwargs: Any
) -> list[list[dict[str, Any]]]:
    """Internal function to concurrently generate images for a list of inputs.

    Args:
        inputs: the list of inputs to generate images for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the generations for each input.
    """
    if self._num_generations_param_supported:
        tasks = [
            asyncio.create_task(
                self.agenerate(
                    input=input, num_generations=num_generations, **kwargs
                )
            )
            for input in inputs
        ]
        return await asyncio.gather(*tasks)

    tasks = [
        asyncio.create_task(self.agenerate(input=input, **kwargs))
        for input in inputs
        for _ in range(num_generations)
    ]
    outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
    return [
        list(group)
        for group in grouper(outputs, n=num_generations, incomplete="ignore")
    ]
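When the backend does not support a `num_generations` parameter, the code above schedules `num_generations` separate `agenerate` calls per input (in input-major order) and regroups the flat result list. A minimal sketch of that regrouping, using the classic itertools recipe as a stand-in for distilabel's own `grouper` helper:

```python
def grouper(iterable, n, incomplete="ignore"):
    # Classic itertools recipe: zip over n copies of the same iterator.
    # With incomplete="ignore", any trailing group shorter than n is dropped.
    args = [iter(iterable)] * n
    return zip(*args)


# 2 inputs x 3 generations produce a flat, input-major list of 6 outputs:
outputs = ["a0", "a1", "a2", "b0", "b1", "b2"]

regrouped = [list(group) for group in grouper(outputs, n=3, incomplete="ignore")]
print(regrouped)  # [['a0', 'a1', 'a2'], ['b0', 'b1', 'b2']]
```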
generate(inputs, num_generations=1, **kwargs)

Method to generate a list of images asynchronously, returning the output synchronously after awaiting the images for each input sent to agenerate.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `list[str]` | the list of inputs to generate images for. | *required* |
| `num_generations` | `int` | the number of generations to generate per input. | `1` |
| `**kwargs` | `Any` | the additional kwargs to be used for the generation. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `list[list[dict[str, Any]]]` | A list containing the images for each input. |

Source code in src/distilabel/models/image_generation/base.py
def generate(
    self,
    inputs: list[str],
    num_generations: int = 1,
    **kwargs: Any,
) -> list[list[dict[str, Any]]]:
    """Method to generate a list of images asynchronously, returning the output
    synchronously awaiting for the image of each input sent to `agenerate`.

    Args:
        inputs: the list of inputs to generate images for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the images for each input.
    """
    return self.event_loop.run_until_complete(
        self._agenerate(inputs=inputs, num_generations=num_generations, **kwargs)
    )
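Since `generate` runs the event loop itself, callers can use it as a plain blocking function. A short sketch, again with the illustrative `EchoImageGeneration` subclass:

```python
igm = EchoImageGeneration()
igm.load()

results = igm.generate(
    inputs=["a white siamese cat", "a red panda"],
    num_generations=2,
)
# results[i] holds the output of `agenerate` for inputs[i], e.g.
# results[0] -> [{"images": ["<image for: a white siamese cat>", ...]}]
```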
__del__()

Closes the event loop when the object is deleted.

Source code in src/distilabel/models/image_generation/base.py
def __del__(self) -> None:
    """Closes the event loop when the object is deleted."""
    if sys.meta_path is None:
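        # the interpreter is shutting down; skip cleanup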
        return

    if self._new_event_loop:
        if self._event_loop.is_running():
            self._event_loop.stop()
        self._event_loop.close()

InferenceEndpointsImageGeneration

Bases: InferenceEndpointsBaseClient, AsyncImageGenerationModel

Inference Endpoint image generation implementation running the async API client.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model_id` | `Optional[str]` | the model ID to use for the ImageGenerationModel as available in the Hugging Face Hub, which will be used to resolve the base URL for the serverless Inference Endpoints API requests. Defaults to `None`. |
| `endpoint_name` | `Optional[RuntimeParameter[str]]` | the name of the Inference Endpoint to use for the model. Defaults to `None`. |
| `endpoint_namespace` | `Optional[RuntimeParameter[str]]` | the namespace of the Inference Endpoint to use for the model. Defaults to `None`. |
| `base_url` | `Optional[RuntimeParameter[str]]` | the base URL to use for the Inference Endpoints API requests. |
| `api_key` | `Optional[RuntimeParameter[SecretStr]]` | the API key to authenticate the requests to the Inference Endpoints API. |

Icon

:hugging:

Examples:

Generate images from text prompts:

from distilabel.models.image_generation import InferenceEndpointsImageGeneration

igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell", api_key="api.key")
igm.load()

output = igm.generate_outputs(
    inputs=["a white siamese cat"],
)
# [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
Source code in src/distilabel/models/image_generation/huggingface/inference_endpoints.py
class InferenceEndpointsImageGeneration(  # type: ignore
    InferenceEndpointsBaseClient, AsyncImageGenerationModel
):
    """Inference Endpoint image generation implementation running the async API client.

    Attributes:
        model_id: the model ID to use for the ImageGenerationModel as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the model. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the model. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.

    Icon:
        `:hugging:`

    Examples:
        Generate images from text prompts:

        ```python
        from distilabel.models.image_generation import InferenceEndpointsImageGeneration

        igm = InferenceEndpointsImageGeneration(model_id="black-forest-labs/FLUX.1-schnell", api_key="api.key")
        igm.load()

        output = igm.generate_outputs(
            inputs=["a white siamese cat"],
        )
        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
        ```
    """

    def load(self) -> None:
        from distilabel.models.image_generation.utils import image_to_str

        # Sets the logger and calls the load method of the BaseClient
        AsyncImageGenerationModel.load(self)
        InferenceEndpointsBaseClient.load(self)

        self._image_to_str = image_to_str

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: str,
        negative_prompt: Optional[str] = None,
        height: Optional[float] = None,
        width: Optional[float] = None,
        num_inference_steps: Optional[float] = None,
        guidance_scale: Optional[float] = None,
        num_generations: int = 1,
    ) -> list[dict[str, Any]]:
        """Generates images from text prompts using `huggingface_hub.AsyncInferenceClient.text_to_image`.

        Args:
            input: Prompt to generate an image from.
            negative_prompt: An optional negative prompt for the image generation. Defaults to None.
            height: The height in pixels of the image to generate.
            width: The width in pixels of the image to generate.
            num_inference_steps: The number of denoising steps. More denoising steps usually lead
                to a higher quality image at the expense of slower inference.
            guidance_scale: A higher guidance scale encourages the model to generate images
                that are closely linked to the text `prompt`, usually at the expense of lower
                image quality.
            num_generations: The number of images to generate. Defaults to `1`.
                It's here to ensure the validation succeeds, but it won't have any effect.

        Returns:
            A list containing a dictionary whose `images` key holds the image as a base64 string.
        """

        image: "Image" = await self._aclient.text_to_image(  # type: ignore
            input,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
        img_str = self._image_to_str(image, image_format="JPEG")

        return [{"images": [img_str]}]
agenerate(input, negative_prompt=None, height=None, width=None, num_inference_steps=None, guidance_scale=None, num_generations=1) async

Generates images from text prompts using huggingface_hub.AsyncInferenceClient.text_to_image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `str` | Prompt to generate an image from. | *required* |
| `negative_prompt` | `Optional[str]` | An optional negative prompt for the image generation. Defaults to `None`. | `None` |
| `height` | `Optional[float]` | The height in pixels of the image to generate. | `None` |
| `width` | `Optional[float]` | The width in pixels of the image to generate. | `None` |
| `num_inference_steps` | `Optional[float]` | The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. | `None` |
| `guidance_scale` | `Optional[float]` | A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. | `None` |
| `num_generations` | `int` | The number of images to generate. Defaults to `1`. It's here to ensure the validation succeeds, but it won't have any effect. | `1` |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Any]]` | A list containing a dictionary whose `images` key holds the image as a base64 string. |

Source code in src/distilabel/models/image_generation/huggingface/inference_endpoints.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: str,
    negative_prompt: Optional[str] = None,
    height: Optional[float] = None,
    width: Optional[float] = None,
    num_inference_steps: Optional[float] = None,
    guidance_scale: Optional[float] = None,
    num_generations: int = 1,
) -> list[dict[str, Any]]:
    """Generates images from text prompts using `huggingface_hub.AsyncInferenceClient.text_to_image`.

    Args:
        input: Prompt to generate an image from.
        negative_prompt: An optional negative prompt for the image generation. Defaults to None.
        height: The height in pixels of the image to generate.
        width: The width in pixels of the image to generate.
        num_inference_steps: The number of denoising steps. More denoising steps usually lead
            to a higher quality image at the expense of slower inference.
        guidance_scale: A higher guidance scale encourages the model to generate images
            that are closely linked to the text `prompt`, usually at the expense of lower
            image quality.
        num_generations: The number of images to generate. Defaults to `1`.
            It's here to ensure the validation succeeds, but it won't have any effect.

    Returns:
        A list containing a dictionary whose `images` key holds the image as a base64 string.
    """

    image: "Image" = await self._aclient.text_to_image(  # type: ignore
        input,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    img_str = self._image_to_str(image, image_format="JPEG")

    return [{"images": [img_str]}]

OpenAIImageGeneration

Bases: OpenAIBaseClient, AsyncImageGenerationModel

OpenAI image generation implementation running the async API client.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | the model name to use for the ImageGenerationModel, e.g. "dall-e-3". Supported models can be found [here](https://platform.openai.com/docs/guides/images). |
| `base_url` | `Optional[RuntimeParameter[str]]` | the base URL to use for the OpenAI API requests. Defaults to `None`, which means that the value set for the environment variable `OPENAI_BASE_URL` will be used, or "https://api.openai.com/v1" if not set. |
| `api_key` | `Optional[RuntimeParameter[SecretStr]]` | the API key to authenticate the requests to the OpenAI API. Defaults to `None`, which means that the value set for the environment variable `OPENAI_API_KEY` will be used, or `None` if not set. |
| `max_retries` | `RuntimeParameter[int]` | the maximum number of times to retry the request to the API before failing. Defaults to `6`. |
| `timeout` | `RuntimeParameter[int]` | the maximum time in seconds to wait for a response from the API. Defaults to `120`. |

Icon

:simple-openai:

Examples:

Generate images from text prompts:

from distilabel.models.image_generation import OpenAIImageGeneration

igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key")

igm.load()

output = igm.generate_outputs(
    inputs=["a white siamese cat"],
    size="1024x1024",
    quality="standard",
    style="natural",
)
# [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
Source code in src/distilabel/models/image_generation/openai.py
class OpenAIImageGeneration(OpenAIBaseClient, AsyncImageGenerationModel):
    """OpenAI image generation implementation running the async API client.

    Attributes:
        model: the model name to use for the ImageGenerationModel e.g. "dall-e-3", etc.
            Supported models can be found [here](https://platform.openai.com/docs/guides/images).
        base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `OPENAI_BASE_URL` will
            be used, or "https://api.openai.com/v1" if not set.
        api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
            `None` which means that the value set for the environment variable `OPENAI_API_KEY`
            will be used, or `None` if not set.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Icon:
        `:simple-openai:`

    Examples:
        Generate images from text prompts:

        ```python
        from distilabel.models.image_generation import OpenAIImageGeneration

        igm = OpenAIImageGeneration(model="dall-e-3", api_key="api.key")

        igm.load()

        output = igm.generate_outputs(
            inputs=["a white siamese cat"],
            size="1024x1024",
            quality="standard",
            style="natural",
        )
        # [{"images": ["iVBORw0KGgoAAAANSUhEUgA..."]}]
        ```
    """

    def load(self) -> None:
        # Sets the logger and calls the load method of the BaseClient
        AsyncImageGenerationModel.load(self)
        OpenAIBaseClient.load(self)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: str,
        num_generations: int = 1,
        quality: Optional[Literal["standard", "hd"]] = "standard",
        response_format: Optional[Literal["url", "b64_json"]] = "url",
        size: Optional[
            Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
        ] = None,
        style: Optional[Literal["vivid", "natural"]] = None,
    ) -> list[dict[str, Any]]:
        """Generates `num_generations` images for the given input using the OpenAI async
        client. The images are base64 string representations.

        Args:
            input: A text description of the desired image(s). The maximum length is 1000
                characters for `dall-e-2` and 4000 characters for `dall-e-3`.
            num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
                `n=1` is supported.
            quality: The quality of the image that will be generated. `hd` creates images with finer
                details and greater consistency across the image. This param is only supported
                for `dall-e-3`.
            response_format: The format in which the generated images are returned. Must be one of `url` or
                `b64_json`. URLs are only valid for 60 minutes after the image has been
                generated.
            size: The size of the generated images. Must be one of `256x256`, `512x512`, or
                `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
                `1024x1792` for `dall-e-3` models.
            style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid
                causes the model to lean towards generating hyper-real and dramatic images.
                Natural causes the model to produce more natural, less hyper-real looking
                images. This param is only supported for `dall-e-3`.

        Returns:
            A list containing a dictionary with the list of generated images.
        """
        images_response: "ImagesResponse" = await self._aclient.images.generate(
            model=self.model_name,
            prompt=input,
            n=num_generations,
            quality=quality,
            response_format=response_format,
            size=size,
            style=style,
        )
        images = []
        for image in images_response.data:
            if response_format == "url":
                image_data = requests.get(
                    image.url
                ).content  # TODO: Keep a requests/httpx session instead
                image_str = base64.b64encode(image_data).decode()
                images.append(image_str)
            elif response_format == "b64_json":
                images.append(image.b64_json)
        return [{"images": images}]
agenerate(input, num_generations=1, quality='standard', response_format='url', size=None, style=None) async

Generates num_generations images for the given input using the OpenAI async client. The images are base64 string representations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `str` | A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`. | *required* |
| `num_generations` | `int` | The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. | `1` |
| `quality` | `Optional[Literal['standard', 'hd']]` | The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`. | `'standard'` |
| `response_format` | `Optional[Literal['url', 'b64_json']]` | The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated. | `'url'` |
| `size` | `Optional[Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']]` | The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. | `None` |
| `style` | `Optional[Literal['vivid', 'natural']]` | The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images; natural produces more natural, less hyper-real looking images. This param is only supported for `dall-e-3`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[dict[str, Any]]` | A list containing a dictionary with the list of generated images. |

Source code in src/distilabel/models/image_generation/openai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: str,
    num_generations: int = 1,
    quality: Optional[Literal["standard", "hd"]] = "standard",
    response_format: Optional[Literal["url", "b64_json"]] = "url",
    size: Optional[
        Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    ] = None,
    style: Optional[Literal["vivid", "natural"]] = None,
) -> list[dict[str, Any]]:
    """Generates `num_generations` images for the given input using the OpenAI async
    client. The images are base64 string representations.

    Args:
        input: A text description of the desired image(s). The maximum length is 1000
            characters for `dall-e-2` and 4000 characters for `dall-e-3`.
        num_generations: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
            `n=1` is supported.
        quality: The quality of the image that will be generated. `hd` creates images with finer
            details and greater consistency across the image. This param is only supported
            for `dall-e-3`.
        response_format: The format in which the generated images are returned. Must be one of `url` or
            `b64_json`. URLs are only valid for 60 minutes after the image has been
            generated.
        size: The size of the generated images. Must be one of `256x256`, `512x512`, or
            `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
            `1024x1792` for `dall-e-3` models.
        style: The style of the generated images. Must be one of `vivid` or `natural`. Vivid
            causes the model to lean towards generating hyper-real and dramatic images.
            Natural causes the model to produce more natural, less hyper-real looking
            images. This param is only supported for `dall-e-3`.

    Returns:
        A list containing a dictionary with the list of generated images.
    """
    images_response: "ImagesResponse" = await self._aclient.images.generate(
        model=self.model_name,
        prompt=input,
        n=num_generations,
        quality=quality,
        response_format=response_format,
        size=size,
        style=style,
    )
    images = []
    for image in images_response.data:
        if response_format == "url":
            image_data = requests.get(
                image.url
            ).content  # TODO: Keep a requests/httpx session instead
            image_str = base64.b64encode(image_data).decode()
            images.append(image_str)
        elif response_format == "b64_json":
            images.append(image.b64_json)
    return [{"images": images}]