Skip to content

Image generation with distilabel

Create synthetic images using distilabel.

This example shows how distilabel can be used to generate image data with the ImageGeneration task, using either InferenceEndpointsImageGeneration or OpenAIImageGeneration as the image generation model.

from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.steps.tasks import ImageGeneration

from datasets import load_dataset

# Load the train split of the personas dataset and keep only the first three
# rows so the example stays small.
ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    igm = InferenceEndpointsImageGeneration(
        model_id="black-forest-labs/FLUX.1-schnell"
    )

    # The image model is passed through `image_generation_model`; the dataset's
    # "persona" column is mapped onto the task's "prompt" input.
    img_generation = ImageGeneration(
        name="flux_schnell",
        image_generation_model=igm,
        input_mappings={"prompt": "persona"},
    )

    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    # Connect the generation task to the column-selection step.
    img_generation >> keep_columns

Sample image for the prompt:

A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.

image_ie

from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.models.image_generation import OpenAIImageGeneration
from distilabel.steps.tasks import ImageGeneration

from datasets import load_dataset

# Load the train split of the personas dataset and keep only the first three
# rows so the example stays small.
ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    igm = OpenAIImageGeneration(
        model="dall-e-3",
        generation_kwargs={
            "size": "1024x1024",
            "quality": "standard",
            "style": "natural",
        },
    )

    # The image model is passed through `image_generation_model`; the dataset's
    # "persona" column is mapped onto the task's "prompt" input.
    img_generation = ImageGeneration(
        name="dalle-3",
        image_generation_model=igm,
        input_mappings={"prompt": "persona"},
    )

    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    # Connect the generation task to the column-selection step.
    img_generation >> keep_columns

Sample image for the prompt:

A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.

image_oai

Save the Distiset as an Image Dataset

Note the call to Distiset.transform_columns_to_image, to have the images uploaded directly as an Image dataset:

if __name__ == "__main__":
    # Run the pipeline on `ds` with caching disabled, then convert the "image"
    # column so the images are saved as `PIL.Image.Image`.
    distiset = pipeline.run(dataset=ds, use_cache=False).transform_columns_to_image(
        "image"
    )
    distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")

The full pipeline can be run with the following example. Keep in mind that you need to install pillow first: pip install "distilabel[vision]".

Run
python examples/image_generation.py
image_generation.py
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import ImageGeneration

# Take the first three rows of the train split as the pipeline input.
personas_train = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train")
ds = personas_train.select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    flux_model = InferenceEndpointsImageGeneration(
        model_id="black-forest-labs/FLUX.1-schnell",
    )

    # The dataset's "persona" column is mapped onto the task's "prompt" input.
    generate_images = ImageGeneration(
        image_generation_model=flux_model,
        input_mappings={"prompt": "persona"},
        name="flux_schnell",
    )

    select_columns = KeepColumns(columns=["persona", "model_name", "image"])

    # Connect the generation task to the column-selection step.
    generate_images >> select_columns


if __name__ == "__main__":
    # Run the pipeline on `ds` with caching disabled, then convert the "image"
    # column so the images are saved as `PIL.Image.Image`.
    distiset = pipeline.run(dataset=ds, use_cache=False).transform_columns_to_image(
        "image"
    )
    distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")