
LLM Gallery

This section contains the existing LLM subclasses implemented in distilabel.

llms

AnthropicLLM

Bases: AsyncLLM

Anthropic LLM implementation running the Async API client.

Attributes:

| Name | Type | Description |
|---|---|---|
| model | str | the name of the model to use for the LLM e.g. "claude-3-opus-20240229", "claude-3-sonnet-20240229", etc. The available models are listed in Anthropic: Models overview. |
| api_key | Optional[RuntimeParameter[SecretStr]] | the API key to authenticate the requests to the Anthropic API. If not provided, it will be read from the ANTHROPIC_API_KEY environment variable. |
| base_url | Optional[RuntimeParameter[str]] | the base URL to use for the Anthropic API. Defaults to the value of the ANTHROPIC_BASE_URL environment variable, or https://api.anthropic.com if not set. |
| timeout | RuntimeParameter[float] | the maximum time in seconds to wait for a response. Defaults to 600.0. |
| max_retries | RuntimeParameter[int] | the maximum number of times to retry the request before failing. Defaults to 6. |
| http_client | Optional[AsyncClient] | if provided, an alternative HTTP client to use for calling the Anthropic API. Defaults to None. |
| structured_output | Optional[RuntimeParameter[InstructorStructuredOutputType]] | a dictionary containing the structured output configuration using instructor. The dictionary structure is defined by InstructorStructuredOutputType in distilabel.steps.tasks.structured_outputs.instructor. |
| _api_key_env_var | str | the name of the environment variable to use for the API key. It is meant to be used internally. |
| _aclient | Optional[AsyncAnthropic] | the AsyncAnthropic client to use for the Anthropic API. It is meant to be used internally and is set in the load method. |

Runtime parameters
  • api_key: the API key to authenticate the requests to the Anthropic API. If not provided, it will be read from ANTHROPIC_API_KEY environment variable.
  • base_url: the base URL to use for the Anthropic API. Defaults to the value of the ANTHROPIC_BASE_URL environment variable, or "https://api.anthropic.com" if not set.
  • timeout: the maximum time in seconds to wait for a response. Defaults to 600.0.
  • max_retries: the maximum number of times to retry the request before failing. Defaults to 6.
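The fallback order for api_key and base_url (an explicit value first, then the environment variable, then the built-in default) and the early failure when no key can be found can be sketched without third-party dependencies. `resolve_setting` and `require_api_key` are hypothetical helpers written for illustration; in the class itself the fallback is done by Pydantic `default_factory` callables that read `os.getenv`, and the missing-key check runs in `load`.

```python
import os
from typing import Mapping, Optional


def resolve_setting(
    explicit: Optional[str],
    env_var: str,
    default: Optional[str] = None,
    environ: Mapping[str, str] = os.environ,
) -> Optional[str]:
    """Hypothetical helper: explicit value, else env var, else default."""
    if explicit is not None:
        return explicit
    return environ.get(env_var, default)


def require_api_key(api_key: Optional[str], env_var: str, class_name: str) -> str:
    """Hypothetical helper mirroring the check done in `load`."""
    if api_key is None:
        raise ValueError(
            f"To use `{class_name}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{env_var}`."
        )
    return api_key


# Simulated environment, so the example does not depend on the real one
fake_env = {"ANTHROPIC_API_KEY": "sk-from-env"}

base_url = resolve_setting(None, "ANTHROPIC_BASE_URL", "https://api.anthropic.com", fake_env)
api_key = resolve_setting(None, "ANTHROPIC_API_KEY", environ=fake_env)
print(base_url)  # https://api.anthropic.com
print(require_api_key(api_key, "ANTHROPIC_API_KEY", "AnthropicLLM"))  # sk-from-env
```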

Examples:

Generate text:

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import AnthropicLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = AnthropicLLM(
    model="claude-3-opus-20240229",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/anthropic.py
class AnthropicLLM(AsyncLLM):
    """Anthropic LLM implementation running the Async API client.

    Attributes:
        model: the name of the model to use for the LLM e.g. "claude-3-opus-20240229",
            "claude-3-sonnet-20240229", etc. Available models can be checked here:
            [Anthropic: Models overview](https://docs.anthropic.com/claude/docs/models-overview).
        api_key: the API key to authenticate the requests to the Anthropic API. If not provided,
            it will be read from `ANTHROPIC_API_KEY` environment variable.
        base_url: the base URL to use for the Anthropic API. Defaults to the value of the
            `ANTHROPIC_BASE_URL` environment variable, or `https://api.anthropic.com` if not set.
        timeout: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        max_retries: the maximum number of times to retry the request before failing. Defaults
            to `6`.
        http_client: if provided, an alternative HTTP client to use for calling the Anthropic
            API. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.
        _aclient: the `AsyncAnthropic` client to use for the Anthropic API. It is meant
            to be used internally. Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Anthropic API. If not
            provided, it will be read from `ANTHROPIC_API_KEY` environment variable.
        - `base_url`: the base URL to use for the Anthropic API. Defaults to the value of the
            `ANTHROPIC_BASE_URL` environment variable, or `"https://api.anthropic.com"` if not set.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        - `max_retries`: the maximum number of times to retry the request before failing.
            Defaults to `6`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AnthropicLLM

        llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import AnthropicLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = AnthropicLLM(
            model="claude-3-opus-20240229",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
        ),
        description="The base URL to use for the Anthropic API.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANTHROPIC_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anthropic API.",
    )
    timeout: RuntimeParameter[float] = Field(
        default=600.0,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    http_client: Optional[AsyncClient] = Field(default=None, exclude=True)
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)

    def _check_model_exists(self) -> None:
        """Checks if the specified model exists in the available models."""
        from anthropic import AsyncAnthropic

        annotation = get_type_hints(AsyncAnthropic().messages.create).get("model", None)
        models = [
            value
            for type_ in get_args(annotation)
            if get_origin(type_) is Literal
            for value in get_args(type_)
        ]

        if self.model not in models:
            raise ValueError(
                f"Model {self.model} does not exist among available models. "
                f"The available models are {', '.join(models)}"
            )

    def load(self) -> None:
        """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
        super().load()

        try:
            from anthropic import AsyncAnthropic
        except ImportError as ie:
            raise ImportError(
                "Anthropic Python client is not installed. Please install it using"
                " `pip install 'distilabel[anthropic]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._check_model_exists()

        self._aclient = AsyncAnthropic(
            api_key=self.api_key.get_secret_value(),
            base_url=self.base_url,
            timeout=self.timeout,
            http_client=self.http_client,
            max_retries=self.max_retries,
        )
        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="anthropic",
            )
            self._aclient = result.get("client")
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_tokens: int = 128,
        stop_sequences: Union[List[str], None] = None,
        temperature: float = 1.0,
        top_p: Union[float, None] = None,
        top_k: Union[int, None] = None,
    ) -> GenerateOutput:
        """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

        Args:
            input: a single input in chat format to generate responses for.
            max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
            stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `None`,
                which means the parameter is not sent to the API (`NOT_GIVEN`).
            temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `None` (not sent, i.e. `NOT_GIVEN`).
            top_k: the top-k value to use for the generation. Defaults to `None` (not sent, i.e. `NOT_GIVEN`).

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from anthropic._types import NOT_GIVEN

        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="anthropic",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "system": (
                input.pop(0)["content"]
                if input and input[0]["role"] == "system"
                else NOT_GIVEN
            ),
            "max_tokens": max_tokens,
            "stream": False,
            "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences,
            "temperature": temperature,
            "top_p": NOT_GIVEN if top_p is None else top_p,
            "top_k": NOT_GIVEN if top_k is None else top_k,
        }

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
            **kwargs
        )  # type: ignore
        if structured_output:
            # raw_response = completion._raw_response
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        if (content := completion.content[0].text) is None:
            self._logger.warning(
                f"Received no response using Anthropic client (model: '{self.model}')."
                f" Finish reason was: {completion.stop_reason}"
            )
        return prepare_output([content], **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "Message") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.input_tokens],
            "output_tokens": [completion.usage.output_tokens],
        }
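The first lines of `agenerate` above decide which structured output configuration applies to a call: an input given as a `(messages, structured_output)` tuple carries its own configuration, and otherwise the instance-level `structured_output` is used. The hypothetical helper `resolve_structured_output` below isolates only that precedence rule and leaves out the client re-wrapping done by `_prepare_structured_output`; the configuration dictionaries in the usage lines are placeholders.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

ChatType = List[Dict[str, str]]
StructuredOutput = Optional[Dict[str, Any]]


def resolve_structured_output(
    input: Union[ChatType, Tuple[ChatType, StructuredOutput]],
    default: StructuredOutput = None,
) -> Tuple[ChatType, StructuredOutput]:
    """Return the messages and the structured output config that applies to them."""
    structured_output = None
    if isinstance(input, tuple):
        # A per-input configuration travels alongside the messages
        input, structured_output = input
    if structured_output is None and default is not None:
        # Otherwise fall back to the configuration set on the LLM instance
        structured_output = default
    return input, structured_output


messages = [{"role": "user", "content": "Hello world!"}]
instance_config = {"schema": "InstanceLevelSchema"}  # placeholder config
per_input_config = {"schema": "PerInputSchema"}  # placeholder config

print(resolve_structured_output(messages, instance_config)[1])  # {'schema': 'InstanceLevelSchema'}
print(resolve_structured_output((messages, per_input_config), instance_config)[1])  # {'schema': 'PerInputSchema'}
```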
model_name property

Returns the model name used for the LLM.

_check_model_exists()

Checks if the specified model exists in the available models.
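The check relies on the model names being enumerated as `Literal` values in the type hints of the SDK's `messages.create` method. The same introspection can be tried with the standard library alone on a hypothetical stand-in function (`create` below is not the SDK's method, and its model names are made up):

```python
from typing import List, Literal, Union, get_args, get_origin, get_type_hints


def create(model: Union[Literal["model-a", "model-b"], str], max_tokens: int = 128) -> None:
    """Hypothetical stand-in for an SDK method whose `model` hint lists known names."""


def available_models(func) -> List[str]:
    # Resolve the annotations of `func` and pick the one for `model`
    annotation = get_type_hints(func).get("model", None)
    # Collect every value of every Literal[...] member of the Union
    return [
        value
        for type_ in get_args(annotation)
        if get_origin(type_) is Literal
        for value in get_args(type_)
    ]


def check_model_exists(model: str, func=create) -> None:
    models = available_models(func)
    if model not in models:
        raise ValueError(
            f"Model {model} does not exist among available models. "
            f"The available models are {', '.join(models)}"
        )


print(available_models(create))  # ['model-a', 'model-b']
check_model_exists("model-a")  # passes silently
```

In this sketch `check_model_exists("model-c")` raises `ValueError` even though the `str` branch of the annotation would accept any string, because only names listed in a `Literal` count as available.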

load()

Loads the AsyncAnthropic client to use the Anthropic async API.

agenerate(input, max_tokens=128, stop_sequences=None, temperature=1.0, top_p=None, top_k=None) async

Generates a response asynchronously, using the Anthropic Async API definition.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | FormattedInput | a single input in chat format to generate responses for. | required |
| max_tokens | int | the maximum number of new tokens that the model will generate. | 128 |
| stop_sequences | Union[List[str], None] | custom text sequences that will cause the model to stop generating. If None, the parameter is not sent to the API (NOT_GIVEN). | None |
| temperature | float | the temperature to use for the generation. Set only if top_p is None. | 1.0 |
| top_p | Union[float, None] | the top-p value to use for the generation. If None, the parameter is not sent to the API (NOT_GIVEN). | None |
| top_k | Union[int, None] | the top-k value to use for the generation. If None, the parameter is not sent to the API (NOT_GIVEN). | None |
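Inside `agenerate`, every optional parameter left as None is replaced by the SDK's `NOT_GIVEN` sentinel so that it is omitted from the request, and a leading system message is moved out of the conversation into the separate `system` argument of the Messages API. The dependency-free sketch below mirrors that mapping. `NOT_GIVEN` here is a local stand-in object (not the one from `anthropic._types`), and `build_kwargs` is a hypothetical helper; unlike the original, it copies the message list instead of calling `pop(0)` on the caller's list.

```python
from typing import Any, Dict, List, Optional


class _NotGiven:
    """Local stand-in for the SDK's NOT_GIVEN sentinel."""

    def __repr__(self) -> str:
        return "NOT_GIVEN"


NOT_GIVEN = _NotGiven()


def build_kwargs(
    messages: List[Dict[str, str]],
    model: str,
    max_tokens: int = 128,
    stop_sequences: Optional[List[str]] = None,
    temperature: float = 1.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
) -> Dict[str, Any]:
    messages = list(messages)  # shallow copy: don't mutate the caller's list
    # A leading system message becomes the top-level `system` argument
    system = (
        messages.pop(0)["content"]
        if messages and messages[0]["role"] == "system"
        else NOT_GIVEN
    )
    return {
        "messages": messages,
        "model": model,
        "system": system,
        "max_tokens": max_tokens,
        "stream": False,
        "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences,
        "temperature": temperature,
        "top_p": NOT_GIVEN if top_p is None else top_p,
        "top_k": NOT_GIVEN if top_k is None else top_k,
    }


conversation = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hello world!"},
]
kwargs = build_kwargs(conversation, model="claude-3-opus-20240229", top_k=40)
print(kwargs["system"])    # You are terse.
print(kwargs["messages"])  # [{'role': 'user', 'content': 'Hello world!'}]
print(kwargs["top_p"])     # NOT_GIVEN
```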

Returns:

| Type | Description |
|---|---|
| GenerateOutput | A list of lists of strings containing the generated responses for each input. |
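Alongside the generated text, `agenerate` reports token usage through `_get_llm_statistics`, which wraps the counts of a single response in one-element lists. The sketch below reproduces that shape with hypothetical stand-in dataclasses (`FakeUsage`, `FakeMessage`) in place of the Anthropic SDK's `Message`, so it runs without the `anthropic` package; the token counts are made up.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeUsage:
    """Hypothetical stand-in for the `usage` attribute of an Anthropic `Message`."""

    input_tokens: int
    output_tokens: int


@dataclass
class FakeMessage:
    """Hypothetical stand-in for an Anthropic `Message`; only `usage` is modelled."""

    usage: FakeUsage


def get_llm_statistics(completion: FakeMessage) -> Dict[str, List[int]]:
    # One-element lists, since each call returns a single generation
    return {
        "input_tokens": [completion.usage.input_tokens],
        "output_tokens": [completion.usage.output_tokens],
    }


stats = get_llm_statistics(FakeMessage(usage=FakeUsage(input_tokens=12, output_tokens=34)))
print(stats)  # {'input_tokens': [12], 'output_tokens': [34]}
```

In the structured-output branch, the source applies the same helper to `completion._raw_response`, presumably the raw SDK message behind the parsed Pydantic object, since the parsed object itself has no `usage`.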


AnyscaleLLM

Bases: OpenAILLM

Anyscale LLM implementation running the async API client of OpenAI.

Attributes:

| Name | Type | Description |
|---|---|---|
| model | str | the model name to use for the LLM, e.g., google/gemma-7b-it. See the supported models under the "Text Generation -> Supported Models" section of the Anyscale Endpoints documentation (https://docs.endpoints.anyscale.com/). |
| base_url | Optional[RuntimeParameter[str]] | the base URL to use for the Anyscale API requests. Defaults to the value of the ANYSCALE_BASE_URL environment variable, or "https://api.endpoints.anyscale.com/v1" if not set. |
| api_key | Optional[RuntimeParameter[SecretStr]] | the API key to authenticate the requests to the Anyscale API. Defaults to the value of the ANYSCALE_API_KEY environment variable, or None if not set. |
| _api_key_env_var | str | the name of the environment variable to use for the API key. It is meant to be used internally. |

Examples:

Generate text:

from distilabel.models.llms import AnyscaleLLM

llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/anyscale.py
class AnyscaleLLM(OpenAILLM):
    """Anyscale LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM, e.g., `google/gemma-7b-it`. See the
            supported models under the "Text Generation -> Supported Models" section
            [here](https://docs.endpoints.anyscale.com/).
        base_url: the base URL to use for the Anyscale API requests. Defaults to `None`, which
            means that the value set for the environment variable `ANYSCALE_BASE_URL` will be used, or
            "https://api.endpoints.anyscale.com/v1" if not set.
        api_key: the API key to authenticate the requests to the Anyscale API. Defaults to `None` which
            means that the value set for the environment variable `ANYSCALE_API_KEY` will be used, or
            `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key.
            It is meant to be used internally.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AnyscaleLLM

        llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANYSCALE_BASE_URL", "https://api.endpoints.anyscale.com/v1"
        ),
        description="The base URL to use for the Anyscale API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANYSCALE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anyscale API.",
    )

    _api_key_env_var: str = PrivateAttr(_ANYSCALE_API_KEY_ENV_VAR_NAME)
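The source above shows that `AnyscaleLLM` reuses `OpenAILLM` and only redefines where `base_url` and `api_key` take their defaults from (plus the private `_api_key_env_var`). That override pattern can be sketched with standard-library dataclasses. `BaseOpenAICompatible` and `AnyscaleLike` are hypothetical classes, not distilabel's, and the base-class defaults (`OPENAI_BASE_URL`, `https://api.openai.com/v1`, `OPENAI_API_KEY`) are assumptions for illustration; distilabel itself uses Pydantic `Field(default_factory=...)`.

```python
import os
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseOpenAICompatible:
    """Hypothetical stand-in for an OpenAI-compatible LLM configuration."""

    model: str
    # Assumed base-class defaults, for illustration only
    base_url: Optional[str] = field(
        default_factory=lambda: os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
    )
    api_key: Optional[str] = field(default_factory=lambda: os.getenv("OPENAI_API_KEY"))


@dataclass
class AnyscaleLike(BaseOpenAICompatible):
    """Same fields and behaviour as the base class; only the defaults change."""

    base_url: Optional[str] = field(
        default_factory=lambda: os.getenv(
            "ANYSCALE_BASE_URL", "https://api.endpoints.anyscale.com/v1"
        )
    )
    api_key: Optional[str] = field(default_factory=lambda: os.getenv("ANYSCALE_API_KEY"))


# An explicit value always wins over the environment-based default
llm = AnyscaleLike(model="google/gemma-7b-it", base_url="http://localhost:8080/v1")
print(llm.base_url)  # http://localhost:8080/v1
```

Because the default factories run at construction time, changing the environment variable affects instances created afterwards, not existing ones.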

AzureOpenAILLM

Bases: OpenAILLM

Azure OpenAI LLM implementation running the async API client.

Attributes:

| Name | Type | Description |
|---|---|---|
| model | str | the model name to use for the LLM, i.e. the name of the Azure deployment. |
| base_url | Optional[RuntimeParameter[str]] | the base URL (i.e. the endpoint) to use for the Azure OpenAI API. Defaults to the value of the AZURE_OPENAI_ENDPOINT environment variable, or None if not set. |
| api_key | Optional[RuntimeParameter[SecretStr]] | the API key to authenticate the requests to the Azure OpenAI API. Defaults to the value of the AZURE_OPENAI_API_KEY environment variable, or None if not set. |
| api_version | Optional[RuntimeParameter[str]] | the API version to use for the Azure OpenAI API. Defaults to the value of the OPENAI_API_VERSION environment variable, or None if not set. |
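In `load`, these attributes are handed to the `AsyncAzureOpenAI` constructor under Azure-specific names: `base_url` becomes `azure_endpoint` and `model` becomes `azure_deployment`. The hypothetical `azure_client_kwargs` helper below only builds that keyword dictionary, reading the environment fallbacks from a plain mapping, so it runs without the `openai` package. The endpoint, key and API version in the usage lines are placeholders.

```python
from typing import Any, Dict, Mapping, Optional


def azure_client_kwargs(
    model: str,
    environ: Mapping[str, str],
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
) -> Dict[str, Any]:
    # Explicit values win; otherwise fall back to the documented environment variables
    base_url = base_url if base_url is not None else environ.get("AZURE_OPENAI_ENDPOINT")
    api_key = api_key if api_key is not None else environ.get("AZURE_OPENAI_API_KEY")
    api_version = api_version if api_version is not None else environ.get("OPENAI_API_VERSION")
    if api_key is None:
        raise ValueError("An API key must be provided or AZURE_OPENAI_API_KEY must be set.")
    return {
        "azure_endpoint": base_url,  # distilabel's `base_url`
        "azure_deployment": model,  # distilabel's `model` is the deployment name
        "api_version": api_version,
        "api_key": api_key,
    }


env = {
    "AZURE_OPENAI_ENDPOINT": "https://my-resource.openai.azure.com",  # placeholder
    "AZURE_OPENAI_API_KEY": "api.key",  # placeholder
    "OPENAI_API_VERSION": "2024-02-01",  # placeholder
}
print(azure_client_kwargs("gpt-4-turbo", env)["azure_deployment"])  # gpt-4-turbo
```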

Examples:

Generate text:

from distilabel.models.llms import AzureOpenAILLM

llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate text from a custom endpoint following the OpenAI API:

from distilabel.models.llms import AzureOpenAILLM

llm = AzureOpenAILLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    base_url=r"http://localhost:8080/v1"
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import AzureOpenAILLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = AzureOpenAILLM(
    model="gpt-4-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/azure.py
class AzureOpenAILLM(OpenAILLM):
    """Azure OpenAI LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM i.e. the name of the Azure deployment.
        base_url: the base URL (i.e. the endpoint) to use for the Azure OpenAI API. Defaults to
            the value of the `AZURE_OPENAI_ENDPOINT` environment variable, or `None` if not set.
        api_key: the API key to authenticate the requests to the Azure OpenAI API. Defaults to `None`
            which means that the value set for the environment variable `AZURE_OPENAI_API_KEY` will be
            used, or `None` if not set.
        api_version: the API version to use for the Azure OpenAI API. Defaults to `None` which means
            that the value set for the environment variable `OPENAI_API_VERSION` will be used, or
            `None` if not set.

    Icon:
        `:material-microsoft-azure:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AzureOpenAILLM

        llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate text from a custom endpoint following the OpenAI API:

        ```python
        from distilabel.models.llms import AzureOpenAILLM

        llm = AzureOpenAILLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            base_url=r"http://localhost:8080/v1"
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import AzureOpenAILLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = AzureOpenAILLM(
            model="gpt-4-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME),
        description="The base URL to use for the Azure OpenAI API requests i.e. the Azure OpenAI endpoint.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Azure OpenAI API.",
    )

    api_version: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv("OPENAI_API_VERSION"),
        description="The API version to use for the Azure OpenAI API.",
    )

    _base_url_env_var: str = PrivateAttr(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME)
    _api_key_env_var: str = PrivateAttr(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAzureOpenAI"] = PrivateAttr(...)  # type: ignore

    @override
    def load(self) -> None:
        """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
        # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output
        # in the load method before we have the proper client.
        with patch(
            "distilabel.models.openai.OpenAILLM._prepare_structured_output", lambda x: x
        ):
            super().load()

        try:
            from openai import AsyncAzureOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[openai]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        # TODO: May be worth adding the AD auth too? Also the `organization`?
        self._aclient = AsyncAzureOpenAI(  # type: ignore
            azure_endpoint=self.base_url,  # type: ignore
            azure_deployment=self.model,
            api_version=self.api_version,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            self._prepare_structured_output(self.structured_output)
load()

Loads the AsyncAzureOpenAI client to benefit from async requests.
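The source comment explains that `OpenAILLM.load` would otherwise call `_prepare_structured_output` before the Azure client exists, so `load` swaps that method for a no-op with `unittest.mock.patch` while `super().load()` runs, then prepares the structured output itself once `AsyncAzureOpenAI` is in place. The toy classes below (`Parent` and `Child`, both hypothetical) show the same technique with `patch.object`; here the replacement accepts any arguments, so it works however the parent calls the hook.

```python
from unittest.mock import patch


class Parent:
    def load(self) -> None:
        self.client = "generic-client"
        self._prepare_structured_output()

    def _prepare_structured_output(self) -> None:
        # Record which client the hook was applied to
        self.prepared_with = self.client


class Child(Parent):
    def load(self) -> None:
        # Neutralise the hook only while the parent sets itself up
        with patch.object(Parent, "_prepare_structured_output", lambda *args, **kwargs: None):
            super().load()
        self.client = "specialised-client"
        # The patch has been undone, so this runs the real hook with the right client
        self._prepare_structured_output()


child = Child()
child.load()
print(child.prepared_with)  # specialised-client
```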


CohereLLM

Bases: AsyncLLM

Cohere API implementation using the async client for concurrent text generation.

Attributes:

| Name | Type | Description |
|---|---|---|
| model | str | the name of the model from the Cohere API to use for the generation. |
| base_url | Optional[RuntimeParameter[str]] | the base URL to use for the Cohere API requests. Defaults to "https://api.cohere.ai/v1". |
| api_key | Optional[RuntimeParameter[SecretStr]] | the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable. |
| timeout | RuntimeParameter[int] | the maximum time in seconds to wait for a response from the API. Defaults to 120. |
| client_name | RuntimeParameter[str] | the name of the client to use for the API requests. Defaults to "distilabel". |
| structured_output | Optional[RuntimeParameter[InstructorStructuredOutputType]] | a dictionary containing the structured output configuration using instructor. The dictionary structure is defined by InstructorStructuredOutputType in distilabel.steps.tasks.structured_outputs.instructor. |
| _ChatMessage | Type[ChatMessage] | the ChatMessage class from the cohere package. |
| _aclient | AsyncClient | the AsyncClient client from the cohere package. |

Runtime parameters
  • base_url: the base URL to use for the Cohere API requests. Defaults to "https://api.cohere.ai/v1".
  • api_key: the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • client_name: the name of the client to use for the API requests. Defaults to "distilabel".
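The `base_url` and `api_key` defaults are read from the environment when no value is passed (the source below uses `COHERE_BASE_URL` and the `COHERE_API_KEY` variable named above). The following is a minimal sketch of that precedence, assuming an explicit value always wins over the environment. `resolve_cohere_settings` is a hypothetical helper, not part of distilabel.

```python
from typing import Mapping, Optional, Tuple

# Default documented for the `base_url` attribute above.
DEFAULT_COHERE_BASE_URL = "https://api.cohere.ai/v1"


def resolve_cohere_settings(
    env: Mapping[str, str],
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
) -> Tuple[str, Optional[str]]:
    """Return (base_url, api_key), preferring explicit values over the environment."""
    if base_url is None:
        base_url = env.get("COHERE_BASE_URL", DEFAULT_COHERE_BASE_URL)
    if api_key is None:
        api_key = env.get("COHERE_API_KEY")  # may still be None
    return base_url, api_key


print(resolve_cohere_settings({}))
```

Passing a plain dictionary instead of `os.environ` keeps the sketch testable and avoids printing a real key.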

Examples:

Generate text:

from distilabel.models.llms import CohereLLM

llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import CohereLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = CohereLLM(
    model="CohereForAI/c4ai-command-r-plus",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/cohere.py
class CohereLLM(AsyncLLM):
    """Cohere API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Cohere API to use for the generation.
        base_url: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        api_key: the API key to authenticate the requests to the Cohere API. Defaults to
            the value of the `COHERE_API_KEY` environment variable.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        client_name: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _ChatMessage: the `ChatMessage` class from the `cohere` package.
        _aclient: the `AsyncClient` client from the `cohere` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        - `api_key`: the API key to authenticate the requests to the Cohere API. Defaults
            to the value of the `COHERE_API_KEY` environment variable.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `client_name`: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import CohereLLM

        llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import CohereLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = CohereLLM(
            model="CohereForAI/c4ai-command-r-plus",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "COHERE_BASE_URL", "https://api.cohere.ai/v1"
        ),
        description="The base URL to use for the Cohere API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_COHERE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Cohere API.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    client_name: RuntimeParameter[str] = Field(
        default="distilabel",
        description="The name of the client to use for the API requests.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
    _aclient: "AsyncClient" = PrivateAttr(...)
    _tokenizer: "Tokenizer" = PrivateAttr(...)

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def load(self) -> None:
        """Loads the `AsyncClient` client from the `cohere` package."""

        super().load()

        try:
            from cohere import AsyncClient, ChatMessage
        except ImportError as ie:
            raise ImportError(
                "The `cohere` package is required to use the `CohereLLM` class."
            ) from ie

        self._ChatMessage = ChatMessage

        self._aclient = AsyncClient(
            api_key=self.api_key.get_secret_value(),  # type: ignore
            client_name=self.client_name,
            base_url=self.base_url,
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

        from cohere.manually_maintained.tokenizers import get_hf_tokenizer

        self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)

    def _format_chat_to_cohere(
        self, input: "FormattedInput"
    ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
        """Formats the chat input to the Cohere Chat API conversational format.

        Args:
            input: The chat input to format.

        Returns:
            A tuple containing the system, chat history, and message.
        """
        system = None
        message = None
        chat_history = []
        for item in input:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system = content
            elif role == "user":
                message = content
            elif role == "assistant":
                if message is None:
                    raise ValueError(
                        "An assistant message must be preceded by a user message."
                    )
                chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
                chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
                message = None

        if message is None:
            raise ValueError("The chat input must end with a user message.")

        return system, chat_history, message

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        seed: Optional[float] = None,
        stop_sequences: Optional[Sequence[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        raw_prompting: Optional[bool] = None,
    ) -> GenerateOutput:
        """Generates a response from the LLM given an input.

        Args:
            input: a single input in chat format to generate responses for.
            temperature: the temperature to use for the generation. Defaults to `None`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            k: the number of highest probability vocabulary tokens to keep for the generation.
                Defaults to `None`.
            p: the nucleus sampling probability to use for the generation. Defaults to
                `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: a list of sequences to use as stopping criteria for the generation.
                Defaults to `None`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `None`.
            raw_prompting: a flag to use raw prompting for the generation. Defaults to
                `None`.

        Returns:
            The generated response from the Cohere API model.
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,  # type: ignore
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        system, chat_history, message = self._format_chat_to_cohere(input)

        kwargs = {
            "message": message,
            "model": self.model,
            "preamble": system,
            "chat_history": chat_history,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "k": k,
            "p": p,
            "seed": seed,
            "stop_sequences": stop_sequences,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "raw_prompting": raw_prompting,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

        response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs)  # type: ignore

        if structured_output:
            return prepare_output(
                [response.model_dump_json()],
                **self._get_llm_statistics(
                    input, orjson.dumps(response.model_dump_json()).decode("utf-8")
                ),  # type: ignore
            )

        if (text := response.text) == "":
            self._logger.warning(  # type: ignore
                f"Received no response using Cohere client (model: '{self.model}')."
                f" Finish reason was: {response.finish_reason}"
            )
            return prepare_output(
                [None],
                **self._get_llm_statistics(input, ""),
            )

        return prepare_output(
            [text],
            **self._get_llm_statistics(input, text),
        )

    def _get_llm_statistics(
        self, input: FormattedInput, output: str
    ) -> "LLMStatistics":
        return {
            "input_tokens": [compute_tokens(input, self._tokenizer.encode)],
            "output_tokens": [compute_tokens(output, self._tokenizer.encode)],
        }
model_name property

Returns the model name used for the LLM.

load()

Loads the AsyncClient client from the cohere package.
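`load` defers the `cohere` import until the LLM is actually loaded, so the package is only required when `CohereLLM` is used, and a missing install surfaces as a clear `ImportError`. The same guard can be written generically with `importlib`. `import_optional` is a hypothetical helper for illustration, not part of distilabel.

```python
import importlib
from types import ModuleType


def import_optional(module_name: str, class_name: str) -> ModuleType:
    """Import `module_name`, re-raising a clearer ImportError if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as ie:
        # Chain the original error so the root cause stays in the traceback.
        raise ImportError(
            f"The `{module_name}` package is required to use the `{class_name}` class."
        ) from ie


json_module = import_optional("json", "ExampleLLM")  # a stdlib module always imports
```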

Source code in src/distilabel/models/llms/cohere.py
def load(self) -> None:
    """Loads the `AsyncClient` client from the `cohere` package."""

    super().load()

    try:
        from cohere import AsyncClient, ChatMessage
    except ImportError as ie:
        raise ImportError(
            "The `cohere` package is required to use the `CohereLLM` class."
        ) from ie

    self._ChatMessage = ChatMessage

    self._aclient = AsyncClient(
        api_key=self.api_key.get_secret_value(),  # type: ignore
        client_name=self.client_name,
        base_url=self.base_url,
        timeout=self.timeout,
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output

    from cohere.manually_maintained.tokenizers import get_hf_tokenizer

    self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)
_format_chat_to_cohere(input)

Formats the chat input to the Cohere Chat API conversational format.

Parameters:

Name Type Description Default
input FormattedInput

The chat input to format.

required

Returns:

Type Description
Tuple[Union[str, None], List[ChatMessage], str]

A tuple containing the system, chat history, and message.
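The mapping can be reproduced without the `cohere` package. This standalone sketch mirrors the logic in the source listing, with a local `ChatMessage` dataclass substituted for `cohere.ChatMessage` (an assumption made only so the example runs offline): the system message becomes the preamble, completed user/assistant turns move into the history, and the final user message is sent as `message`.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class ChatMessage:
    """Local stand-in for cohere's ChatMessage; only role and message are modeled."""

    role: str
    message: str


def format_chat_to_cohere(
    messages: List[Dict[str, str]],
) -> Tuple[Optional[str], List[ChatMessage], str]:
    system: Optional[str] = None
    message: Optional[str] = None
    chat_history: List[ChatMessage] = []
    for item in messages:
        role, content = item["role"], item["content"]
        if role == "system":
            system = content  # sent as the `preamble`
        elif role == "user":
            message = content  # held until an assistant reply completes the turn
        elif role == "assistant":
            if message is None:
                raise ValueError("An assistant message must be preceded by a user message.")
            chat_history.append(ChatMessage(role="USER", message=message))
            chat_history.append(ChatMessage(role="CHATBOT", message=content))
            message = None
    if message is None:
        raise ValueError("The chat input must end with a user message.")
    return system, chat_history, message


system, history, message = format_chat_to_cohere([
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2 + 2?"},
])
```

Note that two consecutive user messages are not merged: the later one silently replaces the earlier one.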

Source code in src/distilabel/models/llms/cohere.py
def _format_chat_to_cohere(
    self, input: "FormattedInput"
) -> Tuple[Union[str, None], List["ChatMessage"], str]:
    """Formats the chat input to the Cohere Chat API conversational format.

    Args:
        input: The chat input to format.

    Returns:
        A tuple containing the system, chat history, and message.
    """
    system = None
    message = None
    chat_history = []
    for item in input:
        role = item["role"]
        content = item["content"]
        if role == "system":
            system = content
        elif role == "user":
            message = content
        elif role == "assistant":
            if message is None:
                raise ValueError(
                    "An assistant message must be preceded by a user message."
                )
            chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
            chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
            message = None

    if message is None:
        raise ValueError("The chat input must end with a user message.")

    return system, chat_history, message
agenerate(input, temperature=None, max_tokens=None, k=None, p=None, seed=None, stop_sequences=None, frequency_penalty=None, presence_penalty=None, raw_prompting=None) async

Generates a response from the LLM given an input.

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

required
temperature Optional[float]

the temperature to use for the generation. Defaults to None.

None
max_tokens Optional[int]

the maximum number of new tokens that the model will generate. Defaults to None.

None
k Optional[int]

the number of highest probability vocabulary tokens to keep for the generation. Defaults to None.

None
p Optional[float]

the nucleus sampling probability to use for the generation. Defaults to None.

None
seed Optional[float]

the seed to use for the generation. Defaults to None.

None
stop_sequences Optional[Sequence[str]]

a list of sequences to use as stopping criteria for the generation. Defaults to None.

None
frequency_penalty Optional[float]

the frequency penalty to use for the generation. Defaults to None.

None
presence_penalty Optional[float]

the presence penalty to use for the generation. Defaults to None.

None
raw_prompting Optional[bool]

a flag to use raw prompting for the generation. Defaults to None.

None

Returns:

Type Description
GenerateOutput

The generated response from the Cohere API model.
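For plain-text generation, `agenerate` turns an empty response into a `None` generation and reports input and output token counts computed with the model tokenizer. The sketch below imitates that post-processing under two assumptions: a toy whitespace tokenizer replaces the Hugging Face tokenizer, and a chat is counted as the sum of its message contents (the real `compute_tokens` rule may differ). The `generations`/`statistics` keys are likewise an assumed shape for illustration.

```python
from typing import Callable, Dict, List, Optional, Union

Message = Dict[str, str]
Encode = Callable[[str], List[int]]


def count_tokens(text_or_chat: Union[str, List[Message]], encode: Encode) -> int:
    """Assumed rule: a chat counts as the sum of its message contents."""
    if isinstance(text_or_chat, str):
        return len(encode(text_or_chat))
    return sum(len(encode(m["content"])) for m in text_or_chat)


def to_output(chat: List[Message], text: str, encode: Encode) -> Dict:
    # As in agenerate, an empty response becomes a `None` generation.
    generation: Optional[str] = text if text != "" else None
    return {
        "generations": [generation],
        "statistics": {
            "input_tokens": [count_tokens(chat, encode)],
            "output_tokens": [count_tokens(text, encode)],
        },
    }


def whitespace_encode(text: str) -> List[int]:
    """Toy tokenizer: one id per whitespace-separated word."""
    return list(range(len(text.split())))


chat = [{"role": "user", "content": "Hello world!"}]
print(to_output(chat, "Hi there, how can I help?", whitespace_encode))
```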

Source code in src/distilabel/models/llms/cohere.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    k: Optional[int] = None,
    p: Optional[float] = None,
    seed: Optional[float] = None,
    stop_sequences: Optional[Sequence[str]] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    raw_prompting: Optional[bool] = None,
) -> GenerateOutput:
    """Generates a response from the LLM given an input.

    Args:
        input: a single input in chat format to generate responses for.
        temperature: the temperature to use for the generation. Defaults to `None`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        k: the number of highest probability vocabulary tokens to keep for the generation.
            Defaults to `None`.
        p: the nucleus sampling probability to use for the generation. Defaults to
            `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: a list of sequences to use as stopping criteria for the generation.
            Defaults to `None`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `None`.
        raw_prompting: a flag to use raw prompting for the generation. Defaults to
            `None`.

    Returns:
        The generated response from the Cohere API model.
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,  # type: ignore
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    system, chat_history, message = self._format_chat_to_cohere(input)

    kwargs = {
        "message": message,
        "model": self.model,
        "preamble": system,
        "chat_history": chat_history,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "k": k,
        "p": p,
        "seed": seed,
        "stop_sequences": stop_sequences,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "raw_prompting": raw_prompting,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

    response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs)  # type: ignore

    if structured_output:
        return prepare_output(
            [response.model_dump_json()],
            **self._get_llm_statistics(
                input, orjson.dumps(response.model_dump_json()).decode("utf-8")
            ),  # type: ignore
        )

    if (text := response.text) == "":
        self._logger.warning(  # type: ignore
            f"Received no response using Cohere client (model: '{self.model}')."
            f" Finish reason was: {response.finish_reason}"
        )
        return prepare_output(
            [None],
            **self._get_llm_statistics(input, ""),
        )

    return prepare_output(
        [text],
        **self._get_llm_statistics(input, text),
    )

GroqLLM

Bases: AsyncLLM

Groq API implementation using the async client for concurrent text generation.

Attributes:

Name Type Description
model str

the name of the model from the Groq API to use for the generation.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Groq API requests. Defaults to "https://api.groq.com".

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Groq API. Defaults to the value of the GROQ_API_KEY environment variable.

max_retries RuntimeParameter[int]

the maximum number of times to retry the request to the API before failing. Defaults to 2.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.

_api_key_env_var str

the name of the environment variable to use for the API key.

_aclient Optional[AsyncGroq]

the AsyncGroq client from the groq package.

Runtime parameters
  • base_url: the base URL to use for the Groq API requests. Defaults to "https://api.groq.com".
  • api_key: the API key to authenticate the requests to the Groq API. Defaults to the value of the GROQ_API_KEY environment variable.
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 2.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.

Examples:

Generate text:

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import GroqLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = GroqLLM(
    model="llama3-70b-8192",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/groq.py
class GroqLLM(AsyncLLM):
    """Groq API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Groq API to use for the generation.
        base_url: the base URL to use for the Groq API requests. Defaults to
            `"https://api.groq.com"`.
        api_key: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the `GROQ_API_KEY` environment variable.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key.
        _aclient: the `AsyncGroq` client from the `groq` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Groq API requests. Defaults to
            `"https://api.groq.com"`.
        - `api_key`: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the `GROQ_API_KEY` environment variable.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import GroqLLM

        llm = GroqLLM(model="llama3-70b-8192")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import GroqLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = GroqLLM(
            model="llama3-70b-8192",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            _GROQ_API_BASE_URL_ENV_VAR_NAME, "https://api.groq.com"
        ),
        description="The base URL to use for the Groq API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_GROQ_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Groq API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=2,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(_GROQ_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncGroq"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncGroq` client to benefit from async requests."""
        super().load()

        try:
            from groq import AsyncGroq
        except ImportError as ie:
            raise ImportError(
                "Groq Python client is not installed. Please install it using"
                ' `pip install "distilabel[groq]"`.'
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncGroq(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="groq",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        seed: Optional[int] = None,
        max_new_tokens: int = 128,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[str] = None,
    ) -> "GenerateOutput":
        """Generates a response for the given input using the Groq async client.

        Args:
            input: a single input in chat format to generate responses for.
            seed: the seed to use for the generation. Defaults to `None`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: the stop sequence to use for the generation. Defaults to `None`.

        Returns:
            The generated responses from the Groq API model and their token usage statistics.

        References:
            - https://console.groq.com/docs/text-chat
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="groq",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "seed": seed,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
            "top_p": top_p,
            "stream": False,
            "stop": stop,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
        if structured_output:
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        generations = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using the Groq client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return prepare_output(generations, **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "ChatCompletion") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.prompt_tokens if completion else 0],
            "output_tokens": [completion.usage.completion_tokens if completion else 0],
        }
model_name property

Returns the model name used for the LLM.

load()

Loads the AsyncGroq client to benefit from async requests.
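Unlike the Cohere loader, `GroqLLM.load` fails early with a `ValueError` when no API key is available, instead of erroring on the first request. The check can be sketched as follows. `resolve_groq_api_key` is a hypothetical helper combining the `GROQ_API_KEY` environment-variable default of the `api_key` attribute with the error raised in `load`; it takes a plain mapping so it can be tested without touching the real environment.

```python
from typing import Mapping, Optional

GROQ_API_KEY_ENV_VAR = "GROQ_API_KEY"  # name taken from the attribute docs above


def resolve_groq_api_key(env: Mapping[str, str], api_key: Optional[str] = None) -> str:
    """Return the explicit key, else the env var, else fail before any request is sent."""
    key = api_key if api_key is not None else env.get(GROQ_API_KEY_ENV_VAR)
    if key is None:
        raise ValueError(
            "To use `GroqLLM` an API key must be provided via `api_key` attribute or"
            f" runtime parameter, or set the environment variable `{GROQ_API_KEY_ENV_VAR}`."
        )
    return key
```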

Source code in src/distilabel/models/llms/groq.py
def load(self) -> None:
    """Loads the `AsyncGroq` client to benefit from async requests."""
    super().load()

    try:
        from groq import AsyncGroq
    except ImportError as ie:
        raise ImportError(
            "Groq Python client is not installed. Please install it using"
            ' `pip install "distilabel[groq]"`.'
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = AsyncGroq(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="groq",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
agenerate(input, seed=None, max_new_tokens=128, temperature=1.0, top_p=1.0, stop=None) async

Generates a response for the given input using the Groq async client.

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

required
seed Optional[int]

the seed to use for the generation. Defaults to None.

None
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
stop Optional[str]

the stop sequence to use for the generation. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.

References
  • https://console.groq.com/docs/text-chat
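The parameters above are forwarded to the Groq chat completions endpoint almost one-to-one; the main rename is that max_new_tokens is sent as max_tokens. The sketch below reproduces the kwargs dict assembled in agenerate through a hypothetical helper, build_groq_request, which is not part of distilabel; the model name is only an assumed example.

```python
from typing import Any, Dict, List, Optional


def build_groq_request(
    messages: List[Dict[str, str]],
    model: str,
    seed: Optional[int] = None,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the `kwargs` dict built in `agenerate`."""
    return {
        "messages": messages,
        "model": model,
        "seed": seed,
        "temperature": temperature,
        # distilabel's `max_new_tokens` is sent as the API's `max_tokens`
        "max_tokens": max_new_tokens,
        "top_p": top_p,
        "stream": False,  # the whole response is requested at once
        "stop": stop,
    }


request = build_groq_request(
    messages=[{"role": "user", "content": "Hello world!"}],
    model="llama3-8b-8192",  # assumed example model name
)
print(request["max_tokens"])  # → 128
```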
Source code in src/distilabel/models/llms/groq.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    seed: Optional[int] = None,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[str] = None,
) -> "GenerateOutput":
    """Generates `num_generations` responses for the given input using the Groq async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        seed: the seed to use for the generation. Defaults to `None`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: the stop sequence to use for the generation. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.

    References:
        - https://console.groq.com/docs/text-chat
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="groq",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "seed": seed,
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "top_p": top_p,
        "stream": False,
        "stop": stop,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
    if structured_output:
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    generations = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using the Groq client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return prepare_output(generations, **self._get_llm_statistics(completion))
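When a choice comes back without content, agenerate logs a warning but still appends None, so the list of generations always has one entry per returned choice. A minimal sketch of that loop, using hypothetical stand-in dataclasses (FakeMessage, FakeChoice) in place of the Groq client's response objects:

```python
import logging
from dataclasses import dataclass
from typing import List, Optional

logger = logging.getLogger("groq_sketch")


@dataclass
class FakeMessage:  # hypothetical stand-in for the client's message object
    content: Optional[str]


@dataclass
class FakeChoice:  # hypothetical stand-in for one entry of `completion.choices`
    message: FakeMessage
    finish_reason: str


def collect_generations(choices: List[FakeChoice], model: str) -> List[Optional[str]]:
    generations: List[Optional[str]] = []
    for choice in choices:
        if (content := choice.message.content) is None:
            # An empty choice is reported but kept as `None`
            logger.warning(
                "Received no response using the Groq client (model: '%s')."
                " Finish reason was: %s",
                model,
                choice.finish_reason,
            )
        generations.append(content)
    return generations


choices = [
    FakeChoice(FakeMessage("Hi there!"), "stop"),
    FakeChoice(FakeMessage(None), "length"),
]
print(collect_generations(choices, model="my-model"))  # → ['Hi there!', None]
```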

InferenceEndpointsLLM

Bases: InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin

InferenceEndpoints LLM implementation running the async API client.

This LLM will internally use huggingface_hub.AsyncInferenceClient.

Attributes:

Name Type Description
model_id Optional[str]

the model ID to use for the LLM as available in the Hugging Face Hub, which will be used to resolve the base URL for the serverless Inference Endpoints API requests. Defaults to None.

endpoint_name Optional[RuntimeParameter[str]]

the name of the Inference Endpoint to use for the LLM. Defaults to None.

endpoint_namespace Optional[RuntimeParameter[str]]

the namespace of the Inference Endpoint to use for the LLM. Defaults to None.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Inference Endpoints API requests.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Inference Endpoints API.

tokenizer_id Optional[str]

the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to None, but defining one is recommended to properly format the prompt.

model_display_name Optional[str]

the model display name to use for the LLM. Defaults to None.

use_magpie_template bool

a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2" or any other pre-query template string. Defaults to None.

structured_output Optional[RuntimeParameter[StructuredOutputType]]

a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None.

Examples:

Free serverless Inference API (set the input_batch_size of the Task that uses this LLM to avoid "Model is overloaded" errors):

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Dedicated Inference Endpoints:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    endpoint_name="<ENDPOINT_NAME>",
    api_key="<HF_API_KEY>",
    endpoint_namespace="<USER|ORG>",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Dedicated Inference Endpoints or TGI:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    api_key="<HF_API_KEY>",
    base_url="<BASE_URL>",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import InferenceEndpointsLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    api_key="<HF_API_KEY>",
    structured_output={"format": "json", "schema": User.model_json_schema()}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
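The structured_output dictionary above uses the keys format and schema, while the TGI grammar argument expects type and value. The sketch below mirrors that conversion as performed in _get_structured_output, through a hypothetical helper, to_tgi_grammar, which is not part of distilabel. distilabel checks for a pydantic model class. Here a duck-typed check for model_json_schema replaces that check, so the sketch runs without pydantic installed.

```python
from typing import Any, Dict


def to_tgi_grammar(structured_output: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: rename `format`/`schema` to the `type`/`value`
    keys expected by the `grammar` argument of `text_generation`."""
    try:
        grammar = {
            "type": structured_output["format"],
            "value": structured_output["schema"],
        }
    except KeyError as e:
        raise ValueError(
            "To use the structured output you have to inform the `format` and `schema`"
            " in the `structured_output` attribute."
        ) from e
    # Duck-typed stand-in for distilabel's pydantic model-class check
    schema = grammar["value"]
    if isinstance(schema, type) and hasattr(schema, "model_json_schema"):
        grammar["value"] = schema.model_json_schema()
    return grammar


schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "id": {"type": "integer"}},
    "required": ["name", "id"],
}
print(to_tgi_grammar({"format": "json", "schema": schema})["type"])  # → json
```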
Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
class InferenceEndpointsLLM(
    InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin
):
    """InferenceEndpoints LLM implementation running the async API client.

    This LLM will internally use `huggingface_hub.AsyncInferenceClient`.

    Attributes:
        model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.
        tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub.
            Defaults to `None`, but defining one is recommended to properly format the prompt.
        model_display_name: the model display name to use for the LLM. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow-up user message. Valid
            values are "llama3", "qwen2" or any other pre-query template string. Defaults
            to `None`.
        structured_output: a dictionary containing the structured output configuration or,
            if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`.
            Defaults to `None`.

    Icon:
        `:hugging:`

    Examples:
        Free serverless Inference API (set the `input_batch_size` of the `Task` that uses this LLM to avoid "Model is overloaded" errors):

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Dedicated Inference Endpoints:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            endpoint_name="<ENDPOINT_NAME>",
            api_key="<HF_API_KEY>",
            endpoint_namespace="<USER|ORG>",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Dedicated Inference Endpoints or TGI:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            api_key="<HF_API_KEY>",
            base_url="<BASE_URL>",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import InferenceEndpointsLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            api_key="<HF_API_KEY>",
            structured_output={"format": "json", "schema": User.model_json_schema()}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
        ```
    """

    def load(self) -> None:
        # Sets the logger and calls the load method of the BaseClient
        self._num_generations_param_supported = False
        AsyncLLM.load(self)
        InferenceEndpointsBaseClient.load(self)

    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validates that only one of `model_id` or `endpoint_name` is provided. If `base_url` is also
        provided, a warning is shown informing the user that the provided `base_url` will be ignored in
        favour of the dynamically calculated one."""

        if self.base_url and (self.model_id or self.endpoint_name):
            self._logger.warning(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

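        # Raising `ValueError` lets pydantic wrap it into a `ValidationError`; pydantic v2's
        # `ValidationError` cannot be instantiated directly from a message string
        raise ValueError(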
            f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _get_structured_output(
        self, input: FormattedInput
    ) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
        """Gets the structured output (if any) for the given input.

        Args:
            input: a single input in chat format to generate responses for.

        Returns:
            The input and the structured output that will be passed as `grammar` to the
            inference endpoint or `None` if not required.
        """
        structured_output = None

        # Specific structured output per input
        if isinstance(input, tuple):
            input, structured_output = input
            structured_output = {
                "type": structured_output["format"],  # type: ignore
                "value": structured_output["schema"],  # type: ignore
            }

        # Same structured output for all the inputs
        if structured_output is None and self.structured_output is not None:
            try:
                structured_output = {
                    "type": self.structured_output["format"],  # type: ignore
                    "value": self.structured_output["schema"],  # type: ignore
                }
            except KeyError as e:
                raise ValueError(
                    "To use the structured output you have to inform the `format` and `schema` in "
                    "the `structured_output` attribute."
                ) from e

        if structured_output:
            if isinstance(structured_output["value"], ModelMetaclass):
                structured_output["value"] = structured_output[
                    "value"
                ].model_json_schema()

        return input, structured_output

    async def _generate_with_text_generation(
        self,
        input: str,
        max_new_tokens: int = 128,
        repetition_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        stop_sequences: Union[List[str], None] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        watermark: bool = False,
        structured_output: Union[Dict[str, Any], None] = None,
    ) -> GenerateOutput:
        generation: Union["TextGenerationOutput", None] = None
        try:
            generation = await self._aclient.text_generation(  # type: ignore
                prompt=input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_n_tokens=top_n_tokens,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, sys.maxsize),
                watermark=watermark,
                grammar=structured_output,  # type: ignore
                details=True,
            )
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
        return prepare_output(
            generations=[generation.generated_text] if generation else [None],
            input_tokens=[
                compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1
            ],
            output_tokens=[
                generation.details.generated_tokens
                if generation and generation.details
                else 0
            ],
            logprobs=self._get_logprobs_from_text_generation(generation)
            if generation
            else None,  # type: ignore
        )

    def _get_logprobs_from_text_generation(
        self, generation: "TextGenerationOutput"
    ) -> Union[List[List[List["Logprob"]]], None]:
        if generation.details is None or generation.details.top_tokens is None:
            return None

        return [
            [
                [
                    {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
                    for top_logprob in token_logprobs
                ]
                for token_logprobs in generation.details.top_tokens
            ]
        ]

    async def _generate_with_chat_completion(
        self,
        input: "StandardInput",
        max_new_tokens: int = 128,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: bool = False,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: float = 1.0,
        tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[PositiveInt] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        message = None
        completion: Union["ChatCompletionOutput", None] = None
        output_logprobs = None
        try:
            completion = await self._aclient.chat_completion(  # type: ignore
                messages=input,  # type: ignore
                max_tokens=max_new_tokens,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                logprobs=logprobs,
                presence_penalty=presence_penalty,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, sys.maxsize),
                stop=stop_sequences,
                temperature=temperature,
                tool_choice=tool_choice,  # type: ignore
                tool_prompt=tool_prompt,
                tools=tools,  # type: ignore
                top_logprobs=top_logprobs,
                top_p=top_p,
            )
            choice = completion.choices[0]  # type: ignore
            if (message := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            if choice_logprobs := self._get_logprobs_from_choice(choice):
                output_logprobs = [choice_logprobs]
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
        return prepare_output(
            generations=[message],
            input_tokens=[completion.usage.prompt_tokens] if completion else None,
            output_tokens=[completion.usage.completion_tokens] if completion else None,
            logprobs=output_logprobs,
        )

    def _get_logprobs_from_choice(
        self, choice: "ChatCompletionOutputComplete"
    ) -> Union[List[List["Logprob"]], None]:
        if choice.logprobs is None:
            return None

        return [
            [
                {"token": top_logprob.token, "logprob": top_logprob.logprob}
                for top_logprob in token_logprobs.top_logprobs
            ]
            for token_logprobs in choice.logprobs.content
        ]

    def _check_stop_sequences(
        self,
        stop_sequences: Optional[Union[str, List[str]]] = None,
    ) -> Union[List[str], None]:
        """Checks that no more than 4 stop sequences are provided.

        Args:
            stop_sequences: the stop sequences to be checked.

        Returns:
            The stop sequences.
        """
        if stop_sequences is not None:
            if isinstance(stop_sequences, str):
                stop_sequences = [stop_sequences]
            if len(stop_sequences) > 4:
                warnings.warn(
                    "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
                    UserWarning,
                    stacklevel=2,
                )
                stop_sequences = stop_sequences[:4]
        return stop_sequences

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: bool = False,
        presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: float = 1.0,
        tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[PositiveInt] = None,
        top_n_tokens: Optional[PositiveInt] = None,
        top_p: Optional[float] = None,
        do_sample: bool = False,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        num_generations: int = 1,
    ) -> GenerateOutput:
        """Generates completions for the given input using the async client. This method
        uses two methods of the `huggingface_hub.AsyncInferenceClient`: `chat_completion` and
        `text_generation`. The `chat_completion` method is used only if no `tokenizer_id` has
        been specified and the input is not a raw string; otherwise `text_generation` is used.
        Some arguments of this function are specific to the `text_generation` method, while
        others are specific to the `chat_completion` method.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
                new tokens based on their existing frequency in the text so far, decreasing
                the model's likelihood to repeat the same line verbatim. Defaults to `None`.
            logit_bias: modify the likelihood of specified tokens appearing in the completion.
                This argument is exclusive to the `chat_completion` method and will be used
                only if `tokenizer_id` is `None`.
                Defaults to `None`.
            logprobs: whether to return the log probabilities or not. This argument is exclusive
                to the `chat_completion` method and will be used only if `tokenizer_id`
                is `None`. Defaults to `False`.
            presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
                new tokens based on whether they appear in the text so far, increasing the
                model's likelihood to talk about new topics. This argument is exclusive to
                the `chat_completion` method and will be used only if `tokenizer_id` is
                `None`. Defaults to `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: either a single string or a list of strings containing the sequences
                to stop the generation at. Defaults to `None`, but will be set to the
                `tokenizer.eos_token` if available.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            tool_choice: the name of the tool the model should call. It can be a dictionary
                like `{"function_name": "my_tool"}` or "auto". If not provided, then the
                model won't use any tool. This argument is exclusive to the `chat_completion`
                method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
            tool_prompt: a prompt to be inserted before the tools. This argument is exclusive
                to the `chat_completion` method and will be used only if `tokenizer_id`
                is `None`. Defaults to `None`.
            tools: a list of tools definitions that the LLM can use.
                This argument is exclusive to the `chat_completion` method and will be used
                only if `tokenizer_id` is `None`. Defaults to `None`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. This argument is exclusive to the `chat_completion` method and
                will be used only if `tokenizer_id` is `None`. Defaults to `None`.
            top_n_tokens: the number of top log probabilities to return per output token
                generated. This argument is exclusive to the `text_generation` method and
                will only be used if `tokenizer_id` is not `None`. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.
            do_sample: whether to use sampling for the generation. This argument is exclusive
                to the `text_generation` method and will only be used if `tokenizer_id` is not
                `None`. Defaults to `False`.
            repetition_penalty: the repetition penalty to use for the generation. This argument
                is exclusive to the `text_generation` method and will only be used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            return_full_text: whether to return the full text of the completion or just
                the generated text. Defaults to `False`, meaning that only the generated
                text will be returned. This argument is exclusive to the `text_generation`
                method and will only be used if `tokenizer_id` is not `None`.
            top_k: the top-k value to use for the generation. This argument is exclusive
                to the `text_generation` method and will only be used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            typical_p: the typical-p value to use for the generation. This argument is exclusive
                to the `text_generation` method and will only be used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            watermark: whether to add the watermark to the generated text. This argument
                is exclusive to the `text_generation` method and will only be used if `tokenizer_id`
                is not `None`. Defaults to `False`.
            num_generations: the number of generations to produce. Defaults to `1`. It is only
                here to ensure the validation succeeds.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        stop_sequences = self._check_stop_sequences(stop_sequences)

        if isinstance(input, str) or self.tokenizer_id is not None:
            structured_output = None
            if not isinstance(input, str):
                input, structured_output = self._get_structured_output(input)
                input = self.prepare_input(input)

            return await self._generate_with_text_generation(
                input=input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_n_tokens=top_n_tokens,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                seed=seed,
                watermark=watermark,
                structured_output=structured_output,
            )

        return await self._generate_with_chat_completion(
            input=input,  # type: ignore
            max_new_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            presence_penalty=presence_penalty,
            seed=seed,
            stop_sequences=stop_sequences,
            temperature=temperature,
            tool_choice=tool_choice,
            tool_prompt=tool_prompt,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
        )
only_one_of_model_id_endpoint_name_or_base_url_provided()

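Validates that exactly one of model_id or endpoint_name is provided, or that only base_url is provided. If base_url is set together with model_id or endpoint_name, a warning informs the user that the provided base_url will be ignored in favour of the dynamically calculated one. The validator also raises an error if use_magpie_template is True while tokenizer_id is None, and sets tokenizer_id to model_id when structured_output is set and no tokenizer_id was given.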

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
@model_validator(mode="after")  # type: ignore
def only_one_of_model_id_endpoint_name_or_base_url_provided(
    self,
) -> "InferenceEndpointsLLM":
    """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
    provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
    favour of the dynamically calculated one."""

    if self.base_url and (self.model_id or self.endpoint_name):
        self._logger.warning(  # type: ignore
            f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
            " or `endpoint_name` is also provided, the `base_url` will either be ignored"
            " or overwritten with the one generated from either of those args, for serverless"
            " or dedicated inference endpoints, respectively."
        )

    if self.use_magpie_template and self.tokenizer_id is None:
        raise ValueError(
            "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
            " set a `tokenizer_id` and try again."
        )

    if (
        self.model_id
        and self.tokenizer_id is None
        and self.structured_output is not None
    ):
        self.tokenizer_id = self.model_id

    if self.base_url and not (self.model_id or self.endpoint_name):
        return self

    if self.model_id and not self.endpoint_name:
        return self

    if self.endpoint_name and not self.model_id:
        return self

    # A ValueError raised in a validator is re-raised by pydantic as a ValidationError.
    raise ValueError(
        "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
        f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
        f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
    )
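The accepted argument combinations can be summarized with a small, self-contained sketch. The helper `resolve_target` is hypothetical (it is not part of distilabel) and only mirrors the three early returns and the final error of the validator above; the warning and the `tokenizer_id` checks are left out.

```python
from typing import Optional


def resolve_target(
    model_id: Optional[str] = None,
    endpoint_name: Optional[str] = None,
    base_url: Optional[str] = None,
) -> str:
    """Hypothetical helper: returns which argument determines where requests go.

    Mirrors the early returns and the final error of the validator above.
    """
    if base_url and not (model_id or endpoint_name):
        return "base_url"  # an explicit URL, used as-is
    if model_id and not endpoint_name:
        return "model_id"  # serverless; a base_url set alongside is ignored
    if endpoint_name and not model_id:
        return "endpoint_name"  # dedicated endpoint; a base_url set alongside is ignored
    # Neither or both of model_id / endpoint_name were given.
    raise ValueError("Only one of `model_id` or `endpoint_name` must be provided.")


print(resolve_target(model_id="some-org/some-model"))
```

Passing both `model_id` and `endpoint_name`, or none of the three arguments, ends in the error branch.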
prepare_input(input)

Prepares the prompt for the provided input by applying the tokenizer's chat template (as text, without tokenizing) and then, if enabled, the Magpie pre-query template. An empty input yields an empty prompt.

Parameters:

  • input (StandardInput, required): the input list containing chat items.

Returns:

  • str: The prompt to send to the LLM.

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the prompt for the provided input by applying the chat template
    (as text, without tokenizing) and, if enabled, the Magpie pre-query template.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            conversation=input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
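To make the chat-template step concrete, here is a minimal sketch assuming a ChatML-style template. In practice the template comes from the tokenizer loaded for `tokenizer_id`, so the exact markup depends on the model; `render_chatml` is a hypothetical stand-in for `apply_chat_template(..., tokenize=False, add_generation_prompt=True)`, and the Magpie pre-query step is not shown.

```python
from typing import Dict, List


def render_chatml(
    messages: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
    """Hypothetical stand-in for `apply_chat_template(..., tokenize=False)` with ChatML."""
    if not messages:
        return ""  # prepare_input also returns an empty prompt for an empty input
    prompt = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues as the assistant.
        prompt += "<|im_start|>assistant\n"
    return prompt


print(render_chatml([{"role": "user", "content": "Hello world!"}]))
```

Because `tokenize=False`, the result is a plain string; tokenization happens later on the endpoint side.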
_get_structured_output(input)

Gets the structured output (if any) for the given input.

Parameters:

  • input (FormattedInput, required): a single input in chat format to generate responses for.

Returns:

  • Tuple[StandardInput, Union[Dict[str, Any], None]]: The input and the structured output that will be passed as grammar to the inference endpoint, or None if not required.

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
def _get_structured_output(
    self, input: FormattedInput
) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
    """Gets the structured output (if any) for the given input.

    Args:
        input: a single input in chat format to generate responses for.

    Returns:
        The input and the structured output that will be passed as `grammar` to the
        inference endpoint or `None` if not required.
    """
    structured_output = None

    # Specific structured output per input
    if isinstance(input, tuple):
        input, structured_output = input
        structured_output = {
            "type": structured_output["format"],  # type: ignore
            "value": structured_output["schema"],  # type: ignore
        }

    # Same structured output for all the inputs
    if structured_output is None and self.structured_output is not None:
        try:
            structured_output = {
                "type": self.structured_output["format"],  # type: ignore
                "value": self.structured_output["schema"],  # type: ignore
            }
        except KeyError as e:
            raise ValueError(
                "To use the structured output you have to inform the `format` and `schema` in "
                "the `structured_output` attribute."
            ) from e

    if structured_output:
        if isinstance(structured_output["value"], ModelMetaclass):
            structured_output["value"] = structured_output[
                "value"
            ].model_json_schema()

    return input, structured_output
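The precedence rules above (a per-input tuple first, then the LLM-wide `structured_output`) can be sketched without distilabel. `to_grammar` is a hypothetical helper; it uses a duck-typed `model_json_schema` check instead of the `ModelMetaclass` check so that the sketch does not depend on pydantic (unlike the original, it would therefore also accept model instances).

```python
from typing import Any, Dict, List, Optional, Tuple, Union

Chat = List[Dict[str, str]]


def to_grammar(
    input: Union[Chat, Tuple[Chat, Dict[str, Any]]],
    default: Optional[Dict[str, Any]] = None,
) -> Tuple[Chat, Optional[Dict[str, Any]]]:
    """Hypothetical helper returning (chat, grammar), where grammar is
    {"type": ..., "value": ...} or None."""
    grammar = None

    # 1. A per-input (chat, {"format": ..., "schema": ...}) tuple takes precedence.
    if isinstance(input, tuple):
        input, config = input
        grammar = {"type": config["format"], "value": config["schema"]}

    # 2. Otherwise fall back to the LLM-wide configuration, if any.
    if grammar is None and default is not None:
        try:
            grammar = {"type": default["format"], "value": default["schema"]}
        except KeyError as e:
            raise ValueError("`structured_output` needs both `format` and `schema`.") from e

    # 3. Schema objects exposing `model_json_schema()` (e.g. pydantic models)
    #    are converted to a plain JSON schema dict.
    if grammar and hasattr(grammar["value"], "model_json_schema"):
        grammar["value"] = grammar["value"].model_json_schema()

    return input, grammar


chat = [{"role": "user", "content": "Give me a JSON object."}]
print(to_grammar((chat, {"format": "json", "schema": {"type": "object"}})))
```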
_check_stop_sequences(stop_sequences=None)

Normalizes the stop sequences: a single string is wrapped in a list, and if more than 4 sequences are provided, a UserWarning is emitted and only the first 4 are kept.

Parameters:

  • stop_sequences (Optional[Union[str, List[str]]], default None): the stop sequences to be checked.

Returns:

  • Union[List[str], None]: The stop sequences as a list, or None if none were provided.

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
def _check_stop_sequences(
    self,
    stop_sequences: Optional[Union[str, List[str]]] = None,
) -> Union[List[str], None]:
    """Checks that no more than 4 stop sequences are provided.

    Args:
        stop_sequences: the stop sequences to be checked.

    Returns:
        The stop sequences.
    """
    if stop_sequences is not None:
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        if len(stop_sequences) > 4:
            warnings.warn(
                "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
                UserWarning,
                stacklevel=2,
            )
            stop_sequences = stop_sequences[:4]
    return stop_sequences
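A standalone version of the same normalization is handy for checking the behaviour in isolation. `normalize_stop_sequences` is a hypothetical name; the limit of 4 and the warning come from the method above.

```python
import warnings
from typing import List, Optional, Union

MAX_STOP_SEQUENCES = 4  # same limit as `_check_stop_sequences`


def normalize_stop_sequences(
    stop_sequences: Optional[Union[str, List[str]]] = None,
) -> Optional[List[str]]:
    """Hypothetical helper: wrap a single string in a list and keep at most 4 items."""
    if stop_sequences is None:
        return None
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    if len(stop_sequences) > MAX_STOP_SEQUENCES:
        warnings.warn(
            f"Only up to {MAX_STOP_SEQUENCES} stop sequences are allowed, "
            f"so keeping the first {MAX_STOP_SEQUENCES} items only.",
            UserWarning,
            stacklevel=2,
        )
        stop_sequences = stop_sequences[:MAX_STOP_SEQUENCES]
    return stop_sequences


print(normalize_stop_sequences("###"))
```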
agenerate(input, max_new_tokens=128, frequency_penalty=None, logit_bias=None, logprobs=False, presence_penalty=None, seed=None, stop_sequences=None, temperature=1.0, tool_choice=None, tool_prompt=None, tools=None, top_logprobs=None, top_n_tokens=None, top_p=None, do_sample=False, repetition_penalty=None, return_full_text=False, top_k=None, typical_p=None, watermark=False, num_generations=1) async

Generates completions for the given input using the async client. This method relies on two methods of huggingface_hub.AsyncInferenceClient: chat_completion and text_generation. text_generation is used when the input is a plain string or when a tokenizer_id has been specified (chat inputs are then formatted locally with prepare_input); otherwise chat_completion is used. Some arguments of this function are specific to the text_generation method, while others are specific to the chat_completion method.
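The routing rule and the backend-specific arguments can be summarized in a short sketch. The helpers and the two argument sets are hypothetical; the sets were compiled from the keyword arguments that the source listing of `agenerate` forwards to `_generate_with_text_generation` and `_generate_with_chat_completion`. Arguments in neither set (for example `max_new_tokens`, `temperature`, `top_p`, `seed`, `frequency_penalty` and `stop_sequences`) are forwarded to both backends.

```python
from typing import Any, Dict, List, Optional, Union

# Arguments forwarded to only one backend (compiled from the agenerate source).
CHAT_COMPLETION_ONLY = {
    "logit_bias", "logprobs", "presence_penalty",
    "tool_choice", "tool_prompt", "tools", "top_logprobs",
}
TEXT_GENERATION_ONLY = {
    "do_sample", "repetition_penalty", "return_full_text",
    "top_k", "top_n_tokens", "typical_p", "watermark",
}


def pick_backend(
    input: Union[str, List[Dict[str, str]]], tokenizer_id: Optional[str]
) -> str:
    """Hypothetical helper mirroring the branch at the top of `agenerate`."""
    if isinstance(input, str) or tokenizer_id is not None:
        return "text_generation"
    return "chat_completion"


def ignored_arguments(backend: str, **kwargs: Any) -> List[str]:
    """Names of the given arguments that the chosen backend never receives."""
    dropped = CHAT_COMPLETION_ONLY if backend == "text_generation" else TEXT_GENERATION_ONLY
    return sorted(name for name in kwargs if name in dropped)


chat = [{"role": "user", "content": "Hello world!"}]
backend = pick_backend(chat, tokenizer_id=None)
print(backend, ignored_arguments(backend, top_k=50, temperature=0.7))
```

In other words, passing `top_k` while no `tokenizer_id` is set has no effect for chat inputs, because that argument never reaches `chat_completion`.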

Parameters:

  • input (FormattedInput, required): a single input in chat format to generate responses for.
  • max_new_tokens (int, default 128): the maximum number of new tokens that the model will generate.
  • frequency_penalty (Optional[Annotated[float, Field(ge=-2.0, le=2.0)]], default None): a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
  • logit_bias (Optional[List[float]], default None): modifies the likelihood of specified tokens appearing in the completion. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • logprobs (bool, default False): whether to return the log probabilities or not. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • presence_penalty (Optional[Annotated[float, Field(ge=-2.0, le=2.0)]], default None): a value between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • seed (Optional[int], default None): the seed to use for the generation.
  • stop_sequences (Optional[List[str]], default None): either a single string or a list of strings containing the sequences to stop the generation at; only the first 4 are kept. Defaults to None, but will be set to the tokenizer.eos_token if available.
  • temperature (float, default 1.0): the temperature to use for the generation.
  • tool_choice (Optional[Union[Dict[str, str], Literal['auto']]], default None): the name of the tool the model should call. It can be a dictionary like {"function_name": "my_tool"} or "auto". If not provided, the model won't use any tool. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • tool_prompt (Optional[str], default None): a prompt to be added before the tools. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • tools (Optional[List[Dict[str, Any]]], default None): a list of tool definitions that the LLM can use. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • top_logprobs (Optional[PositiveInt], default None): the number of top log probabilities to return per generated output token. Exclusive to the chat_completion method; used only if tokenizer_id is None.
  • top_n_tokens (Optional[PositiveInt], default None): the number of top log probabilities to return per generated output token. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • top_p (Optional[float], default None): the top-p value to use for the generation.
  • do_sample (bool, default False): whether to use sampling for the generation. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • repetition_penalty (Optional[float], default None): the repetition penalty to use for the generation. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • return_full_text (bool, default False): whether to return the full text of the completion or just the generated text; with False, only the generated text is returned. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • top_k (Optional[int], default None): the top-k value to use for the generation. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • typical_p (Optional[float], default None): the typical-p value to use for the generation. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • watermark (bool, default False): whether to add the watermark to the generated text. Exclusive to the text_generation method; used only if tokenizer_id is not None.
  • num_generations (int, default 1): the number of generations to generate. It is only accepted so that the validation succeeds.
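The penalty descriptions above follow the OpenAI-style definitions. Assuming the endpoint applies them the way OpenAI documents them (an assumption, not something this page states), each token's logit is reduced by its count so far times frequency_penalty, plus presence_penalty once if it has appeared at all. `apply_penalties` is a hypothetical illustration operating on a dict of logits.

```python
from collections import Counter
from typing import Dict, List


def apply_penalties(
    logits: Dict[str, float],
    generated: List[str],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Dict[str, float]:
    """Hypothetical OpenAI-style penalty: logit - count * freq - (count > 0) * pres."""
    for name, value in (
        ("frequency_penalty", frequency_penalty),
        ("presence_penalty", presence_penalty),
    ):
        # Same bounds as the Field(ge=-2.0, le=2.0) annotations above.
        if not -2.0 <= value <= 2.0:
            raise ValueError(f"{name} must be between -2.0 and 2.0, got {value}")
    counts = Counter(generated)  # how often each token appeared so far
    return {
        token: logit
        - counts[token] * frequency_penalty
        - (1.0 if counts[token] > 0 else 0.0) * presence_penalty
        for token, logit in logits.items()
    }


print(apply_penalties({"a": 1.0, "b": 1.0, "c": 1.0}, ["a", "a", "b"], 0.5, 0.25))
```

The frequency term grows with every repetition, while the presence term is a flat, one-time penalty, which is why the former discourages verbatim repetition and the latter nudges towards new topics.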

Returns:

  • GenerateOutput: A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_new_tokens: int = 128,
    frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
    logit_bias: Optional[List[float]] = None,
    logprobs: bool = False,
    presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
    seed: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    temperature: float = 1.0,
    tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
    tool_prompt: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    top_logprobs: Optional[PositiveInt] = None,
    top_n_tokens: Optional[PositiveInt] = None,
    top_p: Optional[float] = None,
    do_sample: bool = False,
    repetition_penalty: Optional[float] = None,
    return_full_text: bool = False,
    top_k: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: bool = False,
    num_generations: int = 1,
) -> GenerateOutput:
    """Generates completions for the given input using the async client. This method
    uses two methods of the `huggingface_hub.AsyncInferenceClient`: `chat_completion` and
    `text_generation`. `text_generation` is used if the input is a string or if a
    `tokenizer_id` has been specified; otherwise `chat_completion` is used.
    Some arguments of this function are specific to the `text_generation` method, while
    some others are specific to the `chat_completion` method.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
            new tokens based on their existing frequency in the text so far, decreasing
            the model's likelihood to repeat the same line verbatim. Defaults to `None`.
        logit_bias: modify the likelihood of specified tokens appearing in the completion.
            This argument is exclusive to the `chat_completion` method and will be used
            only if `tokenizer_id` is `None`.
            Defaults to `None`.
        logprobs: whether to return the log probabilities or not. This argument is exclusive
            to the `chat_completion` method and will be used only if `tokenizer_id`
            is `None`. Defaults to `False`.
        presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
            new tokens based on whether they appear in the text so far, increasing the
            model's likelihood to talk about new topics. This argument is exclusive to
            the `chat_completion` method and will be used only if `tokenizer_id` is
            `None`. Defaults to `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: either a single string or a list of strings containing the sequences
            to stop the generation at. Defaults to `None`, but will be set to the
            `tokenizer.eos_token` if available.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        tool_choice: the name of the tool the model should call. It can be a dictionary
            like `{"function_name": "my_tool"}` or "auto". If not provided, then the
            model won't use any tool. This argument is exclusive to the `chat_completion`
            method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
        tool_prompt: a prompt to be added before the tools. This argument is exclusive
            to the `chat_completion` method and will be used only if `tokenizer_id`
            is `None`. Defaults to `None`.
        tools: a list of tools definitions that the LLM can use.
            This argument is exclusive to the `chat_completion` method and will be used
            only if `tokenizer_id` is `None`. Defaults to `None`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. This argument is exclusive to the `chat_completion` method and
            will be used only if `tokenizer_id` is `None`. Defaults to `None`.
        top_n_tokens: the number of top log probabilities to return per output token
            generated. This argument is exclusive to the `text_generation` method and
            will be used only if `tokenizer_id` is not `None`. Defaults to `None`.
        top_p: the top-p value to use for the generation. Defaults to `None`.
        do_sample: whether to use sampling for the generation. This argument is exclusive
            to the `text_generation` method and will be used only if `tokenizer_id` is not
            `None`. Defaults to `False`.
        repetition_penalty: the repetition penalty to use for the generation. This argument
            is exclusive to the `text_generation` method and will be used only if `tokenizer_id`
            is not `None`. Defaults to `None`.
        return_full_text: whether to return the full text of the completion or just
            the generated text. Defaults to `False`, meaning that only the generated
            text will be returned. This argument is exclusive to the `text_generation`
            method and will be used only if `tokenizer_id` is not `None`.
        top_k: the top-k value to use for the generation. This argument is exclusive
            to the `text_generation` method and will be used only if `tokenizer_id`
            is not `None`. Defaults to `None`.
        typical_p: the typical-p value to use for the generation. This argument is exclusive
            to the `text_generation` method and will be used only if `tokenizer_id`
            is not `None`. Defaults to `None`.
        watermark: whether to add the watermark to the generated text. This argument
            is exclusive to the `text_generation` method and will be used only if `tokenizer_id`
            is not `None`. Defaults to `False`.
        num_generations: the number of generations to generate. Defaults to `1`. It is only
            accepted so that the validation succeeds.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    stop_sequences = self._check_stop_sequences(stop_sequences)

    if isinstance(input, str) or self.tokenizer_id is not None:
        structured_output = None
        if not isinstance(input, str):
            input, structured_output = self._get_structured_output(input)
            input = self.prepare_input(input)

        return await self._generate_with_text_generation(
            input=input,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            typical_p=typical_p,
            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            temperature=temperature,
            top_n_tokens=top_n_tokens,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=stop_sequences,
            return_full_text=return_full_text,
            seed=seed,
            watermark=watermark,
            structured_output=structured_output,
        )

    return await self._generate_with_chat_completion(
        input=input,  # type: ignore
        max_new_tokens=max_new_tokens,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        logprobs=logprobs,
        presence_penalty=presence_penalty,
        seed=seed,
        stop_sequences=stop_sequences,
        temperature=temperature,
        tool_choice=tool_choice,
        tool_prompt=tool_prompt,
        tools=tools,
        top_logprobs=top_logprobs,
        top_p=top_p,
    )

TransformersLLM

Bases: LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin

Hugging Face transformers library LLM implementation using the text generation pipeline.

Attributes:

  • model (str): the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files.
  • revision (str): if model refers to a Hugging Face Hub repository, then the revision (e.g. a branch name or a commit id) to use. Defaults to "main".
  • torch_dtype (str): the torch dtype to use for the model e.g. "float16", "float32", etc. Defaults to "auto".
  • trust_remote_code (bool): whether to allow fetching and executing remote code fetched from the repository in the Hub. Defaults to False.
  • model_kwargs (Optional[Dict[str, Any]]): additional dictionary of keyword arguments that will be passed to the from_pretrained method of the model.
  • tokenizer (Optional[str]): the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer config files. If not provided, the one associated with the model will be used. Defaults to None.
  • use_fast (bool): whether to use a fast tokenizer or not. Defaults to True.
  • chat_template (Optional[str]): a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, then ChatML template will be used. Defaults to None.
  • device (Optional[Union[str, int]]): the name or index of the device where the model will be loaded. Defaults to None.
  • device_map (Optional[Union[str, Dict[str, Any]]]): a dictionary mapping each layer of the model to a device, or a mode like "sequential" or "auto". Defaults to None.
  • token (Optional[SecretStr]): the Hugging Face Hub token that will be used to authenticate to the Hugging Face Hub. If not provided, the HF_TOKEN environment variable or the huggingface_hub local configuration will be used. Defaults to None.
  • structured_output (Optional[RuntimeParameter[OutlinesStructuredOutputType]]): a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None.
  • use_magpie_template (bool): a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.
  • magpie_pre_query_template (Union[MagpieAvailablePreQueryTemplates, str, None]): the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2", or a custom pre-query template string. Defaults to None.

Examples:

Generate text:

from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
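Under the hood, generate keeps only the generated_text of each returned sequence, grouped per input, before attaching token counts with prepare_output. The sketch below reproduces that flattening step on a hard-coded, hypothetical pipeline output so that it runs without downloading a model; the texts are made up for illustration.

```python
from typing import Dict, List

# Hypothetical stand-in for what the text-generation pipeline returns for two
# prompts with num_return_sequences=2: one list of {"generated_text": ...} per prompt.
pipeline_outputs: List[List[Dict[str, str]]] = [
    [{"generated_text": "Hi there!"}, {"generated_text": "Hello!"}],
    [{"generated_text": "4"}, {"generated_text": "Four."}],
]

# Same comprehension that `generate` uses: keep only the text, grouped per input.
llm_output = [
    [generation["generated_text"] for generation in output]
    for output in pipeline_outputs
]
print(llm_output)
```

Because the pipeline is created with return_full_text=False, each generated_text contains only the newly generated continuation, not the prompt.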
Source code in src/distilabel/models/llms/huggingface/transformers.py
class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
    """Hugging Face `transformers` library LLM implementation using the text generation
    pipeline.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        revision: if `model` refers to a Hugging Face Hub repository, then the revision
            (e.g. a branch name or a commit id) to use. Defaults to `"main"`.
        torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc.
            Defaults to `"auto"`.
        trust_remote_code: whether to allow fetching and executing remote code fetched
            from the repository in the Hub. Defaults to `False`.
        model_kwargs: additional dictionary of keyword arguments that will be passed to
            the `from_pretrained` method of the model.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_fast: whether to use a fast tokenizer or not. Defaults to `True`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        device: the name or index of the device where the model will be loaded. Defaults
            to `None`.
        device_map: a dictionary mapping each layer of the model to a device, or a mode
            like `"sequential"` or `"auto"`. Defaults to `None`.
        token: the Hugging Face Hub token that will be used to authenticate to the Hugging
            Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package
            local configuration will be used. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.

    Icon:
        `:hugging:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import TransformersLLM

        llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str
    revision: str = "main"
    torch_dtype: str = "auto"
    trust_remote_code: bool = False
    model_kwargs: Optional[Dict[str, Any]] = None
    tokenizer: Optional[str] = None
    use_fast: bool = True
    chat_template: Optional[str] = None
    device: Optional[Union[str, int]] = None
    device_map: Optional[Union[str, Dict[str, Any]]] = None
    token: Optional[SecretStr] = Field(
        default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR)
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _pipeline: Optional["Pipeline"] = PrivateAttr(...)
    _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
    _logits_processor: Union[List[Callable], None] = PrivateAttr(default=None)

    def load(self) -> None:
        """Loads the model and tokenizer and creates the text generation pipeline. In addition,
        it will configure the tokenizer chat template."""
        if self.device == "cuda":
            CudaDevicePlacementMixin.load(self)

        try:
            from transformers import pipeline
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie

        token = self.token.get_secret_value() if self.token is not None else self.token

        self._pipeline = pipeline(
            "text-generation",
            model=self.model,
            revision=self.revision,
            torch_dtype=self.torch_dtype,
            trust_remote_code=self.trust_remote_code,
            model_kwargs=self.model_kwargs or {},
            tokenizer=self.tokenizer or self.model,
            use_fast=self.use_fast,
            device=self.device,
            device_map=self.device_map,
            token=token,
            return_full_text=False,
        )

        if self.chat_template is not None:
            self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore

        if self._pipeline.tokenizer.pad_token is None:  # type: ignore
            self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

        if self.structured_output:
            processor = self._prepare_structured_output(self.structured_output)
            if _is_outlines_version_below_0_1_0():
                self._prefix_allowed_tokens_fn = processor
            else:
                self._logits_processor = [processor]

        super().load()

    def unload(self) -> None:
        """Unloads the `transformers` model."""
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the prompt for the provided input by applying the chat template
        (as text, without tokenizing) and, if enabled, the Magpie pre-query template.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if self._pipeline.tokenizer.chat_template is None:  # type: ignore
            return input[0]["content"]

        prompt: str = (
            self._pipeline.tokenizer.apply_chat_template(  # type: ignore
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[StandardInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        temperature: float = 0.1,
        repetition_penalty: float = 1.1,
        top_p: float = 1.0,
        top_k: int = 0,
        do_sample: bool = True,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults
                to `1.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `0`.
            do_sample: whether to use sampling or not. Defaults to `True`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        prepared_inputs = [self.prepare_input(input=input) for input in inputs]

        outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
            prepared_inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            num_return_sequences=num_generations,
            prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
            pad_token_id=self._pipeline.tokenizer.eos_token_id,
            logits_processor=self._logits_processor,
        )
        llm_output = [
            [generation["generated_text"] for generation in output]
            for output in outputs
        ]

        result = []
        for input, output in zip(inputs, llm_output):
            result.append(
                prepare_output(
                    output,
                    input_tokens=[
                        compute_tokens(input, self._pipeline.tokenizer.encode)
                    ],
                    output_tokens=[
                        compute_tokens(row, self._pipeline.tokenizer.encode)
                        for row in output
                    ],
                )
            )

        return result

    def get_last_hidden_states(
        self, inputs: List["StandardInput"]
    ) -> List["HiddenState"]:
        """Gets the last `hidden_states` of the model for the given inputs. It doesn't
        execute the task head.

        Args:
            inputs: a list of inputs in chat format to generate the embeddings for.

        Returns:
            A list containing the last hidden state for each sequence using a NumPy array
            with shape [num_tokens, hidden_size].
        """
        model: "PreTrainedModel" = (
            self._pipeline.model.model  # type: ignore
            if hasattr(self._pipeline.model, "model")  # type: ignore
            else next(self._pipeline.model.children())  # type: ignore
        )
        tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
        input_ids = tokenizer(
            [self.prepare_input(input) for input in inputs],  # type: ignore
            return_tensors="pt",
            padding=True,
        ).to(model.device)
        last_hidden_states = model(**input_ids)["last_hidden_state"]

        return [
            seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
            for seq_last_hidden_state, attention_mask in zip(
                last_hidden_states,
                input_ids["attention_mask"],  # type: ignore
            )
        ]

    def _prepare_structured_output(
        self, structured_output: Optional[OutlinesStructuredOutputType] = None
    ) -> Union[Callable, List[Callable]]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        result = prepare_guided_output(
            structured_output, "transformers", self._pipeline
        )
        if schema := result.get("schema"):
            self.structured_output["schema"] = schema
        return result["processor"]
model_name property

Returns the model name used for the LLM.

load()

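Loads the model and tokenizer and creates the text generation pipeline. It also applies the configured chat template to the tokenizer (if one is given), falls back to the EOS token when the tokenizer has no pad token, and prepares the structured-output processor when structured_output is set.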
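The two tokenizer adjustments made by `load()` can be illustrated without any model. This is a minimal sketch: `FakeTokenizer` and `configure_tokenizer` are hypothetical stand-ins for the Hugging Face tokenizer held by the pipeline. Only the two rules mirror the source: an explicit `chat_template` overrides the tokenizer's own, and a missing pad token falls back to the EOS token.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTokenizer:
    # Hypothetical stand-in for a Hugging Face tokenizer; only the attributes
    # touched by `load()` are modelled here.
    eos_token: str
    pad_token: Optional[str] = None
    chat_template: Optional[str] = None


def configure_tokenizer(
    tokenizer: FakeTokenizer, chat_template: Optional[str] = None
) -> FakeTokenizer:
    # An explicitly provided chat template replaces the tokenizer's own one.
    if chat_template is not None:
        tokenizer.chat_template = chat_template
    # Batched generation needs a pad token; reuse EOS when the model has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


tok = configure_tokenizer(FakeTokenizer(eos_token="</s>"), chat_template="{{ messages }}")
print(tok.pad_token, tok.chat_template)  # </s> {{ messages }}
```

An existing pad token is left untouched, so models that ship with a dedicated padding token keep it.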

Source code in src/distilabel/models/llms/huggingface/transformers.py
def load(self) -> None:
    """Loads the model and tokenizer and creates the text generation pipeline. In addition,
    it will configure the tokenizer chat template."""
    if self.device == "cuda":
        CudaDevicePlacementMixin.load(self)

    try:
        from transformers import pipeline
    except ImportError as ie:
        raise ImportError(
            "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
        ) from ie

    token = self.token.get_secret_value() if self.token is not None else self.token

    self._pipeline = pipeline(
        "text-generation",
        model=self.model,
        revision=self.revision,
        torch_dtype=self.torch_dtype,
        trust_remote_code=self.trust_remote_code,
        model_kwargs=self.model_kwargs or {},
        tokenizer=self.tokenizer or self.model,
        use_fast=self.use_fast,
        device=self.device,
        device_map=self.device_map,
        token=token,
        return_full_text=False,
    )

    if self.chat_template is not None:
        self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore

    if self._pipeline.tokenizer.pad_token is None:  # type: ignore
        self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

    if self.structured_output:
        processor = self._prepare_structured_output(self.structured_output)
        if _is_outlines_version_below_0_1_0():
            self._prefix_allowed_tokens_fn = processor
        else:
            self._logits_processor = [processor]

    super().load()
unload()

Unloads the Transformers model.

Source code in src/distilabel/models/llms/huggingface/transformers.py
def unload(self) -> None:
    """Unloads the `vLLM` model."""
    CudaDevicePlacementMixin.unload(self)
    super().unload()
prepare_input(input)

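Prepares the prompt for the provided input by applying the tokenizer's chat template (rendered as a string, without tokenizing). If the tokenizer has no chat template, the content of the first message is returned as-is; otherwise, the rendered prompt is also passed through apply_magpie_pre_query_template.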

Parameters:

Name Type Description Default
input StandardInput

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.
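The fallback logic of `prepare_input` can be sketched in plain Python. In this sketch, `prepare_prompt` and `toy_template` are hypothetical: the template is a simple callable rather than a Jinja chat template, it appends an assistant header to imitate `add_generation_prompt=True`, and the Magpie pre-query step is omitted.

```python
from typing import Callable, Dict, List, Optional

ChatInput = List[Dict[str, str]]


def prepare_prompt(
    conversation: ChatInput,
    chat_template: Optional[Callable[[ChatInput], str]] = None,
) -> str:
    # Without a chat template, only the first message's content is used.
    if chat_template is None:
        return conversation[0]["content"]
    # With a template, render the whole conversation to a string (no tokenization);
    # an empty conversation yields an empty prompt.
    return chat_template(conversation) if conversation else ""


def toy_template(conversation: ChatInput) -> str:
    # Hypothetical template: one "<|role|>content" line per turn, followed by an
    # assistant header that imitates `add_generation_prompt=True`.
    turns = "".join(f"<|{m['role']}|>{m['content']}\n" for m in conversation)
    return turns + "<|assistant|>"


conv = [{"role": "user", "content": "Hello world!"}]
print(repr(prepare_prompt(conv)))                # 'Hello world!'
print(repr(prepare_prompt(conv, toy_template)))  # '<|user|>Hello world!\n<|assistant|>'
```

Note that without a template any system message or later turns are silently ignored, so multi-turn inputs only make sense with a chat template.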

Source code in src/distilabel/models/llms/huggingface/transformers.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if self._pipeline.tokenizer.chat_template is None:  # type: ignore
        return input[0]["content"]

    prompt: str = (
        self._pipeline.tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
generate(inputs, num_generations=1, max_new_tokens=128, temperature=0.1, repetition_penalty=1.1, top_p=1.0, top_k=0, do_sample=True)

Generates num_generations responses for each input using the text generation pipeline.

Parameters:

Name Type Description Default
inputs List[StandardInput]

a list of inputs in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
temperature float

the temperature to use for the generation. Defaults to 0.1.

0.1
repetition_penalty float

the repetition penalty to use for the generation. Defaults to 1.1.

1.1
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
top_k int

the top-k value to use for the generation. Defaults to 0.

0
do_sample bool

whether to use sampling or not. Defaults to True.

True

Returns:

Type Description
List[GenerateOutput]

A list with one GenerateOutput per input, containing the generated responses together with the input and output token counts.
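The pipeline returns one list of `{"generated_text": ...}` dicts per input, and `generate()` pairs each of those lists with token counts before returning. The sketch below mimics that post-processing under stated assumptions: `str.split` stands in for `tokenizer.encode`, the prompt string is counted instead of the chat input, and `prepare_output_sketch` is a hypothetical helper whose `generations`/`statistics` layout may differ from the real `prepare_output`.

```python
from typing import Any, Callable, Dict, List

Encoder = Callable[[str], List[Any]]


def count_tokens(text: str, encode: Encoder) -> int:
    # Number of tokens produced by the given encoder.
    return len(encode(text))


def prepare_output_sketch(
    generations: List[str], input_tokens: List[int], output_tokens: List[int]
) -> Dict[str, Any]:
    # Hypothetical output shape: generations plus per-generation token statistics.
    return {
        "generations": generations,
        "statistics": {"input_tokens": input_tokens, "output_tokens": output_tokens},
    }


def postprocess(
    prompts: List[str], pipeline_outputs: List[List[Dict[str, str]]], encode: Encoder
) -> List[Dict[str, Any]]:
    # One inner list per prompt, holding `num_return_sequences` generations.
    texts = [[g["generated_text"] for g in output] for output in pipeline_outputs]
    return [
        prepare_output_sketch(
            output,
            input_tokens=[count_tokens(prompt, encode)],
            output_tokens=[count_tokens(row, encode) for row in output],
        )
        for prompt, output in zip(prompts, texts)
    ]


result = postprocess(
    ["Hello world!"],
    [[{"generated_text": "Hi there"}, {"generated_text": "Hello to you too"}]],
    str.split,  # toy whitespace tokenizer
)
print(result[0]["statistics"])  # {'input_tokens': [2], 'output_tokens': [2, 4]}
```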

Source code in src/distilabel/models/llms/huggingface/transformers.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[StandardInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    temperature: float = 0.1,
    repetition_penalty: float = 1.1,
    top_p: float = 1.0,
    top_k: int = 0,
    do_sample: bool = True,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults
            to `1.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `0`.
        do_sample: whether to use sampling or not. Defaults to `True`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    prepared_inputs = [self.prepare_input(input=input) for input in inputs]

    outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
        prepared_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        num_return_sequences=num_generations,
        prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
        pad_token_id=self._pipeline.tokenizer.eos_token_id,
        logits_processor=self._logits_processor,
    )
    llm_output = [
        [generation["generated_text"] for generation in output]
        for output in outputs
    ]

    result = []
    for input, output in zip(inputs, llm_output):
        result.append(
            prepare_output(
                output,
                input_tokens=[
                    compute_tokens(input, self._pipeline.tokenizer.encode)
                ],
                output_tokens=[
                    compute_tokens(row, self._pipeline.tokenizer.encode)
                    for row in output
                ],
            )
        )

    return result
get_last_hidden_states(inputs)

Gets the last hidden_states of the model for the given inputs. It doesn't execute the task head.

Parameters:

Name Type Description Default
inputs List[StandardInput]

a list of inputs in chat format to generate the embeddings for.

required

Returns:

Type Description
List[HiddenState]

A list containing the last hidden state for each sequence as a NumPy array with shape [num_tokens, hidden_size].
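Because the inputs are padded to a common length, each returned array keeps only the positions whose attention mask is 1. That is why the first dimension is the number of real tokens rather than the padded length. A pure-Python sketch of the masking step, where nested lists stand in for the PyTorch tensors and `strip_padding` is a hypothetical name:

```python
from typing import List

Matrix = List[List[float]]


def strip_padding(
    batch_hidden: List[Matrix], attention_masks: List[List[int]]
) -> List[Matrix]:
    # For each sequence, keep only the hidden-state rows whose attention mask is 1,
    # mirroring `seq_last_hidden_state[attention_mask.bool(), :]` in the source.
    return [
        [row for row, keep in zip(seq_hidden, mask) if keep]
        for seq_hidden, mask in zip(batch_hidden, attention_masks)
    ]


# Toy batch: 2 sequences padded to length 3, hidden_size = 2.
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[0.0, 0.0], [0.7, 0.8], [0.9, 1.0]],  # first position is (left) padding
]
masks = [[1, 1, 1], [0, 1, 1]]
states = strip_padding(hidden, masks)
print([len(s) for s in states])  # [3, 2]
```

Selecting by mask rather than by a fixed offset works for both left and right padding.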

Source code in src/distilabel/models/llms/huggingface/transformers.py
def get_last_hidden_states(
    self, inputs: List["StandardInput"]
) -> List["HiddenState"]:
    """Gets the last `hidden_states` of the model for the given inputs. It doesn't
    execute the task head.

    Args:
        inputs: a list of inputs in chat format to generate the embeddings for.

    Returns:
        A list containing the last hidden state for each sequence using a NumPy array
        with shape [num_tokens, hidden_size].
    """
    model: "PreTrainedModel" = (
        self._pipeline.model.model  # type: ignore
        if hasattr(self._pipeline.model, "model")  # type: ignore
        else next(self._pipeline.model.children())  # type: ignore
    )
    tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
    input_ids = tokenizer(
        [self.prepare_input(input) for input in inputs],  # type: ignore
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    last_hidden_states = model(**input_ids)["last_hidden_state"]

    return [
        seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
        for seq_last_hidden_state, attention_mask in zip(
            last_hidden_states,
            input_ids["attention_mask"],  # type: ignore
        )
    ]
_prepare_structured_output(structured_output=None)

Creates the appropriate function to filter tokens to generate structured outputs.

Parameters:

Name Type Description Default
structured_output Optional[OutlinesStructuredOutputType]

the configuration dict to prepare the structured output.

None

Returns:

Type Description
Union[Callable, List[Callable]]

The callable that will be used to guide the generation of the model.
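The callable returned by `_prepare_structured_output()` is wired into generation differently depending on the installed `outlines` version. `load()` stores it as `prefix_allowed_tokens_fn` for versions below 0.1.0, and otherwise wraps it in a one-element `logits_processor` list. A minimal sketch of that dispatch, assuming a plain `major.minor.patch` version string; `parse_version` and `generation_kwargs_for` are hypothetical helpers, not distilabel functions:

```python
from typing import Any, Callable, Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    # Assumes a simple "major.minor.patch" string such as "0.0.46".
    return tuple(int(part) for part in version.split(".")[:3])


def generation_kwargs_for(outlines_version: str, processor: Callable) -> Dict[str, Any]:
    # Older outlines releases provide a prefix-allowed-tokens function...
    if parse_version(outlines_version) < (0, 1, 0):
        return {"prefix_allowed_tokens_fn": processor, "logits_processor": None}
    # ...newer ones provide a logits processor, passed as a one-element list.
    return {"prefix_allowed_tokens_fn": None, "logits_processor": [processor]}


def dummy_processor(*args: Any) -> Any:
    return args


old = generation_kwargs_for("0.0.46", dummy_processor)
new = generation_kwargs_for("0.1.3", dummy_processor)
print(old["logits_processor"], new["prefix_allowed_tokens_fn"])  # None None
```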

Source code in src/distilabel/models/llms/huggingface/transformers.py
def _prepare_structured_output(
    self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union[Callable, List[Callable]]:
    """Creates the appropriate function to filter tokens to generate structured outputs.

    Args:
        structured_output: the configuration dict to prepare the structured output.

    Returns:
        The callable that will be used to guide the generation of the model.
    """
    from distilabel.steps.tasks.structured_outputs.outlines import (
        prepare_guided_output,
    )

    result = prepare_guided_output(
        structured_output, "transformers", self._pipeline
    )
    if schema := result.get("schema"):
        self.structured_output["schema"] = schema
    return result["processor"]

LiteLLM

Bases: AsyncLLM

LiteLLM implementation running the async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large", etc.

verbose RuntimeParameter[bool]

whether to log the LiteLLM client's logs. Defaults to False.

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

a dictionary containing the structured output configuration using instructor. See InstructorStructuredOutputType in distilabel.steps.tasks.structured_outputs.instructor for the dictionary structure.

Runtime parameters
  • verbose: whether to log the LiteLLM client's logs. Defaults to False.

Examples:

Generate text:

from distilabel.models.llms import LiteLLM

llm = LiteLLM(model="gpt-3.5-turbo")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import LiteLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = LiteLLM(
    model="gpt-3.5-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/litellm.py
class LiteLLM(AsyncLLM):
    """LiteLLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large",
            etc.
        verbose: whether to log the LiteLLM client's logs. Defaults to `False`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.

    Runtime parameters:
        - `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import LiteLLM

        llm = LiteLLM(model="gpt-3.5-turbo")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import LiteLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = LiteLLM(
            model="gpt-3.5-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    verbose: RuntimeParameter[bool] = Field(
        default=False, description="Whether to log the LiteLLM client's logs."
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _aclient: Optional[Callable] = PrivateAttr(...)

    def load(self) -> None:
        """
        Loads the `acompletion` LiteLLM client to benefit from async requests.
        """
        super().load()

        try:
            import litellm

            litellm.telemetry = False
        except ImportError as e:
            raise ImportError(
                "LiteLLM Python client is not installed. Please install it using"
                " `pip install 'distilabel[litellm]'`."
            ) from e
        self._aclient = litellm.acompletion

        if not self.verbose:
            litellm.suppress_debug_info = True
            for key in logging.Logger.manager.loggerDict.keys():
                if "litellm" not in key.lower():
                    continue
                logging.getLogger(key).setLevel(logging.CRITICAL)

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="litellm",
            )
            self._aclient = result.get("client")
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore # noqa: C901
        self,
        input: FormattedInput,
        num_generations: int = 1,
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        stop: Optional[Union[str, list]] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        metadata: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,
        mock_response: Optional[str] = None,
        force_timeout: Optional[int] = 600,
        custom_llm_provider: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            functions: a list of functions to apply to the conversation messages. Defaults to
                `None`.
            function_call: the name of the function to call within the conversation. Defaults
                to `None`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: Up to 4 sequences where the LLM API will stop generating further tokens.
                Defaults to `None`.
            max_tokens: The maximum number of tokens in the generated completion. Defaults to
                `None`.
            presence_penalty: It is used to penalize new tokens based on their existence in the
                text so far. Defaults to `None`.
            frequency_penalty: It is used to penalize new tokens based on their frequency in the
                text so far. Defaults to `None`.
            logit_bias: Used to modify the probability of specific tokens appearing in the
                completion. Defaults to `None`.
            user: A unique identifier representing your end-user. This can help the LLM provider
                to monitor and detect abuse. Defaults to `None`.
            metadata: Pass in additional metadata to tag your completion calls, e.g. prompt
                version, details, etc. Defaults to `None`.
            api_base: Base URL for the API. Defaults to `None`.
            api_version: API version. Defaults to `None`.
            api_key: API key. Defaults to `None`.
            model_list: List of api base, version, keys. Defaults to `None`.
            mock_response: If provided, return a mock completion response for testing or debugging
                purposes. Defaults to `None`.
            force_timeout: The maximum execution time in seconds for the completion request.
                Defaults to `600`.
            custom_llm_provider: Used for non-OpenAI LLMs. For example, for Bedrock, set
                model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to
                `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        import litellm
        from litellm import token_counter

        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="litellm",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "model": self.model,
            "messages": input,
            "n": num_generations,
            "functions": functions,
            "function_call": function_call,
            "temperature": temperature,
            "top_p": top_p,
            "stream": False,
            "stop": stop,
            "max_tokens": max_tokens,
            "presence_penalty": presence_penalty,
            "frequency_penalty": frequency_penalty,
            "logit_bias": logit_bias,
            "user": user,
            "metadata": metadata,
            "api_base": api_base,
            "api_version": api_version,
            "api_key": api_key,
            "model_list": model_list,
            "mock_response": mock_response,
            "force_timeout": force_timeout,
            "custom_llm_provider": custom_llm_provider,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        async def _call_aclient_until_n_choices() -> List["Choices"]:
            choices = []
            while len(choices) < num_generations:
                completion = await self._aclient(**kwargs)  # type: ignore
                if not self.structured_output:
                    completion = completion.choices
                choices.extend(completion)
            return choices

        # litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used
        try:
            litellm.drop_params = False
            choices = await _call_aclient_until_n_choices()
        except litellm.exceptions.APIError as e:
            if "does not support parameters" in str(e):
                litellm.drop_params = True
                choices = await _call_aclient_until_n_choices()
            else:
                raise e

        generations = []
        input_tokens = [
            token_counter(model=self.model, messages=input)
        ] * num_generations
        output_tokens = []

        if self.structured_output:
            for choice in choices:
                generations.append(choice.model_dump_json())
                output_tokens.append(
                    token_counter(
                        model=self.model,
                        text=orjson.dumps(choice.model_dump_json()).decode("utf-8"),
                    )
                )
            return prepare_output(
                generations,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
            )

        for choice in choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using LiteLLM client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
            output_tokens.append(token_counter(model=self.model, text=content))

        return prepare_output(
            generations, input_tokens=input_tokens, output_tokens=output_tokens
        )
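In case a provider returns fewer than `n` choices in a single response, `agenerate` keeps calling the client until it has collected `num_generations` choices. The asyncio sketch below reproduces that loop with a fake client that ignores `n` and returns one choice per call. `fake_acompletion` and `call_until_n_choices` are hypothetical stand-ins for `litellm.acompletion` and the inner `_call_aclient_until_n_choices` helper.

```python
import asyncio
from typing import Any, Dict, List


async def fake_acompletion(**kwargs: Any) -> Dict[str, List[str]]:
    # Hypothetical client that ignores `n` and returns a single choice per call.
    await asyncio.sleep(0)
    return {"choices": [f"answer to {kwargs['messages'][-1]['content']!r}"]}


async def call_until_n_choices(num_generations: int, **kwargs: Any) -> List[str]:
    choices: List[str] = []
    # Keep requesting until enough choices are collected; like the source loop,
    # this can overshoot if a single response returns several choices at once.
    while len(choices) < num_generations:
        completion = await fake_acompletion(**kwargs)
        choices.extend(completion["choices"])
    return choices


choices = asyncio.run(
    call_until_n_choices(3, messages=[{"role": "user", "content": "Hi"}], n=3)
)
print(len(choices))  # 3
```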
model_name property

Returns the model name used for the LLM.

load()

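Loads LiteLLM's acompletion function as the async client. It also disables LiteLLM telemetry, silences LiteLLM's loggers unless verbose is set, and wraps the client for structured output when structured_output is configured.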

Source code in src/distilabel/models/llms/litellm.py
def load(self) -> None:
    """
    Loads the `acompletion` LiteLLM client to benefit from async requests.
    """
    super().load()

    try:
        import litellm

        litellm.telemetry = False
    except ImportError as e:
        raise ImportError(
            "LiteLLM Python client is not installed. Please install it using"
            " `pip install 'distilabel[litellm]'`."
        ) from e
    self._aclient = litellm.acompletion

    if not self.verbose:
        litellm.suppress_debug_info = True
        for key in logging.Logger.manager.loggerDict.keys():
            if "litellm" not in key.lower():
                continue
            logging.getLogger(key).setLevel(logging.CRITICAL)

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="litellm",
        )
        self._aclient = result.get("client")
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
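When `verbose` is `False`, `load()` raises every registered logger whose name contains "litellm" (case-insensitive) to `CRITICAL`. The same pattern, using only the standard `logging` module, is shown below. `silence_loggers` is a hypothetical helper, and the logger names are chosen for illustration. Unlike the source, it iterates over a copy of the logger names so that creating a logger from a placeholder entry cannot change the dict mid-iteration.

```python
import logging
from typing import List


def silence_loggers(substring: str = "litellm") -> List[str]:
    # Raise every already-registered logger whose name contains `substring`
    # (case-insensitive) to CRITICAL and report which ones were changed.
    silenced = []
    for name in list(logging.Logger.manager.loggerDict.keys()):
        if substring not in name.lower():
            continue
        logging.getLogger(name).setLevel(logging.CRITICAL)
        silenced.append(name)
    return silenced


# Register two loggers to demonstrate (names chosen for illustration).
logging.getLogger("LiteLLM")
logging.getLogger("my_app")
silenced = silence_loggers()
print(silenced)  # contains 'LiteLLM' but not 'my_app'
```

Only loggers that already exist are affected, which is why the source runs this step after importing `litellm`.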
agenerate(input, num_generations=1, functions=None, function_call=None, temperature=1.0, top_p=1.0, stop=None, max_tokens=None, presence_penalty=None, frequency_penalty=None, logit_bias=None, user=None, metadata=None, api_base=None, api_version=None, api_key=None, model_list=None, mock_response=None, force_timeout=600, custom_llm_provider=None) async

Generates num_generations responses for the given input using the LiteLLM async client.
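The input is a FormattedInput. It is either a plain list of chat messages or, when a row carries its own schema, a `(messages, structured_output)` tuple. A per-row configuration takes precedence over the instance-level `structured_output`. A library-free sketch of that resolution, where `resolve_input` is a hypothetical helper:

```python
from typing import Any, Dict, List, Optional, Tuple, Union

Messages = List[Dict[str, str]]
Config = Dict[str, Any]


def resolve_input(
    formatted_input: Union[Messages, Tuple[Messages, Config]],
    default_structured_output: Optional[Config] = None,
) -> Tuple[Messages, Optional[Config]]:
    structured_output = None
    # A tuple bundles the messages with a row-specific structured-output config.
    if isinstance(formatted_input, tuple):
        formatted_input, structured_output = formatted_input
    # Otherwise fall back to the configuration set on the LLM instance.
    if structured_output is None and default_structured_output is not None:
        structured_output = default_structured_output
    return formatted_input, structured_output


msgs = [{"role": "user", "content": "Hi"}]
print(resolve_input((msgs, {"schema": "row"}), {"schema": "default"})[1])  # {'schema': 'row'}
print(resolve_input(msgs, {"schema": "default"})[1])  # {'schema': 'default'}
```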

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
functions Optional[List]

a list of function definitions the model may call (OpenAI-style function calling). Defaults to None.

None
function_call Optional[str]

the name of the function to call within the conversation. Defaults to None.

None
temperature Optional[float]

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p Optional[float]

the top-p value to use for the generation. Defaults to 1.0.

1.0
stop Optional[Union[str, list]]

Up to 4 sequences where the LLM API will stop generating further tokens. Defaults to None.

None
max_tokens Optional[int]

The maximum number of tokens in the generated completion. Defaults to None.

None
presence_penalty Optional[float]

Penalizes new tokens based on whether they already appear in the text so far. Defaults to None.

None
frequency_penalty Optional[float]

Penalizes new tokens based on their frequency in the text so far. Defaults to None.

None
logit_bias Optional[dict]

Used to modify the probability of specific tokens appearing in the completion. Defaults to None.

None
user Optional[str]

A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. Defaults to None.

None
metadata Optional[dict]

Additional metadata to tag your completion calls with, e.g. prompt version or other details. Defaults to None.

None
api_base Optional[str]

Base URL for the API. Defaults to None.

None
api_version Optional[str]

API version. Defaults to None.

None
api_key Optional[str]

API key. Defaults to None.

None
model_list Optional[list]

A list of API base URLs, versions, and keys. Defaults to None.

None
mock_response Optional[str]

If provided, return a mock completion response for testing or debugging purposes. Defaults to None.

None
force_timeout Optional[int]

The maximum execution time in seconds for the completion request. Defaults to 600.

600
custom_llm_provider Optional[str]

Used for non-OpenAI LLMs. For example, for Bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to None.

None

Returns:

Type Description
GenerateOutput

The generated responses for the given input, together with the input and output token counts.
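Not every provider accepts every OpenAI-style parameter. `agenerate` first calls the client with `litellm.drop_params = False`. If the API error message contains "does not support parameters", it retries once with `drop_params = True` so that LiteLLM drops the unsupported arguments. A dependency-free sketch of that retry pattern follows. `UnsupportedParamsError`, `flaky_client`, and the `state` dict are hypothetical stand-ins for `litellm.exceptions.APIError`, a real provider, and the global flag.

```python
from typing import Any, Callable, Dict


class UnsupportedParamsError(Exception):
    """Hypothetical stand-in for `litellm.exceptions.APIError`."""


# Stand-in for the module-level `litellm.drop_params` flag.
state: Dict[str, bool] = {"drop_params": False}


def flaky_client(**kwargs: Any) -> str:
    # Rejects `logit_bias` unless unsupported parameters are being dropped.
    if "logit_bias" in kwargs and not state["drop_params"]:
        raise UnsupportedParamsError("Provider does not support parameters: ['logit_bias']")
    return "ok"


def call_with_fallback(client: Callable[..., str], **kwargs: Any) -> str:
    try:
        state["drop_params"] = False
        return client(**kwargs)
    except UnsupportedParamsError as e:
        # Retry once with parameter dropping enabled; re-raise anything else.
        if "does not support parameters" in str(e):
            state["drop_params"] = True
            return client(**kwargs)
        raise


result = call_with_fallback(flaky_client, logit_bias={"50256": -100})
print(result, state)  # ok {'drop_params': True}
```

Because the flag is global in LiteLLM, it is reset to `False` at the start of every call so one provider's fallback does not silently apply to the next request.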

Source code in src/distilabel/models/llms/litellm.py
@validate_call
async def agenerate(  # type: ignore # noqa: C901
    self,
    input: FormattedInput,
    num_generations: int = 1,
    functions: Optional[List] = None,
    function_call: Optional[str] = None,
    temperature: Optional[float] = 1.0,
    top_p: Optional[float] = 1.0,
    stop: Optional[Union[str, list]] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    logit_bias: Optional[dict] = None,
    user: Optional[str] = None,
    metadata: Optional[dict] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    api_key: Optional[str] = None,
    model_list: Optional[list] = None,
    mock_response: Optional[str] = None,
    force_timeout: Optional[int] = 600,
    custom_llm_provider: Optional[str] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        functions: a list of functions to apply to the conversation messages. Defaults to
            `None`.
        function_call: the name of the function to call within the conversation. Defaults
            to `None`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: Up to 4 sequences where the LLM API will stop generating further tokens.
            Defaults to `None`.
        max_tokens: The maximum number of tokens in the generated completion. Defaults to
            `None`.
        presence_penalty: It is used to penalize new tokens based on their existence in the
            text so far. Defaults to `None`.
        frequency_penalty: It is used to penalize new tokens based on their frequency in the
            text so far. Defaults to `None`.
        logit_bias: Used to modify the probability of specific tokens appearing in the
            completion. Defaults to `None`.
        user: A unique identifier representing your end-user. This can help the LLM provider
            to monitor and detect abuse. Defaults to `None`.
        metadata: Pass in additional metadata to tag your completion calls, e.g. prompt
            version, details, etc. Defaults to `None`.
        api_base: Base URL for the API. Defaults to `None`.
        api_version: API version. Defaults to `None`.
        api_key: API key. Defaults to `None`.
        model_list: List of api base, version, keys. Defaults to `None`.
        mock_response: If provided, return a mock completion response for testing or debugging
            purposes. Defaults to `None`.
        force_timeout: The maximum execution time in seconds for the completion request.
            Defaults to `600`.
        custom_llm_provider: Used for non-OpenAI LLMs. For example, for Bedrock, set
            model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to
            `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    import litellm
    from litellm import token_counter

    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="litellm",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "model": self.model,
        "messages": input,
        "n": num_generations,
        "functions": functions,
        "function_call": function_call,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
        "stop": stop,
        "max_tokens": max_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "logit_bias": logit_bias,
        "user": user,
        "metadata": metadata,
        "api_base": api_base,
        "api_version": api_version,
        "api_key": api_key,
        "model_list": model_list,
        "mock_response": mock_response,
        "force_timeout": force_timeout,
        "custom_llm_provider": custom_llm_provider,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    async def _call_aclient_until_n_choices() -> List["Choices"]:
        choices = []
        while len(choices) < num_generations:
            completion = await self._aclient(**kwargs)  # type: ignore
            if not self.structured_output:
                completion = completion.choices
            choices.extend(completion)
        return choices

    # litellm.drop_params is used to en/disable sending **kwargs parameters to the API if they cannot be used
    try:
        litellm.drop_params = False
        choices = await _call_aclient_until_n_choices()
    except litellm.exceptions.APIError as e:
        if "does not support parameters" in str(e):
            litellm.drop_params = True
            choices = await _call_aclient_until_n_choices()
        else:
            raise e

    generations = []
    input_tokens = [
        token_counter(model=self.model, messages=input)
    ] * num_generations
    output_tokens = []

    if self.structured_output:
        for choice in choices:
            generations.append(choice.model_dump_json())
            output_tokens.append(
                token_counter(
                    model=self.model,
                    text=orjson.dumps(choice.model_dump_json()).decode("utf-8"),
                )
            )
        return prepare_output(
            generations,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )

    for choice in choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using LiteLLM client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
        output_tokens.append(token_counter(model=self.model, text=content))

    return prepare_output(
        generations, input_tokens=input_tokens, output_tokens=output_tokens
    )
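The method above combines two patterns. It keeps calling the client until `num_generations` choices have been collected, and it retries once with `litellm.drop_params = True` when the provider rejects a parameter. The following is a minimal, self-contained sketch of both patterns. It assumes a hypothetical `fake_completion` provider, which rejects `logit_bias` and returns at most two choices per call, and a hypothetical `UnsupportedParamsError` that stands in for `litellm.exceptions.APIError`:

```python
import asyncio
from typing import Any, Dict, List


class UnsupportedParamsError(Exception):
    """Hypothetical stand-in for `litellm.exceptions.APIError`."""


async def fake_completion(n: int, drop_params: bool, **params: Any) -> List[str]:
    """Hypothetical provider: rejects `logit_bias` and returns at most 2 choices per call."""
    if params.get("logit_bias") is not None and not drop_params:
        raise UnsupportedParamsError("model does not support parameters: ['logit_bias']")
    return [f"choice-{i}" for i in range(min(n, 2))]


async def call_until_n_choices(params: Dict[str, Any], n: int, drop_params: bool) -> List[str]:
    choices: List[str] = []
    while len(choices) < n:
        # Ask only for the choices still missing, so the result never exceeds `n`.
        remaining = n - len(choices)
        choices.extend(await fake_completion(remaining, drop_params, **params))
    return choices


async def generate(params: Dict[str, Any], n: int) -> List[str]:
    try:
        # First attempt: send every parameter as-is.
        return await call_until_n_choices(params, n, drop_params=False)
    except UnsupportedParamsError as e:
        if "does not support parameters" in str(e):
            # Retry once, letting the provider ignore the unsupported parameters.
            return await call_until_n_choices(params, n, drop_params=True)
        raise


print(asyncio.run(generate({"logit_bias": {"50256": -100}}, n=3)))
# ['choice-0', 'choice-1', 'choice-0']
```

The loop above sends the same `n` on every call, so it can collect more than `num_generations` choices. This sketch requests only the remaining number instead. It also passes `drop_params` explicitly rather than toggling a module-level flag.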

LlamaCppLLM

Bases: LLM, MagpieChatTemplateMixin

llama.cpp LLM implementation running the Python bindings for the C++ code.

Attributes:

Name Type Description
model_path RuntimeParameter[FilePath]

contains the path to the GGUF quantized model, compatible with the installed version of the llama.cpp Python bindings.

n_gpu_layers RuntimeParameter[int]

the number of layers to offload to the GPU. Defaults to -1, meaning that all the layers will be offloaded to the available GPU device.

chat_format Optional[RuntimeParameter[str]]

the chat format to use for the model. Defaults to None, which means the Llama format will be used.

n_ctx int

the context size to use for the model. Defaults to 512.

n_batch int

the prompt processing maximum batch size to use for the model. Defaults to 512.

seed int

random seed to use for the generation. Defaults to 4294967295.

verbose RuntimeParameter[bool]

whether to print verbose output. Defaults to False.

structured_output Optional[RuntimeParameter[OutlinesStructuredOutputType]]

a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None.

extra_kwargs Optional[RuntimeParameter[Dict[str, Any]]]

additional dictionary of keyword arguments that will be passed to the Llama class of llama_cpp library. Defaults to {}.

tokenizer_id Optional[RuntimeParameter[str]]

the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer config files. If not provided, the one associated with the model will be used. Defaults to None.

use_magpie_template bool

a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2" or a custom pre-query template. Defaults to None.

_model Optional[Llama]

the Llama model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method.

Runtime parameters
  • model_path: the path to the GGUF quantized model.
  • n_gpu_layers: the number of layers to use for the GPU. Defaults to -1.
  • chat_format: the chat format to use for the model. Defaults to None.
  • verbose: whether to print verbose output. Defaults to False.
  • extra_kwargs: additional dictionary of keyword arguments that will be passed to the Llama class of llama_cpp library. Defaults to {}.
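References
  • llama.cpp: https://github.com/ggerganov/llama.cpp
  • llama-cpp-python: https://github.com/abetlen/llama-cpp-python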

Examples:

Generate text:

from pathlib import Path
from distilabel.models.llms import LlamaCppLLM

# To follow along with this example, download the following model to your `Downloads`
# folder by running this command in the terminal:
# curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf

model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

llm = LlamaCppLLM(
    model_path=str(Path.home() / model_path),
    n_gpu_layers=-1,  # To use the GPU if available
    n_ctx=1024,       # Set the context size
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pathlib import Path

from pydantic import BaseModel

from distilabel.models.llms import LlamaCppLLM

model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = LlamaCppLLM(
    model_path=str(Path.home() / model_path),  # type: ignore
    n_gpu_layers=-1,
    n_ctx=1024,
    structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
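When `structured_output` is set, each generation should be a JSON string that matches the schema. The stdlib-only sketch below checks such a string against the fields of the `User` model defined above. `parse_user` and the sample strings are hypothetical. In practice, pydantic's `User.model_validate_json(...)` performs a stricter version of this check.

```python
import json
from typing import Any, Dict

# Field names and types mirroring the `User` model from the example above.
USER_FIELDS = {"name": str, "last_name": str, "id": int}


def parse_user(generation: str) -> Dict[str, Any]:
    """Parse a generated JSON string and check it has the `User` fields with the right types."""
    data = json.loads(generation)
    missing = USER_FIELDS.keys() - data.keys()
    if missing:
        raise ValueError(f"missing fields: {sorted(missing)}")
    for field, expected_type in USER_FIELDS.items():
        value = data[field]
        # `bool` is a subclass of `int` in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, expected_type):
            raise TypeError(f"{field!r} should be of type {expected_type.__name__}")
    return data


print(parse_user('{"name": "Ada", "last_name": "Lovelace", "id": 1}'))
```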
Source code in src/distilabel/models/llms/llamacpp.py
class LlamaCppLLM(LLM, MagpieChatTemplateMixin):
    """llama.cpp LLM implementation running the Python bindings for the C++ code.

    Attributes:
        model_path: contains the path to the GGUF quantized model, compatible with the
            installed version of the `llama.cpp` Python bindings.
        n_gpu_layers: the number of layers to offload to the GPU. Defaults to `-1`, meaning
            that all the layers will be offloaded to the available GPU device.
        chat_format: the chat format to use for the model. Defaults to `None`, which means the
            Llama format will be used.
        n_ctx: the context size to use for the model. Defaults to `512`.
        n_batch: the prompt processing maximum batch size to use for the model. Defaults to `512`.
        seed: random seed to use for the generation. Defaults to `4294967295`.
        verbose: whether to print verbose output. Defaults to `False`.
        structured_output: a dictionary containing the structured output configuration or, if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.
        tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated with the `model`
            will be used. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow-up user message. Valid
            values are "llama3", "qwen2" or a custom pre-query template. Defaults
            to `None`.
        _model: the Llama model instance. This attribute is meant to be used internally and
            should not be accessed directly. It will be set in the `load` method.

    Runtime parameters:
        - `model_path`: the path to the GGUF quantized model.
        - `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`.
        - `chat_format`: the chat format to use for the model. Defaults to `None`.
        - `verbose`: whether to print verbose output. Defaults to `False`.
        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the
            `Llama` class of `llama_cpp` library. Defaults to `{}`.

    References:
        - [`llama.cpp`](https://github.com/ggerganov/llama.cpp)
        - [`llama-cpp-python`](https://github.com/abetlen/llama-cpp-python)

    Examples:
        Generate text:

        ```python
        from pathlib import Path
        from distilabel.models.llms import LlamaCppLLM

        # To follow along with this example, download the following model to your `Downloads`
        # folder by running this command in the terminal:
        # curl -L -o ~/Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q4_K_M.gguf

        model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

        llm = LlamaCppLLM(
            model_path=str(Path.home() / model_path),
            n_gpu_layers=-1,  # To use the GPU if available
            n_ctx=1024,       # Set the context size
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pathlib import Path

        from pydantic import BaseModel

        from distilabel.models.llms import LlamaCppLLM

        model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = LlamaCppLLM(
            model_path=str(Path.home() / model_path),  # type: ignore
            n_gpu_layers=-1,
            n_ctx=1024,
            structured_output={"format": "json", "schema": User},
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model_path: RuntimeParameter[FilePath] = Field(
        default=None, description="The path to the GGUF quantized model.", exclude=True
    )
    n_gpu_layers: RuntimeParameter[int] = Field(
        default=-1,
        description="The number of layers that will be loaded in the GPU.",
    )
    chat_format: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The chat format to use for the model. Defaults to `None`, which means the Llama format will be used.",
    )

    n_ctx: int = 512
    n_batch: int = 512
    seed: int = 4294967295
    verbose: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to print verbose output from llama.cpp library.",
    )
    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `Llama` class of `llama_cpp` library. See all the supported arguments at: "
        "https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__",
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )
    tokenizer_id: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The Hugging Face Hub repo id or a path to a directory containing"
        " the tokenizer config files. If not provided, the one associated to the `model`"
        " will be used.",
    )
    _logits_processor: Optional["LogitsProcessorList"] = PrivateAttr(default=None)
    _model: Optional["Llama"] = PrivateAttr(...)

    @model_validator(mode="after")
    def validate_magpie_usage(
        self,
    ) -> "LlamaCppLLM":
        """Validates that magpie usage is valid."""

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        # `mode="after"` model validators must return the validated instance.
        return self

    def load(self) -> None:
        """Loads the `Llama` model from the `model_path`."""
        try:
            from llama_cpp import Llama
        except ImportError as ie:
            raise ImportError(
                "The `llama_cpp` package is required to use the `LlamaCppLLM` class."
            ) from ie

        self._model = Llama(
            model_path=self.model_path.as_posix(),
            seed=self.seed,
            n_ctx=self.n_ctx,
            n_batch=self.n_batch,
            chat_format=self.chat_format,
            n_gpu_layers=self.n_gpu_layers,
            verbose=self.verbose,
            **self.extra_kwargs,
        )

        if self.structured_output:
            self._logits_processor = self._prepare_structured_output(
                self.structured_output
            )

        if self.use_magpie_template or self.magpie_pre_query_template:
            if not self.tokenizer_id:
                raise ValueError(
                    "The Hugging Face Hub repo id or a path to a directory containing"
                    " the tokenizer config files is required when using the `use_magpie_template`"
                    " or `magpie_pre_query_template` runtime parameters."
                )

        if self.tokenizer_id:
            try:
                from transformers import AutoTokenizer
            except ImportError as ie:
                raise ImportError(
                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                ) from ie
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
            if self._tokenizer.chat_template is None:
                raise ValueError(
                    "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
                )

        # NOTE: Here because of the custom `logging` interface used, since it will create the logging name
        # out of the model name, which won't be available until the `Llama` instance is created.
        super().load()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self._model.model_path  # type: ignore

    def _generate_chat_completion(
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "CreateChatCompletionResponse":
        return self._model.create_chat_completion(  # type: ignore
            messages=input,  # type: ignore
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            logits_processor=self._logits_processor,
            **(extra_generation_kwargs or {}),
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _generate_with_text_generation(
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "CreateChatCompletionResponse":
        prompt = self.prepare_input(input)
        return self._model.create_completion(
            prompt=prompt,
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            logits_processor=self._logits_processor,
            **(extra_generation_kwargs or {}),
        )

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[FormattedInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_generation_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for the given input using the Llama model.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            extra_generation_kwargs: dictionary with additional arguments to be passed to
                the `create_chat_completion` method. Reference at
                https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion

        Returns:
            A list with one `GenerateOutput` per input, containing the generated responses
            and the input and output token counts.
        """
        batch_outputs = []
        for input in inputs:
            # Reset for every input, so that a structured output passed together with a
            # previous input does not leak into this one.
            structured_output = None
            if isinstance(input, tuple):
                input, structured_output = input
            elif self.structured_output:
                structured_output = self.structured_output

            outputs = []
            output_tokens = []
            for _ in range(num_generations):
                # NOTE(plaguss): There seems to be a bug in how the logits processor
                # is used. Basically it consumes the FSM internally, and it isn't reinitialized
                # after each generation, so subsequent calls yield nothing. This is a workaround
                # until it is fixed in the `llama_cpp` or `outlines` libraries.
                if structured_output:
                    self._logits_processor = self._prepare_structured_output(
                        structured_output
                    )
                else:
                    # Do not reuse a processor prepared for a previous input.
                    self._logits_processor = None
                if self.tokenizer_id is None:
                    completion = self._generate_chat_completion(
                        input,
                        max_new_tokens,
                        frequency_penalty,
                        presence_penalty,
                        temperature,
                        top_p,
                        extra_generation_kwargs,
                    )
                    outputs.append(completion["choices"][0]["message"]["content"])
                    output_tokens.append(completion["usage"]["completion_tokens"])
                else:
                    completion: "CreateChatCompletionResponse" = (
                        self._generate_with_text_generation(  # type: ignore
                            input,
                            max_new_tokens,
                            frequency_penalty,
                            presence_penalty,
                            temperature,
                            top_p,
                            extra_generation_kwargs,
                        )
                    )
                    outputs.append(completion["choices"][0]["text"])
                    output_tokens.append(completion["usage"]["completion_tokens"])
            batch_outputs.append(
                prepare_output(
                    outputs,
                    input_tokens=[completion["usage"]["prompt_tokens"]]
                    * num_generations,
                    output_tokens=output_tokens,
                )
            )

        return batch_outputs

    def _prepare_structured_output(
        self, structured_output: Optional[OutlinesStructuredOutputType] = None
    ) -> Union["LogitsProcessorList", "LogitsProcessor"]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        result = prepare_guided_output(structured_output, "llamacpp", self._model)
        if (schema := result.get("schema")) and self.structured_output:
            self.structured_output["schema"] = schema
        return [result["processor"]]
model_name property

Returns the model name used for the LLM.

validate_magpie_usage()

Validates that magpie usage is valid.

load()

Loads the Llama model from the model_path.

prepare_input(input)

Prepares the prompt for the provided input by applying the tokenizer's chat template (the prompt is returned as text, not tokenized).

Parameters:

Name Type Description Default
input StandardInput

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.
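`prepare_input` is used only when `tokenizer_id` is set. Otherwise, `generate` sends the messages to `create_chat_completion`. The method renders the conversation into a single prompt string using the tokenizer's chat template with `add_generation_prompt=True`, and it returns an empty string for an empty conversation. The sketch below mimics that step with a hypothetical ChatML-style template; a real tokenizer ships its own template. The sketch also leaves out the Magpie pre-query step.

```python
from typing import Dict, List

Message = Dict[str, str]


def render_chatml(messages: List[Message], add_generation_prompt: bool = True) -> str:
    """Hypothetical stand-in for `tokenizer.apply_chat_template(..., tokenize=False)`."""
    prompt = "".join(
        f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        for message in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so that the model continues as the assistant.
        prompt += "<|im_start|>assistant\n"
    return prompt


def prepare_prompt(messages: List[Message]) -> str:
    # Like `prepare_input`, an empty conversation produces an empty prompt.
    return render_chatml(messages) if messages else ""


print(prepare_prompt([{"role": "user", "content": "Hello world!"}]))
```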

generate(inputs, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, extra_generation_kwargs=None)

Generates num_generations responses for the given input using the Llama model.

Parameters:

Name Type Description Default
inputs List[FormattedInput]

a list of inputs in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
frequency_penalty float

the frequency penalty to use for the generation. Defaults to 0.0.

0.0
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
extra_generation_kwargs Optional[Dict[str, Any]]

dictionary with additional arguments to be passed to the create_chat_completion method. Reference at https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion

None
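`frequency_penalty` and `presence_penalty` both lower the scores of tokens that already appear in the generated text. The frequency penalty scales with how many times a token has appeared, while the presence penalty is applied once to any token that has appeared at all. The sketch below follows the common OpenAI-style formulation, which, to our understanding, llama.cpp applies as well. `apply_penalties` is a hypothetical helper that operates on a toy mapping from tokens to logits. With the defaults of `0.0`, neither penalty changes anything.

```python
from collections import Counter
from typing import Dict, List


def apply_penalties(
    logits: Dict[str, float],
    generated: List[str],
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Dict[str, float]:
    """Lower the logit of every token that already appears in `generated`."""
    counts = Counter(generated)
    adjusted: Dict[str, float] = {}
    for token, logit in logits.items():
        count = counts[token]  # 0 for tokens not generated yet
        penalty = count * frequency_penalty + (presence_penalty if count > 0 else 0.0)
        adjusted[token] = logit - penalty
    return adjusted


print(apply_penalties({"a": 2.0, "b": 1.0, "c": 0.5}, ["a", "a", "b"], 0.5, 0.25))
# {'a': 0.75, 'b': 0.25, 'c': 0.5}
```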

Returns:

Type Description
List[GenerateOutput]

A list with one GenerateOutput per input, containing the generated responses and the input and output token counts.
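Each element of the returned list bundles the `num_generations` texts for one input together with their token statistics. Because every generation for an input shares the same prompt, `generate` repeats the prompt token count `num_generations` times. The sketch below is a hypothetical stand-in for distilabel's `prepare_output`. As a simplification, it assumes that the output is a dict with `generations` and `statistics` keys.

```python
from typing import Any, Dict, List, Optional


def prepare_output_sketch(
    generations: List[Optional[str]],
    input_tokens: List[int],
    output_tokens: List[int],
) -> Dict[str, Any]:
    # Hypothetical mirror of `prepare_output`: one such dict is produced per input.
    return {
        "generations": generations,
        "statistics": {"input_tokens": input_tokens, "output_tokens": output_tokens},
    }


num_generations = 2
prompt_tokens = 12  # every generation for this input shares the same prompt
output = prepare_output_sketch(
    ["Hi!", "Hello there."],
    input_tokens=[prompt_tokens] * num_generations,
    output_tokens=[3, 4],
)
print(output)
```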

_prepare_structured_output(structured_output=None)

Creates the appropriate function to filter tokens to generate structured outputs.

Parameters:

Name Type Description Default
structured_output Optional[OutlinesStructuredOutputType]

the configuration dict to prepare the structured output.

None

Returns:

Type Description
Union[LogitsProcessorList, LogitsProcessor]

The callable that will be used to guide the generation of the model.
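Guided generation works by masking, before each token is chosen, the scores of tokens that would break the target format. The stdlib-only sketch below shows that idea with a hypothetical fixed set of allowed token ids. It mirrors the `(input_ids, scores) -> scores` shape of the processors that llama-cpp-python accepts, but it uses plain lists instead of arrays. In `LlamaCppLLM`, the real processor comes from `outlines` via `prepare_guided_output`, and the allowed set changes at every step.

```python
import math
from typing import Callable, List, Set

LogitsProcessor = Callable[[List[int], List[float]], List[float]]


def make_mask_processor(allowed_token_ids: Set[int]) -> LogitsProcessor:
    def processor(input_ids: List[int], scores: List[float]) -> List[float]:
        # Disallowed tokens get -inf, so they can never be sampled or picked greedily.
        return [s if i in allowed_token_ids else -math.inf for i, s in enumerate(scores)]

    return processor


def greedy_pick(scores: List[float]) -> int:
    return max(range(len(scores)), key=scores.__getitem__)


scores = [0.1, 2.5, 1.0, 3.0]
processor = make_mask_processor({0, 2})
masked = processor([], scores)

print(greedy_pick(scores))  # 3 (unconstrained)
print(greedy_pick(masked))  # 2 (best token among the allowed ones)
```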

Source code in src/distilabel/models/llms/llamacpp.py
def _prepare_structured_output(
    self, structured_output: Optional[OutlinesStructuredOutputType] = None
) -> Union["LogitsProcessorList", "LogitsProcessor"]:
    """Creates the appropriate function to filter tokens to generate structured outputs.

    Args:
        structured_output: the configuration dict to prepare the structured output.

    Returns:
        The callable that will be used to guide the generation of the model.
    """
    from distilabel.steps.tasks.structured_outputs.outlines import (
        prepare_guided_output,
    )

    result = prepare_guided_output(structured_output, "llamacpp", self._model)
    if (schema := result.get("schema")) and self.structured_output:
        self.structured_output["schema"] = schema
    return [result["processor"]]
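The processor returned by `prepare_guided_output` is called by the model at each decoding step to restrict which tokens can be sampled. The core idea can be sketched in plain Python: every token outside an allowed set gets a logit of negative infinity, so it ends up with zero probability after the softmax. This is a simplified illustration of guided decoding, not the `outlines` implementation; `mask_logits` and the allowed-token set are hypothetical.

```python
import math
from typing import List, Set


def mask_logits(logits: List[float], allowed_token_ids: Set[int]) -> List[float]:
    """Set the logit of every disallowed token to -inf."""
    return [
        score if token_id in allowed_token_ids else -math.inf
        for token_id, score in enumerate(logits)
    ]


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability; exp(-inf) is 0.0.
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Vocabulary of 4 tokens; suppose the schema only allows tokens 1 and 3 next.
logits = [2.0, 1.0, 3.0, 1.0]
probs = softmax(mask_logits(logits, {1, 3}))
print(probs)  # [0.0, 0.5, 0.0, 0.5]
```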

MistralLLM

Bases: AsyncLLM

Mistral LLM implementation running the async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.

endpoint str

the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Mistral API. Defaults to None, which means that the value of the MISTRAL_API_KEY environment variable will be used, or None if it is not set.

max_retries RuntimeParameter[int]

the maximum number of retries to attempt when a request fails. Defaults to 6.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response. Defaults to 120.

max_concurrent_requests RuntimeParameter[int]

the maximum number of concurrent requests to send. Defaults to 64.

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.

_api_key_env_var str

the name of the environment variable to use for the API key. It is meant to be used internally.

_aclient Optional[Mistral]

the Mistral client to use for the Mistral API. It is meant to be used internally. Set in the load method.

Runtime parameters
  • api_key: the API key to authenticate the requests to the Mistral API.
  • max_retries: the maximum number of retries to attempt when a request fails. Defaults to 6.
  • timeout: the maximum time in seconds to wait for a response. Defaults to 120.
  • max_concurrent_requests: the maximum number of concurrent requests to send. Defaults to 64.
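`api_key` falls back to an environment variable through the field's `default_factory`, and `load` raises a `ValueError` when neither source provides a key. A minimal sketch of that resolution order, assuming the variable is `MISTRAL_API_KEY`; `resolve_api_key` is a hypothetical helper, and the environment is passed in explicitly so the sketch can be exercised without touching the real process environment.

```python
import os
from typing import Mapping, Optional


def resolve_api_key(
    explicit: Optional[str],
    env_var: str = "MISTRAL_API_KEY",
    environ: Mapping[str, str] = os.environ,
) -> str:
    """Prefer an explicit key, then the environment, else fail as `load` does."""
    key = explicit if explicit is not None else environ.get(env_var)
    if key is None:
        raise ValueError(
            f"An API key must be provided via `api_key` or the `{env_var}`"
            " environment variable."
        )
    return key


print(resolve_api_key(None, environ={"MISTRAL_API_KEY": "sk-test"}))  # sk-test
```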

Examples:

Generate text:

from distilabel.models.llms import MistralLLM

llm = MistralLLM(model="open-mixtral-8x22b")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import MistralLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = MistralLLM(
    model="open-mixtral-8x22b",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/mistral.py
class MistralLLM(AsyncLLM):
    """Mistral LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.
        endpoint: the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".
        api_key: the API key to authenticate the requests to the Mistral API. Defaults to `None` which
            means that the value set for the environment variable `MISTRAL_API_KEY` will be used, or
            `None` if not set.
        max_retries: the maximum number of retries to attempt when a request fails. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response. Defaults to `120`.
        max_concurrent_requests: the maximum number of concurrent requests to send. Defaults
            to `64`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key. It is meant to
            be used internally.
        _aclient: the `Mistral` client to use for the Mistral API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Mistral API.
        - `max_retries`: the maximum number of retries to attempt when a request fails.
            Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`.
        - `max_concurrent_requests`: the maximum number of concurrent requests to send.
            Defaults to `64`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import MistralLLM

        llm = MistralLLM(model="open-mixtral-8x22b")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import MistralLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = MistralLLM(
            model="open-mixtral-8x22b",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    endpoint: str = "https://api.mistral.ai"
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_MISTRALAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Mistral API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_concurrent_requests: RuntimeParameter[int] = Field(
        default=64, description="The maximum number of concurrent requests to send."
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["Mistral"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `Mistral` client to benefit from async requests."""
        super().load()

        try:
            from mistralai import Mistral
        except ImportError as ie:
            raise ImportError(
                "MistralAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[mistralai]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = Mistral(
            api_key=self.api_key.get_secret_value(),
            endpoint=self.endpoint,
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,  # type: ignore
            max_concurrent_requests=self.max_concurrent_requests,  # type: ignore
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="mistral",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    # TODO: add `num_generations` parameter once Mistral client allows `n` parameter
    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the MistralAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            temperature: the temperature to use for the generation. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.

        Returns:
            A `GenerateOutput` with the generated response and its token usage statistics.
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="mistral",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        generations = []
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)
            # TODO: This should work just with the _aclient.chat method, but it's not working.
            # We need to check instructor and see if we can create a PR.
            completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
        else:
            # completion = await self._aclient.chat(**kwargs)  # type: ignore
            completion = await self._aclient.chat.complete_async(**kwargs)  # type: ignore

        if structured_output:
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using MistralAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)

        return prepare_output(generations, **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "ChatCompletionResponse") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.prompt_tokens],
            "output_tokens": [completion.usage.completion_tokens],
        }
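`agenerate` accepts either a plain list of chat messages or a `(messages, structured_output)` tuple, and a per-input configuration takes precedence over the instance-level `structured_output`. A stdlib-only sketch of that precedence rule; `split_input` is a hypothetical helper, the schema dicts are placeholders, and the client re-wrapping that `_prepare_structured_output` performs is deliberately left out.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

Messages = List[Dict[str, str]]


def split_input(
    item: Union[Messages, Tuple[Messages, Dict[str, Any]]],
    default_structured_output: Optional[Dict[str, Any]] = None,
) -> Tuple[Messages, Optional[Dict[str, Any]]]:
    """Return the messages and the structured-output config that applies to them."""
    structured_output = None
    if isinstance(item, tuple):
        # A per-input config travels alongside the messages.
        item, structured_output = item
    if structured_output is None:
        # Otherwise fall back to the config set on the LLM instance, if any.
        structured_output = default_structured_output
    return item, structured_output


messages = [{"role": "user", "content": "Create a user profile"}]
print(split_input((messages, {"schema": "User"}), {"schema": "Default"}))
```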
model_name property

Returns the model name used for the LLM.

load()

Loads the Mistral client to benefit from async requests.

Source code in src/distilabel/models/llms/mistral.py
def load(self) -> None:
    """Loads the `Mistral` client to benefit from async requests."""
    super().load()

    try:
        from mistralai import Mistral
    except ImportError as ie:
        raise ImportError(
            "MistralAI Python client is not installed. Please install it using"
            " `pip install 'distilabel[mistralai]'`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = Mistral(
        api_key=self.api_key.get_secret_value(),
        endpoint=self.endpoint,
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,  # type: ignore
        max_concurrent_requests=self.max_concurrent_requests,  # type: ignore
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="mistral",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
agenerate(input, max_new_tokens=None, temperature=None, top_p=None) async

Generates a response for the given input using the MistralAI async client. The n parameter is not sent to the API, so each call produces a single generation.

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

required
max_new_tokens Optional[int]

the maximum number of new tokens that the model will generate. Defaults to None.

None
temperature Optional[float]

the temperature to use for the generation. Defaults to None.

None
top_p Optional[float]

the top-p value to use for the generation. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A GenerateOutput with the generated response and its token usage statistics.

Source code in src/distilabel/models/llms/mistral.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the MistralAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        temperature: the temperature to use for the generation. Defaults to `None`.
        top_p: the top-p value to use for the generation. Defaults to `None`.

    Returns:
        A `GenerateOutput` with the generated response and its token usage statistics.
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="mistral",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    generations = []
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)
        # TODO: This should work just with the _aclient.chat method, but it's not working.
        # We need to check instructor and see if we can create a PR.
        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
    else:
        # completion = await self._aclient.chat(**kwargs)  # type: ignore
        completion = await self._aclient.chat.complete_async(**kwargs)  # type: ignore

    if structured_output:
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using MistralAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)

    return prepare_output(generations, **self._get_llm_statistics(completion))
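Without structured output, each choice's `message.content` is collected (a `None` content is logged and kept in the list), and token usage is read from `completion.usage`. The sketch below mirrors that post-processing, using `types.SimpleNamespace` as a stand-in for the Mistral response object. `collect_generations` is hypothetical, and the shape of the returned dict is an assumption meant to resemble what `prepare_output` builds from `_get_llm_statistics`.

```python
import logging
from types import SimpleNamespace
from typing import Any, Dict

logger = logging.getLogger("mistral-sketch")


def collect_generations(completion: Any, model: str) -> Dict[str, Any]:
    generations = []
    for choice in completion.choices:
        content = choice.message.content
        if content is None:
            # As in the source: warn, but keep the None entry.
            logger.warning(
                "Received no response (model: %r). Finish reason was: %s",
                model,
                choice.finish_reason,
            )
        generations.append(content)
    return {
        "generations": generations,
        "statistics": {
            "input_tokens": [completion.usage.prompt_tokens],
            "output_tokens": [completion.usage.completion_tokens],
        },
    }


fake = SimpleNamespace(
    choices=[
        SimpleNamespace(message=SimpleNamespace(content="Hello!"), finish_reason="stop")
    ],
    usage=SimpleNamespace(prompt_tokens=9, completion_tokens=2),
)
print(collect_generations(fake, "open-mixtral-8x22b"))
```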

MlxLLM

Bases: LLM, MagpieChatTemplateMixin

Apple MLX LLM implementation.

Attributes:

Name Type Description
path_or_hf_repo str

the path to the model or the Hugging Face Hub repo id.

tokenizer_config Dict[str, Any]

the tokenizer configuration.

model_config Dict[str, Any]

the model configuration.

adapter_path Optional[str]

the path to the adapter.

use_magpie_template bool

a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2" or a custom pre-query template string. Defaults to None.

Examples:

Generate text:

from distilabel.models.llms import MlxLLM

llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/mlx.py
class MlxLLM(LLM, MagpieChatTemplateMixin):
    """Apple MLX LLM implementation.

    Attributes:
        path_or_hf_repo: the path to the model or the Hugging Face Hub repo id.
        tokenizer_config: the tokenizer configuration.
        model_config: the model configuration.
        adapter_path: the path to the adapter.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow-up user message. Valid
            values are "llama3", "qwen2" or a custom pre-query template string. Defaults
            to `None`.

    Icon:
        `:apple:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import MlxLLM

        llm = MlxLLM(path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    path_or_hf_repo: str
    tokenizer_config: Dict[str, Any] = {}
    model_config: Dict[str, Any] = {}
    adapter_path: Optional[str] = None

    _mlx_generate: Optional[Callable] = PrivateAttr(default=None)
    _model: Optional["nn.Module"] = PrivateAttr(...)
    _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the model and tokenizer and creates the text generation pipeline. In addition,
        if the tokenizer has no pad token, the EOS token is used as the pad token."""
        try:
            import mlx  # noqa
            from mlx_lm import generate, load
        except ImportError as ie:
            raise ImportError(
                "MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
            ) from ie

        self._model, self._tokenizer = load(
            self.path_or_hf_repo,
            tokenizer_config=self.tokenizer_config,
            model_config=self.model_config,
            adapter_path=self.adapter_path,
        )

        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._mlx_generate = generate

        super().load()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.path_or_hf_repo

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if self._tokenizer.chat_template is None:
            return input[0]["content"]

        prompt: str = (
            self._tokenizer.apply_chat_template(
                input,
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    def generate(
        self,
        inputs: List[StandardInput],
        num_generations: int = 1,
        max_tokens: int = 256,
        sampler: Optional[Callable] = None,
        logits_processors: Optional[List[Callable]] = None,
        max_kv_size: Optional[int] = None,
        prompt_cache: Optional[Any] = None,
        prefill_step_size: int = 512,
        kv_bits: Optional[int] = None,
        kv_group_size: int = 64,
        quantized_kv_start: int = 0,
        prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
        temp: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        repetition_context_size: Optional[int] = None,
        top_p: Optional[float] = None,
        min_p: Optional[float] = None,
        min_tokens_to_keep: Optional[int] = None,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: the inputs to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `256`.
            sampler: the sampler to use for the generation. Defaults to `None`.
            logits_processors: the logits processors to use for the generation. Defaults to
                `None`.
            max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
            prompt_cache: the prompt cache to use for the generation. Defaults to `None`.
            prefill_step_size: the prefill step size. Defaults to `512`.
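            kv_bits: the number of bits used to quantize the key-value cache. Defaults to
                `None`, which disables cache quantization.
            kv_group_size: the group size for the key-value cache quantization. Defaults to `64`.
            quantized_kv_start: the step at which to start using the quantized key-value
                cache when `kv_bits` is set. Defaults to `0`.
            prompt_progress_callback: a callback that receives the number of prompt tokens
                processed so far and the total number of prompt tokens. Defaults to `None`.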
            temp: the temperature to use for the generation. Defaults to `None`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults to
                `None`.
            repetition_context_size: the context size for the repetition penalty. Defaults to
                `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.
            min_p: the minimum p value to use for the generation. Defaults to `None`.
            min_tokens_to_keep: the minimum number of tokens to keep. Defaults to `None`.

        Returns:
            A list with one `GenerateOutput` per input, holding the generated responses and their token statistics.
        """
        structured_output = None
        result = []
        for input in inputs:
            if isinstance(input, tuple):
                input, structured_output = input

            output: List[str] = []
            for _ in range(num_generations):
                if structured_output:  # will raise a NotImplementedError
                    self._prepare_structured_output(structured_output)
                prompt = self.prepare_input(input)
                generation = self._mlx_generate(
                    prompt=prompt,
                    model=self._model,
                    tokenizer=self._tokenizer,
                    logits_processors=logits_processors,
                    max_tokens=max_tokens,
                    sampler=sampler,
                    max_kv_size=max_kv_size,
                    prompt_cache=prompt_cache,
                    prefill_step_size=prefill_step_size,
                    kv_bits=kv_bits,
                    kv_group_size=kv_group_size,
                    quantized_kv_start=quantized_kv_start,
                    prompt_progress_callback=prompt_progress_callback,
                    temp=temp,
                    repetition_penalty=repetition_penalty,
                    repetition_context_size=repetition_context_size,
                    top_p=top_p,
                    min_p=min_p,
                    min_tokens_to_keep=min_tokens_to_keep,
                )

                output.append(generation)

            result.append(
                prepare_output(
                    output,
                    input_tokens=[compute_tokens(input, self._tokenizer.encode)],
                    output_tokens=[
                        compute_tokens(
                            text_or_messages=generation,
                            tokenizer=self._tokenizer.encode,
                        )
                        for generation in output
                    ],
                )
            )
        return result
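`generate` calls the model `num_generations` times per input and counts tokens by re-encoding text: the chat input is counted once per input, and each generation is counted separately. A sketch of that bookkeeping with stand-ins for `mlx_lm.generate` and the tokenizer's `encode`. The whitespace tokenizer, the `echo` generator, the "last message as prompt" rule, and counting the input by encoding each message's content are all hypothetical simplifications.

```python
from typing import Callable, Dict, List


def whitespace_encode(text: str) -> List[str]:
    # Stand-in tokenizer: one token per whitespace-separated word.
    return text.split()


def generate_all(
    inputs: List[List[Dict[str, str]]],
    generate_fn: Callable[[str], str],
    encode: Callable[[str], List[str]] = whitespace_encode,
    num_generations: int = 1,
) -> List[Dict[str, object]]:
    results = []
    for messages in inputs:
        # Hypothetical prompt: the content of the last message.
        prompt = messages[-1]["content"]
        outputs = [generate_fn(prompt) for _ in range(num_generations)]
        results.append(
            {
                "generations": outputs,
                # One input count per input, one output count per generation.
                "input_tokens": [sum(len(encode(m["content"])) for m in messages)],
                "output_tokens": [len(encode(o)) for o in outputs],
            }
        )
    return results


def echo(prompt: str) -> str:
    return f"You said: {prompt}"


print(generate_all([[{"role": "user", "content": "Hello world!"}]], echo, num_generations=2))
```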
model_name property

Returns the model name used for the LLM.

load()

Loads the model and tokenizer with mlx_lm.load and keeps a reference to mlx_lm.generate for text generation. If the tokenizer has no pad token, the EOS token is used as the pad token.

Source code in src/distilabel/models/llms/mlx.py
def load(self) -> None:
    """Loads the model and tokenizer and creates the text generation pipeline. In addition,
    if the tokenizer has no pad token, the EOS token is used as the pad token."""
    try:
        import mlx  # noqa
        from mlx_lm import generate, load
    except ImportError as ie:
        raise ImportError(
            "MLX is not installed. Please install it using `pip install 'distilabel[mlx]'`."
        ) from ie

    self._model, self._tokenizer = load(
        self.path_or_hf_repo,
        tokenizer_config=self.tokenizer_config,
        model_config=self.model_config,
        adapter_path=self.adapter_path,
    )

    if self._tokenizer.pad_token is None:
        self._tokenizer.pad_token = self._tokenizer.eos_token

    self._mlx_generate = generate

    super().load()
prepare_input(input)

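Prepares the prompt for the provided input by applying the tokenizer's chat template (without tokenizing it) and, if enabled, the Magpie pre-query template. If the tokenizer has no chat template, the content of the first message is returned as-is.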

Parameters:

Name Type Description Default
input StandardInput

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.

Source code in src/distilabel/models/llms/mlx.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if self._tokenizer.chat_template is None:
        return input[0]["content"]

    prompt: str = (
        self._tokenizer.apply_chat_template(
            input,
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
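`prepare_input` only renders text (`tokenize=False`). When the tokenizer has no chat template, it falls back to the first message's content and drops any later turns. A sketch of that branching with a toy template function in place of `apply_chat_template`; the `<|role|>` markup is invented for illustration, real chat templates are model-specific, and the Magpie pre-query step is omitted.

```python
from typing import Callable, Dict, List, Optional

Messages = List[Dict[str, str]]


def toy_template(messages: Messages) -> str:
    # Invented markup; real chat templates are model-specific Jinja templates.
    turns = "".join(f"<|{m['role']}|>{m['content']}\n" for m in messages)
    return turns + "<|assistant|>"  # plays the role of add_generation_prompt=True


def render_prompt(
    messages: Messages, template: Optional[Callable[[Messages], str]]
) -> str:
    if template is None:
        # No chat template: only the first message's content is used.
        return messages[0]["content"]
    return template(messages) if messages else ""


chat = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
]
print(render_prompt(chat, toy_template))
```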
generate(inputs, num_generations=1, max_tokens=256, sampler=None, logits_processors=None, max_kv_size=None, prompt_cache=None, prefill_step_size=512, kv_bits=None, kv_group_size=64, quantized_kv_start=0, prompt_progress_callback=None, temp=None, repetition_penalty=None, repetition_context_size=None, top_p=None, min_p=None, min_tokens_to_keep=None)

Generates num_generations responses for each input using the text generation pipeline.

Parameters:

Name Type Description Default
inputs List[StandardInput]

the inputs to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_tokens int

the maximum number of new tokens that the model will generate. Defaults to 256.

256
sampler Optional[Callable]

the sampler to use for the generation. Defaults to None.

None
logits_processors Optional[List[Callable]]

the logits processors to use for the generation. Defaults to None.

None
max_kv_size Optional[int]

the maximum size of the key-value cache. Defaults to None.

None
prompt_cache Optional[Any]

the prompt cache to use for the generation. Defaults to None.

None
prefill_step_size int

the prefill step size. Defaults to 512.

512
kv_bits Optional[int]

the number of bits used to quantize the key-value cache. Defaults to None, which disables cache quantization.

None
kv_group_size int

the group size for the key-value cache quantization. Defaults to 64.

64
quantized_kv_start int

the step at which to start using the quantized key-value cache when kv_bits is set. Defaults to 0.

0
prompt_progress_callback Optional[Callable[[int, int], None]]

a callback that receives the number of prompt tokens processed so far and the total number of prompt tokens. Defaults to None.

None
temp Optional[float]

the temperature to use for the generation. Defaults to None.

None
repetition_penalty Optional[float]

the repetition penalty to use for the generation. Defaults to None.

None
repetition_context_size Optional[int]

the context size for the repetition penalty. Defaults to None.

None
top_p Optional[float]

the top-p value to use for the generation. Defaults to None.

None
min_p Optional[float]

the minimum p value to use for the generation. Defaults to None.

None
min_tokens_to_keep Optional[int]

the minimum number of tokens to keep. Defaults to None.

None
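`temp`, `top_p`, `min_p` and `min_tokens_to_keep` are forwarded to `mlx_lm` unchanged. As a rough guide to what these knobs do, here is a pure-Python sketch of temperature scaling followed by top-p (nucleus) and min-p filtering over a toy distribution. It illustrates the general techniques only; it is not the `mlx_lm` sampler, whose ordering and edge-case handling may differ, and `filter_probs` is a hypothetical name.

```python
import math
from typing import Dict, List


def filter_probs(
    logits: List[float],
    temp: float = 1.0,
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
) -> Dict[int, float]:
    """Return renormalized probabilities of the tokens that survive filtering."""
    # Temperature (must be > 0 here): lower values sharpen the distribution.
    scaled = [x / temp for x in logits]
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for i in ranked:
        # Top-p: stop once the kept tokens already cover `top_p` of the mass.
        enough_mass = cumulative >= top_p
        # Min-p: drop tokens below `min_p` times the top token's probability.
        too_unlikely = probs[i] < min_p * probs[ranked[0]]
        if len(kept) >= min_tokens_to_keep and (enough_mass or too_unlikely):
            break
        kept.append(i)
        cumulative += probs[i]
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


print(filter_probs([2.0, 1.0, 0.5], top_p=0.9))
```

Because tokens are visited in descending probability, both stopping conditions are monotonic, so breaking out of the loop at the first failing token is enough.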

Returns:

Type Description
List[GenerateOutput]

A list with one GenerateOutput per input, holding the generated responses and their token statistics.

Source code in src/distilabel/models/llms/mlx.py
@validate_call
def generate(
    self,
    inputs: List[StandardInput],
    num_generations: int = 1,
    max_tokens: int = 256,
    sampler: Optional[Callable] = None,
    logits_processors: Optional[List[Callable]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
    temp: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: Optional[float] = None,
    min_p: Optional[float] = None,
    min_tokens_to_keep: Optional[int] = None,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: the inputs to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `256`.
        sampler: the sampler to use for the generation. Defaults to `None`.
        logits_processors: the logits processors to use for the generation. Defaults to
            `None`.
        max_kv_size: the maximum size of the key-value cache. Defaults to `None`.
        prompt_cache: the prompt cache to use for the generation. Defaults to `None`.
        prefill_step_size: the prefill step size. Defaults to `512`.
        kv_bits: the number of bits to use for the key-value cache. Defaults to `None`.
        kv_group_size: the group size for the key-value cache. Defaults to `64`.
        quantized_kv_start: the start of the quantized key-value cache. Defaults to `0`.
        prompt_progress_callback: a callback invoked during prompt processing with the
            number of prompt tokens processed so far and the total number of prompt
            tokens. Defaults to `None`.
        temp: the temperature to use for the generation. Defaults to `None`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults to
            `None`.
        repetition_context_size: the context size for the repetition penalty. Defaults to
            `None`.
        top_p: the top-p value to use for the generation. Defaults to `None`.
        min_p: the minimum p value to use for the generation. Defaults to `None`.
        min_tokens_to_keep: the minimum number of tokens to keep. Defaults to `None`.

    Returns:
        A list with one `GenerateOutput` per input, each holding the `num_generations`
        generated responses and the token counts for the input and for each response.
    """
    structured_output = None
    result = []
    for input in inputs:
        if isinstance(input, tuple):
            input, structured_output = input

        output: List[str] = []
        for _ in range(num_generations):
            if structured_output:  # will raise a NotImplementedError
                self._prepare_structured_output(structured_output)
            prompt = self.prepare_input(input)
            generation = self._mlx_generate(
                prompt=prompt,
                model=self._model,
                tokenizer=self._tokenizer,
                logits_processors=logits_processors,
                max_tokens=max_tokens,
                sampler=sampler,
                max_kv_size=max_kv_size,
                prompt_cache=prompt_cache,
                prefill_step_size=prefill_step_size,
                kv_bits=kv_bits,
                kv_group_size=kv_group_size,
                quantized_kv_start=quantized_kv_start,
                prompt_progress_callback=prompt_progress_callback,
                temp=temp,
                repetition_penalty=repetition_penalty,
                repetition_context_size=repetition_context_size,
                top_p=top_p,
                min_p=min_p,
                min_tokens_to_keep=min_tokens_to_keep,
            )

            output.append(generation)

        result.append(
            prepare_output(
                output,
                input_tokens=[compute_tokens(input, self._tokenizer.encode)],
                output_tokens=[
                    compute_tokens(
                        text_or_messages=generation,
                        tokenizer=self._tokenizer.encode,
                    )
                    for generation in output
                ],
            )
        )
    return result
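To make the control flow of `generate` concrete, here is a minimal pure-Python sketch of the same per-input loop. It does not call MLX: `generate_fn` and `count_tokens` are hypothetical stand-ins for `_mlx_generate` and the tokenizer, the prompt rendering is a placeholder for `prepare_input`, and the `generations`/`statistics` dictionary shape is an assumption about what `prepare_output` returns. It also simply drops the structured-output element of tuple inputs, whereas the real method passes it to `_prepare_structured_output`, which the source notes will raise `NotImplementedError`.

```python
from typing import Callable, Dict, List, Tuple, Union

Message = Dict[str, str]
Conversation = List[Message]


def generate_sketch(
    inputs: List[Union[Conversation, Tuple[Conversation, dict]]],
    num_generations: int,
    generate_fn: Callable[[str], str],
    count_tokens: Callable[[str], int],
) -> List[dict]:
    """Mimics the loop in `MlxLLM.generate` (illustrative, not the library code)."""
    results = []
    for item in inputs:
        # Inputs may arrive as `(conversation, structured_output)` tuples.
        conversation = item[0] if isinstance(item, tuple) else item
        # Placeholder prompt rendering; the real method calls `self.prepare_input`.
        prompt = "\n".join(message["content"] for message in conversation)
        outputs = [generate_fn(prompt) for _ in range(num_generations)]
        results.append(
            {
                "generations": outputs,
                "statistics": {
                    # One count for the input, one count per generation.
                    "input_tokens": [count_tokens(prompt)],
                    "output_tokens": [count_tokens(output) for output in outputs],
                },
            }
        )
    return results


conversation = [{"role": "user", "content": "Hi there"}]
results = generate_sketch(
    [conversation, (conversation, {"format": "json"})],
    num_generations=2,
    generate_fn=lambda prompt: "hello world",  # hypothetical model
    count_tokens=lambda text: len(text.split()),  # hypothetical tokenizer
)
print(results[0])
```

Note that `input_tokens` holds a single count per input while `output_tokens` holds one count per generation, mirroring the `compute_tokens` calls in the source above.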

MixtureOfAgentsLLM

Bases: AsyncLLM

Mixture-of-Agents implementation.

An LLM class that leverages the collective strengths of several LLMs to generate a response, as described in the "Mixture-of-Agents Enhances Large Language Model Capabilities" paper. A list of proposer LLMs generates outputs that the LLMs in the next round (layer) can use as auxiliary information. Finally, an aggregator LLM combines those outputs to generate the final response.
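The round structure described above can be sketched without any model at all. In this toy version (an illustration, not distilabel's implementation), each proposer and the aggregator is a plain function that receives the prompt plus the previous round's outputs as references; real LLMs would instead receive those references through an injected system prompt.

```python
from typing import Callable, List

# A proposer (or the aggregator) maps a prompt and reference outputs to a response.
Agent = Callable[[str, List[str]], str]


def mixture_of_agents(
    prompt: str, proposers: List[Agent], aggregator: Agent, rounds: int = 1
) -> str:
    references: List[str] = []  # the first round has no references
    for _ in range(rounds):
        # Every proposer in a round sees the outputs of the previous round.
        references = [propose(prompt, references) for propose in proposers]
    # The aggregator combines the outputs of the last round into the final answer.
    return aggregator(prompt, references)


def make_proposer(name: str) -> Agent:
    # Toy proposer that reports how many references it received.
    return lambda prompt, refs: f"{name}({len(refs)})"


answer = mixture_of_agents(
    "Write a witty review.",
    proposers=[make_proposer("a"), make_proposer("b")],
    aggregator=lambda prompt, refs: " + ".join(refs),
    rounds=2,
)
print(answer)  # a(2) + b(2)
```

With `rounds=2`, the second-round proposers each see the two first-round outputs, and the aggregator only sees the second round.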

Attributes:

Name Type Description
aggregator_llm LLM

The LLM that aggregates the outputs of the proposer LLMs.

proposers_llms List[AsyncLLM]

The list of LLMs that propose outputs to be aggregated.

rounds int

The number of layers (rounds) in which the proposers_llms generate outputs. Defaults to 1.

References
  • Mixture-of-Agents Enhances Large Language Model Capabilities (https://arxiv.org/abs/2406.04692)

Examples:

Generate text:

from distilabel.models.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM

llm = MixtureOfAgentsLLM(
    aggregator_llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    ),
    proposers_llms=[
        InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        ),
        InferenceEndpointsLLM(
            model_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            tokenizer_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
        ),
        InferenceEndpointsLLM(
            model_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
            tokenizer_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
        ),
    ],
    rounds=2,
)

llm.load()

output = llm.generate_outputs(
    inputs=[
        [
            {
                "role": "user",
                "content": "My favorite witty review of The Rings of Power series is this: Input:",
            }
        ]
    ]
)
Source code in src/distilabel/models/llms/moa.py
class MixtureOfAgentsLLM(AsyncLLM):
    """`Mixture-of-Agents` implementation.

    An `LLM` class that leverages the collective strengths of several `LLM`s to generate
    a response, as described in the "Mixture-of-Agents Enhances Large Language Model
    Capabilities" paper. A list of proposer `LLM`s generates outputs that the `LLM`s in
    the next round (layer) can use as auxiliary information. Finally, an aggregator `LLM`
    combines those outputs to generate the final response.

    Attributes:
        aggregator_llm: The `LLM` that aggregates the outputs of the proposer `LLM`s.
        proposers_llms: The list of `LLM`s that propose outputs to be aggregated.
        rounds: The number of layers (rounds) in which the `proposers_llms` generate
            outputs. Defaults to `1`.

    References:
        - [Mixture-of-Agents Enhances Large Language Model Capabilities](https://arxiv.org/abs/2406.04692)

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM

        llm = MixtureOfAgentsLLM(
            aggregator_llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            ),
            proposers_llms=[
                InferenceEndpointsLLM(
                    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
                    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
                ),
                InferenceEndpointsLLM(
                    model_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                    tokenizer_id="NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
                ),
                InferenceEndpointsLLM(
                    model_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
                    tokenizer_id="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
                ),
            ],
            rounds=2,
        )

        llm.load()

        output = llm.generate_outputs(
            inputs=[
                [
                    {
                        "role": "user",
                        "content": "My favorite witty review of The Rings of Power series is this: Input:",
                    }
                ]
            ]
        )
        ```
    """

    aggregator_llm: LLM
    proposers_llms: List[AsyncLLM] = Field(default_factory=list)
    rounds: int = 1

    @property
    def runtime_parameters_names(self) -> "RuntimeParametersNames":
        """Returns the runtime parameters of the `LLM`, which are a combination of the
        `RuntimeParameter`s of the `LLM`, the `aggregator_llm` and the `proposers_llms`.

        Returns:
            The runtime parameters of the `LLM`.
        """
        runtime_parameters_names = super().runtime_parameters_names
        del runtime_parameters_names["generation_kwargs"]
        return runtime_parameters_names

    def load(self) -> None:
        """Loads all the `LLM`s in the `MixtureOfAgents`."""
        super().load()

        for llm in self.proposers_llms:
            self._logger.debug(f"Loading proposer LLM in MoA: {llm}")  # type: ignore
            llm.load()

        self._logger.debug(f"Loading aggregator LLM in MoA: {self.aggregator_llm}")  # type: ignore
        self.aggregator_llm.load()

    @property
    def model_name(self) -> str:
        """Returns the aggregated model name."""
        return f"moa-{self.aggregator_llm.model_name}-{'-'.join([llm.model_name for llm in self.proposers_llms])}"

    def get_generation_kwargs(self) -> Dict[str, Any]:
        """Returns the generation kwargs of the `MixtureOfAgents` as a dictionary.

        Returns:
            The generation kwargs of the `MixtureOfAgents`.
        """
        return {
            "aggregator_llm": self.aggregator_llm.get_generation_kwargs(),
            "proposers_llms": [
                llm.get_generation_kwargs() for llm in self.proposers_llms
            ],
        }

    # `abstractmethod`, had to be implemented but not used
    async def agenerate(
        self, input: "FormattedInput", num_generations: int = 1, **kwargs: Any
    ) -> List[Union[str, None]]:
        raise NotImplementedError(
            "`agenerate` method is not implemented for `MixtureOfAgents`"
        )

    def _build_moa_system_prompt(self, prev_outputs: List[str]) -> str:
        """Builds the Mixture-of-Agents system prompt.

        Args:
            prev_outputs: The list of previous outputs to use as references.

        Returns:
            The Mixture-of-Agents system prompt.
        """
        moa_system_prompt = MOA_SYSTEM_PROMPT
        for i, prev_output in enumerate(prev_outputs):
            if prev_output is not None:
                moa_system_prompt += f"\n{i + 1}. {prev_output}"
        return moa_system_prompt

    def _inject_moa_system_prompt(
        self, input: "StandardInput", prev_outputs: List[str]
    ) -> "StandardInput":
        """Injects the Mixture-of-Agents system prompt into the input.

        Args:
            input: The input to inject the system prompt into.
            prev_outputs: The list of previous outputs to use as references.

        Returns:
            The input with the Mixture-of-Agents system prompt injected.
        """
        if len(prev_outputs) == 0:
            return input

        moa_system_prompt = self._build_moa_system_prompt(prev_outputs)

        system = next((item for item in input if item["role"] == "system"), None)
        if system:
            original_system_prompt = system["content"]
            system["content"] = f"{moa_system_prompt}\n\n{original_system_prompt}"
        else:
            input.insert(0, {"role": "system", "content": moa_system_prompt})

        return input

    async def _agenerate(
        self,
        inputs: List["FormattedInput"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Internal function to concurrently generate responses for a list of inputs.

        Args:
            inputs: the list of inputs to generate responses for.
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.

        Returns:
            A list containing the generations for each input.
        """
        aggregator_llm_kwargs: Dict[str, Any] = kwargs.get("aggregator_llm", {})
        proposers_llms_kwargs: List[Dict[str, Any]] = kwargs.get(
            "proposers_llms", [{}] * len(self.proposers_llms)
        )

        prev_outputs = []
        for round in range(self.rounds):
            self._logger.debug(f"Generating round {round + 1}/{self.rounds} in MoA")  # type: ignore

            # Generate one output per input with each proposer LLM, concurrently
            tasks = [
                asyncio.create_task(
                    llm._agenerate(
                        inputs=[
                            self._inject_moa_system_prompt(
                                cast("StandardInput", input), prev_input_outputs
                            )
                            for input, prev_input_outputs in itertools.zip_longest(
                                inputs, prev_outputs, fillvalue=[]
                            )
                        ],
                        num_generations=1,
                        **generation_kwargs,
                    )
                )
                for llm, generation_kwargs in zip(
                    self.proposers_llms, proposers_llms_kwargs
                )
            ]

            # Group generations per input
            outputs: List[List["GenerateOutput"]] = await asyncio.gather(*tasks)
            prev_outputs = [
                list(itertools.chain(*input_outputs)) for input_outputs in zip(*outputs)
            ]

        self._logger.debug("Aggregating outputs in MoA")  # type: ignore
        if isinstance(self.aggregator_llm, AsyncLLM):
            return await self.aggregator_llm._agenerate(
                inputs=[
                    self._inject_moa_system_prompt(
                        cast("StandardInput", input), prev_input_outputs
                    )
                    for input, prev_input_outputs in zip(inputs, prev_outputs)
                ],
                num_generations=num_generations,
                **aggregator_llm_kwargs,
            )

        return self.aggregator_llm.generate(
            inputs=[
                self._inject_moa_system_prompt(
                    cast("StandardInput", input), prev_input_outputs
                )
                for input, prev_input_outputs in zip(inputs, prev_outputs)
            ],
            num_generations=num_generations,
            **aggregator_llm_kwargs,
        )
runtime_parameters_names property

Returns the runtime parameters of the LLM, which are a combination of the RuntimeParameters of the LLM, the aggregator_llm and the proposers_llms.

Returns:

Type Description
RuntimeParametersNames

The runtime parameters of the LLM.

model_name property

Returns the aggregated model name.
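The aggregated name is built by joining the aggregator's model_name and the proposers' model_names with hyphens, as in the model_name property source above. A minimal sketch with plain strings (the example model names are hypothetical):

```python
from typing import List


def moa_model_name(aggregator_name: str, proposer_names: List[str]) -> str:
    # Same f-string as `MixtureOfAgentsLLM.model_name`, with names passed in directly.
    return f"moa-{aggregator_name}-{'-'.join(proposer_names)}"


print(moa_model_name("llama3-70b", ["llama3-70b", "mixtral"]))
```

Because the parts are joined with `-`, and model names often contain hyphens themselves, the result works as a label rather than something to parse back into its components.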

load()

Loads all the LLMs in the MixtureOfAgents.

Source code in src/distilabel/models/llms/moa.py
def load(self) -> None:
    """Loads all the `LLM`s in the `MixtureOfAgents`."""
    super().load()

    for llm in self.proposers_llms:
        self._logger.debug(f"Loading proposer LLM in MoA: {llm}")  # type: ignore
        llm.load()

    self._logger.debug(f"Loading aggregator LLM in MoA: {self.aggregator_llm}")  # type: ignore
    self.aggregator_llm.load()
get_generation_kwargs()

Returns the generation kwargs of the MixtureOfAgents as a dictionary.

Returns:

Type Description
Dict[str, Any]

The generation kwargs of the MixtureOfAgents.

Source code in src/distilabel/models/llms/moa.py
def get_generation_kwargs(self) -> Dict[str, Any]:
    """Returns the generation kwargs of the `MixtureOfAgents` as a dictionary.

    Returns:
        The generation kwargs of the `MixtureOfAgents`.
    """
    return {
        "aggregator_llm": self.aggregator_llm.get_generation_kwargs(),
        "proposers_llms": [
            llm.get_generation_kwargs() for llm in self.proposers_llms
        ],
    }
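The nested dictionary returned here mirrors how _agenerate reads its keyword arguments: `kwargs.get("aggregator_llm", {})` for the aggregator and `kwargs.get("proposers_llms", [{}] * len(self.proposers_llms))` for the proposers, which are then paired with the proposer LLMs via `zip`. The sketch below reproduces that routing with plain values; the proposer strings and `temperature` keys are illustrative only. One consequence of `zip`: if fewer proposer kwargs dicts are supplied than there are proposers, the unmatched proposers would not be called.

```python
from typing import Any, Dict, List, Tuple


def split_moa_kwargs(
    kwargs: Dict[str, Any], num_proposers: int
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    # Same defaults as `MixtureOfAgentsLLM._agenerate`.
    aggregator_kwargs = kwargs.get("aggregator_llm", {})
    proposers_kwargs = kwargs.get("proposers_llms", [{}] * num_proposers)
    return aggregator_kwargs, proposers_kwargs


proposers = ["proposer-1", "proposer-2", "proposer-3"]  # hypothetical stand-ins for LLMs

# Default case: every proposer gets an empty kwargs dict.
_, default_kwargs = split_moa_kwargs({}, len(proposers))
print(list(zip(proposers, default_kwargs)))

# Only one proposer kwargs dict supplied: `zip` pairs it with the first proposer and stops.
agg_kwargs, short_kwargs = split_moa_kwargs(
    {"aggregator_llm": {"temperature": 0.2}, "proposers_llms": [{"temperature": 0.7}]},
    len(proposers),
)
print(agg_kwargs, list(zip(proposers, short_kwargs)))
```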
_build_moa_system_prompt(prev_outputs)

Builds the Mixture-of-Agents system prompt.

Parameters:

Name Type Description Default
prev_outputs List[str]

The list of previous outputs to use as references.

required

Returns:

Type Description
str

The Mixture-of-Agents system prompt.

Source code in src/distilabel/models/llms/moa.py
def _build_moa_system_prompt(self, prev_outputs: List[str]) -> str:
    """Builds the Mixture-of-Agents system prompt.

    Args:
        prev_outputs: The list of previous outputs to use as references.

    Returns:
        The Mixture-of-Agents system prompt.
    """
    moa_system_prompt = MOA_SYSTEM_PROMPT
    for i, prev_output in enumerate(prev_outputs):
        if prev_output is not None:
            moa_system_prompt += f"\n{i + 1}. {prev_output}"
    return moa_system_prompt
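A standalone version of the numbering logic, using a hypothetical placeholder for MOA_SYSTEM_PROMPT (its actual text is not shown on this page). The example shows a detail that follows from taking the index from `enumerate` before the `None` check: skipped outputs leave gaps in the numbering.

```python
from typing import List, Optional

# Hypothetical placeholder text; the real constant is defined in the library.
MOA_SYSTEM_PROMPT = "Synthesize these responses into one answer:"


def build_moa_system_prompt(prev_outputs: List[Optional[str]]) -> str:
    moa_system_prompt = MOA_SYSTEM_PROMPT
    for i, prev_output in enumerate(prev_outputs):
        # The index comes from `enumerate`, so skipped `None` entries leave a gap.
        if prev_output is not None:
            moa_system_prompt += f"\n{i + 1}. {prev_output}"
    return moa_system_prompt


print(build_moa_system_prompt(["First answer", None, "Third answer"]))
```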
_inject_moa_system_prompt(input, prev_outputs)

Injects the Mixture-of-Agents system prompt into the input.

Parameters:

Name Type Description Default
input StandardInput

The input to inject the system prompt into.

required
prev_outputs List[str]

The list of previous outputs to use as references.

required

Returns:

Type Description
StandardInput

The input with the Mixture-of-Agents system prompt injected.

Source code in src/distilabel/models/llms/moa.py
def _inject_moa_system_prompt(
    self, input: "StandardInput", prev_outputs: List[str]
) -> "StandardInput":
    """Injects the Mixture-of-Agents system prompt into the input.

    Args:
        input: The input to inject the system prompt into.
        prev_outputs: The list of previous outputs to use as references.

    Returns:
        The input with the Mixture-of-Agents system prompt injected.
    """
    if len(prev_outputs) == 0:
        return input

    moa_system_prompt = self._build_moa_system_prompt(prev_outputs)

    system = next((item for item in input if item["role"] == "system"), None)
    if system:
        original_system_prompt = system["content"]
        system["content"] = f"{moa_system_prompt}\n\n{original_system_prompt}"
    else:
        input.insert(0, {"role": "system", "content": moa_system_prompt})

    return input
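_inject_moa_system_prompt edits `input` in place (`system["content"] = ...` and `input.insert(0, ...)`). Because _agenerate passes the same message lists to every proposer and then to the aggregator, this appears to let the references be prepended more than once when there are several proposers and `rounds` is greater than 1. A copy-based variant avoids that. The sketch below is an alternative written for illustration, not the library's code, and it takes the already-built MoA prompt as a string.

```python
import copy
from typing import Dict, List

Conversation = List[Dict[str, str]]


def inject_system_prompt(conversation: Conversation, moa_prompt: str) -> Conversation:
    """Return a copy of `conversation` with `moa_prompt` added to the system message."""
    new_conversation = copy.deepcopy(conversation)  # the caller's messages stay untouched
    system = next((m for m in new_conversation if m["role"] == "system"), None)
    if system is not None:
        # Prepend the references, keeping the original system prompt after a blank line.
        system["content"] = f"{moa_prompt}\n\n{system['content']}"
    else:
        new_conversation.insert(0, {"role": "system", "content": moa_prompt})
    return new_conversation


messages = [{"role": "user", "content": "Hi"}]
first = inject_system_prompt(messages, "REFS")
second = inject_system_prompt(messages, "REFS")  # same result, no accumulation
print(messages, first == second)
```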
_agenerate(inputs, num_generations=1, **kwargs) async

Internal function to concurrently generate responses for a list of inputs.

Parameters:

Name Type Description Default
inputs List[FormattedInput]

the list of inputs to generate responses for.

required
num_generations int

the number of generations to generate per input.

1
**kwargs Any

the additional kwargs to be used for the generation.

{}

Returns:

Type Description
List[GenerateOutput]

A list containing the generations for each input.

Source code in src/distilabel/models/llms/moa.py
async def _agenerate(
    self,
    inputs: List["FormattedInput"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Internal function to concurrently generate responses for a list of inputs.

    Args:
        inputs: the list of inputs to generate responses for.
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.

    Returns:
        A list containing the generations for each input.
    """
    aggregator_llm_kwargs: Dict[str, Any] = kwargs.get("aggregator_llm", {})
    proposers_llms_kwargs: List[Dict[str, Any]] = kwargs.get(
        "proposers_llms", [{}] * len(self.proposers_llms)
    )

    prev_outputs = []
    for round in range(self.rounds):
        self._logger.debug(f"Generating round {round + 1}/{self.rounds} in MoA")  # type: ignore

        # Generate one output per input with each proposer LLM, concurrently
        tasks = [
            asyncio.create_task(
                llm._agenerate(
                    inputs=[
                        self._inject_moa_system_prompt(
                            cast("StandardInput", input), prev_input_outputs
                        )
                        for input, prev_input_outputs in itertools.zip_longest(
                            inputs, prev_outputs, fillvalue=[]
                        )
                    ],
                    num_generations=1,
                    **generation_kwargs,
                )
            )
            for llm, generation_kwargs in zip(
                self.proposers_llms, proposers_llms_kwargs
            )
        ]

        # Group generations per input
        outputs: List[List["GenerateOutput"]] = await asyncio.gather(*tasks)
        prev_outputs = [
            list(itertools.chain(*input_outputs)) for input_outputs in zip(*outputs)
        ]

    self._logger.debug("Aggregating outputs in MoA")  # type: ignore
    if isinstance(self.aggregator_llm, AsyncLLM):
        return await self.aggregator_llm._agenerate(
            inputs=[
                self._inject_moa_system_prompt(
                    cast("StandardInput", input), prev_input_outputs
                )
                for input, prev_input_outputs in zip(inputs, prev_outputs)
            ],
            num_generations=num_generations,
            **aggregator_llm_kwargs,
        )

    return self.aggregator_llm.generate(
        inputs=[
            self._inject_moa_system_prompt(
                cast("StandardInput", input), prev_input_outputs
            )
            for input, prev_input_outputs in zip(inputs, prev_outputs)
        ],
        num_generations=num_generations,
        **aggregator_llm_kwargs,
    )
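The concurrency pattern in _agenerate (one task per proposer, `asyncio.gather`, then regrouping per input with `zip(*outputs)` and `itertools.chain`) can be reproduced with toy coroutines. This sketch assumes each proposer returns, for every prompt, a list of generated strings; the proposers are hypothetical async functions, not distilabel LLMs, and references are passed as an argument instead of being injected into a system prompt.

```python
import asyncio
import itertools
from typing import Awaitable, Callable, List

# Takes the batch of prompts and per-prompt references; returns, per prompt,
# the list of generations produced by this proposer.
BatchProposer = Callable[[List[str], List[List[str]]], Awaitable[List[List[str]]]]


async def run_proposer_rounds(
    prompts: List[str], proposers: List[BatchProposer], rounds: int
) -> List[List[str]]:
    prev_outputs: List[List[str]] = []
    for _ in range(rounds):
        # First round: no previous outputs, so every prompt gets `[]` as references.
        refs = [r for _, r in itertools.zip_longest(prompts, prev_outputs, fillvalue=[])]
        # One concurrent call per proposer over the whole batch.
        outputs = await asyncio.gather(*(propose(prompts, refs) for propose in proposers))
        # outputs[j][i] holds proposer j's generations for prompt i: regroup per prompt.
        prev_outputs = [list(itertools.chain(*per_prompt)) for per_prompt in zip(*outputs)]
    return prev_outputs


def make_proposer(name: str) -> BatchProposer:
    async def propose(prompts: List[str], refs: List[List[str]]) -> List[List[str]]:
        await asyncio.sleep(0)  # yield to the event loop, as a real API call would
        return [[f"{name}:{p}:{len(r)}"] for p, r in zip(prompts, refs)]

    return propose


result = asyncio.run(
    run_proposer_rounds(["q1", "q2"], [make_proposer("a"), make_proposer("b")], rounds=2)
)
print(result)
```

After two rounds, each prompt holds one output from each proposer, and each output records that it was produced with two references from the previous round.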

OllamaLLM

Bases: AsyncLLM, MagpieChatTemplateMixin

Ollama LLM implementation running the Async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "notus".

host Optional[RuntimeParameter[str]]

the Ollama server host.

timeout RuntimeParameter[int]

the timeout for the LLM. Defaults to 120.

follow_redirects bool

whether to follow redirects. Defaults to True.

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

a dictionary containing the structured output configuration, following the InstructorStructuredOutputType structure. Defaults to None.

tokenizer_id Optional[RuntimeParameter[str]]

the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer config files. If not provided, the one associated with the model will be used. Defaults to None.

use_magpie_template bool

a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2" or a custom pre-query template string. Defaults to None.

_aclient Optional[AsyncClient]

the AsyncClient to use for the Ollama API. It is meant to be used internally. Set in the load method.

Runtime parameters
  • host: the Ollama server host.
  • timeout: the client timeout for the Ollama API. Defaults to 120.

Examples:

Generate text:

from distilabel.models.llms import OllamaLLM

llm = OllamaLLM(model="llama3")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/ollama.py
class OllamaLLM(AsyncLLM, MagpieChatTemplateMixin):
    """Ollama LLM implementation running the Async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "notus".
        host: the Ollama server host.
        timeout: the timeout for the LLM. Defaults to `120`.
        follow_redirects: whether to follow redirects. Defaults to `True`.
        structured_output: a dictionary containing the structured output configuration,
            following the `InstructorStructuredOutputType` structure. Defaults to `None`.
        tokenizer_id: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated with the `model`
            will be used. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow-up user message. Valid
            values are "llama3", "qwen2" or a custom pre-query template string. Defaults
            to `None`.
        _aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `host`: the Ollama server host.
        - `timeout`: the client timeout for the Ollama API. Defaults to `120`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import OllamaLLM

        llm = OllamaLLM(model="llama3")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str
    host: Optional[RuntimeParameter[str]] = Field(
        default=None, description="The host of the Ollama API."
    )
    timeout: RuntimeParameter[int] = Field(
        default=120, description="The timeout for the Ollama API."
    )
    follow_redirects: bool = True
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )
    tokenizer_id: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The Hugging Face Hub repo id or a path to a directory containing"
        " the tokenizer config files. If not provided, the one associated to the `model`"
        " will be used.",
    )
    _num_generations_param_supported = False
    _aclient: Optional["AsyncClient"] = PrivateAttr(...)  # type: ignore

    @model_validator(mode="after")  # type: ignore
    def validate_magpie_usage(
        self,
    ) -> "OllamaLLM":
        """Validates that `use_magpie_template` is only enabled when a `tokenizer_id` is set."""

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )
        return self

    def load(self) -> None:
        """Loads the `AsyncClient` for the Ollama async API and, if `tokenizer_id` is set, the tokenizer."""
        super().load()

        try:
            from ollama import AsyncClient

            self._aclient = AsyncClient(
                host=self.host,
                timeout=self.timeout,
                follow_redirects=self.follow_redirects,
            )
        except ImportError as e:
            raise ImportError(
                "Ollama Python client is not installed. Please install it using"
                " `pip install 'distilabel[ollama]'`."
            ) from e

        if self.tokenizer_id:
            try:
                from transformers import AutoTokenizer
            except ImportError as ie:
                raise ImportError(
                    "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
                ) from ie
            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
            if self._tokenizer.chat_template is None:
                raise ValueError(
                    "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
                )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    async def _generate_chat_completion(
        self,
        input: "StandardInput",
        format: Literal["", "json"] = "",
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> "ChatResponse":
        return await self._aclient.chat(
            model=self.model,
            messages=input,
            stream=False,
            format=format,
            options=options,
            keep_alive=keep_alive,
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the prompt for the provided input by applying the tokenizer's chat
        template (as text, without tokenizing) and then the Magpie pre-query template.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(
                conversation=input,
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    async def _generate_with_text_generation(
        self,
        input: "StandardInput",
        format: Literal["", "json"] = None,
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> "GenerateResponse":
        input = self.prepare_input(input)
        return await self._aclient.generate(
            model=self.model,
            prompt=input,
            format=format,
            options=options,
            keep_alive=keep_alive,
            raw=True,
        )

    @validate_call
    async def agenerate(
        self,
        input: StandardInput,
        format: Literal["", "json"] = "",
        # TODO: include relevant options from `Options` in `agenerate` method.
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> GenerateOutput:
        """
        Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).

        Args:
            input: the input to use for the generation.
            format: the format to use for the generation. Defaults to `""`.
            options: the options to use for the generation. Defaults to `None`.
            keep_alive: whether to keep the connection alive. Defaults to `None`.

        Returns:
            A `GenerateOutput` with the generated text for the given input and its token statistics.
        """
        text = None
        statistics = {}
        try:
            if not format:
                format = None
            if self.tokenizer_id is None:
                completion = await self._generate_chat_completion(
                    input, format, options, keep_alive
                )
                text = completion["message"]["content"]
            else:
                completion = await self._generate_with_text_generation(
                    input, format, options, keep_alive
                )
                text = completion.response
            statistics = self._get_llm_statistics(completion)
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )

        # `completion` is unbound if the request failed, so only pass statistics when available
        return prepare_output([text], **statistics)

    @staticmethod
    def _get_llm_statistics(completion: Dict[str, Any]) -> "LLMStatistics":
        return {
            "input_tokens": [completion["prompt_eval_count"]],
            "output_tokens": [completion["eval_count"]],
        }
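The two request paths in agenerate and the token accounting in _get_llm_statistics can be mirrored with plain Python. This is a sketch: the response below is a hand-written dictionary standing in for an Ollama `ChatResponse`, and `prompt_eval_count`/`eval_count` are the fields the source above reads for prompt and completion tokens.

```python
from typing import Any, Dict, List, Optional


def choose_endpoint(tokenizer_id: Optional[str]) -> str:
    # Without a tokenizer, `agenerate` uses the chat endpoint; with one, it renders the
    # chat template locally and calls `generate` with `raw=True`.
    return "chat" if tokenizer_id is None else "generate"


def normalize_format(fmt: str) -> Optional[str]:
    # `agenerate` turns the empty-string default into `None` before calling Ollama.
    return fmt or None


def llm_statistics(completion: Dict[str, Any]) -> Dict[str, List[int]]:
    # Same mapping as `OllamaLLM._get_llm_statistics`.
    return {
        "input_tokens": [completion["prompt_eval_count"]],
        "output_tokens": [completion["eval_count"]],
    }


fake_response = {"message": {"content": "Hello!"}, "prompt_eval_count": 12, "eval_count": 3}
print(choose_endpoint(None), normalize_format(""), llm_statistics(fake_response))
```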
model_name property

Returns the model name used for the LLM.

validate_magpie_usage()

Validates that use_magpie_template is only enabled when a tokenizer_id is set.

Source code in src/distilabel/models/llms/ollama.py
@model_validator(mode="after")  # type: ignore
def validate_magpie_usage(
    self,
) -> "OllamaLLM":
    """Validates that `use_magpie_template` is only enabled when a `tokenizer_id` is set."""

    if self.use_magpie_template and self.tokenizer_id is None:
        raise ValueError(
            "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
            " set a `tokenizer_id` and try again."
        )
    return self
load()

Loads the AsyncClient for the Ollama async API and, if tokenizer_id is set, the tokenizer.

Source code in src/distilabel/models/llms/ollama.py
def load(self) -> None:
    """Loads the `AsyncClient` for the Ollama async API and, if `tokenizer_id` is set, the tokenizer."""
    super().load()

    try:
        from ollama import AsyncClient

        self._aclient = AsyncClient(
            host=self.host,
            timeout=self.timeout,
            follow_redirects=self.follow_redirects,
        )
    except ImportError as e:
        raise ImportError(
            "Ollama Python client is not installed. Please install it using"
            " `pip install 'distilabel[ollama]'`."
        ) from e

    if self.tokenizer_id:
        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
        if self._tokenizer.chat_template is None:
            raise ValueError(
                "The tokenizer does not have a chat template. Please use a tokenizer with a chat template."
            )
prepare_input(input)

Prepares the prompt for the provided input by applying the tokenizer's chat template (as text, without tokenizing) and then the Magpie pre-query template.

Parameters:

Name Type Description Default
input StandardInput

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.

Source code in src/distilabel/models/llms/ollama.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the prompt for the provided input by applying the tokenizer's chat
    template (as text, without tokenizing it) and then the Magpie pre-query template.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(
            conversation=input,
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
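With tokenize=False, apply_chat_template returns the formatted prompt as a string instead of token IDs. add_generation_prompt=True appends the header that opens the assistant's turn. The sketch below imitates this behaviour for a ChatML-style template so that it runs without transformers. The template and the `apply_chatml_template` name are assumptions for illustration, since each model's tokenizer defines its own template. The Magpie pre-query step that prepare_input applies afterwards is not reproduced.

```python
from typing import Dict, List


def apply_chatml_template(
    conversation: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
    """Render chat messages as a ChatML-style prompt string (no tokenization)."""
    if not conversation:
        # Mirrors `prepare_input`: an empty input yields an empty prompt.
        return ""
    prompt = "".join(
        f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
        for message in conversation
    )
    if add_generation_prompt:
        # Open an assistant turn so the model's continuation is the reply.
        prompt += "<|im_start|>assistant\n"
    return prompt


prompt = apply_chatml_template([{"role": "user", "content": "Hello world!"}])
```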
agenerate(input, format='', options=None, keep_alive=None) async

Generates a response asynchronously, using the Ollama Async API definition.

Parameters:

Name Type Description Default
input StandardInput

the input to use for the generation.

required
format Literal['', 'json']

the format to use for the generation: "" for plain text or "json" to request JSON output. Defaults to "".

''
options Union[Options, None]

the options to use for the generation. Defaults to None.

None
keep_alive Union[bool, None]

whether to keep the connection alive. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A GenerateOutput containing the generated text (a single generation) and the token usage statistics for the given input.

Source code in src/distilabel/models/llms/ollama.py
@validate_call
async def agenerate(
    self,
    input: StandardInput,
    format: Literal["", "json"] = "",
    # TODO: include relevant options from `Options` in `agenerate` method.
    options: Union[Options, None] = None,
    keep_alive: Union[bool, None] = None,
) -> GenerateOutput:
    """
    Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).

    Args:
        input: the input to use for the generation.
        format: the format to use for the generation. Defaults to `""`.
        options: the options to use for the generation. Defaults to `None`.
        keep_alive: whether to keep the connection alive. Defaults to `None`.

    Returns:
        A `GenerateOutput` with the generated text and the token usage statistics.
    """
    text = None
    completion = None  # stays `None` if the request below fails
    try:
        if not format:
            format = None
        if self.tokenizer_id is None:
            completion = await self._generate_chat_completion(
                input, format, options, keep_alive
            )
            text = completion["message"]["content"]
        else:
            completion = await self._generate_with_text_generation(
                input, format, options, keep_alive
            )
            text = completion.response
    except Exception as e:
        self._logger.warning(  # type: ignore
            f"⚠️ Received no response using Ollama client (model: '{self.model_name}')."
            f" Finish reason was: {e}"
        )

    if completion is None:
        # The request failed, so there are no token statistics to report.
        return prepare_output([text])
    return prepare_output([text], **self._get_llm_statistics(completion))
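The try/except in agenerate is meant to log a failed request and return a None generation instead of raising. The standalone sketch below reproduces that control flow with a fake client. `fake_chat`, `generate_with_fallback` and the plain-dict output are assumptions for illustration, since distilabel builds its output with prepare_output. The prompt_eval_count and eval_count fields match the statistics extracted at the top of this section. The sketch sets `completion = None` before the try block, so the failure path never reads an unbound variable.

```python
import asyncio
from typing import Any, Dict, List, Optional


async def fake_chat(messages: List[Dict[str, str]], fail: bool) -> Dict[str, Any]:
    """Hypothetical stand-in for an Ollama chat call."""
    if fail:
        raise ConnectionError("server unavailable")
    return {
        "message": {"role": "assistant", "content": "Hi there!"},
        "prompt_eval_count": 12,  # tokens in the prompt
        "eval_count": 4,  # tokens generated
    }


async def generate_with_fallback(
    messages: List[Dict[str, str]], fail: bool = False
) -> Dict[str, Any]:
    text: Optional[str] = None
    completion: Optional[Dict[str, Any]] = None
    try:
        completion = await fake_chat(messages, fail)
        text = completion["message"]["content"]
    except Exception as e:
        # Log and continue: a failed request yields a `None` generation.
        print(f"Received no response. Finish reason was: {e}")
    statistics = (
        {
            "input_tokens": [completion["prompt_eval_count"]],
            "output_tokens": [completion["eval_count"]],
        }
        if completion is not None
        else {}
    )
    return {"generations": [text], "statistics": statistics}


messages = [{"role": "user", "content": "Hello"}]
ok = asyncio.run(generate_with_fallback(messages))
failed = asyncio.run(generate_with_fallback(messages, fail=True))
```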

OpenAILLM

Bases: OpenAIBaseClient, AsyncLLM

OpenAI LLM implementation running the async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM, e.g. "gpt-3.5-turbo", "gpt-4", etc. Supported models are listed in the OpenAI text generation guide: https://platform.openai.com/docs/guides/text-generation.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the OpenAI API requests. Defaults to None, which means that the value set for the environment variable OPENAI_BASE_URL will be used, or "https://api.openai.com/v1" if not set.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the OpenAI API. Defaults to None, which means that the value set for the environment variable OPENAI_API_KEY will be used, or None if not set.

default_headers Optional[RuntimeParameter[Dict[str, str]]]

the default headers to use for the OpenAI API requests.

max_retries RuntimeParameter[int]

the maximum number of times to retry the request to the API before failing. Defaults to 6.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.

Runtime parameters
  • base_url: the base URL to use for the OpenAI API requests. Defaults to None.
  • api_key: the API key to authenticate the requests to the OpenAI API. Defaults to None.
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 6.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
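For base_url and api_key, an explicit value takes precedence, then the environment variable, then the default. The sketch below shows that resolution order. `resolve_base_url` and `resolve_api_key` are hypothetical helpers and are not distilabel functions. In practice, part of this fallback may happen inside the openai client itself.

```python
import os
from typing import Optional

DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"


def resolve_base_url(base_url: Optional[str] = None) -> str:
    # Explicit value first, then `OPENAI_BASE_URL`, then the public endpoint.
    return base_url or os.environ.get("OPENAI_BASE_URL") or DEFAULT_OPENAI_BASE_URL


def resolve_api_key(api_key: Optional[str] = None) -> Optional[str]:
    # Explicit key first, then `OPENAI_API_KEY`; returns `None` if neither is set.
    return api_key or os.environ.get("OPENAI_API_KEY")
```

Using `or` means an empty string counts as unset, which is usually the desired behaviour for values read from the environment.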

Examples:

Generate text:

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate text from a custom endpoint following the OpenAI API:

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    base_url=r"http://localhost:8080/v1"
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import OpenAILLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = OpenAILLM(
    model="gpt-4-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
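When structured_output is set, the source code below returns each generation as the JSON dump of the validated model (completion.model_dump_json()). That JSON text can be parsed back into an object. The sketch uses a plain dataclass mirror of User and a made-up generation string, so it runs without pydantic. With pydantic installed, User.model_validate_json(generation) does the same in one step.

```python
import json
from dataclasses import dataclass


@dataclass
class User:
    name: str
    last_name: str
    id: int


# Hypothetical generation text, shaped like the JSON dump of a `User` instance.
generation = '{"name": "Ada", "last_name": "Lovelace", "id": 1}'

user = User(**json.loads(generation))
```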

Generate with Batch API (offline batch generation):

from distilabel.models.llms import OpenAILLM

llm = OpenAILLM(
    model="gpt-3.5-turbo",
    use_offline_batch_generation=True,
    offline_batch_generation_block_until_done=5,  # poll for results every 5 seconds
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
# e.g. [{'generations': ['Hello! How can I assist you today?'], 'statistics': {...}}]
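With offline_batch_generation_block_until_done set, generation keeps polling the Batch API until results can be retrieved. The source code below then sorts the retrieved rows by custom_id so that they line up with the inputs. The standalone sketch below shows both steps. The status values are the ones checked in _check_and_get_batch_results. `wait_for_batch`, `order_results` and the fake status sequence are hypothetical. The sketch also assumes that custom_id holds each input's position as a string.

```python
import time
from typing import Callable, Dict, List, Sequence

PENDING = ("validating", "in_progress", "finalizing")
FAILED = ("failed", "expired", "cancelled", "cancelling")


def wait_for_batch(get_status: Callable[[], str], poll_interval: float) -> str:
    """Poll `get_status` until the batch leaves a pending state."""
    while (status := get_status()) in PENDING:
        time.sleep(poll_interval)
    if status in FAILED:
        raise RuntimeError(f"Batch finished with status '{status}'.")
    return status


def order_results(results: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
    # Sort numerically so "10" comes after "2"; a plain string sort would not.
    return sorted(results, key=lambda r: int(r["custom_id"]))


statuses = iter(["validating", "in_progress", "completed"])
final = wait_for_batch(lambda: next(statuses), poll_interval=0)
ordered = order_results([{"custom_id": "10"}, {"custom_id": "2"}, {"custom_id": "0"}])
```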
Source code in src/distilabel/models/llms/openai.py
class OpenAILLM(OpenAIBaseClient, AsyncLLM):
    """OpenAI LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "gpt-3.5-turbo", "gpt-4", etc.
            Supported models can be found [here](https://platform.openai.com/docs/guides/text-generation).
        base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `OPENAI_BASE_URL` will
            be used, or "https://api.openai.com/v1" if not set.
        api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
            `None` which means that the value set for the environment variable `OPENAI_API_KEY`
            will be used, or `None` if not set.
        default_headers: the default headers to use for the OpenAI API requests.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.

    Runtime parameters:
        - `base_url`: the base URL to use for the OpenAI API requests. Defaults to `None`.
        - `api_key`: the API key to authenticate the requests to the OpenAI API. Defaults
            to `None`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Icon:
        `:simple-openai:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import OpenAILLM

        llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate text from a custom endpoint following the OpenAI API:

        ```python
        from distilabel.models.llms import OpenAILLM

        llm = OpenAILLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            base_url=r"http://localhost:8080/v1"
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import OpenAILLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = OpenAILLM(
            model="gpt-4-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```

        Generate with Batch API (offline batch generation):

        ```python
        from distilabel.models.llms import OpenAILLM

        llm = OpenAILLM(
            model="gpt-3.5-turbo",
            use_offline_batch_generation=True,
            offline_batch_generation_block_until_done=5,  # poll for results every 5 seconds
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        # e.g. [{'generations': ['Hello! How can I assist you today?'], 'statistics': {...}}]
        ```
    """

    def load(self) -> None:
        AsyncLLM.load(self)
        OpenAIBaseClient.load(self)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: NonNegativeInt = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        echo: bool = False,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[Dict[str, str]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            logprobs: whether to return the log probabilities or not. Defaults to `False`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. Defaults to `None`.
            echo: whether to echo the input in the response or not. It's only used if the
                `input` argument is a `str`. Defaults to `False`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return, as a dictionary such
                as `{"type": "text"}` or `{"type": "json_object"}`. Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use JSON mode with OpenAI. Defaults to `None`,
                which returns text. To return JSON, use `{"type": "json_object"}`.
            extra_body: an optional dictionary containing extra body parameters that will
                be sent to the OpenAI API endpoint. Defaults to `None`.

        Returns:
            A `GenerateOutput` with the generated responses for the input, the token usage
            statistics and, if requested, the log probabilities.
        """

        if isinstance(input, str):
            return await self._generate_completion(
                input=input,
                num_generations=num_generations,
                max_new_tokens=max_new_tokens,
                echo=echo,
                top_logprobs=top_logprobs,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                temperature=temperature,
                top_p=top_p,
                extra_body=extra_body,
            )

        return await self._generate_chat_completion(
            input=input,
            num_generations=num_generations,
            max_new_tokens=max_new_tokens,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            response_format=response_format,
            extra_body=extra_body,
        )

    async def _generate_completion(
        self,
        input: str,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        echo: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        completion = await self._aclient.completions.create(
            prompt=input,
            echo=echo,
            model=self.model,
            n=num_generations,
            max_tokens=max_new_tokens,
            logprobs=top_logprobs,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            extra_body=extra_body,
        )

        generations = []
        logprobs = []
        for choice in completion.choices:
            generations.append(choice.text)
            if choice_logprobs := self._get_logprobs_from_completion_choice(choice):
                logprobs.append(choice_logprobs)

        statistics = self._get_llm_statistics(completion)
        return prepare_output(
            generations=generations,
            input_tokens=statistics["input_tokens"],
            output_tokens=statistics["output_tokens"],
            logprobs=logprobs,
        )

    def _get_logprobs_from_completion_choice(
        self, choice: "OpenAICompletionChoice"
    ) -> Union[List[Union[List["Logprob"], None]], None]:
        if choice.logprobs is None or choice.logprobs.top_logprobs is None:
            return None

        return [
            [
                {"token": token, "logprob": token_logprob}
                for token, token_logprob in logprobs.items()
            ]
            if logprobs is not None
            else None
            for logprobs in choice.logprobs.top_logprobs
        ]

    async def _generate_chat_completion(
        self,
        input: Union["StandardInput", "StructuredInput"],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[Dict[str, str]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,  # type: ignore
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")  # type: ignore

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "logprobs": logprobs,
            "top_logprobs": top_logprobs,
            "max_tokens": max_new_tokens,
            "n": num_generations,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
            "extra_body": extra_body,
        }

        # Checks whether the first user message has list content (e.g. it includes an
        # image); in that case "stop" is removed, as the API raises an error otherwise.
        if isinstance(
            [row for row in input if row["role"] == "user"][0]["content"], list
        ):
            kwargs.pop("stop")

        if response_format is not None:
            kwargs["response_format"] = response_format

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore

        if structured_output:
            # NOTE: `instructor` doesn't work with `n` parameter, so it will always return
            # only 1 choice.
            statistics = self._get_llm_statistics(completion._raw_response)
            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
                completion._raw_response.choices[0]
            ):
                output_logprobs = [choice_logprobs]
            else:
                output_logprobs = None
            return prepare_output(
                generations=[completion.model_dump_json()],
                input_tokens=statistics["input_tokens"],
                output_tokens=statistics["output_tokens"],
                logprobs=output_logprobs,
            )

        return self._generations_from_openai_completion(completion)

    def _generations_from_openai_completion(
        self, completion: "OpenAIChatCompletion"
    ) -> "GenerateOutput":
        """Get the generations from the OpenAI Chat Completion object.

        Args:
            completion: the completion object to get the generations from.

        Returns:
            A `GenerateOutput` with the generated responses for the input, the token usage
            statistics and the log probabilities, if any.
        """
        generations = []
        logprobs = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
                choice
            ):
                logprobs.append(choice_logprobs)

        statistics = self._get_llm_statistics(completion)
        return prepare_output(
            generations=generations,
            input_tokens=statistics["input_tokens"],
            output_tokens=statistics["output_tokens"],
            logprobs=logprobs,
        )

    def _get_logprobs_from_chat_completion_choice(
        self, choice: "OpenAIChatCompletionChoice"
    ) -> Union[List[List["Logprob"]], None]:
        if choice.logprobs is None or choice.logprobs.content is None:
            return None

        return [
            [
                {"token": top_logprob.token, "logprob": top_logprob.logprob}
                for top_logprob in token_logprobs.top_logprobs
            ]
            for token_logprobs in choice.logprobs.content
        ]

    def offline_batch_generate(
        self,
        inputs: Union[List["FormattedInput"], None] = None,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[str] = None,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Uses the OpenAI batch API to generate `num_generations` responses for the given
        inputs.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            logprobs: whether to return the log probabilities or not. Defaults to `False`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. Defaults to `None`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of
                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use JSON mode with OpenAI. Defaults to `None`,
                which returns text.

        Returns:
            A list of `GenerateOutput`s, one per input in `inputs`, containing the
            generated responses.

        Raises:
            DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
                is not finished yet.
            ValueError: if no job IDs were found to retrieve the results from.
        """
        if self.jobs_ids:
            return self._check_and_get_batch_results()

        if inputs:
            self.jobs_ids = self._create_jobs(
                inputs=inputs,
                **{
                    "model": self.model,
                    "logprobs": logprobs,
                    "top_logprobs": top_logprobs,
                    "max_tokens": max_new_tokens,
                    "n": num_generations,
                    "frequency_penalty": frequency_penalty,
                    "presence_penalty": presence_penalty,
                    "temperature": temperature,
                    "top_p": top_p,
                    "stop": stop,
                    "response_format": response_format,
                },
            )
            raise DistilabelOfflineBatchGenerationNotFinishedException(
                jobs_ids=self.jobs_ids
            )

        raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")

    def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
        """Checks the status of the batch jobs and retrieves the results from the OpenAI
        Batch API.

        Returns:
            A list of `GenerateOutput`s containing the generated responses for each input.

        Raises:
            ValueError: if no job IDs were found to retrieve the results from.
            DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
                is not finished yet.
            RuntimeError: if the only batch job found failed.
        """
        if not self.jobs_ids:
            raise ValueError("No job IDs were found to retrieve the results from.")

        outputs = []
        for batch_id in self.jobs_ids:
            batch = self._get_openai_batch(batch_id)

            if batch.status in ("validating", "in_progress", "finalizing"):
                raise DistilabelOfflineBatchGenerationNotFinishedException(
                    jobs_ids=self.jobs_ids
                )

            if batch.status in ("failed", "expired", "cancelled", "cancelling"):
                self._logger.error(  # type: ignore
                    f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
                )
                if len(self.jobs_ids) == 1:
                    self.jobs_ids = None
                    raise RuntimeError(
                        f"The only OpenAI API Batch that was created with ID '{batch_id}'"
                        f" failed with status '{batch.status}'."
                    )

                continue

            outputs.extend(self._retrieve_batch_results(batch))

        # sort by `custom_id` to return the results in the same order as the inputs
        outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
        return [self._parse_output(output) for output in outputs]

    def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
        """Parses the output from the OpenAI Batch API into a list of strings.

        Args:
            output: the output to parse.

        Returns:
            A `GenerateOutput` with the generated responses for the input, or an empty
            list if the request has no successful response.
        """
        from openai.types.chat import ChatCompletion as OpenAIChatCompletion

        if "response" not in output:
            return []

        if output["response"]["status_code"] != 200:
            return []

        return self._generations_from_openai_completion(
            OpenAIChatCompletion(**output["response"]["body"])
        )

    def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
        """Gets a batch from the OpenAI Batch API.

        Args:
            batch_id: the ID of the batch to retrieve.

        Returns:
            The batch retrieved from the OpenAI Batch API.

        Raises:
            openai.OpenAIError: if there was an error while retrieving the batch from the
                OpenAI Batch API.
        """
        import openai

        try:
            return self._client.batches.retrieve(batch_id)
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
            )
            raise e

    def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
        """Retrieves the results of a batch from its output file, parsing the JSONL content
        into a list of dictionaries.

        Args:
            batch: the batch to retrieve the results from.

        Returns:
            A list of dictionaries containing the results of the batch.

        Raises:
            AssertionError: if no output file ID was found in the batch.
        """
        import openai

        assert batch.output_file_id, "No output file ID was found in the batch."

        try:
            file_response = self._client.files.content(batch.output_file_id)
            return [orjson.loads(line) for line in file_response.text.splitlines()]
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
            )
            return []

    def _create_jobs(
        self, inputs: List["FormattedInput"], **kwargs: Any
    ) -> Tuple[str, ...]:
        """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            kwargs: the keyword arguments to use for the generation.

        Returns:
            A tuple with the IDs of the jobs created in the OpenAI Batch API.
        """
        batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
        jobs = []
        for batch_input_file in batch_input_files:
            if batch := self._create_batch_api_job(batch_input_file):
                jobs.append(batch.id)
        return tuple(jobs)

    def _create_batch_api_job(
        self, batch_input_file: "OpenAIFileObject"
    ) -> Union["OpenAIBatch", None]:
        """Creates a job in the OpenAI Batch API to generate responses for the given input
        file.

        Args:
            batch_input_file: the input file to generate responses for.

        Returns:
            The batch job created in the OpenAI Batch API.
        """
        import openai

        metadata = {"description": "distilabel"}

        if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
            metadata["distilabel_pipeline_name"] = distilabel_pipeline_name

        if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
            metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id

        batch = None
        try:
            batch = self._client.batches.create(
                completion_window="24h",
                endpoint="/v1/chat/completions",
                input_file_id=batch_input_file.id,
                metadata=metadata,
            )
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while creating OpenAI Batch API job for file with ID"
                f" '{batch_input_file.id}': {e}."
            )
            raise e
        return batch

    def _create_batch_files(
        self, inputs: List["FormattedInput"], **kwargs: Any
    ) -> List["OpenAIFileObject"]:
        """Creates the necessary input files for the batch API to generate responses. The
        maximum size of each file so the OpenAI Batch API can process it is 100MB, so we
        need to split the inputs into multiple files if necessary.

        More information: https://platform.openai.com/docs/api-reference/files/create

        Args:
            inputs: a list of inputs in chat format to generate responses for, optionally
                including structured output.
            kwargs: the keyword arguments to use for the generation.

        Returns:
            The list of file objects created for the OpenAI Batch API.

        Raises:
            openai.OpenAIError: if there was an error while creating the batch input file
                in the OpenAI Batch API.
        """
        import openai

        files = []
        for file_no, buffer in enumerate(
            self._create_jsonl_buffers(inputs=inputs, **kwargs)
        ):
            try:
                # TODO: add distilabel pipeline name and id
                batch_input_file = self._client.files.create(
                    file=(self._name_for_openai_files(file_no), buffer),
                    purpose="batch",
                )
                files.append(batch_input_file)
            except openai.OpenAIError as e:
                self._logger.error(  # type: ignore
                    f"Error while creating OpenAI batch input file: {e}"
                )
                raise e
        return files

    def _create_jsonl_buffers(
        self, inputs: List["FormattedInput"], **kwargs: Any
    ) -> Generator[io.BytesIO, None, None]:
        """Creates a generator of buffers containing the JSONL formatted inputs to be
        used by the OpenAI Batch API. The buffers created are of size 100MB or less.

        Args:
            inputs: a list of inputs in chat format to generate responses for, optionally
                including structured output.
            kwargs: the keyword arguments to use for the generation.

        Yields:
            A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
            API.
        """
        buffer = io.BytesIO()
        buffer_current_size = 0
        for i, input in enumerate(inputs):
            # We create the smallest `custom_id` so we don't increase the size of the file
            # too much, but we can still sort the results with the order of the inputs.
            row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
            row_size = len(row)
            if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
                buffer.seek(0)
                yield buffer
                buffer = io.BytesIO()
                buffer_current_size = 0
            buffer.write(row)
            buffer_current_size += row_size

        if buffer_current_size > 0:
            buffer.seek(0)
            yield buffer

    def _create_jsonl_row(
        self, input: "FormattedInput", custom_id: str, **kwargs: Any
    ) -> bytes:
        """Creates a JSONL formatted row to be used by the OpenAI Batch API.

        Args:
            input: a list of inputs in chat format to generate responses for, optionally
                including structured output.
            custom_id: a custom ID to use for the row.
            kwargs: the keyword arguments to use for the generation.

        Returns:
            A JSONL formatted row to be used by the OpenAI Batch API.
        """
        # TODO: depending on the format of the input, add `response_format` to the kwargs
        row = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {"messages": input, **kwargs},
        }
        json_row = orjson.dumps(row)
        return json_row + b"\n"

    def _name_for_openai_files(self, file_no: int) -> str:
        if (
            envs.DISTILABEL_PIPELINE_NAME is None
            or envs.DISTILABEL_PIPELINE_CACHE_ID is None
        ):
            return f"distilabel-pipeline-fileno-{file_no}.jsonl"

        return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl"

    @staticmethod
    def _get_llm_statistics(
        completion: Union["OpenAIChatCompletion", "OpenAICompletion"],
    ) -> "LLMStatistics":
        return {
            "output_tokens": [
                completion.usage.completion_tokens if completion.usage else 0
            ],
            "input_tokens": [completion.usage.prompt_tokens if completion.usage else 0],
        }
agenerate(input, num_generations=1, max_new_tokens=128, logprobs=False, top_logprobs=None, echo=False, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, stop=None, response_format=None, extra_body=None) async

Generates num_generations responses for the given input using the OpenAI async client.

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens NonNegativeInt

the maximum number of new tokens that the model will generate. Defaults to 128.

128
logprobs bool

whether to return the log probabilities or not. Defaults to False.

False
top_logprobs Optional[PositiveInt]

the number of top log probabilities to return per output token generated. Defaults to None.

None
echo bool

whether to echo the input in the response or not. It's only used if the input argument is a str. Defaults to False.

False
frequency_penalty float

the repetition penalty to use for the generation. Defaults to 0.0.

0.0
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
stop Optional[Union[str, List[str]]]

a string or a list of strings to use as a stop sequence for the generation. Defaults to None.

None
response_format Optional[Dict[str, str]]

the format of the response to return. Defaults to None, which returns text. To return JSON, use {"type": "json_object"}. See OpenAI's JSON mode documentation (https://platform.openai.com/docs/guides/text-generation/json-mode) for more information.

None
extra_body Optional[Dict[str, Any]]

an optional dictionary containing extra body parameters that will be sent to the OpenAI API endpoint. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/openai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    num_generations: int = 1,
    max_new_tokens: NonNegativeInt = 128,
    logprobs: bool = False,
    top_logprobs: Optional[PositiveInt] = None,
    echo: bool = False,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the OpenAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        logprobs: whether to return the log probabilities or not. Defaults to `False`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. Defaults to `None`.
        echo: whether to echo the input in the response or not. It's only used if the
            `input` argument is a `str`. Defaults to `False`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: a string or a list of strings to use as a stop sequence for the generation.
            Defaults to `None`.
        response_format: the format of the response to return. Must be one of
            "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
            for more information on how to use JSON mode with OpenAI. Defaults to None
            which returns text. To return JSON, use {"type": "json_object"}.
        extra_body: an optional dictionary containing extra body parameters that will
            be sent to the OpenAI API endpoint. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """

    if isinstance(input, str):
        return await self._generate_completion(
            input=input,
            num_generations=num_generations,
            max_new_tokens=max_new_tokens,
            echo=echo,
            top_logprobs=top_logprobs,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            extra_body=extra_body,
        )

    return await self._generate_chat_completion(
        input=input,
        num_generations=num_generations,
        max_new_tokens=max_new_tokens,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        temperature=temperature,
        top_p=top_p,
        stop=stop,
        response_format=response_format,
        extra_body=extra_body,
    )
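As the source shows, `agenerate` routes a bare `str` input to the legacy completions path (`_generate_completion`) and a chat-formatted list of messages to `_generate_chat_completion`; only the chat path receives `stop` and `response_format`. The sketch below mirrors just that dispatch rule using hypothetical stand-in functions (`fake_completion`, `fake_chat_completion`, `dispatch` are illustrative names, not distilabel APIs) and makes no network calls. It also shows that JSON mode is requested with a dictionary.

```python
from typing import Any, Dict, List, Optional, Union

Messages = List[Dict[str, str]]


def fake_completion(prompt: str, **kwargs: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for `_generate_completion` (legacy completions endpoint).
    return {"endpoint": "completions", "prompt": prompt, **kwargs}


def fake_chat_completion(messages: Messages, **kwargs: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for `_generate_chat_completion` (chat completions endpoint).
    return {"endpoint": "chat.completions", "messages": messages, **kwargs}


def dispatch(
    input: Union[str, Messages],
    temperature: float = 1.0,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    # Same rule as `agenerate`: a bare string goes to the completions path,
    # which is not given `stop` or `response_format` in the source above.
    if isinstance(input, str):
        return fake_completion(input, temperature=temperature)
    return fake_chat_completion(
        input, temperature=temperature, stop=stop, response_format=response_format
    )


# JSON mode is requested with a dictionary, not a bare string.
chat = dispatch(
    [{"role": "user", "content": "Reply with a JSON object."}],
    response_format={"type": "json_object"},
)
print(chat["endpoint"])  # chat.completions
```

With the real class, the same call shape would be `await llm.agenerate(input=messages, response_format={"type": "json_object"})` on a loaded `OpenAILLM`.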
_generations_from_openai_completion(completion)

Get the generations from the OpenAI Chat Completion object.

Parameters:

Name Type Description Default
completion ChatCompletion

the completion object to get the generations from.

required

Returns:

Type Description
GenerateOutput

A list of strings containing the generated responses for the input.

Source code in src/distilabel/models/llms/openai.py
def _generations_from_openai_completion(
    self, completion: "OpenAIChatCompletion"
) -> "GenerateOutput":
    """Get the generations from the OpenAI Chat Completion object.

    Args:
        completion: the completion object to get the generations from.

    Returns:
        A list of strings containing the generated responses for the input.
    """
    generations = []
    logprobs = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using OpenAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
        if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
            choice
        ):
            logprobs.append(choice_logprobs)

    statistics = self._get_llm_statistics(completion)
    return prepare_output(
        generations=generations,
        input_tokens=statistics["input_tokens"],
        output_tokens=statistics["output_tokens"],
        logprobs=logprobs,
    )
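`_generations_from_openai_completion` keeps one entry per choice, appending `None` (and logging a warning) when a choice has no content, and reads token usage with a fallback of `0`. The sketch below reproduces that logic on plain dictionaries shaped like a Chat Completions JSON response, which is an assumption made to avoid depending on the `openai` SDK objects; logprobs handling and `prepare_output` are omitted, and `generations_from_choices` is a hypothetical name.

```python
import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger("generations_sketch")


def generations_from_choices(completion: Dict[str, Any]) -> Dict[str, Any]:
    generations: List[Optional[str]] = []
    for choice in completion["choices"]:
        content = choice["message"].get("content")
        if content is None:
            # Warn but still append, so there is one entry per choice.
            logger.warning(
                "Received no response. Finish reason was: %s",
                choice.get("finish_reason"),
            )
        generations.append(content)
    # Missing usage information counts as zero tokens.
    usage = completion.get("usage")
    return {
        "generations": generations,
        "input_tokens": [usage["prompt_tokens"] if usage else 0],
        "output_tokens": [usage["completion_tokens"] if usage else 0],
    }
```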
offline_batch_generate(inputs=None, num_generations=1, max_new_tokens=128, logprobs=False, top_logprobs=None, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, stop=None, response_format=None, **kwargs)

Uses the OpenAI batch API to generate num_generations responses for the given inputs.

Parameters:

Name Type Description Default
inputs Union[List[FormattedInput], None]

a list of inputs in chat format to generate responses for.

None
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
logprobs bool

whether to return the log probabilities or not. Defaults to False.

False
top_logprobs Optional[PositiveInt]

the number of top log probabilities to return per output token generated. Defaults to None.

None
frequency_penalty float

the repetition penalty to use for the generation. Defaults to 0.0.

0.0
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
stop Optional[Union[str, List[str]]]

a string or a list of strings to use as a stop sequence for the generation. Defaults to None.

None
response_format Optional[str]

the format of the response to return. Must be one of "text" or "json". See OpenAI's JSON mode documentation (https://platform.openai.com/docs/guides/text-generation/json-mode) for more information. Defaults to None, which returns text.

None

Returns:

Type Description
List[GenerateOutput]

A list of lists of strings containing the generated responses for each input in inputs.

Raises:

Type Description
DistilabelOfflineBatchGenerationNotFinishedException

if the batch generation is not finished yet.

ValueError

if no job IDs were found to retrieve the results from.

Source code in src/distilabel/models/llms/openai.py
def offline_batch_generate(
    self,
    inputs: Union[List["FormattedInput"], None] = None,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    logprobs: bool = False,
    top_logprobs: Optional[PositiveInt] = None,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[str] = None,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Uses the OpenAI batch API to generate `num_generations` responses for the given
    inputs.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        logprobs: whether to return the log probabilities or not. Defaults to `False`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. Defaults to `None`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: a string or a list of strings to use as a stop sequence for the generation.
            Defaults to `None`.
        response_format: the format of the response to return. Must be one of
            "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
            for more information on how to use JSON mode with OpenAI. Defaults to `None`,
            which returns text.

    Returns:
        A list of lists of strings containing the generated responses for each input
        in `inputs`.

    Raises:
        DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
            is not finished yet.
        ValueError: if no job IDs were found to retrieve the results from.
    """
    if self.jobs_ids:
        return self._check_and_get_batch_results()

    if inputs:
        self.jobs_ids = self._create_jobs(
            inputs=inputs,
            **{
                "model": self.model,
                "logprobs": logprobs,
                "top_logprobs": top_logprobs,
                "max_tokens": max_new_tokens,
                "n": num_generations,
                "frequency_penalty": frequency_penalty,
                "presence_penalty": presence_penalty,
                "temperature": temperature,
                "top_p": top_p,
                "stop": stop,
                "response_format": response_format,
            },
        )
        raise DistilabelOfflineBatchGenerationNotFinishedException(
            jobs_ids=self.jobs_ids
        )

    raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
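`offline_batch_generate` works in two phases: the first call with `inputs` creates the batch jobs, stores their IDs in `jobs_ids`, and raises the "not finished" exception; later calls (with or without inputs) only check the stored jobs. The toy class below reproduces that control flow under stated assumptions: `BatchNotFinished` stands in for `DistilabelOfflineBatchGenerationNotFinishedException`, and the `finished` flag is a hypothetical substitute for the remote job status.

```python
from typing import List, Optional, Tuple


class BatchNotFinished(Exception):
    """Hypothetical stand-in for DistilabelOfflineBatchGenerationNotFinishedException."""

    def __init__(self, jobs_ids: Tuple[str, ...]) -> None:
        super().__init__(f"Batch jobs not finished: {jobs_ids}")
        self.jobs_ids = jobs_ids


class FakeOfflineLLM:
    def __init__(self) -> None:
        self.jobs_ids: Optional[Tuple[str, ...]] = None
        self.finished = False  # hypothetical flag simulating the remote job status

    def _create_jobs(self, inputs: List[str]) -> Tuple[str, ...]:
        return ("batch-0",)  # pretend one job was created

    def _check_and_get_batch_results(self) -> List[str]:
        if not self.finished:
            raise BatchNotFinished(self.jobs_ids or ())
        return ["result"]

    def offline_batch_generate(self, inputs: Optional[List[str]] = None) -> List[str]:
        # Existing jobs take priority over new inputs, as in the source above.
        if self.jobs_ids:
            return self._check_and_get_batch_results()
        if inputs:
            self.jobs_ids = self._create_jobs(inputs)
            raise BatchNotFinished(self.jobs_ids)
        raise ValueError("No `inputs` were provided and no `jobs_ids` were found.")
```

Raising right after job creation lets the caller persist `jobs_ids` and resume later instead of blocking for up to the 24h completion window.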
_check_and_get_batch_results()

Checks the status of the batch jobs and retrieves the results from the OpenAI Batch API.

Returns:

Type Description
List[GenerateOutput]

A list of lists of strings containing the generated responses for each input.

Raises:

Type Description
ValueError

if no job IDs were found to retrieve the results from.

DistilabelOfflineBatchGenerationNotFinishedException

if the batch generation is not finished yet.

RuntimeError

if the only batch job found failed.

Source code in src/distilabel/models/llms/openai.py
def _check_and_get_batch_results(self) -> List["GenerateOutput"]:
    """Checks the status of the batch jobs and retrieves the results from the OpenAI
    Batch API.

    Returns:
        A list of lists of strings containing the generated responses for each input.

    Raises:
        ValueError: if no job IDs were found to retrieve the results from.
        DistilabelOfflineBatchGenerationNotFinishedException: if the batch generation
            is not finished yet.
        RuntimeError: if the only batch job found failed.
    """
    if not self.jobs_ids:
        raise ValueError("No job IDs were found to retrieve the results from.")

    outputs = []
    for batch_id in self.jobs_ids:
        batch = self._get_openai_batch(batch_id)

        if batch.status in ("validating", "in_progress", "finalizing"):
            raise DistilabelOfflineBatchGenerationNotFinishedException(
                jobs_ids=self.jobs_ids
            )

        if batch.status in ("failed", "expired", "cancelled", "cancelling"):
            self._logger.error(  # type: ignore
                f"OpenAI API batch with ID '{batch_id}' failed with status '{batch.status}'."
            )
            if len(self.jobs_ids) == 1:
                self.jobs_ids = None
                raise RuntimeError(
                    f"The only OpenAI API Batch that was created with ID '{batch_id}'"
                    f" failed with status '{batch.status}'."
                )

            continue

        outputs.extend(self._retrieve_batch_results(batch))

    # sort by `custom_id` to return the results in the same order as the inputs
    outputs = sorted(outputs, key=lambda x: int(x["custom_id"]))
    return [self._parse_output(output) for output in outputs]
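The status handling and ordering in `_check_and_get_batch_results` can be illustrated without the API. In this sketch, batches are given as a hypothetical `{batch_id: (status, rows)}` mapping instead of being fetched, `NotFinished` stands in for the distilabel exception, and the reset of `jobs_ids` to `None` before the `RuntimeError` is left out. The status names are the ones checked in the source above, plus `"completed"` for a finished batch.

```python
from typing import Any, Dict, List, Optional, Tuple

PENDING = ("validating", "in_progress", "finalizing")
FAILED = ("failed", "expired", "cancelled", "cancelling")


class NotFinished(Exception):
    """Hypothetical stand-in for the not-finished exception."""


def collect_results(
    batches: Dict[str, Tuple[str, List[Dict[str, Any]]]],
    jobs_ids: Optional[Tuple[str, ...]],
) -> List[Dict[str, Any]]:
    if not jobs_ids:
        raise ValueError("No job IDs were found to retrieve the results from.")
    outputs: List[Dict[str, Any]] = []
    for batch_id in jobs_ids:
        status, rows = batches[batch_id]
        if status in PENDING:
            # Any job still running means the whole generation is not finished.
            raise NotFinished(batch_id)
        if status in FAILED:
            # A failed job is fatal only if it is the only job.
            if len(jobs_ids) == 1:
                raise RuntimeError(f"The only batch '{batch_id}' failed with status '{status}'.")
            continue
        outputs.extend(rows)
    # Sort numerically: as strings, "10" would sort before "2".
    return sorted(outputs, key=lambda x: int(x["custom_id"]))
```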
_parse_output(output)

Parses the output from the OpenAI Batch API into a list of strings.

Parameters:

Name Type Description Default
output Dict[str, Any]

the output to parse.

required

Returns:

Type Description
GenerateOutput

A list of strings containing the generated responses for the input.

Source code in src/distilabel/models/llms/openai.py
def _parse_output(self, output: Dict[str, Any]) -> "GenerateOutput":
    """Parses the output from the OpenAI Batch API into a list of strings.

    Args:
        output: the output to parse.

    Returns:
        A list of strings containing the generated responses for the input.
    """
    from openai.types.chat import ChatCompletion as OpenAIChatCompletion

    if "response" not in output:
        return []

    if output["response"]["status_code"] != 200:
        return []

    return self._generations_from_openai_completion(
        OpenAIChatCompletion(**output["response"]["body"])
    )
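Each line of a Batch API output file wraps a regular Chat Completions body under `response.body`, next to a `status_code`. The sketch below applies the same two checks as `_parse_output` but, as a simplification, returns the message contents directly from a plain dictionary instead of building an `OpenAIChatCompletion` object; `parse_output` is a hypothetical name.

```python
from typing import Any, Dict, List, Optional


def parse_output(output: Dict[str, Any]) -> List[Optional[str]]:
    # Same checks as `_parse_output`: no `response` key, or a non-200 status
    # code, yields no generations for that line.
    if "response" not in output:
        return []
    if output["response"]["status_code"] != 200:
        return []
    # The body has the same shape as a regular Chat Completions response.
    body = output["response"]["body"]
    return [choice["message"]["content"] for choice in body["choices"]]
```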
_get_openai_batch(batch_id)

Gets a batch from the OpenAI Batch API.

Parameters:

Name Type Description Default
batch_id str

the ID of the batch to retrieve.

required

Returns:

Type Description
Batch

The batch retrieved from the OpenAI Batch API.

Raises:

Type Description
OpenAIError

if there was an error while retrieving the batch from the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _get_openai_batch(self, batch_id: str) -> "OpenAIBatch":
    """Gets a batch from the OpenAI Batch API.

    Args:
        batch_id: the ID of the batch to retrieve.

    Returns:
        The batch retrieved from the OpenAI Batch API.

    Raises:
        openai.OpenAIError: if there was an error while retrieving the batch from the
            OpenAI Batch API.
    """
    import openai

    try:
        return self._client.batches.retrieve(batch_id)
    except openai.OpenAIError as e:
        self._logger.error(  # type: ignore
            f"Error while retrieving batch '{batch_id}' from OpenAI: {e}"
        )
        raise e
_retrieve_batch_results(batch)

Retrieves the results of a batch from its output file, parsing the JSONL content into a list of dictionaries.

Parameters:

Name Type Description Default
batch Batch

the batch to retrieve the results from.

required

Returns:

Type Description
List[Dict[str, Any]]

A list of dictionaries containing the results of the batch.

Raises:

Type Description
AssertionError

if no output file ID was found in the batch.

Source code in src/distilabel/models/llms/openai.py
def _retrieve_batch_results(self, batch: "OpenAIBatch") -> List[Dict[str, Any]]:
    """Retrieves the results of a batch from its output file, parsing the JSONL content
    into a list of dictionaries.

    Args:
        batch: the batch to retrieve the results from.

    Returns:
        A list of dictionaries containing the results of the batch.

    Raises:
        AssertionError: if no output file ID was found in the batch.
    """
    import openai

    assert batch.output_file_id, "No output file ID was found in the batch."

    try:
        file_response = self._client.files.content(batch.output_file_id)
        return [orjson.loads(line) for line in file_response.text.splitlines()]
    except openai.OpenAIError as e:
        self._logger.error(  # type: ignore
            f"Error while retrieving batch results from file '{batch.output_file_id}': {e}"
        )
        return []
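`_retrieve_batch_results` downloads the output file and decodes it line by line. A stdlib-only sketch of the decoding step, using `json` in place of `orjson` (an assumption for portability) on a hypothetical JSONL string:

```python
import json
from typing import Any, Dict, List


def parse_jsonl(text: str) -> List[Dict[str, Any]]:
    # `splitlines()` does not yield a trailing empty string for a final "\n",
    # so a JSONL file without blank lines produces exactly one dict per line.
    return [json.loads(line) for line in text.splitlines()]


rows = parse_jsonl('{"custom_id": "0"}\n{"custom_id": "1"}\n')
print(len(rows))  # 2
```

Like the source, this would fail on a blank line in the middle of the file; filtering with `if line.strip()` is one possible guard.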
_create_jobs(inputs, **kwargs)

Creates jobs in the OpenAI Batch API to generate responses for the given inputs.

Parameters:

Name Type Description Default
inputs List[FormattedInput]

a list of inputs in chat format to generate responses for.

required
kwargs Any

the keyword arguments to use for the generation.

{}

Returns:

Type Description
Tuple[str, ...]

A list of job IDs created in the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_jobs(
    self, inputs: List["FormattedInput"], **kwargs: Any
) -> Tuple[str, ...]:
    """Creates jobs in the OpenAI Batch API to generate responses for the given inputs.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        kwargs: the keyword arguments to use for the generation.

    Returns:
        A list of job IDs created in the OpenAI Batch API.
    """
    batch_input_files = self._create_batch_files(inputs=inputs, **kwargs)
    jobs = []
    for batch_input_file in batch_input_files:
        if batch := self._create_batch_api_job(batch_input_file):
            jobs.append(batch.id)
    return tuple(jobs)
_create_batch_api_job(batch_input_file)

Creates a job in the OpenAI Batch API to generate responses for the given input file.

Parameters:

Name Type Description Default
batch_input_file FileObject

the input file to generate responses for.

required

Returns:

Type Description
Union[Batch, None]

The batch job created in the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_batch_api_job(
    self, batch_input_file: "OpenAIFileObject"
) -> Union["OpenAIBatch", None]:
    """Creates a job in the OpenAI Batch API to generate responses for the given input
    file.

    Args:
        batch_input_file: the input file to generate responses for.

    Returns:
        The batch job created in the OpenAI Batch API.
    """
    import openai

    metadata = {"description": "distilabel"}

    if distilabel_pipeline_name := envs.DISTILABEL_PIPELINE_NAME:
        metadata["distilabel_pipeline_name"] = distilabel_pipeline_name

    if distilabel_pipeline_cache_id := envs.DISTILABEL_PIPELINE_CACHE_ID:
        metadata["distilabel_pipeline_cache_id"] = distilabel_pipeline_cache_id

    batch = None
    try:
        batch = self._client.batches.create(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id=batch_input_file.id,
            metadata=metadata,
        )
    except openai.OpenAIError as e:
        self._logger.error(  # type: ignore
            f"Error while creating OpenAI Batch API job for file with ID"
            f" '{batch_input_file.id}': {e}."
        )
        raise e
    return batch
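`_create_batch_api_job` tags every batch with a metadata dictionary and adds the pipeline name and cache ID only when they are set. The sketch below assumes that distilabel's `envs.DISTILABEL_PIPELINE_NAME` and `envs.DISTILABEL_PIPELINE_CACHE_ID` correspond to environment variables of the same names, and reads them from a mapping (defaulting to `os.environ`) so it can be tested; `build_batch_metadata` is a hypothetical helper.

```python
import os
from typing import Dict, Mapping, Optional


def build_batch_metadata(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    env = os.environ if env is None else env
    metadata = {"description": "distilabel"}
    # Only non-empty values are added, matching the walrus checks in the source.
    if name := env.get("DISTILABEL_PIPELINE_NAME"):
        metadata["distilabel_pipeline_name"] = name
    if cache_id := env.get("DISTILABEL_PIPELINE_CACHE_ID"):
        metadata["distilabel_pipeline_cache_id"] = cache_id
    return metadata
```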
_create_batch_files(inputs, **kwargs)

Creates the input files the Batch API needs to generate responses. The OpenAI Batch API accepts input files of at most 100MB each, so the inputs are split across multiple files when necessary.

More information: https://platform.openai.com/docs/api-reference/files/create

Parameters:

Name Type Description Default
inputs List[FormattedInput]

a list of inputs in chat format to generate responses for, optionally including structured output.

required
kwargs Any

the keyword arguments to use for the generation.

{}

Returns:

Type Description
List[FileObject]

The list of file objects created for the OpenAI Batch API.

Raises:

Type Description
OpenAIError

if there was an error while creating the batch input file in the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_batch_files(
    self, inputs: List["FormattedInput"], **kwargs: Any
) -> List["OpenAIFileObject"]:
    """Creates the necessary input files for the batch API to generate responses. The
    maximum size of each file so the OpenAI Batch API can process it is 100MB, so we
    need to split the inputs into multiple files if necessary.

    More information: https://platform.openai.com/docs/api-reference/files/create

    Args:
        inputs: a list of inputs in chat format to generate responses for, optionally
            including structured output.
        kwargs: the keyword arguments to use for the generation.

    Returns:
        The list of file objects created for the OpenAI Batch API.

    Raises:
        openai.OpenAIError: if there was an error while creating the batch input file
            in the OpenAI Batch API.
    """
    import openai

    files = []
    for file_no, buffer in enumerate(
        self._create_jsonl_buffers(inputs=inputs, **kwargs)
    ):
        try:
            # TODO: add distilabel pipeline name and id
            batch_input_file = self._client.files.create(
                file=(self._name_for_openai_files(file_no), buffer),
                purpose="batch",
            )
            files.append(batch_input_file)
        except openai.OpenAIError as e:
            self._logger.error(  # type: ignore
                f"Error while creating OpenAI batch input file: {e}"
            )
            raise e
    return files
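Each buffer is uploaded as a `(filename, buffer)` tuple with `purpose="batch"`, and the filename comes from `_name_for_openai_files`, shown in the class source above. The naming rule, rewritten as a hypothetical standalone function that takes the pipeline name and cache ID as explicit arguments instead of reading `envs`:

```python
from typing import Optional


def name_for_openai_file(
    file_no: int, pipeline_name: Optional[str], cache_id: Optional[str]
) -> str:
    # Fall back to a generic name unless both identifiers are available.
    if pipeline_name is None or cache_id is None:
        return f"distilabel-pipeline-fileno-{file_no}.jsonl"
    return f"distilabel-pipeline-{pipeline_name}-{cache_id}-fileno-{file_no}.jsonl"


print(name_for_openai_file(3, "demo", "abc123"))
```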
_create_jsonl_buffers(inputs, **kwargs)

Creates a generator of buffers containing the JSONL formatted inputs to be used by the OpenAI Batch API. The buffers created are of size 100MB or less.

Parameters:

Name Type Description Default
inputs List[FormattedInput]

a list of inputs in chat format to generate responses for, optionally including structured output.

required
kwargs Any

the keyword arguments to use for the generation.

{}

Yields:

Type Description
BytesIO

A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch

BytesIO

API.

Source code in src/distilabel/models/llms/openai.py
def _create_jsonl_buffers(
    self, inputs: List["FormattedInput"], **kwargs: Any
) -> Generator[io.BytesIO, None, None]:
    """Creates a generator of buffers containing the JSONL formatted inputs to be
    used by the OpenAI Batch API. The buffers created are of size 100MB or less.

    Args:
        inputs: a list of inputs in chat format to generate responses for, optionally
            including structured output.
        kwargs: the keyword arguments to use for the generation.

    Yields:
        A buffer containing the JSONL formatted inputs to be used by the OpenAI Batch
        API.
    """
    buffer = io.BytesIO()
    buffer_current_size = 0
    for i, input in enumerate(inputs):
        # We create the smallest `custom_id` so we don't increase the size of the file
        # too much, but we can still sort the results with the order of the inputs.
        row = self._create_jsonl_row(input=input, custom_id=str(i), **kwargs)
        row_size = len(row)
        if row_size + buffer_current_size > _OPENAI_BATCH_API_MAX_FILE_SIZE:
            buffer.seek(0)
            yield buffer
            buffer = io.BytesIO()
            buffer_current_size = 0
        buffer.write(row)
        buffer_current_size += row_size

    if buffer_current_size > 0:
        buffer.seek(0)
        yield buffer
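The splitting strategy in `_create_jsonl_buffers` can be tried with a tiny size limit in place of `_OPENAI_BATCH_API_MAX_FILE_SIZE`. This sketch uses `json` instead of `orjson` and differs in one deliberate way: it only flushes a non-empty buffer, whereas the source above yields an empty buffer if the very first row alone exceeds the limit. A single row larger than the limit still ends up in its own oversized buffer. `jsonl_buffers` is a hypothetical name.

```python
import io
import json
from typing import Any, Dict, Generator, List


def jsonl_buffers(
    rows: List[Dict[str, Any]], max_size: int
) -> Generator[io.BytesIO, None, None]:
    buffer = io.BytesIO()
    current_size = 0
    for row in rows:
        # One compact JSON object per line, newline-terminated.
        line = json.dumps(row, separators=(",", ":")).encode("utf-8") + b"\n"
        # Start a new buffer before the limit would be exceeded (never flush an empty one).
        if current_size > 0 and current_size + len(line) > max_size:
            buffer.seek(0)
            yield buffer
            buffer = io.BytesIO()
            current_size = 0
        buffer.write(line)
        current_size += len(line)
    if current_size > 0:
        buffer.seek(0)
        yield buffer


rows = [{"custom_id": str(i), "x": "a" * 10} for i in range(5)]
bufs = list(jsonl_buffers(rows, max_size=80))
print([len(b.getvalue()) for b in bufs])  # [70, 70, 35]
```

Each row here encodes to 35 bytes, so two rows fit under the 80-byte limit and the fifth row starts a third buffer.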
_create_jsonl_row(input, custom_id, **kwargs)

Creates a JSONL formatted row to be used by the OpenAI Batch API.

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for, optionally including structured output.

required
custom_id str

a custom ID to use for the row.

required
kwargs Any

the keyword arguments to use for the generation.

{}

Returns:

Type Description
bytes

A JSONL formatted row to be used by the OpenAI Batch API.

Source code in src/distilabel/models/llms/openai.py
def _create_jsonl_row(
    self, input: "FormattedInput", custom_id: str, **kwargs: Any
) -> bytes:
    """Creates a JSONL formatted row to be used by the OpenAI Batch API.

    Args:
        input: a list of inputs in chat format to generate responses for, optionally
            including structured output.
        custom_id: a custom ID to use for the row.
        kwargs: the keyword arguments to use for the generation.

    Returns:
        A JSONL formatted row to be used by the OpenAI Batch API.
    """
    # TODO: depending on the format of the input, add `response_format` to the kwargs
    row = {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"messages": input, **kwargs},
    }
    json_row = orjson.dumps(row)
    return json_row + b"\n"
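A minimal sketch of the row format, using the standard-library `json` module instead of `orjson` (compact separators and `ensure_ascii=False` approximate orjson's output). The `create_jsonl_row` name and the sample generation kwargs are hypothetical; only the envelope keys come from the source above.

```python
import json
from typing import Any, Dict, List


def create_jsonl_row(messages: List[Dict[str, str]], custom_id: str, **kwargs: Any) -> bytes:
    # Same request envelope as `_create_jsonl_row`: one POST to the chat
    # completions endpoint per line, with the generation kwargs merged into the body.
    row = {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"messages": messages, **kwargs},
    }
    # Compact, UTF-8 encoded JSON followed by a newline, i.e. one JSONL line.
    return json.dumps(row, separators=(",", ":"), ensure_ascii=False).encode("utf-8") + b"\n"


row = create_jsonl_row(
    [{"role": "user", "content": "Hello!"}], custom_id="0", max_tokens=16, temperature=0.7
)
parsed = json.loads(row)
```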

TogetherLLM

Bases: OpenAILLM

TogetherLLM LLM implementation running the async API client of OpenAI.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1". Supported models can be found here.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Together API. Defaults to None, which means that the value of the TOGETHER_BASE_URL environment variable will be used, or "https://api.together.xyz/v1" if it is not set.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Together API. Defaults to None which means that the value set for the environment variable TOGETHER_API_KEY will be used, or None if not set.

_api_key_env_var str

the name of the environment variable to use for the API key. It is meant to be used internally.

Examples:

Generate text:

from distilabel.models.llms import TogetherLLM

llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/together.py
class TogetherLLM(OpenAILLM):
    """TogetherLLM LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1".
            Supported models can be found [here](https://api.together.xyz/models).
        base_url: the base URL to use for the Together API. Defaults to `None`, which
            means that the value of the `TOGETHER_BASE_URL` environment variable will be
            used, or "https://api.together.xyz/v1" if it is not set.
        api_key: the API key to authenticate the requests to the Together API. Defaults to `None`
            which means that the value set for the environment variable `TOGETHER_API_KEY` will be
            used, or `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import TogetherLLM

        llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "TOGETHER_BASE_URL", "https://api.together.xyz/v1"
        ),
        description="The base URL to use for the Together API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_TOGETHER_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Together API.",
    )

    _api_key_env_var: str = PrivateAttr(_TOGETHER_API_KEY_ENV_VAR_NAME)
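`TogetherLLM` only overrides the defaults of `base_url` and `api_key` so that they are read from the environment. The sketch below mimics that with a plain `dataclasses` stand-in (the real class uses pydantic `Field` and `SecretStr`). `TogetherSettings` is a hypothetical name, while the environment variable names and the fallback URL come from the source above.

```python
import os
from dataclasses import dataclass, field
from typing import Optional

# Env var name documented for the API key; the base URL env var and fallback
# URL are copied from the `TogetherLLM` source.
_TOGETHER_API_KEY_ENV_VAR_NAME = "TOGETHER_API_KEY"


@dataclass
class TogetherSettings:
    """Hypothetical stand-in for the two overridden `TogetherLLM` fields."""

    base_url: Optional[str] = field(
        default_factory=lambda: os.getenv("TOGETHER_BASE_URL", "https://api.together.xyz/v1")
    )
    api_key: Optional[str] = field(
        default_factory=lambda: os.getenv(_TOGETHER_API_KEY_ENV_VAR_NAME)
    )
```

Because the values come from `default_factory`, they are read each time an instance is created rather than when the module is imported, and an explicitly passed argument always takes precedence.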

VertexAILLM

Bases: AsyncLLM

VertexAI LLM implementation running the async API clients for Gemini.

  • Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

To use VertexAILLM, Google Cloud authentication must be configured using one of these methods:

  • Setting the GOOGLE_APPLICATION_CREDENTIALS environment variable
  • Using gcloud auth application-default login command
  • Using vertexai.init function from the google-cloud-aiplatform library

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "gemini-1.0-pro". See the list of supported models.

_aclient Optional[GenerativeModel]

the GenerativeModel to use for the Vertex AI Gemini API. It is meant to be used internally. Set in the load method.

Examples:

Generate text:

from distilabel.models.llms import VertexAILLM

llm = VertexAILLM(model="gemini-1.5-pro")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
Source code in src/distilabel/models/llms/vertexai.py
class VertexAILLM(AsyncLLM):
    """VertexAI LLM implementation running the async API clients for Gemini.

    - Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

    To use `VertexAILLM`, Google Cloud authentication must be configured using one of
    these methods:

    - Setting the `GOOGLE_APPLICATION_CREDENTIALS` environment variable
    - Using `gcloud auth application-default login` command
    - Using `vertexai.init` function from the `google-cloud-aiplatform` library

    Attributes:
        model: the model name to use for the LLM e.g. "gemini-1.0-pro". [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models).
        _aclient: the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant
            to be used internally. Set in the `load` method.

    Icon:
        `:simple-googlecloud:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import VertexAILLM

        llm = VertexAILLM(model="gemini-1.5-pro")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str

    _num_generations_param_supported = False

    _aclient: Optional["GenerativeModel"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
        super().load()

        try:
            from vertexai.generative_models import GenerationConfig, GenerativeModel

            self._generation_config_class = GenerationConfig
        except ImportError as e:
            raise ImportError(
                "vertexai is not installed. Please install it using"
                " `pip install 'distilabel[vertexai]'`."
            ) from e

        if _is_gemini_model(self.model):
            self._aclient = GenerativeModel(model_name=self.model)
        else:
            raise NotImplementedError(
                "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
            )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
        """Converts a chat type to a list of content items expected by the API.

        Args:
            input: the chat type to be converted.

        Returns:
            A list of `Content` items expected by the API.
        """
        from vertexai.generative_models import Content, Part

        contents = []
        for message in input:
            if message["role"] not in ["user", "model"]:
                raise ValueError(
                    "`VertexAILLM` only supports the roles 'user' or 'model'."
                )
            contents.append(
                Content(
                    role=message["role"], parts=[Part.from_text(message["content"])]
                )
            )
        return contents

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: VertexChatType,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_output_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        safety_settings: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerateOutput:
        """Generates a single response (one candidate per call) for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

        Args:
            input: a single input in chat format to generate responses for.
            temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
            top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
            top_k: If specified, top-k sampling will be used. Defaults to `None`.
            max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
            stop_sequences: A list of stop sequences. Defaults to `None`.
            safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
            tools: A potential list of tools that can be used by the API. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from vertexai.generative_models import GenerationConfig

        content: "GenerationResponse" = await self._aclient.generate_content_async(  # type: ignore
            contents=self._chattype_to_content(input),
            generation_config=GenerationConfig(
                candidate_count=1,  # only one candidate allowed per call
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_output_tokens=max_output_tokens,
                stop_sequences=stop_sequences,
            ),
            safety_settings=safety_settings,  # type: ignore
            tools=tools,  # type: ignore
            stream=False,
        )

        text = None
        try:
            text = content.candidates[0].text
        except ValueError:
            self._logger.warning(  # type: ignore
                f"Received no response using VertexAI client (model: '{self.model}')."
                f" Finish reason was: '{content.candidates[0].finish_reason}'."
            )
        return prepare_output([text], **self._get_llm_statistics(content))

    @staticmethod
    def _get_llm_statistics(content: "GenerationResponse") -> "LLMStatistics":
        return {
            "input_tokens": [content.usage_metadata.prompt_token_count],
            "output_tokens": [content.usage_metadata.candidates_token_count],
        }
model_name property

Returns the model name used for the LLM.

load()

Loads the GenerativeModel class which has access to generate_content_async to benefit from async requests.

Source code in src/distilabel/models/llms/vertexai.py
def load(self) -> None:
    """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
    super().load()

    try:
        from vertexai.generative_models import GenerationConfig, GenerativeModel

        self._generation_config_class = GenerationConfig
    except ImportError as e:
        raise ImportError(
            "vertexai is not installed. Please install it using"
            " `pip install 'distilabel[vertexai]'`."
        ) from e

    if _is_gemini_model(self.model):
        self._aclient = GenerativeModel(model_name=self.model)
    else:
        raise NotImplementedError(
            "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
        )
_chattype_to_content(input)

Converts a chat type to a list of content items expected by the API.

Parameters:

Name Type Description Default
input StandardInput

the chat type to be converted.

required

Returns:

Type Description
List[Content]

A list of Content items expected by the API.

Source code in src/distilabel/models/llms/vertexai.py
def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
    """Converts a chat type to a list of content items expected by the API.

    Args:
        input: the chat type to be converted.

    Returns:
        A list of `Content` items expected by the API.
    """
    from vertexai.generative_models import Content, Part

    contents = []
    for message in input:
        if message["role"] not in ["user", "model"]:
            raise ValueError(
                "`VertexAILLM` only supports the roles 'user' or 'model'."
            )
        contents.append(
            Content(
                role=message["role"], parts=[Part.from_text(message["content"])]
            )
        )
    return contents
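The role check and the one-part-per-message conversion can be sketched without the Vertex AI SDK. Plain dicts stand in for `Content` and `Part` here; that dict shape and the `chat_to_contents` name are assumptions for illustration, not the SDK's wire format. Note that OpenAI-style `"assistant"` and `"system"` messages are rejected, so assistant turns must use the `"model"` role.

```python
from typing import Dict, List

# Roles accepted by `_chattype_to_content`; anything else is rejected.
_ALLOWED_ROLES = {"user", "model"}


def chat_to_contents(messages: List[Dict[str, str]]) -> List[Dict[str, object]]:
    """Hypothetical dict-based analogue of `_chattype_to_content`."""
    contents = []
    for message in messages:
        if message["role"] not in _ALLOWED_ROLES:
            raise ValueError("`VertexAILLM` only supports the roles 'user' or 'model'.")
        # One text part per message, analogous to `Part.from_text(message["content"])`.
        contents.append({"role": message["role"], "parts": [{"text": message["content"]}]})
    return contents


contents = chat_to_contents(
    [{"role": "user", "content": "Hi"}, {"role": "model", "content": "Hello"}]
)
```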
agenerate(input, temperature=None, top_p=None, top_k=None, max_output_tokens=None, stop_sequences=None, safety_settings=None, tools=None) async

Generates a single response (one candidate per call) for the given input using the VertexAI async client definition.

Parameters:

Name Type Description Default
input VertexChatType

a single input in chat format to generate responses for.

required
temperature Optional[float]

Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to None.

None
top_p Optional[float]

If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to None.

None
top_k Optional[int]

If specified, top-k sampling will be used. Defaults to None.

None
max_output_tokens Optional[int]

The maximum number of output tokens to generate per message. Defaults to None.

None
stop_sequences Optional[List[str]]

A list of stop sequences. Defaults to None.

None
safety_settings Optional[Dict[str, Any]]

Safety configuration for returned content from the API. Defaults to None.

None
tools Optional[List[Dict[str, Any]]]

A potential list of tools that can be used by the API. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/vertexai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: VertexChatType,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    safety_settings: Optional[Dict[str, Any]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> GenerateOutput:
    """Generates a single response (one candidate per call) for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

    Args:
        input: a single input in chat format to generate responses for.
        temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
        top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
        top_k: If specified, top-k sampling will be used. Defaults to `None`.
        max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
        stop_sequences: A list of stop sequences. Defaults to `None`.
        safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
        tools: A potential list of tools that can be used by the API. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from vertexai.generative_models import GenerationConfig

    content: "GenerationResponse" = await self._aclient.generate_content_async(  # type: ignore
        contents=self._chattype_to_content(input),
        generation_config=GenerationConfig(
            candidate_count=1,  # only one candidate allowed per call
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_output_tokens=max_output_tokens,
            stop_sequences=stop_sequences,
        ),
        safety_settings=safety_settings,  # type: ignore
        tools=tools,  # type: ignore
        stream=False,
    )

    text = None
    try:
        text = content.candidates[0].text
    except ValueError:
        self._logger.warning(  # type: ignore
            f"Received no response using VertexAI client (model: '{self.model}')."
            f" Finish reason was: '{content.candidates[0].finish_reason}'."
        )
    return prepare_output([text], **self._get_llm_statistics(content))
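The fallback in `agenerate` (return `None` and log the finish reason when the first candidate has no text) and the token statistics can be exercised with stand-in objects. `SimpleNamespace` replaces `GenerationResponse`; `_BlockedCandidate`, its `"SAFETY"` finish reason, and `extract_output` are hypothetical, and the returned dict only approximates what `prepare_output` builds (its exact shape is an assumption).

```python
import logging
from types import SimpleNamespace
from typing import Any, Dict, Optional

logger = logging.getLogger("vertexai-sketch")


class _BlockedCandidate:
    """Hypothetical candidate whose `.text` access fails, e.g. after a block."""

    finish_reason = "SAFETY"

    @property
    def text(self) -> str:
        raise ValueError("no text available")


def extract_output(response: Any, model: str) -> Dict[str, Any]:
    text: Optional[str] = None
    try:
        text = response.candidates[0].text
    except ValueError:
        # Same behavior as `agenerate`: keep `None` and log the finish reason.
        logger.warning(
            "Received no response (model: '%s'). Finish reason was: '%s'.",
            model,
            response.candidates[0].finish_reason,
        )
    usage = response.usage_metadata
    # Approximation of `prepare_output([text], **_get_llm_statistics(content))`.
    return {
        "generations": [text],
        "statistics": {
            "input_tokens": [usage.prompt_token_count],
            "output_tokens": [usage.candidates_token_count],
        },
    }


ok_response = SimpleNamespace(
    candidates=[SimpleNamespace(text="Hi!", finish_reason="STOP")],
    usage_metadata=SimpleNamespace(prompt_token_count=5, candidates_token_count=7),
)
blocked_response = SimpleNamespace(
    candidates=[_BlockedCandidate()],
    usage_metadata=SimpleNamespace(prompt_token_count=5, candidates_token_count=0),
)
```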

ClientvLLM

Bases: OpenAILLM, MagpieChatTemplateMixin

A client for the vLLM server implementing the OpenAI API specification.

Attributes:

Name Type Description
base_url Optional[RuntimeParameter[str]]

the base URL of the vLLM server. Defaults to "http://localhost:8000".

max_retries RuntimeParameter[int]

the maximum number of times to retry the request to the API before failing. Defaults to 6.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

httpx_client_kwargs RuntimeParameter[int]

extra kwargs that will be passed to the httpx.AsyncClient created to communicate with the vLLM server. Defaults to None.

tokenizer Optional[str]

the Hugging Face Hub repo id or path of the tokenizer that will be used to apply the chat template and tokenize the inputs before sending it to the server. Defaults to None.

tokenizer_revision Optional[str]

the revision of the tokenizer to load. Defaults to None.

_aclient AsyncOpenAI

the httpx.AsyncClient used to communicate with the vLLM server. Defaults to None.

Runtime parameters
  • base_url: the base url of the vLLM server. Defaults to "http://localhost:8000".
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 6.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • httpx_client_kwargs: extra kwargs that will be passed to the httpx.AsyncClient created to communicate with the vLLM server. Defaults to None.

Examples:

Generate text:

from distilabel.models.llms import ClientvLLM

llm = ClientvLLM(
    base_url="http://localhost:8000/v1",
    tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
)

llm.load()

results = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
    num_generations=3,
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=256,
)
# [
#     [
#         "I'm functioning properly, thank you for asking. How can I assist you today?",
#         "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
#         "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
#     ]
# ]
Source code in src/distilabel/models/llms/vllm.py
class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
    """A client for the `vLLM` server implementing the OpenAI API specification.

    Attributes:
        base_url: the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        httpx_client_kwargs: extra kwargs that will be passed to the `httpx.AsyncClient`
            created to communicate with the `vLLM` server. Defaults to `None`.
        tokenizer: the Hugging Face Hub repo id or path of the tokenizer that will be used
            to apply the chat template and tokenize the inputs before sending it to the
            server. Defaults to `None`.
        tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
        _aclient: the `httpx.AsyncClient` used to communicate with the `vLLM` server. Defaults
            to `None`.

    Runtime parameters:
        - `base_url`: the base url of the `vLLM` server. Defaults to `"http://localhost:8000"`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `httpx_client_kwargs`: extra kwargs that will be passed to the `httpx.AsyncClient`
            created to communicate with the `vLLM` server. Defaults to `None`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import ClientvLLM

        llm = ClientvLLM(
            base_url="http://localhost:8000/v1",
            tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
        )

        llm.load()

        results = llm.generate_outputs(
            inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
            num_generations=3,
            temperature=0.7,
            top_p=1.0,
            max_new_tokens=256,
        )
        # [
        #     [
        #         "I'm functioning properly, thank you for asking. How can I assist you today?",
        #         "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
        #         "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
        #     ]
        # ]
        ```
    """

    model: str = ""  # Default value so that passing `ClientvLLM(model="...")` is not required
    tokenizer: Optional[str] = None
    tokenizer_revision: Optional[str] = None

    # We need the sync client to get the list of models
    _client: "OpenAI" = PrivateAttr(None)
    _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)

    def load(self) -> None:
        """Creates an `httpx.AsyncClient` to connect to the vLLM server and, optionally, a
        tokenizer."""

        self.api_key = SecretStr("EMPTY")

        # We need to first create the sync client to get the model name that will be used
        # in the `super().load()` when creating the logger.
        try:
            from openai import OpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[openai]'`."
            ) from ie

        self._client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),  # type: ignore
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        super().load()

        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "To use `ClientvLLM` you need to install `transformers`. "
                "Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer, revision=self.tokenizer_revision
        )

    @cached_property
    def model_name(self) -> str:  # type: ignore
        """Returns the name of the model served with vLLM server."""
        models = self._client.models.list()
        return models.data[0].id

    def _prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,  # type: ignore
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        logit_bias: Optional[Dict[str, int]] = None,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for each input.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            logit_bias: modify the likelihood of specified tokens appearing in the completion.
                Defaults to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: nucleus sampling. The value refers to the top-p tokens that should be
                considered for sampling. Defaults to `1.0`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """

        completion = await self._aclient.completions.create(
            model=self.model_name,
            prompt=self._prepare_input(input),  # type: ignore
            n=num_generations,
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
        )

        generations = []
        for choice in completion.choices:
            text = choice.text
            if text == "":
                self._logger.warning(  # type: ignore
                    f"Received no response from vLLM server (model: '{self.model_name}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(text)

        return prepare_output(generations, **self._get_llm_statistics(completion))
model_name cached property

Returns the name of the model served with vLLM server.
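Because `model_name` is a `functools.cached_property`, the server's model list is fetched at most once per instance, and the first served model is used. The sketch below shows that behavior with a hypothetical `FakeModelsAPI` standing in for the sync `OpenAI` client's `models` resource; `ServedModel` is likewise an illustrative name.

```python
from functools import cached_property
from types import SimpleNamespace
from typing import List


class FakeModelsAPI:
    """Hypothetical stand-in for `OpenAI(...).models` that counts `list()` calls."""

    def __init__(self, model_ids: List[str]) -> None:
        self.model_ids = model_ids
        self.calls = 0

    def list(self) -> SimpleNamespace:
        self.calls += 1
        return SimpleNamespace(data=[SimpleNamespace(id=m) for m in self.model_ids])


class ServedModel:
    def __init__(self, models_api: FakeModelsAPI) -> None:
        self._models = models_api

    @cached_property
    def model_name(self) -> str:
        # Like `ClientvLLM.model_name`: query the server once, keep the first model id.
        return self._models.list().data[0].id


api = FakeModelsAPI(["meta-llama/Meta-Llama-3.1-8B-Instruct"])
served = ServedModel(api)
first = served.model_name
second = served.model_name  # served from the cache, no second request
```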

load()

Creates an httpx.AsyncClient to connect to the vLLM server and, optionally, a tokenizer.

Source code in src/distilabel/models/llms/vllm.py
def load(self) -> None:
    """Creates an `httpx.AsyncClient` to connect to the vLLM server and, optionally, a
    tokenizer."""

    self.api_key = SecretStr("EMPTY")

    # We need to first create the sync client to get the model name that will be used
    # in the `super().load()` when creating the logger.
    try:
        from openai import OpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install 'distilabel[openai]'`."
        ) from ie

    self._client = OpenAI(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),  # type: ignore
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    super().load()

    try:
        from transformers import AutoTokenizer
    except ImportError as ie:
        raise ImportError(
            "To use `ClientvLLM` you need to install `transformers`. "
            "Please install it using `pip install 'distilabel[hf-transformers]'`."
        ) from ie

    self._tokenizer = AutoTokenizer.from_pretrained(
        self.tokenizer, revision=self.tokenizer_revision
    )
_prepare_input(input)

Prepares the input (applying the chat template and tokenization) for the provided input.

Parameters:

Name Type Description Default
input StandardInput

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.

Source code in src/distilabel/models/llms/vllm.py
def _prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,  # type: ignore
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
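`_prepare_input` delegates prompt building to the tokenizer's chat template, which varies by model. To make the effect of `tokenize=False` and `add_generation_prompt=True` concrete without downloading a tokenizer, the sketch below renders messages in the ChatML format by hand. The `apply_chatml_template` helper is hypothetical, it is not the template of any particular model, and the Magpie pre-query step is not reproduced.

```python
from typing import Dict, List


def apply_chatml_template(
    messages: List[Dict[str, str]], add_generation_prompt: bool = True
) -> str:
    # Mirror `_prepare_input`: an empty conversation yields an empty prompt.
    if not messages:
        return ""
    # Each turn becomes "<|im_start|>{role}\n{content}<|im_end|>\n" (plain text, not token ids).
    prompt = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here.
        prompt += "<|im_start|>assistant\n"
    return prompt


prompt = apply_chatml_template([{"role": "user", "content": "Hi"}])
```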
agenerate(input, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, logit_bias=None, presence_penalty=0.0, temperature=1.0, top_p=1.0) async

Generates num_generations responses for each input.

Parameters:

Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
frequency_penalty float

the repetition penalty to use for the generation. Defaults to 0.0.

0.0
logit_bias Optional[Dict[str, int]]

modify the likelihood of specified tokens appearing in the completion. Defaults to None.

None
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the cumulative probability threshold for nucleus sampling: only the smallest set of most likely tokens whose probabilities add up to top_p is considered. Defaults to 1.0.

1.0

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/vllm.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    logit_bias: Optional[Dict[str, int]] = None,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> GenerateOutput:
    """Generates `num_generations` responses for each input.

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        logit_bias: modify the likelihood of specified tokens appearing in the completion.
            Defaults to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: nucleus sampling. The value refers to the top-p tokens that should be
            considered for sampling. Defaults to `1.0`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """

    completion = await self._aclient.completions.create(
        model=self.model_name,
        prompt=self._prepare_input(input),  # type: ignore
        n=num_generations,
        max_tokens=max_new_tokens,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        presence_penalty=presence_penalty,
        temperature=temperature,
        top_p=top_p,
    )

    generations = []
    for choice in completion.choices:
        text = choice.text
        if text == "":
            self._logger.warning(  # type: ignore
                f"Received no response from vLLM server (model: '{self.model_name}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(text)

    return prepare_output(generations, **self._get_llm_statistics(completion))
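`agenerate` mostly renames distilabel's generation arguments to the OpenAI completions parameters and then collects `choice.text` from each returned choice. The two hypothetical helpers below separate those steps so they can be checked without a running vLLM server; `SimpleNamespace` objects stand in for the completion choices.

```python
import logging
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

logger = logging.getLogger("clientvllm-sketch")


def build_completion_kwargs(
    model: str,
    prompt: str,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    logit_bias: Optional[Dict[str, int]] = None,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Dict[str, Any]:
    # distilabel names are mapped to the completions API names:
    # `num_generations` -> `n`, `max_new_tokens` -> `max_tokens`.
    return {
        "model": model,
        "prompt": prompt,
        "n": num_generations,
        "max_tokens": max_new_tokens,
        "frequency_penalty": frequency_penalty,
        "logit_bias": logit_bias,
        "presence_penalty": presence_penalty,
        "temperature": temperature,
        "top_p": top_p,
    }


def collect_generations(choices: List[Any], model: str) -> List[str]:
    generations = []
    for choice in choices:
        if choice.text == "":
            # Empty completions are logged but still kept, as in `agenerate`.
            logger.warning(
                "Received no response from vLLM server (model: '%s'). Finish reason was: %s",
                model,
                choice.finish_reason,
            )
        generations.append(choice.text)
    return generations


kwargs = build_completion_kwargs(model="served-model", prompt="Hi", num_generations=2)
choices = [
    SimpleNamespace(text="Hello!", finish_reason="stop"),
    SimpleNamespace(text="", finish_reason="length"),
]
generations = collect_generations(choices, model="served-model")
```

Because empty completions are not dropped, the output keeps one entry per returned choice.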

vLLM

Bases: LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin

vLLM library LLM implementation.

Attributes:

Name Type Description
model str

the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files.

dtype str

the data type to use for the model. Defaults to auto.

trust_remote_code bool

whether to trust the remote code when loading the model. Defaults to False.

quantization Optional[str]

the quantization mode to use for the model. Defaults to None.

revision Optional[str]

the revision of the model to load. Defaults to None.

tokenizer Optional[str]

the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer files. If not provided, the tokenizer will be loaded from the model directory. Defaults to None.

tokenizer_mode Literal['auto', 'slow']

the mode to use for the tokenizer. Defaults to auto.

tokenizer_revision Optional[str]

the revision of the tokenizer to load. Defaults to None.

skip_tokenizer_init bool

whether to skip the initialization of the tokenizer. Defaults to False.

chat_template Optional[str]

a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, then ChatML template will be used. Defaults to None.

structured_output Optional[RuntimeParameter[OutlinesStructuredOutputType]]

a dictionary containing the structured output configuration or if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None.

seed int

the seed to use for the random number generator. Defaults to 0.

extra_kwargs Optional[RuntimeParameter[Dict[str, Any]]]

additional dictionary of keyword arguments that will be passed to the LLM class of vllm library. Defaults to {}.

_model LLM

the vLLM model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method.

_tokenizer PreTrainedTokenizer

the tokenizer instance used to format the prompt before passing it to the LLM. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method.

use_magpie_template bool

a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow up user message. Valid values are "llama3", "qwen2" or another pre-query template provided. Defaults to None.

References
  • https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
Runtime parameters
  • extra_kwargs: additional dictionary of keyword arguments that will be passed to the LLM class of vllm library.

Examples:

Generate text:

from distilabel.models.llms import vLLM

# You can pass a custom chat_template to the model
llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    chat_template="[INST] {{ messages[0]['content'] }}\n{{ messages[1]['content'] }}[/INST]",
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import vLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/vllm.py
class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
    """`vLLM` library LLM implementation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        dtype: the data type to use for the model. Defaults to `auto`.
        trust_remote_code: whether to trust the remote code when loading the model. Defaults
            to `False`.
        quantization: the quantization mode to use for the model. Defaults to `None`.
        revision: the revision of the model to load. Defaults to `None`.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer files. If not provided, the tokenizer will be loaded from the
            model directory. Defaults to `None`.
        tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`.
        tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
        skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults
            to `False`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        seed: the seed to use for the random number generator. Defaults to `0`.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `LLM` class of `vllm` library. Defaults to `{}`.
        _model: the `vLLM` model instance. This attribute is meant to be used internally
            and should not be accessed directly. It will be set in the `load` method.
        _tokenizer: the tokenizer instance used to format the prompt before passing it to
            the `LLM`. This attribute is meant to be used internally and should not be
            accessed directly. It will be set in the `load` method.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.

    References:
        - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

    Runtime parameters:
        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to
            the `LLM` class of `vllm` library.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import vLLM

        # You can pass a custom chat_template to the model
        llm = vLLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            chat_template="[INST] {{ messages[0]['content'] }}\\n{{ messages[1]['content'] }}[/INST]",
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import vLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = vLLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            structured_output={"format": "json", "schema": User},
        )

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    dtype: str = "auto"
    trust_remote_code: bool = False
    quantization: Optional[str] = None
    revision: Optional[str] = None

    tokenizer: Optional[str] = None
    tokenizer_mode: Literal["auto", "slow"] = "auto"
    tokenizer_revision: Optional[str] = None
    skip_tokenizer_init: bool = False
    chat_template: Optional[str] = None

    seed: int = 0

    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `vLLM` class of `vllm` library. See all the supported arguments at: "
        "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _model: "_vLLM" = PrivateAttr(None)
    _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)
    _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None)

    def load(self) -> None:
        """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
        If `chat_template` is provided, this method also sets it on the tokenizer so that the
        list of OpenAI-formatted inputs is parsed using the format expected by the model;
        otherwise, the tokenizer's own chat template is used (ChatML if it does not define one).
        """
        super().load()

        CudaDevicePlacementMixin.load(self)

        try:
            from vllm import LLM as _vLLM
        except ImportError as ie:
            raise ImportError(
                "vLLM is not installed. Please install it using `pip install 'distilabel[vllm]'`."
            ) from ie

        self._model = _vLLM(
            self.model,
            dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            quantization=self.quantization,
            revision=self.revision,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            seed=self.seed,
            **self.extra_kwargs,  # type: ignore
        )

        self._tokenizer = self._model.get_tokenizer()  # type: ignore
        if self.chat_template is not None:
            self._tokenizer.chat_template = self.chat_template  # type: ignore

        if self.structured_output:
            self._structured_output_logits_processor = self._prepare_structured_output(
                self.structured_output
            )

    def unload(self) -> None:
        """Unloads the `vLLM` model."""
        self._cleanup_vllm_model()
        self._model = None  # type: ignore
        self._tokenizer = None  # type: ignore
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    def _cleanup_vllm_model(self) -> None:
        if self._model is None:
            return

        import torch  # noqa
        from vllm.distributed.parallel_state import (
            destroy_distributed_environment,
            destroy_model_parallel,
        )

        destroy_model_parallel()
        destroy_distributed_environment()
        del self._model.llm_engine.model_executor
        del self._model
        with contextlib.suppress(AssertionError):
            torch.distributed.destroy_process_group()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: Union["StandardInput", str]) -> str:
        """Prepares the prompt for the provided input by applying the chat template
        (without tokenizing) and, if enabled, the Magpie pre-query template.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if isinstance(input, str):
            return input

        prompt: str = (
            self._tokenizer.apply_chat_template(
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,  # type: ignore
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _prepare_batches(
        self, inputs: List["StructuredInput"]
    ) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]:
        """Prepares the inputs by grouping them by the structured output.

        When we generate structured outputs with schemas obtained from a dataset, we need to
        prepare the data so that batches of inputs, rather than single inputs, are sent to the
        model to take advantage of the engine. So we group the inputs by the structured output
        to be passed in the `generate` method.

        Args:
            inputs: The batch of inputs passed to the generate method. As we expect to be generating
                structured outputs, each element will be a tuple containing the instruction and the
                structured output.

        Returns:
            The prepared sub-batches to be passed to the `generate` method, where each tuple
            contains a list of instructions (instead of a single instruction) and the structured
            output they share, together with the original index of each instruction.
        """
        instruction_order = {}
        batches: Dict[str, List[str]] = {}
        for i, (instruction, structured_output) in enumerate(inputs):
            instruction = self.prepare_input(instruction)
            instruction_order[instruction] = i

            structured_output = json.dumps(structured_output)
            if structured_output not in batches:
                batches[structured_output] = [instruction]
            else:
                batches[structured_output].append(instruction)

        # Build a list with instructions sorted by structured output
        flat_instructions = [
            instruction for _, group in batches.items() for instruction in group
        ]

        # Generate the list of indices based on the original order
        sorted_indices = [
            instruction_order[instruction] for instruction in flat_instructions
        ]

        return [
            (batch, json.loads(schema)) for schema, batch in batches.items()
        ], sorted_indices

    @validate_call
    def generate(  # noqa: C901 # type: ignore
        self,
        inputs: List[FormattedInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        logprobs: Optional[PositiveInt] = None,
        stop: Optional[List[str]] = None,
        stop_token_ids: Optional[List[int]] = None,
        include_stop_str_in_output: bool = False,
        skip_special_tokens: bool = True,
        logits_processors: Optional[LogitsProcessors] = None,
        extra_sampling_params: Optional[Dict[str, Any]] = None,
        echo: bool = False,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `0.0`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults to
                `1.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `-1`, which
                disables top-k filtering.
            min_p: the minimum probability to use for the generation. Defaults to `0.0`.
            logprobs: number of log probabilities to return per output token. If `None`,
                no log probabilities are returned. Defaults to `None`.
            stop: a list of strings that will be used to stop the generation when found.
                Defaults to `None`.
            stop_token_ids: a list of token ids that will be used to stop the generation
                when found. Defaults to `None`.
            include_stop_str_in_output: whether to include the stop string in the output.
                Defaults to `False`.
            skip_special_tokens: whether to exclude special tokens from the output. Defaults
                to `True`.
            logits_processors: a list of functions to process the logits before sampling.
                Defaults to `None`.
            extra_sampling_params: dictionary with additional arguments to be passed to
                the `SamplingParams` class from `vllm`. Defaults to `None`.
            echo: whether to include the prompt in the response or not. Defaults
                to `False`.

        Returns:
            A list of `GenerateOutput`, one per input, containing the generated responses,
            their token statistics and, if requested, their log probabilities.
        """
        from vllm import SamplingParams

        if not logits_processors:
            logits_processors = []

        if extra_sampling_params is None:
            extra_sampling_params = {}

        structured_output = None

        if isinstance(inputs[0], tuple):
            # Prepare the batches for structured generation
            prepared_batches, sorted_indices = self._prepare_batches(inputs)  # type: ignore
        else:
            # Simulate a batch without the structured output content
            prepared_batches = [([self.prepare_input(input) for input in inputs], None)]  # type: ignore
            sorted_indices = None

        # Case in which we have a single structured output for the dataset
        if self._structured_output_logits_processor:
            logits_processors.append(self._structured_output_logits_processor)

        batched_outputs: List["LLMOutput"] = []
        generations = []

        for prepared_inputs, structured_output in prepared_batches:
            if self.structured_output is not None and structured_output is not None:
                self._logger.warning(
                    "A `structured_output` was provided in the model configuration, but"
                    " one was also provided in the input. The input structured output will"
                    " be used."
                )

            if structured_output is not None:
                logits_processors.append(
                    self._prepare_structured_output(structured_output)  # type: ignore
                )

            sampling_params = SamplingParams(  # type: ignore
                n=num_generations,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                min_p=min_p,
                max_tokens=max_new_tokens,
                prompt_logprobs=logprobs if echo else None,
                logprobs=logprobs,
                stop=stop,
                stop_token_ids=stop_token_ids,
                include_stop_str_in_output=include_stop_str_in_output,
                skip_special_tokens=skip_special_tokens,
                logits_processors=logits_processors,
                **extra_sampling_params,
            )

            batch_outputs: List["RequestOutput"] = self._model.generate(
                prompts=prepared_inputs,
                sampling_params=sampling_params,
                use_tqdm=False,
            )

            # Remove the structured output logits processor to avoid stacking structured output
            # logits processors, which leads to nonsensical generations
            if structured_output is not None:
                logits_processors.pop(-1)

            for input, outputs in zip(prepared_inputs, batch_outputs):
                processed_prompt_logprobs = []
                if outputs.prompt_logprobs is not None:
                    processed_prompt_logprobs = self._get_llm_logprobs(
                        outputs.prompt_logprobs
                    )
                texts, statistics, outputs_logprobs = self._process_outputs(
                    input=input,
                    outputs=outputs,
                    echo=echo,
                    prompt_logprobs=processed_prompt_logprobs,
                )
                batched_outputs.append(texts)
                generation = prepare_output(
                    generations=texts,
                    input_tokens=statistics["input_tokens"],
                    output_tokens=statistics["output_tokens"],
                    logprobs=outputs_logprobs,
                )

                generations.append(generation)

        if sorted_indices is not None:
            pairs = list(enumerate(sorted_indices))
            pairs.sort(key=lambda x: x[1])
            generations = [generations[original_idx] for original_idx, _ in pairs]

        return generations

    def _process_outputs(
        self,
        input: str,
        outputs: "RequestOutput",
        prompt_logprobs: List[List["Logprob"]],
        echo: bool = False,
    ) -> Tuple["LLMOutput", "LLMStatistics", "LLMLogprobs"]:
        texts = []
        outputs_logprobs = []
        statistics = {
            "input_tokens": [compute_tokens(input, self._tokenizer.encode)]
            * len(outputs.outputs),
            "output_tokens": [],
        }
        for output in outputs.outputs:
            text = output.text
            if echo:
                text = input + text
            texts.append(text)
            statistics["output_tokens"].append(len(output.token_ids))
            if output.logprobs is not None:
                processed_output_logprobs = self._get_llm_logprobs(output.logprobs)
                outputs_logprobs.append(prompt_logprobs + processed_output_logprobs)
        return texts, statistics, outputs_logprobs

    def _prepare_structured_output(  # type: ignore
        self, structured_output: "OutlinesStructuredOutputType"
    ) -> Union[Callable, None]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        assert structured_output is not None, "`structured_output` cannot be `None`"

        result = prepare_guided_output(structured_output, "vllm", self._model)
        if (schema := result.get("schema")) and self.structured_output:
            self.structured_output["schema"] = schema
        return result["processor"]

    def _get_llm_logprobs(
        self, logprobs: Union["PromptLogprobs", "SampleLogprobs"]
    ) -> List[List["Logprob"]]:
        processed_logprobs = []
        for token_logprob in logprobs:  # type: ignore
            token_logprobs = []
            if token_logprob is None:
                processed_logprobs.append(None)
                continue
            for logprob in token_logprob.values():
                token_logprobs.append(
                    {"token": logprob.decoded_token, "logprob": logprob.logprob}
                )
            processed_logprobs.append(token_logprobs)
        return processed_logprobs
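`_get_llm_logprobs` flattens vLLM's per-position logprob dictionaries (token id to `Logprob`) into lists of `{"token": ..., "logprob": ...}` entries, keeping a `None` placeholder for positions without logprobs (vLLM typically reports `None` for the first prompt token). The stdlib sketch below mirrors that transformation; the `Logprob` dataclass is a hypothetical stand-in that only carries the two attributes the method reads, and `to_llm_logprobs` is an illustrative name.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Logprob:
    # Hypothetical stand-in exposing only the attributes `_get_llm_logprobs` reads.
    logprob: float
    decoded_token: Optional[str]


def to_llm_logprobs(
    logprobs: List[Optional[Dict[int, Logprob]]],
) -> List[Optional[List[Dict[str, object]]]]:
    processed: List[Optional[List[Dict[str, object]]]] = []
    for position in logprobs:
        if position is None:
            # No logprobs for this position: keep a placeholder so positions stay aligned.
            processed.append(None)
            continue
        # One entry per candidate token at this position, in the dict's order.
        processed.append(
            [{"token": lp.decoded_token, "logprob": lp.logprob} for lp in position.values()]
        )
    return processed


raw = [None, {42: Logprob(-0.1, "Hi"), 7: Logprob(-2.3, "Hey")}]
print(to_llm_logprobs(raw))
```

In `_process_outputs`, the processed prompt logprobs (empty unless `echo=True`) are concatenated in front of each output's processed logprobs.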
model_name property

Returns the model name used for the LLM.

load()

Loads the vLLM model using either the path or the Hugging Face Hub repository id. If chat_template is provided, this method also sets it on the tokenizer so that the list of OpenAI-formatted inputs is parsed using the format expected by the model; otherwise, the tokenizer's own chat template is used (ChatML if it does not define one).
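`load()` forwards `extra_kwargs` by unpacking it into the `vllm` `LLM` constructor after the fields the class passes explicitly (`dtype`, `seed`, and so on). The stdlib sketch below reproduces only that call shape, with a hypothetical `fake_vllm_llm` in place of `vllm.LLM` and a reduced set of fields. It suggests that `extra_kwargs` is meant for engine options the class does not already expose, because repeating one of the explicit fields makes Python raise a `TypeError` at the call site.

```python
from typing import Any, Dict


def fake_vllm_llm(model: str, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for `vllm.LLM`: just echoes the arguments it receives."""
    return {"model": model, **kwargs}


def build_engine(model: str, dtype: str, seed: int, extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the call shape in `load()`: named fields first, then `**extra_kwargs`.
    return fake_vllm_llm(model, dtype=dtype, seed=seed, **extra_kwargs)


# Options that are not explicit fields are forwarded untouched.
print(build_engine("some/model", "auto", 0, {"gpu_memory_utilization": 0.8}))

# Repeating a field that is already passed explicitly is rejected by Python itself.
try:
    build_engine("some/model", "auto", 0, {"dtype": "float16"})
except TypeError as err:
    print(f"TypeError: {err}")
```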

unload()

Unloads the vLLM model.

prepare_input(input)

Prepares the prompt for the provided input by applying the chat template (without tokenizing) and, if enabled, the Magpie pre-query template.

Parameters:

Name Type Description Default
input Union[StandardInput, str]

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.
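For chat-formatted input, `prepare_input` renders the conversation with the tokenizer's chat template as text (`tokenize=False`) and appends the generation prompt; plain strings are returned unchanged and an empty conversation yields an empty prompt. The sketch below mirrors that control flow with a hypothetical ChatML-style `chatml_template` function standing in for `tokenizer.apply_chat_template`, and it omits the Magpie pre-query step. It is an illustration, not the template any particular tokenizer uses.

```python
from typing import Dict, List, Union

ChatInput = List[Dict[str, str]]


def chatml_template(messages: ChatInput) -> str:
    """Hypothetical stand-in for `apply_chat_template(..., tokenize=False,
    add_generation_prompt=True)` using a ChatML-style layout."""
    rendered = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    # add_generation_prompt=True: open an assistant turn for the model to complete.
    return rendered + "<|im_start|>assistant\n"


def prepare_input(chat_or_prompt: Union[ChatInput, str]) -> str:
    # Raw strings are treated as already-formatted prompts and passed through.
    if isinstance(chat_or_prompt, str):
        return chat_or_prompt
    # An empty conversation yields an empty prompt instead of calling the template.
    return chatml_template(chat_or_prompt) if chat_or_prompt else ""


print(prepare_input([{"role": "user", "content": "Hi"}]))
```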

_prepare_batches(inputs)

Prepares the inputs by grouping them by the structured output.

When generating structured outputs with schemas obtained from a dataset, the data must be prepared so that batches of inputs, rather than single inputs, are sent to the model to take advantage of the engine. The inputs are therefore grouped by structured output before being passed to the generate method.

Parameters:

Name Type Description Default
inputs List[StructuredInput]

The batch of inputs passed to the generate method. As we expect to be generating structured outputs, each element will be a tuple containing the instruction and the structured output.

required

Returns:

Type Description
List[Tuple[List[str], OutlinesStructuredOutputType]]

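The prepared sub-batches to be passed to the generate method. Instead of a single instruction, each tuple contains a list of instructions together with the structured output they share.

List[int]

The original index of each instruction in the flattened sub-batches, used to restore the input order after generation.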
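The grouping and the later reordering can be sketched in plain Python as follows. Schemas are keyed by `json.dumps`, as in the source, and `sorted_indices[k]` records the original position of the k-th prompt in the flattened sub-batches, which is what the reordering at the end of `generate` relies on. `prepare_batches` and `restore_order` are hypothetical names for this sketch, and prompts are assumed to be already formatted.

```python
import json
from typing import Any, Dict, List, Tuple

Schema = Dict[str, Any]


def prepare_batches(
    inputs: List[Tuple[str, Schema]],
) -> Tuple[List[Tuple[List[str], Schema]], List[int]]:
    """Group prompts that share a structured-output config; remember original positions."""
    groups: Dict[str, List[Tuple[int, str]]] = {}
    for position, (prompt, schema) in enumerate(inputs):
        # json.dumps gives a hashable key; dicts keep first-seen insertion order.
        key = json.dumps(schema)
        groups.setdefault(key, []).append((position, prompt))

    batches = [([p for _, p in members], json.loads(key)) for key, members in groups.items()]
    # sorted_indices[k] = original position of the k-th prompt in the flattened batches.
    sorted_indices = [pos for members in groups.values() for pos, _ in members]
    return batches, sorted_indices


def restore_order(generations: List[Any], sorted_indices: List[int]) -> List[Any]:
    """Put per-prompt results (in flattened-batch order) back into input order."""
    restored: List[Any] = [None] * len(generations)
    for flat_idx, original_idx in enumerate(sorted_indices):
        restored[original_idx] = generations[flat_idx]
    return restored


inputs = [("p0", {"type": "A"}), ("p1", {"type": "B"}), ("p2", {"type": "A"})]
batches, sorted_indices = prepare_batches(inputs)
flat = [f"out:{p}" for batch, _ in batches for p in batch]
print(restore_order(flat, sorted_indices))
```

One design difference: this sketch records positions directly instead of mapping each prompt string to its index, as the source's `instruction_order` dict does, so repeated prompts keep distinct slots.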

Source code in src/distilabel/models/llms/vllm.py
def _prepare_batches(
    self, inputs: List["StructuredInput"]
) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]:
    """Prepares the inputs by grouping them by the structured output.

    When we generate structured outputs with schemas obtained from a dataset, we need to
    prepare the data to try to send batches of inputs instead of single inputs to the model
    to take advante of the engine. So we group the inputs by the structured output to be
    passed in the `generate` method.

    Args:
        inputs: The batch of inputs passed to the generate method. As we expect to be generating
            structured outputs, each element will be a tuple containing the instruction and the
            structured output.

    Returns:
        The prepared batches (sub-batches let's say) to be passed to the `generate` method.
        Each new tuple will contain instead of the single instruction, a list of instructions
    """
    instruction_order = {}
    batches: Dict[str, List[str]] = {}
    for i, (instruction, structured_output) in enumerate(inputs):
        instruction = self.prepare_input(instruction)
        instruction_order[instruction] = i

        structured_output = json.dumps(structured_output)
        if structured_output not in batches:
            batches[structured_output] = [instruction]
        else:
            batches[structured_output].append(instruction)

    # Built a list with instructions sorted by structured output
    flat_instructions = [
        instruction for _, group in batches.items() for instruction in group
    ]

    # Generate the list of indices based on the original order
    sorted_indices = [
        instruction_order[instruction] for instruction in flat_instructions
    ]

    return [
        (batch, json.loads(schema)) for schema, batch in batches.items()
    ], sorted_indices
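The grouping performed by _prepare_batches can be sketched in plain Python. This is a minimal sketch, not the distilabel implementation: group_by_schema is a hypothetical helper, the prepare_input step is omitted (instructions are assumed to be already formatted strings), and the original index is stored next to each instruction instead of in a dictionary keyed by the instruction text, so two identical instructions with different schemas keep distinct positions.

```python
import json
from typing import Any, Dict, List, Optional, Tuple

Schema = Optional[Dict[str, Any]]


def group_by_schema(
    inputs: List[Tuple[str, Schema]],
) -> Tuple[List[Tuple[List[str], Schema]], List[int]]:
    """Group (instruction, schema) pairs into sub-batches that share a schema.

    Also returns, for each position in the concatenated sub-batches, the index
    that the instruction had in `inputs`.
    """
    # Dicts are not hashable, so the serialised schema is used as the key.
    # sort_keys=True is a choice of this sketch: equal schemas written with a
    # different key order still land in the same group.
    groups: Dict[str, List[Tuple[int, str]]] = {}
    for index, (instruction, schema) in enumerate(inputs):
        key = json.dumps(schema, sort_keys=True)
        groups.setdefault(key, []).append((index, instruction))

    batches: List[Tuple[List[str], Schema]] = []
    sorted_indices: List[int] = []
    for key, group in groups.items():
        batches.append(([instruction for _, instruction in group], json.loads(key)))
        sorted_indices.extend(index for index, _ in group)
    return batches, sorted_indices
```

For example, three inputs whose schemas are S, None and S produce two sub-batches, the first and third instructions with S and the second with None, and sorted_indices [0, 2, 1].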
generate(inputs, num_generations=1, max_new_tokens=128, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, logprobs=None, stop=None, stop_token_ids=None, include_stop_str_in_output=False, skip_special_tokens=True, logits_processors=None, extra_sampling_params=None, echo=False)

Generates num_generations responses for each input.

Parameters:

Name Type Description Default
inputs List[FormattedInput]

a list of inputs in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
frequency_penalty float

the frequency penalty to use for the generation. Defaults to 0.0.

0.0
repetition_penalty float

the repetition penalty to use for the generation. Defaults to 1.0.

1.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
top_k int

the top-k value to use for the generation. Defaults to -1, which disables top-k sampling.

-1
min_p float

the minimum probability to use for the generation. Defaults to 0.0.

0.0
logprobs Optional[PositiveInt]

number of log probabilities to return per output token. If None, no log probabilities are returned. Defaults to None.

None
stop Optional[List[str]]

a list of strings that will be used to stop the generation when found. Defaults to None.

None
stop_token_ids Optional[List[int]]

a list of token ids that will be used to stop the generation when found. Defaults to None.

None
include_stop_str_in_output bool

whether to include the stop string in the output. Defaults to False.

False
skip_special_tokens bool

whether to exclude special tokens from the output. Defaults to True.

True
logits_processors Optional[LogitsProcessors]

a list of functions to process the logits before sampling. Defaults to None.

None
extra_sampling_params Optional[Dict[str, Any]]

dictionary with additional arguments to be passed to the SamplingParams class from vllm.

None
echo bool

whether to include the prompt in the response or not. Defaults to False.

False

Returns:

Type Description
List[GenerateOutput]

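A list with one GenerateOutput per input, in the same order as inputs, each containing the generated responses, the input and output token counts and, if requested, the log probabilities.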
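The way generate forwards its arguments to vLLM's SamplingParams can be illustrated with a hypothetical helper, build_sampling_kwargs, that only builds the keyword dictionary for a subset of the parameters. Two details are taken from the source below: max_new_tokens is renamed to max_tokens, and prompt log probabilities are requested only when echo is True. Because the source unpacks extra_sampling_params into the same call, a key that is also set explicitly would make that call fail with a TypeError; the sketch raises the same error explicitly.

```python
from typing import Any, Dict, List, Optional


def build_sampling_kwargs(
    num_generations: int = 1,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
    logprobs: Optional[int] = None,
    stop: Optional[List[str]] = None,
    echo: bool = False,
    extra_sampling_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build keyword arguments for vllm.SamplingParams (subset of generate's options)."""
    kwargs: Dict[str, Any] = {
        "n": num_generations,
        "max_tokens": max_new_tokens,  # renamed from max_new_tokens
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,  # -1 disables top-k sampling
        "logprobs": logprobs,
        # Mirrors the source: prompt log probabilities only when echo is True.
        "prompt_logprobs": logprobs if echo else None,
        "stop": stop,
    }
    extra = extra_sampling_params or {}
    # SamplingParams(**kwargs, **extra) would fail on duplicated keys; fail the same way.
    duplicated = sorted(kwargs.keys() & extra.keys())
    if duplicated:
        raise TypeError(f"got multiple values for keyword arguments: {duplicated}")
    return {**kwargs, **extra}
```

Under these assumptions, the result could then be passed as SamplingParams(**kwargs).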

Source code in src/distilabel/models/llms/vllm.py
@validate_call
def generate(  # noqa: C901 # type: ignore
    self,
    inputs: List[FormattedInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repetition_penalty: float = 1.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
    min_p: float = 0.0,
    logprobs: Optional[PositiveInt] = None,
    stop: Optional[List[str]] = None,
    stop_token_ids: Optional[List[int]] = None,
    include_stop_str_in_output: bool = False,
    skip_special_tokens: bool = True,
    logits_processors: Optional[LogitsProcessors] = None,
    extra_sampling_params: Optional[Dict[str, Any]] = None,
    echo: bool = False,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `0.0`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults to
            `1.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `-1`, which
            disables top-k sampling.
        min_p: the minimum probability to use for the generation. Defaults to `0.0`.
        logprobs: number of log probabilities to return per output token. If `None`,
            no log probabilities are returned. Defaults to `None`.
        stop: a list of strings that will be used to stop the generation when found.
            Defaults to `None`.
        stop_token_ids: a list of token ids that will be used to stop the generation
            when found. Defaults to `None`.
        include_stop_str_in_output: whether to include the stop string in the output.
            Defaults to `False`.
        skip_special_tokens: whether to exclude special tokens from the output. Defaults
            to `True`.
        logits_processors: a list of functions to process the logits before sampling.
            Defaults to `None`.
        extra_sampling_params: dictionary with additional arguments to be passed to
            the `SamplingParams` class from `vllm`.
        echo: whether to include the prompt in the response or not. Defaults to
            `False`.

    Returns:
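        A list with one `GenerateOutput` per input, in the same order as `inputs`, each
        containing the generated responses, the input and output token counts and, if
        requested, the log probabilities.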
    """
    from vllm import SamplingParams

    if not logits_processors:
        logits_processors = []

    if extra_sampling_params is None:
        extra_sampling_params = {}

    structured_output = None

    if isinstance(inputs[0], tuple):
        # Prepare the batches for structured generation
        prepared_batches, sorted_indices = self._prepare_batches(inputs)  # type: ignore
    else:
        # Simulate a batch without the structured output content
        prepared_batches = [([self.prepare_input(input) for input in inputs], None)]  # type: ignore
        sorted_indices = None

    # Case in which we have a single structured output for the dataset
    if self._structured_output_logits_processor:
        logits_processors.append(self._structured_output_logits_processor)

    batched_outputs: List["LLMOutput"] = []
    generations = []

    for prepared_inputs, structured_output in prepared_batches:
        if self.structured_output is not None and structured_output is not None:
            self._logger.warning(
                "An `structured_output` was provided in the model configuration, but"
                " one was also provided in the input. The input structured output will"
                " be used."
            )

        if structured_output is not None:
            logits_processors.append(
                self._prepare_structured_output(structured_output)  # type: ignore
            )

        sampling_params = SamplingParams(  # type: ignore
            n=num_generations,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            max_tokens=max_new_tokens,
            prompt_logprobs=logprobs if echo else None,
            logprobs=logprobs,
            stop=stop,
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_stop_str_in_output,
            skip_special_tokens=skip_special_tokens,
            logits_processors=logits_processors,
            **extra_sampling_params,
        )

        batch_outputs: List["RequestOutput"] = self._model.generate(
            prompts=prepared_inputs,
            sampling_params=sampling_params,
            use_tqdm=False,
        )

        # Remove the structured output logits processor so that processors do not stack
        # across batches, which leads to nonsensical generations
        if structured_output is not None:
            logits_processors.pop(-1)

        for input, outputs in zip(prepared_inputs, batch_outputs):
            processed_prompt_logprobs = []
            if outputs.prompt_logprobs is not None:
                processed_prompt_logprobs = self._get_llm_logprobs(
                    outputs.prompt_logprobs
                )
            texts, statistics, outputs_logprobs = self._process_outputs(
                input=input,
                outputs=outputs,
                echo=echo,
                prompt_logprobs=processed_prompt_logprobs,
            )
            batched_outputs.append(texts)
            generation = prepare_output(
                generations=texts,
                input_tokens=statistics["input_tokens"],
                output_tokens=statistics["output_tokens"],
                logprobs=outputs_logprobs,
            )

            generations.append(generation)

    if sorted_indices is not None:
        pairs = list(enumerate(sorted_indices))
        pairs.sort(key=lambda x: x[1])
        generations = [generations[original_idx] for original_idx, _ in pairs]

    return generations
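After all sub-batches are generated, the outputs are in grouped order, and the last lines of generate put them back in input order. The same enumerate-and-sort step, isolated in a hypothetical restore_order function, looks like this; note that the first element of each pair is the position in the grouped output, not the original index.

```python
from typing import List, TypeVar

T = TypeVar("T")


def restore_order(outputs: List[T], sorted_indices: List[int]) -> List[T]:
    """Reorder grouped outputs so that result[i] belongs to the i-th original input.

    outputs[k] was produced for the input whose original index is sorted_indices[k].
    """
    # Pair every grouped position with the original index of its input ...
    pairs = list(enumerate(sorted_indices))
    # ... sort the pairs by original index ...
    pairs.sort(key=lambda pair: pair[1])
    # ... and read the outputs back following the grouped positions.
    return [outputs[grouped_pos] for grouped_pos, _ in pairs]


print(restore_order(["out0", "out2", "out1"], [0, 2, 1]))  # ['out0', 'out1', 'out2']
```

An equivalent linear-time alternative is to allocate a list of the same length and assign result[original] = outputs[k] for each k; the sort-based version is kept here because it follows the source.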
_prepare_structured_output(structured_output)

Creates the appropriate function to filter tokens to generate structured outputs.

Parameters:

Name Type Description Default
structured_output OutlinesStructuredOutputType

the configuration dict to prepare the structured output.

required

Returns:

Type Description
Union[Callable, None]

The callable that will be used to guide the generation of the model.

Source code in src/distilabel/models/llms/vllm.py
def _prepare_structured_output(  # type: ignore
    self, structured_output: "OutlinesStructuredOutputType"
) -> Union[Callable, None]:
    """Creates the appropriate function to filter tokens to generate structured outputs.

    Args:
        structured_output: the configuration dict to prepare the structured output.

    Returns:
        The callable that will be used to guide the generation of the model.
    """
    from distilabel.steps.tasks.structured_outputs.outlines import (
        prepare_guided_output,
    )

    assert structured_output is not None, "`structured_output` cannot be `None`"

    result = prepare_guided_output(structured_output, "vllm", self._model)
    if (schema := result.get("schema")) and self.structured_output:
        self.structured_output["schema"] = schema
    return result["processor"]
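prepare_guided_output returns a processor callable that vLLM applies to the logits at every decoding step. The masking idea behind such a processor can be shown with a toy, hypothetical make_allowed_tokens_processor that only permits a fixed set of token IDs. Assumptions: logits are a plain Python list rather than a torch tensor, and the (generated token IDs, logits) argument order follows the convention of vLLM's older logits_processors API. A real schema-guided processor, such as one built with outlines, changes the allowed set at every step based on the tokens generated so far.

```python
import math
from typing import Callable, Iterable, List

LogitsProcessor = Callable[[List[int], List[float]], List[float]]


def make_allowed_tokens_processor(allowed_ids: Iterable[int]) -> LogitsProcessor:
    """Return a processor that masks every token ID outside allowed_ids."""
    allowed = set(allowed_ids)

    def processor(token_ids: List[int], logits: List[float]) -> List[float]:
        # token_ids (tokens generated so far) is unused in this toy version; a
        # grammar-based processor would use it to track its state in the schema.
        return [
            logit if token_id in allowed else -math.inf
            for token_id, logit in enumerate(logits)
        ]

    return processor
```

In generate, a processor of this kind is appended to logits_processors for each sub-batch with its own structured output and popped right after that sub-batch, so processors for different schemas never stack.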

CudaDevicePlacementMixin

Bases: BaseModel

Mixin class to assign CUDA devices to the LLM based on the cuda_devices attribute and the device placement information provided in _device_llm_placement_map. Providing the device placement information is optional, but if it is provided, it will be used to assign CUDA devices to the LLMs, trying to avoid using the same device for different LLMs.

Attributes:

Name Type Description
cuda_devices RuntimeParameter[Union[List[int], Literal['auto']]]

a list with the IDs of the CUDA devices to be used by the LLM. If set to "auto", the devices will be automatically assigned based on the device placement information provided in _device_llm_placement_map. If set to a list of devices, a warning will be logged for each device that is also assigned to another LLM, and a RuntimeError will be raised if a device is not available on the node (checked only when pynvml is installed).

disable_cuda_device_placement RuntimeParameter[bool]

Whether to disable the CUDA device placement logic or not. Defaults to False.

_llm_identifier Union[str, None]

the identifier of the LLM to be used as key in _device_llm_placement_map.

_device_llm_placement_map Generator[Dict[str, List[int]], None, None]

a dictionary with the device placement information for each LLM.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
class CudaDevicePlacementMixin(BaseModel):
    """Mixin class to assign CUDA devices to the `LLM` based on the `cuda_devices` attribute
    and the device placement information provided in `_device_llm_placement_map`. Providing
    the device placement information is optional, but if it is provided, it will be used to
    assign CUDA devices to the `LLM`s, trying to avoid using the same device for different
    `LLM`s.

    Attributes:
        cuda_devices: a list with the IDs of the CUDA devices to be used by the `LLM`. If
            set to "auto", the devices will be automatically assigned based on the device
            placement information provided in `_device_llm_placement_map`. If set to a list
            of devices, a warning will be logged for each device that is also assigned to
            another `LLM`, and a `RuntimeError` will be raised if a device is not available
            on the node (checked only when `pynvml` is installed).
        disable_cuda_device_placement: Whether to disable the CUDA device placement logic
            or not. Defaults to `False`.
        _llm_identifier: the identifier of the `LLM` to be used as key in `_device_llm_placement_map`.
        _device_llm_placement_map: a dictionary with the device placement information for each
            `LLM`.
    """

    cuda_devices: RuntimeParameter[Union[List[int], Literal["auto"]]] = Field(
        default="auto", description="A list with the ID of the CUDA devices to be used."
    )
    disable_cuda_device_placement: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to disable the CUDA device placement logic or not.",
    )

    _llm_identifier: Union[str, None] = PrivateAttr(default=None)
    _desired_num_gpus: PositiveInt = PrivateAttr(default=1)
    _available_cuda_devices: List[int] = PrivateAttr(default_factory=list)
    _can_check_cuda_devices: bool = PrivateAttr(default=False)

    _logger: "Logger" = PrivateAttr(None)

    def load(self) -> None:
        """Assign CUDA devices to the LLM based on the device placement information provided
        in `_device_llm_placement_map`."""

        if self.disable_cuda_device_placement:
            return

        try:
            import pynvml

            pynvml.nvmlInit()
            device_count = pynvml.nvmlDeviceGetCount()
            self._available_cuda_devices = list(range(device_count))
            self._can_check_cuda_devices = True
        except ImportError as ie:
            if self.cuda_devices == "auto":
                raise ImportError(
                    "The 'pynvml' library is not installed. It is required to automatically"
                    " assign CUDA devices to the `LLM`s. Please, install it and try again."
                ) from ie

            if self.cuda_devices:
                self._logger.warning(  # type: ignore
                    "The 'pynvml' library is not installed. It is recommended to install it"
                    " to check if the CUDA devices assigned to the LLM are available."
                )

        self._assign_cuda_devices()

    def unload(self) -> None:
        """Unloads the LLM and removes the CUDA devices assigned to it from the device
        placement information provided in `_device_llm_placement_map`."""
        if self.disable_cuda_device_placement:
            return

        with self._device_llm_placement_map() as device_map:
            if self._llm_identifier in device_map:
                self._logger.debug(  # type: ignore
                    f"Removing '{self._llm_identifier}' from the CUDA device map file"
                    f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'."
                )
                del device_map[self._llm_identifier]

    @contextmanager
    def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]:
        """Reads the content of the device placement file of the node with a lock, yields
        the content, and writes the content back to the file after the context manager is
        closed. If the file doesn't exist, an empty dictionary will be yielded.

        Yields:
            The content of the device placement file.
        """
        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.parent.mkdir(parents=True, exist_ok=True)
        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch()
        with portalocker.Lock(
            _CUDA_DEVICE_PLACEMENT_MIXIN_FILE,
            "r+",
            flags=portalocker.LockFlags.EXCLUSIVE,
        ) as f:
            try:
                content = json.load(f)
            except json.JSONDecodeError:
                content = {}
            yield content
            f.seek(0)
            f.truncate()
            f.write(json.dumps(content))

    def _assign_cuda_devices(self) -> None:
        """Assigns CUDA devices to the LLM based on the device placement information provided
        in `_device_llm_placement_map`. If the `cuda_devices` attribute is set to "auto", it
        will be set to the first available CUDA device that is not going to be used by any
        other LLM. If the `cuda_devices` attribute is set to a list of devices, it will be
        checked if the devices are available to be used by the LLM. If not, a warning will be
        logged."""

        # Take the lock and read the device placement information for each LLM.
        with self._device_llm_placement_map() as device_map:
            if self.cuda_devices == "auto":
                self.cuda_devices = []
                for _ in range(self._desired_num_gpus):
                    if (device_id := self._get_cuda_device(device_map)) is not None:
                        self.cuda_devices.append(device_id)
                        device_map[self._llm_identifier] = self.cuda_devices  # type: ignore
                if len(self.cuda_devices) != self._desired_num_gpus:
                    self._logger.warning(  # type: ignore
                        f"Could not assign the desired number of GPUs {self._desired_num_gpus}"
                        f" for LLM with identifier '{self._llm_identifier}'."
                    )
            else:
                self._check_cuda_devices(device_map)

            device_map[self._llm_identifier] = self.cuda_devices  # type: ignore

        # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices`
        # attribute. In this case, the `cuda_devices` attribute will be set to an empty list.
        if self.cuda_devices == "auto":
            self.cuda_devices = []

        self._set_cuda_visible_devices()

    def _check_cuda_devices(self, device_map: Dict[str, List[int]]) -> None:
        """Checks if the CUDA devices assigned to the LLM are also assigned to other LLMs.

        Args:
            device_map: a dictionary with the device placement information for each LLM.
        """
        for device in self.cuda_devices:  # type: ignore
            for llm, devices in device_map.items():
                if device in devices:
                    self._logger.warning(  # type: ignore
                        f"LLM with identifier '{llm}' is also going to use CUDA device "
                        f"'{device}'. This may lead to performance issues or running out"
                        " of memory depending on the device capabilities and the loaded"
                        " models."
                    )

    def _get_cuda_device(self, device_map: Dict[str, List[int]]) -> Union[int, None]:
        """Returns the first available CUDA device to be used by the LLM that is not going
        to be used by any other LLM.

        Args:
            device_map: a dictionary with the device placement information for each LLM.

        Returns:
            The first available CUDA device to be used by the LLM, or `None` if all the
            available devices are already assigned to other LLMs.
        """
        for device in self._available_cuda_devices:
            if all(device not in devices for devices in device_map.values()):
                return device

        return None

    def _set_cuda_visible_devices(self) -> None:
        """Sets the `CUDA_VISIBLE_DEVICES` environment variable to the list of CUDA devices
        to be used by the LLM.
        """
        if not self.cuda_devices:
            return

        if self._can_check_cuda_devices and not all(
            device in self._available_cuda_devices for device in self.cuda_devices
        ):
            raise RuntimeError(
                f"Invalid CUDA devices for LLM '{self._llm_identifier}': {self.cuda_devices}."
                f" The available devices are: {self._available_cuda_devices}. Please, review"
                " the 'cuda_devices' attribute and try again."
            )

        cuda_devices = ",".join([str(device) for device in self.cuda_devices])
        self._logger.info(  # type: ignore
            f"🎮 LLM '{self._llm_identifier}' is going to use the following CUDA devices:"
            f" {self.cuda_devices}."
        )
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices
load()

Assign CUDA devices to the LLM based on the device placement information provided in _device_llm_placement_map.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def load(self) -> None:
    """Assign CUDA devices to the LLM based on the device placement information provided
    in `_device_llm_placement_map`."""

    if self.disable_cuda_device_placement:
        return

    try:
        import pynvml

        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        self._available_cuda_devices = list(range(device_count))
        self._can_check_cuda_devices = True
    except ImportError as ie:
        if self.cuda_devices == "auto":
            raise ImportError(
                "The 'pynvml' library is not installed. It is required to automatically"
                " assign CUDA devices to the `LLM`s. Please, install it and try again."
            ) from ie

        if self.cuda_devices:
            self._logger.warning(  # type: ignore
                "The 'pynvml' library is not installed. It is recommended to install it"
                " to check if the CUDA devices assigned to the LLM are available."
            )

    self._assign_cuda_devices()
unload()

Unloads the LLM and removes the CUDA devices assigned to it from the device placement information provided in _device_llm_placement_map.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def unload(self) -> None:
    """Unloads the LLM and removes the CUDA devices assigned to it from the device
    placement information provided in `_device_llm_placement_map`."""
    if self.disable_cuda_device_placement:
        return

    with self._device_llm_placement_map() as device_map:
        if self._llm_identifier in device_map:
            self._logger.debug(  # type: ignore
                f"Removing '{self._llm_identifier}' from the CUDA device map file"
                f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'."
            )
            del device_map[self._llm_identifier]
_device_llm_placement_map()

Reads the content of the device placement file of the node with a lock, yields the content, and writes the content back to the file after the context manager is closed. If the file doesn't exist, an empty dictionary will be yielded.

Yields:

Type Description
Dict[str, List[int]]

The content of the device placement file.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
@contextmanager
def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]:
    """Reads the content of the device placement file of the node with a lock, yields
    the content, and writes the content back to the file after the context manager is
    closed. If the file doesn't exist, an empty dictionary will be yielded.

    Yields:
        The content of the device placement file.
    """
    _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.parent.mkdir(parents=True, exist_ok=True)
    _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch()
    with portalocker.Lock(
        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE,
        "r+",
        flags=portalocker.LockFlags.EXCLUSIVE,
    ) as f:
        try:
            content = json.load(f)
        except json.JSONDecodeError:
            content = {}
        yield content
        f.seek(0)
        f.truncate()
        f.write(json.dumps(content))
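The read-modify-write pattern of _device_llm_placement_map can be reproduced with the standard library alone. This sketch swaps portalocker.Lock for fcntl.flock, so it assumes a POSIX system; locked_json_map is a hypothetical name. As in the source, an empty or invalid file yields an empty dictionary, and whatever the caller leaves in the dictionary is written back when the with block exits normally.

```python
import fcntl
import json
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, List


@contextmanager
def locked_json_map(path: Path) -> Iterator[Dict[str, List[int]]]:
    """Yield the JSON dict stored at path while holding an exclusive lock on it."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch()  # make sure the file exists before opening it in "r+" mode
    with open(path, "r+") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # blocks until no other process holds the lock
        try:
            try:
                content = json.load(f)
            except json.JSONDecodeError:
                content = {}  # empty (freshly created) or invalid file
            yield content
            # Only reached if the caller's block finished without an exception.
            f.seek(0)
            f.truncate()
            f.write(json.dumps(content))
            f.flush()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


if __name__ == "__main__":
    placement_file = Path(tempfile.mkdtemp()) / "placement" / "devices.json"
    with locked_json_map(placement_file) as device_map:
        device_map["llm_a"] = [0]  # like _assign_cuda_devices
    with locked_json_map(placement_file) as device_map:
        print(device_map)  # {'llm_a': [0]}
        del device_map["llm_a"]  # like unload
```

Deleting an entry inside the with block, as unload does, is persisted the same way as adding one.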
_assign_cuda_devices()

Assigns CUDA devices to the LLM based on the device placement information provided in _device_llm_placement_map. If the cuda_devices attribute is set to "auto", it will be set to the first _desired_num_gpus available CUDA devices that are not going to be used by any other LLM. If the cuda_devices attribute is set to a list of devices, a warning will be logged for each device that is also assigned to another LLM.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _assign_cuda_devices(self) -> None:
    """Assigns CUDA devices to the LLM based on the device placement information provided
    in `_device_llm_placement_map`. If the `cuda_devices` attribute is set to "auto", it
    will be set to the first `_desired_num_gpus` available CUDA devices that are not going
    to be used by any other LLM. If the `cuda_devices` attribute is set to a list of
    devices, a warning will be logged for each device that is also assigned to another
    LLM."""

    # Take the lock and read the device placement information for each LLM.
    with self._device_llm_placement_map() as device_map:
        if self.cuda_devices == "auto":
            self.cuda_devices = []
            for _ in range(self._desired_num_gpus):
                if (device_id := self._get_cuda_device(device_map)) is not None:
                    self.cuda_devices.append(device_id)
                    device_map[self._llm_identifier] = self.cuda_devices  # type: ignore
            if len(self.cuda_devices) != self._desired_num_gpus:
                self._logger.warning(  # type: ignore
                    f"Could not assign the desired number of GPUs {self._desired_num_gpus}"
                    f" for LLM with identifier '{self._llm_identifier}'."
                )
        else:
            self._check_cuda_devices(device_map)

        device_map[self._llm_identifier] = self.cuda_devices  # type: ignore

    # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices`
    # attribute. In this case, the `cuda_devices` attribute will be set to an empty list.
    if self.cuda_devices == "auto":
        self.cuda_devices = []

    self._set_cuda_visible_devices()
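The "auto" branch of _assign_cuda_devices, together with _get_cuda_device, amounts to a first-fit search over the node's devices. A stand-alone sketch with hypothetical function names, where available stands in for _available_cuda_devices and device_map for the content of the placement file:

```python
from typing import Dict, List, Optional


def first_free_device(
    available: List[int], device_map: Dict[str, List[int]]
) -> Optional[int]:
    """Return the first device not assigned to any LLM, or None if all are taken."""
    for device in available:
        if all(device not in devices for devices in device_map.values()):
            return device
    return None


def assign_auto(
    llm_id: str,
    available: List[int],
    device_map: Dict[str, List[int]],
    desired_num_gpus: int = 1,
) -> List[int]:
    """Claim up to desired_num_gpus free devices for llm_id, recording them in device_map."""
    assigned: List[int] = []
    # Register the (still empty) list so every appended device is immediately
    # visible to first_free_device in the next iteration.
    device_map[llm_id] = assigned
    for _ in range(desired_num_gpus):
        device = first_free_device(available, device_map)
        if device is None:
            break  # fewer free devices than requested
        assigned.append(device)
    return assigned
```

The source logs a warning when fewer than _desired_num_gpus devices could be assigned; with this sketch, the caller can compare the length of the returned list with the request.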
_check_cuda_devices(device_map)

Checks if the CUDA devices assigned to the LLM are also assigned to other LLMs.

Parameters:

Name Type Description Default
device_map Dict[str, List[int]]

a dictionary with the device placement information for each LLM.

required
Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _check_cuda_devices(self, device_map: Dict[str, List[int]]) -> None:
    """Checks if the CUDA devices assigned to the LLM are also assigned to other LLMs.

    Args:
        device_map: a dictionary with the device placement information for each LLM.
    """
    for device in self.cuda_devices:  # type: ignore
        for llm, devices in device_map.items():
            if device in devices:
                self._logger.warning(  # type: ignore
                    f"LLM with identifier '{llm}' is also going to use CUDA device "
                    f"'{device}'. This may lead to performance issues or running out"
                    " of memory depending on the device capabilities and the loaded"
                    " models."
                )
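The overlap check in _check_cuda_devices can be expressed as a function that returns the conflicts instead of logging them, which makes it easy to test. find_conflicts is hypothetical, and skipping the LLM's own identifier is an addition of this sketch: it avoids reporting a stale entry for the same LLM as a conflict with itself.

```python
from typing import Dict, List, Tuple


def find_conflicts(
    llm_id: str, requested: List[int], device_map: Dict[str, List[int]]
) -> List[Tuple[str, int]]:
    """Return (other_llm, device) pairs for requested devices already used elsewhere."""
    conflicts: List[Tuple[str, int]] = []
    for device in requested:
        for other_llm, devices in device_map.items():
            # Skipping llm_id itself is an addition of this sketch.
            if other_llm != llm_id and device in devices:
                conflicts.append((other_llm, device))
    return conflicts
```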
_get_cuda_device(device_map)

Returns the first available CUDA device to be used by the LLM that is not going to be used by any other LLM.

Parameters:

Name Type Description Default
device_map Dict[str, List[int]]

a dictionary with the device placement information for each LLM.

required

Returns:

Type Description
Union[int, None]

The first available CUDA device to be used by the LLM, or None if all the available devices are already assigned to other LLMs.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _get_cuda_device(self, device_map: Dict[str, List[int]]) -> Union[int, None]:
    """Returns the first available CUDA device to be used by the LLM that is not going
    to be used by any other LLM.

    Args:
        device_map: a dictionary with the device placement information for each LLM.

    Returns:
        The first available CUDA device to be used by the LLM, or `None` if all the
        available devices are already assigned to other LLMs.
    """
    for device in self._available_cuda_devices:
        if all(device not in devices for devices in device_map.values()):
            return device

    return None
_set_cuda_visible_devices()

Sets the CUDA_VISIBLE_DEVICES environment variable to the list of CUDA devices to be used by the LLM.

Source code in src/distilabel/models/mixins/cuda_device_placement.py
def _set_cuda_visible_devices(self) -> None:
    """Sets the `CUDA_VISIBLE_DEVICES` environment variable to the list of CUDA devices
    to be used by the LLM.
    """
    if not self.cuda_devices:
        return

    if self._can_check_cuda_devices and not all(
        device in self._available_cuda_devices for device in self.cuda_devices
    ):
        raise RuntimeError(
            f"Invalid CUDA devices for LLM '{self._llm_identifier}': {self.cuda_devices}."
            f" The available devices are: {self._available_cuda_devices}. Please, review"
            " the 'cuda_devices' attribute and try again."
        )

    cuda_devices = ",".join([str(device) for device in self.cuda_devices])
    self._logger.info(  # type: ignore
        f"🎮 LLM '{self._llm_identifier}' is going to use the following CUDA devices:"
        f" {self.cuda_devices}."
    )
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices
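_set_cuda_visible_devices boils down to validating the device list and joining it into a comma-separated string. A minimal sketch with a hypothetical set_cuda_visible_devices function; passing available=None plays the role of _can_check_cuda_devices being False (pynvml not installed), in which case no validation is done.

```python
import os
from typing import List, Optional


def set_cuda_visible_devices(
    devices: List[int], available: Optional[List[int]] = None
) -> None:
    """Export devices as CUDA_VISIBLE_DEVICES, validating them when possible."""
    if not devices:
        return  # leave whatever the environment already has
    if available is not None:
        invalid = [device for device in devices if device not in available]
        if invalid:
            raise RuntimeError(
                f"Invalid CUDA devices {invalid}; available devices: {available}"
            )
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(device) for device in devices)
```

CUDA_VISIBLE_DEVICES is read when CUDA is initialized in a process, so it has to be set before the model is loaded for it to take effect.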