Image generation with distilabel
Create synthetic images using distilabel.
This example shows how distilabel can be used to generate image data with the ImageGeneration task, using either InferenceEndpointsImageGeneration or OpenAIImageGeneration as the image generation model.
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.steps.tasks import ImageGeneration
from datasets import load_dataset
# Keep only the first three rows of the personas dataset for this example.
ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    # FLUX.1-schnell accessed through `InferenceEndpointsImageGeneration`.
    igm = InferenceEndpointsImageGeneration(
        model_id="black-forest-labs/FLUX.1-schnell"
    )

    # The model is passed as `image_generation_model`, matching the full
    # runnable example further down this page.
    img_generation = ImageGeneration(
        name="flux_schnell",
        image_generation_model=igm,
        # Map the task's `prompt` input to the dataset's `persona` column.
        input_mappings={"prompt": "persona"},
    )

    # Drop everything except the prompt, the model name and the generated image.
    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    # Connect the steps: generation output flows into the column filter.
    img_generation >> keep_columns
Sample image for the prompt:
A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.models.image_generation import OpenAIImageGeneration
from distilabel.steps.tasks import ImageGeneration
from datasets import load_dataset
# Keep only the first three rows of the personas dataset for this example.
ds = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train").select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    # DALL-E 3 accessed through `OpenAIImageGeneration`; the generation
    # kwargs below are forwarded with every image request.
    igm = OpenAIImageGeneration(
        model="dall-e-3",
        generation_kwargs={
            "size": "1024x1024",
            "quality": "standard",
            "style": "natural",
        },
    )

    # The model is passed as `image_generation_model`, matching the full
    # runnable example further down this page. The comma after `name=...`
    # was missing in the original snippet, which made it a SyntaxError.
    img_generation = ImageGeneration(
        name="dalle-3",
        image_generation_model=igm,
        # Map the task's `prompt` input to the dataset's `persona` column.
        input_mappings={"prompt": "persona"},
    )

    # Drop everything except the prompt, the model name and the generated image.
    keep_columns = KeepColumns(columns=["persona", "model_name", "image"])

    # Connect the steps: generation output flows into the column filter.
    img_generation >> keep_columns
Sample image for the prompt:
A local art historian and museum professional interested in 19th-century American art and the local cultural heritage of Cincinnati.
Save the Distiset as an Image Dataset
Note the call to Distiset.transform_columns_to_image in the example below, which is what allows the images to be uploaded directly as an Image dataset.
The full pipeline can be run with the following example. Keep in mind that you need to install pillow first: pip install "distilabel[vision]".
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from distilabel.models.image_generation import InferenceEndpointsImageGeneration
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns
from distilabel.steps.tasks import ImageGeneration
# Load the train split, then keep only its first three rows as input.
train_split = load_dataset("dvilasuero/finepersonas-v0.1-tiny", split="train")
ds = train_split.select(range(3))

with Pipeline(name="image_generation_pipeline") as pipeline:
    # Model that turns each prompt into an image.
    flux_model = InferenceEndpointsImageGeneration(
        model_id="black-forest-labs/FLUX.1-schnell",
    )

    # Generation task: its `prompt` input is read from the `persona` column.
    generate_images = ImageGeneration(
        name="flux_schnell",
        image_generation_model=flux_model,
        input_mappings={"prompt": "persona"},
    )

    # Final step: retain only the columns we want in the published dataset.
    select_columns = KeepColumns(columns=["persona", "model_name", "image"])

    generate_images >> select_columns


if __name__ == "__main__":
    # Run the pipeline over the sampled rows, bypassing any cached results.
    raw_distiset = pipeline.run(dataset=ds, use_cache=False)

    # Save the images as `PIL.Image.Image`
    image_distiset = raw_distiset.transform_columns_to_image("image")

    image_distiset.push_to_hub("plaguss/test-finepersonas-v0.1-tiny-flux-schnell")

