LLM Gallery

This section contains the existing LLM subclasses implemented in distilabel.

llms

AnthropicLLM

Bases: AsyncLLM

Anthropic LLM implementation running the Async API client.

Attributes:

  • model (str): the name of the model to use for the LLM, e.g. "claude-3-opus-20240229", "claude-3-sonnet-20240229", etc. Available models can be checked here: Anthropic: Models overview.
  • api_key (Optional[RuntimeParameter[SecretStr]]): the API key to authenticate the requests to the Anthropic API. If not provided, it will be read from the ANTHROPIC_API_KEY environment variable.
  • base_url (Optional[RuntimeParameter[str]]): the base URL to use for the Anthropic API. Defaults to None, which means that https://api.anthropic.com will be used internally.
  • timeout (RuntimeParameter[float]): the maximum time in seconds to wait for a response. Defaults to 600.0.
  • max_retries (RuntimeParameter[int]): the maximum number of times to retry the request before failing. Defaults to 6.
  • http_client (Optional[AsyncClient]): if provided, an alternative HTTP client to use for calling the Anthropic API. Defaults to None.
  • structured_output (Optional[RuntimeParameter[InstructorStructuredOutputType]]): a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.
  • _api_key_env_var (str): the name of the environment variable to use for the API key. It is meant to be used internally.
  • _aclient (Optional[AsyncAnthropic]): the AsyncAnthropic client to use for the Anthropic API. It is meant to be used internally. Set in the load method.

Runtime parameters
  • api_key: the API key to authenticate the requests to the Anthropic API. If not provided, it will be read from ANTHROPIC_API_KEY environment variable.
  • base_url: the base URL to use for the Anthropic API. Defaults to "https://api.anthropic.com".
  • timeout: the maximum time in seconds to wait for a response. Defaults to 600.0.
  • max_retries: the maximum number of times to retry the request before failing. Defaults to 6.
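
These runtime parameters can also be set directly when instantiating the class. A minimal sketch with illustrative values, assuming ANTHROPIC_API_KEY is set in the environment:

from distilabel.models.llms import AnthropicLLM

# Illustrative overrides of the runtime parameters listed above
llm = AnthropicLLM(
    model="claude-3-opus-20240229",
    timeout=120.0,   # wait at most 120 seconds per response (default: 600.0)
    max_retries=3,   # retry failed requests up to 3 times (default: 6)
)
llm.load()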

Examples:

Generate text:

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import AnthropicLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = AnthropicLLM(
    model="claude-3-opus-20240229",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
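
When structured_output is set, each generation is the JSON serialization of the schema instance, so it can be validated back into the Pydantic model. A sketch, assuming the GenerateOutput dictionaries returned above expose the generated strings under a "generations" key:

# Hypothetical post-processing of the structured example above
user = User.model_validate_json(output[0]["generations"][0])
print(user.name, user.last_name, user.id)
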
Source code in src/distilabel/models/llms/anthropic.py
class AnthropicLLM(AsyncLLM):
    """Anthropic LLM implementation running the Async API client.

    Attributes:
        model: the name of the model to use for the LLM e.g. "claude-3-opus-20240229",
            "claude-3-sonnet-20240229", etc. Available models can be checked here:
            [Anthropic: Models overview](https://docs.anthropic.com/claude/docs/models-overview).
        api_key: the API key to authenticate the requests to the Anthropic API. If not provided,
            it will be read from `ANTHROPIC_API_KEY` environment variable.
        base_url: the base URL to use for the Anthropic API. Defaults to `None` which means
            that `https://api.anthropic.com` will be used internally.
        timeout: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        max_retries: The maximum number of times to retry the request before failing. Defaults
            to `6`.
        http_client: if provided, an alternative HTTP client to use for calling Anthropic
            API. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.
        _aclient: the `AsyncAnthropic` client to use for the Anthropic API. It is meant
            to be used internally. Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Anthropic API. If not
            provided, it will be read from `ANTHROPIC_API_KEY` environment variable.
        - `base_url`: the base URL to use for the Anthropic API. Defaults to `"https://api.anthropic.com"`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        - `max_retries`: the maximum number of times to retry the request before failing.
            Defaults to `6`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AnthropicLLM

        llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import AnthropicLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = AnthropicLLM(
            model="claude-3-opus-20240229",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
        ),
        description="The base URL to use for the Anthropic API.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANTHROPIC_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anthropic API.",
    )
    timeout: RuntimeParameter[float] = Field(
        default=600.0,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    http_client: Optional[AsyncClient] = Field(default=None, exclude=True)
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)

    def _check_model_exists(self) -> None:
        """Checks if the specified model exists in the available models."""
        from anthropic import AsyncAnthropic

        annotation = get_type_hints(AsyncAnthropic().messages.create).get("model", None)
        models = [
            value
            for type_ in get_args(annotation)
            if get_origin(type_) is Literal
            for value in get_args(type_)
        ]

        if self.model not in models:
            raise ValueError(
                f"Model {self.model} does not exist among available models. "
                f"The available models are {', '.join(models)}"
            )

    def load(self) -> None:
        """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
        super().load()

        try:
            from anthropic import AsyncAnthropic
        except ImportError as ie:
            raise ImportError(
                "Anthropic Python client is not installed. Please install it using"
                " `pip install 'distilabel[anthropic]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._check_model_exists()

        self._aclient = AsyncAnthropic(
            api_key=self.api_key.get_secret_value(),
            base_url=self.base_url,
            timeout=self.timeout,
            http_client=self.http_client,
            max_retries=self.max_retries,
        )
        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="anthropic",
            )
            self._aclient = result.get("client")
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_tokens: int = 128,
        stop_sequences: Union[List[str], None] = None,
        temperature: float = 1.0,
        top_p: Union[float, None] = None,
        top_k: Union[int, None] = None,
    ) -> GenerateOutput:
        """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

        Args:
            input: a single input in chat format to generate responses for.
            max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
            stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `None`, which maps to `NOT_GIVEN` in the API call.
            temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `None`, which maps to `NOT_GIVEN` in the API call.
            top_k: the top-k value to use for the generation. Defaults to `None`, which maps to `NOT_GIVEN` in the API call.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from anthropic._types import NOT_GIVEN

        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="anthropic",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "system": (
                input.pop(0)["content"]
                if input and input[0]["role"] == "system"
                else NOT_GIVEN
            ),
            "max_tokens": max_tokens,
            "stream": False,
            "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences,
            "temperature": temperature,
            "top_p": NOT_GIVEN if top_p is None else top_p,
            "top_k": NOT_GIVEN if top_k is None else top_k,
        }

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
            **kwargs
        )  # type: ignore
        if structured_output:
            # raw_response = completion._raw_response
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        if (content := completion.content[0].text) is None:
            self._logger.warning(
                f"Received no response using Anthropic client (model: '{self.model}')."
                f" Finish reason was: {completion.stop_reason}"
            )
        return prepare_output([content], **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "Message") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.input_tokens],
            "output_tokens": [completion.usage.output_tokens],
        }
model_name property

Returns the model name used for the LLM.

_check_model_exists()

Checks if the specified model exists in the available models.

Source code in src/distilabel/models/llms/anthropic.py
def _check_model_exists(self) -> None:
    """Checks if the specified model exists in the available models."""
    from anthropic import AsyncAnthropic

    annotation = get_type_hints(AsyncAnthropic().messages.create).get("model", None)
    models = [
        value
        for type_ in get_args(annotation)
        if get_origin(type_) is Literal
        for value in get_args(type_)
    ]

    if self.model not in models:
        raise ValueError(
            f"Model {self.model} does not exist among available models. "
            f"The available models are {', '.join(models)}"
        )
load()

Loads the AsyncAnthropic client to use the Anthropic async API.

Source code in src/distilabel/models/llms/anthropic.py
def load(self) -> None:
    """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
    super().load()

    try:
        from anthropic import AsyncAnthropic
    except ImportError as ie:
        raise ImportError(
            "Anthropic Python client is not installed. Please install it using"
            " `pip install 'distilabel[anthropic]'`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._check_model_exists()

    self._aclient = AsyncAnthropic(
        api_key=self.api_key.get_secret_value(),
        base_url=self.base_url,
        timeout=self.timeout,
        http_client=self.http_client,
        max_retries=self.max_retries,
    )
    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="anthropic",
        )
        self._aclient = result.get("client")
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
agenerate(input, max_tokens=128, stop_sequences=None, temperature=1.0, top_p=None, top_k=None) async

Generates a response asynchronously, using the Anthropic Async API definition.

Parameters:

  • input (FormattedInput): a single input in chat format to generate responses for. Required.
  • max_tokens (int): the maximum number of new tokens that the model will generate. Defaults to 128.
  • stop_sequences (Union[List[str], None]): custom text sequences that will cause the model to stop generating. Defaults to None, which maps to NOT_GIVEN in the API call.
  • temperature (float): the temperature to use for the generation. Set only if top_p is None. Defaults to 1.0.
  • top_p (Union[float, None]): the top-p value to use for the generation. Defaults to None, which maps to NOT_GIVEN in the API call.
  • top_k (Union[int, None]): the top-k value to use for the generation. Defaults to None, which maps to NOT_GIVEN in the API call.

Returns:

  • GenerateOutput: a list of lists of strings containing the generated responses for each input.
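
Since agenerate handles a single chat-formatted input, it can also be awaited directly outside a pipeline. A minimal sketch, assuming ANTHROPIC_API_KEY is set in the environment and using illustrative sampling values:

import asyncio

from distilabel.models.llms import AnthropicLLM

llm = AnthropicLLM(model="claude-3-opus-20240229")
llm.load()

# Await a single generation for one chat input
output = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        max_tokens=256,
        temperature=0.7,
    )
)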

Source code in src/distilabel/models/llms/anthropic.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_tokens: int = 128,
    stop_sequences: Union[List[str], None] = None,
    temperature: float = 1.0,
    top_p: Union[float, None] = None,
    top_k: Union[int, None] = None,
) -> GenerateOutput:
    """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

    Args:
        input: a single input in chat format to generate responses for.
        max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
        stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `None`, which maps to `NOT_GIVEN` in the API call.
        temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `None`, which maps to `NOT_GIVEN` in the API call.
        top_k: the top-k value to use for the generation. Defaults to `None`, which maps to `NOT_GIVEN` in the API call.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from anthropic._types import NOT_GIVEN

    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="anthropic",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "system": (
            input.pop(0)["content"]
            if input and input[0]["role"] == "system"
            else NOT_GIVEN
        ),
        "max_tokens": max_tokens,
        "stream": False,
        "stop_sequences": NOT_GIVEN if stop_sequences is None else stop_sequences,
        "temperature": temperature,
        "top_p": NOT_GIVEN if top_p is None else top_p,
        "top_k": NOT_GIVEN if top_k is None else top_k,
    }

    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
        **kwargs
    )  # type: ignore
    if structured_output:
        # raw_response = completion._raw_response
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    if (content := completion.content[0].text) is None:
        self._logger.warning(
            f"Received no response using Anthropic client (model: '{self.model}')."
            f" Finish reason was: {completion.stop_reason}"
        )
    return prepare_output([content], **self._get_llm_statistics(completion))

AnyscaleLLM

Bases: OpenAILLM

Anyscale LLM implementation running the async API client of OpenAI.

Attributes:

  • model (str): the model name to use for the LLM, e.g., google/gemma-7b-it. See the supported models under the "Text Generation -> Supported Models" section here.
  • base_url (Optional[RuntimeParameter[str]]): the base URL to use for the Anyscale API requests. Defaults to None, which means that the value set for the environment variable ANYSCALE_BASE_URL will be used, or "https://api.endpoints.anyscale.com/v1" if not set.
  • api_key (Optional[RuntimeParameter[SecretStr]]): the API key to authenticate the requests to the Anyscale API. Defaults to None, which means that the value set for the environment variable ANYSCALE_API_KEY will be used, or None if not set.
  • _api_key_env_var (str): the name of the environment variable to use for the API key. It is meant to be used internally.

Examples:

Generate text:

from distilabel.models.llms import AnyscaleLLM

llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
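
Because both base_url and api_key fall back to environment variables, the client can also be configured without passing them explicitly. A sketch with placeholder values:

import os

from distilabel.models.llms import AnyscaleLLM

# Placeholder values; these are the fallbacks read by the attributes above
os.environ["ANYSCALE_API_KEY"] = "api.key"
os.environ["ANYSCALE_BASE_URL"] = "https://api.endpoints.anyscale.com/v1"

llm = AnyscaleLLM(model="google/gemma-7b-it")
llm.load()
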
Source code in src/distilabel/models/llms/anyscale.py
class AnyscaleLLM(OpenAILLM):
    """Anyscale LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM, e.g., `google/gemma-7b-it`. See the
            supported models under the "Text Generation -> Supported Models" section
            [here](https://docs.endpoints.anyscale.com/).
        base_url: the base URL to use for the Anyscale API requests. Defaults to `None`, which
            means that the value set for the environment variable `ANYSCALE_BASE_URL` will be used, or
            "https://api.endpoints.anyscale.com/v1" if not set.
        api_key: the API key to authenticate the requests to the Anyscale API. Defaults to `None` which
            means that the value set for the environment variable `ANYSCALE_API_KEY` will be used, or
            `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key.
            It is meant to be used internally.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AnyscaleLLM

        llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANYSCALE_BASE_URL", "https://api.endpoints.anyscale.com/v1"
        ),
        description="The base URL to use for the Anyscale API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANYSCALE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anyscale API.",
    )

    _api_key_env_var: str = PrivateAttr(_ANYSCALE_API_KEY_ENV_VAR_NAME)

AzureOpenAILLM

Bases: OpenAILLM

Azure OpenAI LLM implementation running the async API client.

Attributes:

  • model (str): the model name to use for the LLM, i.e. the name of the Azure deployment.
  • base_url (Optional[RuntimeParameter[str]]): the base URL to use for the Azure OpenAI API, which can be set with AZURE_OPENAI_ENDPOINT. Defaults to None, which means that the value set for the environment variable AZURE_OPENAI_ENDPOINT will be used, or None if not set.
  • api_key (Optional[RuntimeParameter[SecretStr]]): the API key to authenticate the requests to the Azure OpenAI API. Defaults to None, which means that the value set for the environment variable AZURE_OPENAI_API_KEY will be used, or None if not set.
  • api_version (Optional[RuntimeParameter[str]]): the API version to use for the Azure OpenAI API. Defaults to None, which means that the value set for the environment variable OPENAI_API_VERSION will be used, or None if not set.

Icon: :material-microsoft-azure:

Examples:

Generate text:

from distilabel.models.llms import AzureOpenAILLM

llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate text from a custom endpoint following the OpenAI API:

from distilabel.models.llms import AzureOpenAILLM

llm = AzureOpenAILLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    base_url=r"http://localhost:8080/v1"
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import AzureOpenAILLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = AzureOpenAILLM(
    model="gpt-4-turbo",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
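
Since all three attributes read from environment variables by default, an Azure deployment can also be configured entirely through the environment. A sketch with placeholder values (the endpoint and API version below are hypothetical):

import os

from distilabel.models.llms import AzureOpenAILLM

# Hypothetical values; each attribute falls back to the corresponding variable
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://my-resource.openai.azure.com"
os.environ["AZURE_OPENAI_API_KEY"] = "api.key"
os.environ["OPENAI_API_VERSION"] = "2024-02-01"

llm = AzureOpenAILLM(model="gpt-4-turbo")  # the name of the Azure deployment
llm.load()
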
Source code in src/distilabel/models/llms/azure.py
class AzureOpenAILLM(OpenAILLM):
    """Azure OpenAI LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM i.e. the name of the Azure deployment.
        base_url: the base URL to use for the Azure OpenAI API can be set with `AZURE_OPENAI_ENDPOINT`.
            Defaults to `None` which means that the value set for the environment variable
            `AZURE_OPENAI_ENDPOINT` will be used, or `None` if not set.
        api_key: the API key to authenticate the requests to the Azure OpenAI API. Defaults to `None`
            which means that the value set for the environment variable `AZURE_OPENAI_API_KEY` will be
            used, or `None` if not set.
        api_version: the API version to use for the Azure OpenAI API. Defaults to `None` which means
            that the value set for the environment variable `OPENAI_API_VERSION` will be used, or
            `None` if not set.

    Icon:
        `:material-microsoft-azure:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import AzureOpenAILLM

        llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate text from a custom endpoint following the OpenAI API:

        ```python
        from distilabel.models.llms import AzureOpenAILLM

        llm = AzureOpenAILLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            base_url=r"http://localhost:8080/v1"
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import AzureOpenAILLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = AzureOpenAILLM(
            model="gpt-4-turbo",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME),
        description="The base URL to use for the Azure OpenAI API requests i.e. the Azure OpenAI endpoint.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Azure OpenAI API.",
    )

    api_version: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv("OPENAI_API_VERSION"),
        description="The API version to use for the Azure OpenAI API.",
    )

    _base_url_env_var: str = PrivateAttr(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME)
    _api_key_env_var: str = PrivateAttr(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAzureOpenAI"] = PrivateAttr(...)  # type: ignore

    @override
    def load(self) -> None:
        """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
        # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output
        # in the load method before we have the proper client.
        with patch(
            "distilabel.models.openai.OpenAILLM._prepare_structured_output", lambda x: x
        ):
            super().load()

        try:
            from openai import AsyncAzureOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install 'distilabel[openai]'`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        # TODO: May be worth adding the AD auth too? Also the `organization`?
        self._aclient = AsyncAzureOpenAI(  # type: ignore
            azure_endpoint=self.base_url,  # type: ignore
            azure_deployment=self.model,
            api_version=self.api_version,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            self._prepare_structured_output(self.structured_output)
load()

Loads the AsyncAzureOpenAI client to benefit from async requests.

Source code in src/distilabel/models/llms/azure.py
@override
def load(self) -> None:
    """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
    # This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output
    # in the load method before we have the proper client.
    with patch(
        "distilabel.models.openai.OpenAILLM._prepare_structured_output", lambda x: x
    ):
        super().load()

    try:
        from openai import AsyncAzureOpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install 'distilabel[openai]'`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    # TODO: May be worth adding the AD auth too? Also the `organization`?
    self._aclient = AsyncAzureOpenAI(  # type: ignore
        azure_endpoint=self.base_url,  # type: ignore
        azure_deployment=self.model,
        api_version=self.api_version,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    if self.structured_output:
        self._prepare_structured_output(self.structured_output)

CohereLLM

Bases: AsyncLLM

Cohere API implementation using the async client for concurrent text generation.

Attributes:

  • model (str): the name of the model from the Cohere API to use for the generation.
  • base_url (Optional[RuntimeParameter[str]]): the base URL to use for the Cohere API requests. Defaults to "https://api.cohere.ai/v1".
  • api_key (Optional[RuntimeParameter[SecretStr]]): the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.
  • timeout (RuntimeParameter[int]): the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • client_name (RuntimeParameter[str]): the name of the client to use for the API requests. Defaults to "distilabel".
  • structured_output (Optional[RuntimeParameter[InstructorStructuredOutputType]]): a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.
  • _ChatMessage (Type[ChatMessage]): the ChatMessage class from the cohere package.
  • _aclient (AsyncClient): the AsyncClient client from the cohere package.

Runtime parameters
  • base_url: the base URL to use for the Cohere API requests. Defaults to "https://api.cohere.ai/v1".
  • api_key: the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • client_name: the name of the client to use for the API requests. Defaults to "distilabel".
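
As with the other clients, these runtime parameters can be set at instantiation. A minimal sketch with illustrative values, assuming COHERE_API_KEY is set in the environment:

from distilabel.models.llms import CohereLLM

# Illustrative overrides of the runtime parameters listed above
llm = CohereLLM(
    model="CohereForAI/c4ai-command-r-plus",
    timeout=60,            # default: 120
    client_name="my-app",  # default: "distilabel"
)
llm.load()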

Examples:

Generate text:

from distilabel.models.llms import CohereLLM

llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import CohereLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = CohereLLM(
    model="CohereForAI/c4ai-command-r-plus",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/models/llms/cohere.py
class CohereLLM(AsyncLLM):
    """Cohere API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Cohere API to use for the generation.
        base_url: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        api_key: the API key to authenticate the requests to the Cohere API. Defaults to
            the value of the `COHERE_API_KEY` environment variable.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        client_name: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _ChatMessage: the `ChatMessage` class from the `cohere` package.
        _aclient: the `AsyncClient` client from the `cohere` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        - `api_key`: the API key to authenticate the requests to the Cohere API. Defaults
            to the value of the `COHERE_API_KEY` environment variable.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `client_name`: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import CohereLLM

        llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import CohereLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = CohereLLM(
            model="CohereForAI/c4ai-command-r-plus",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "COHERE_BASE_URL", "https://api.cohere.ai/v1"
        ),
        description="The base URL to use for the Cohere API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_COHERE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Cohere API.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    client_name: RuntimeParameter[str] = Field(
        default="distilabel",
        description="The name of the client to use for the API requests.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
    _aclient: "AsyncClient" = PrivateAttr(...)
    _tokenizer: "Tokenizer" = PrivateAttr(...)

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def load(self) -> None:
        """Loads the `AsyncClient` client from the `cohere` package."""

        super().load()

        try:
            from cohere import AsyncClient, ChatMessage
        except ImportError as ie:
            raise ImportError(
                "The `cohere` package is required to use the `CohereLLM` class."
            ) from ie

        self._ChatMessage = ChatMessage

        self._aclient = AsyncClient(
            api_key=self.api_key.get_secret_value(),  # type: ignore
            client_name=self.client_name,
            base_url=self.base_url,
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

        from cohere.manually_maintained.tokenizers import get_hf_tokenizer

        self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)

    def _format_chat_to_cohere(
        self, input: "FormattedInput"
    ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
        """Formats the chat input to the Cohere Chat API conversational format.

        Args:
            input: The chat input to format.

        Returns:
            A tuple containing the system, chat history, and message.
        """
        system = None
        message = None
        chat_history = []
        for item in input:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system = content
            elif role == "user":
                message = content
            elif role == "assistant":
                if message is None:
                    raise ValueError(
                        "An assistant message but be preceded by a user message."
                    )
                chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
                chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
                message = None

        if message is None:
            raise ValueError("The chat input must end with a user message.")

        return system, chat_history, message

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        seed: Optional[float] = None,
        stop_sequences: Optional[Sequence[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        raw_prompting: Optional[bool] = None,
    ) -> GenerateOutput:
        """Generates a response from the LLM given an input.

        Args:
            input: a single input in chat format to generate responses for.
            temperature: the temperature to use for the generation. Defaults to `None`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            k: the number of highest probability vocabulary tokens to keep for the generation.
                Defaults to `None`.
            p: the nucleus sampling probability to use for the generation. Defaults to
                `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: a list of sequences to use as stopping criteria for the generation.
                Defaults to `None`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `None`.
            raw_prompting: a flag to use raw prompting for the generation. Defaults to
                `None`.

        Returns:
            The generated response from the Cohere API model.
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,  # type: ignore
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        system, chat_history, message = self._format_chat_to_cohere(input)

        kwargs = {
            "message": message,
            "model": self.model,
            "preamble": system,
            "chat_history": chat_history,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "k": k,
            "p": p,
            "seed": seed,
            "stop_sequences": stop_sequences,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "raw_prompting": raw_prompting,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

        response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs)  # type: ignore

        if structured_output:
            return prepare_output(
                [response.model_dump_json()],
                **self._get_llm_statistics(
                    input, orjson.dumps(response.model_dump_json()).decode("utf-8")
                ),  # type: ignore
            )

        if (text := response.text) == "":
            self._logger.warning(  # type: ignore
                f"Received no response using Cohere client (model: '{self.model}')."
                f" Finish reason was: {response.finish_reason}"
            )
            return prepare_output(
                [None],
                **self._get_llm_statistics(input, ""),
            )

        return prepare_output(
            [text],
            **self._get_llm_statistics(input, text),
        )

    def _get_llm_statistics(
        self, input: FormattedInput, output: str
    ) -> "LLMStatistics":
        return {
            "input_tokens": [compute_tokens(input, self._tokenizer.encode)],
            "output_tokens": [compute_tokens(output, self._tokenizer.encode)],
        }
model_name property

Returns the model name used for the LLM.

load()

Loads the AsyncClient client from the cohere package.

Source code in src/distilabel/models/llms/cohere.py
def load(self) -> None:
    """Loads the `AsyncClient` client from the `cohere` package."""

    super().load()

    try:
        from cohere import AsyncClient, ChatMessage
    except ImportError as ie:
        raise ImportError(
            "The `cohere` package is required to use the `CohereLLM` class."
        ) from ie

    self._ChatMessage = ChatMessage

    self._aclient = AsyncClient(
        api_key=self.api_key.get_secret_value(),  # type: ignore
        client_name=self.client_name,
        base_url=self.base_url,
        timeout=self.timeout,
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output

    from cohere.manually_maintained.tokenizers import get_hf_tokenizer

    self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model)
_format_chat_to_cohere(input)

Formats the chat input to the Cohere Chat API conversational format.

Parameters:

  • input (FormattedInput): The chat input to format. Required.

Returns:

  • Tuple[Union[str, None], List[ChatMessage], str]: A tuple containing the system, chat history, and message.

Source code in src/distilabel/models/llms/cohere.py
def _format_chat_to_cohere(
    self, input: "FormattedInput"
) -> Tuple[Union[str, None], List["ChatMessage"], str]:
    """Formats the chat input to the Cohere Chat API conversational format.

    Args:
        input: The chat input to format.

    Returns:
        A tuple containing the system, chat history, and message.
    """
    system = None
    message = None
    chat_history = []
    for item in input:
        role = item["role"]
        content = item["content"]
        if role == "system":
            system = content
        elif role == "user":
            message = content
        elif role == "assistant":
            if message is None:
                raise ValueError(
                    "An assistant message but be preceded by a user message."
                )
            chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
            chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
            message = None

    if message is None:
        raise ValueError("The chat input must end with a user message.")

    return system, chat_history, message
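
To make the conversion concrete, here is a sketch of what the method produces for a short conversation, following the loop above (values are illustrative, and the method is private, so this is for understanding rather than direct use):

conversation = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

system, chat_history, message = llm._format_chat_to_cohere(conversation)
# system       -> "You are concise."
# chat_history -> [ChatMessage(role="USER", message="Hi!"),
#                  ChatMessage(role="CHATBOT", message="Hello, how can I help?")]
# message      -> "Tell me a joke."
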
agenerate(input, temperature=None, max_tokens=None, k=None, p=None, seed=None, stop_sequences=None, frequency_penalty=None, presence_penalty=None, raw_prompting=None) async

Generates a response from the LLM given an input.

Parameters:

  • input (FormattedInput): a single input in chat format to generate responses for. Required.
  • temperature (Optional[float]): the temperature to use for the generation. Defaults to None.
  • max_tokens (Optional[int]): the maximum number of new tokens that the model will generate. Defaults to None.
  • k (Optional[int]): the number of highest probability vocabulary tokens to keep for the generation. Defaults to None.
  • p (Optional[float]): the nucleus sampling probability to use for the generation. Defaults to None.
  • seed (Optional[float]): the seed to use for the generation. Defaults to None.
  • stop_sequences (Optional[Sequence[str]]): a list of sequences to use as stopping criteria for the generation. Defaults to None.
  • frequency_penalty (Optional[float]): the frequency penalty to use for the generation. Defaults to None.
  • presence_penalty (Optional[float]): the presence penalty to use for the generation. Defaults to None.
  • raw_prompting (Optional[bool]): a flag to use raw prompting for the generation. Defaults to None.

Returns:

  • GenerateOutput: The generated response from the Cohere API model.
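
As with the other async clients, agenerate can be awaited directly for a single input. A minimal sketch, assuming COHERE_API_KEY is set in the environment and using illustrative sampling values:

import asyncio

from distilabel.models.llms import CohereLLM

llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")
llm.load()

output = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        temperature=0.7,
        max_tokens=256,
    )
)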

Source code in src/distilabel/models/llms/cohere.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    k: Optional[int] = None,
    p: Optional[float] = None,
    seed: Optional[float] = None,
    stop_sequences: Optional[Sequence[str]] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    raw_prompting: Optional[bool] = None,
) -> GenerateOutput:
    """Generates a response from the LLM given an input.

    Args:
        input: a single input in chat format to generate responses for.
        temperature: the temperature to use for the generation. Defaults to `None`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        k: the number of highest probability vocabulary tokens to keep for the generation.
            Defaults to `None`.
        p: the nucleus sampling probability to use for the generation. Defaults to
            `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: a list of sequences to use as stopping criteria for the generation.
            Defaults to `None`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `None`.
        raw_prompting: a flag to use raw prompting for the generation. Defaults to
            `None`.

    Returns:
        The generated response from the Cohere API model.
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,  # type: ignore
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    system, chat_history, message = self._format_chat_to_cohere(input)

    kwargs = {
        "message": message,
        "model": self.model,
        "preamble": system,
        "chat_history": chat_history,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "k": k,
        "p": p,
        "seed": seed,
        "stop_sequences": stop_sequences,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "raw_prompting": raw_prompting,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

    response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs)  # type: ignore

    if structured_output:
        return prepare_output(
            [response.model_dump_json()],
            **self._get_llm_statistics(
                input, orjson.dumps(response.model_dump_json()).decode("utf-8")
            ),  # type: ignore
        )

    if (text := response.text) == "":
        self._logger.warning(  # type: ignore
            f"Received no response using Cohere client (model: '{self.model}')."
            f" Finish reason was: {response.finish_reason}"
        )
        return prepare_output(
            [None],
            **self._get_llm_statistics(input, ""),
        )

    return prepare_output(
        [text],
        **self._get_llm_statistics(input, text),
    )

GroqLLM

Bases: AsyncLLM

Groq API implementation using the async client for concurrent text generation.

Attributes:

  • model (str): the name of the model from the Groq API to use for the generation.
  • base_url (Optional[RuntimeParameter[str]]): the base URL to use for the Groq API requests. Defaults to "https://api.groq.com".
  • api_key (Optional[RuntimeParameter[SecretStr]]): the API key to authenticate the requests to the Groq API. Defaults to the value of the GROQ_API_KEY environment variable.
  • max_retries (RuntimeParameter[int]): the maximum number of times to retry the request to the API before failing. Defaults to 2.
  • timeout (RuntimeParameter[int]): the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • structured_output (Optional[RuntimeParameter[InstructorStructuredOutputType]]): a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.
  • _api_key_env_var (str): the name of the environment variable to use for the API key.
  • _aclient (Optional[AsyncGroq]): the AsyncGroq client from the groq package.

Runtime parameters
  • base_url: the base URL to use for the Groq API requests. Defaults to "https://api.groq.com".
  • api_key: the API key to authenticate the requests to the Groq API. Defaults to the value of the GROQ_API_KEY environment variable.
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 2.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.

Examples:

Generate text:

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import GroqLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = GroqLLM(
    model="llama3-70b-8192",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
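
The async client can also be driven directly; note that this class uses max_new_tokens rather than max_tokens and accepts a seed. A minimal sketch, assuming GROQ_API_KEY is set in the environment and using illustrative values:

import asyncio

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192")
llm.load()

output = asyncio.run(
    llm.agenerate(
        input=[[{"role": "user", "content": "Hello world!"}]][0],  # a single chat-formatted input
        max_new_tokens=256,
        temperature=0.7,
        seed=42,
    )
)
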
Source code in src/distilabel/models/llms/groq.py
class GroqLLM(AsyncLLM):
    """Groq API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Groq API to use for the generation.
        base_url: the base URL to use for the Groq API requests. Defaults to
            `"https://api.groq.com"`.
        api_key: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the `GROQ_API_KEY` environment variable.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key.
        _aclient: the `AsyncGroq` client from the `groq` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Groq API requests. Defaults to
            `"https://api.groq.com"`.
        - `api_key`: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the `GROQ_API_KEY` environment variable.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import GroqLLM

        llm = GroqLLM(model="llama3-70b-8192")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import GroqLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = GroqLLM(
            model="llama3-70b-8192",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            _GROQ_API_BASE_URL_ENV_VAR_NAME, "https://api.groq.com"
        ),
        description="The base URL to use for the Groq API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_GROQ_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Groq API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=2,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(_GROQ_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncGroq"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncGroq` client to benefit from async requests."""
        super().load()

        try:
            from groq import AsyncGroq
        except ImportError as ie:
            raise ImportError(
                "Groq Python client is not installed. Please install it using"
                ' `pip install "distilabel[groq]"`.'
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncGroq(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="groq",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        seed: Optional[int] = None,
        max_new_tokens: int = 128,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[str] = None,
    ) -> "GenerateOutput":
        """Generates `num_generations` responses for the given input using the Groq async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            seed: the seed to use for the generation. Defaults to `None`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: the stop sequence to use for the generation. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.

        References:
            - https://console.groq.com/docs/text-chat
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="groq",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "seed": seed,
            "temperature": temperature,
            "max_tokens": max_new_tokens,
            "top_p": top_p,
            "stream": False,
            "stop": stop,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
        if structured_output:
            return prepare_output(
                [completion.model_dump_json()],
                **self._get_llm_statistics(completion._raw_response),
            )

        generations = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using the Groq client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return prepare_output(generations, **self._get_llm_statistics(completion))

    @staticmethod
    def _get_llm_statistics(completion: "ChatCompletion") -> "LLMStatistics":
        return {
            "input_tokens": [completion.usage.prompt_tokens if completion else 0],
            "output_tokens": [completion.usage.completion_tokens if completion else 0],
        }
model_name property

Returns the model name used for the LLM.

load()

Loads the AsyncGroq client to benefit from async requests.

Source code in src/distilabel/models/llms/groq.py
def load(self) -> None:
    """Loads the `AsyncGroq` client to benefit from async requests."""
    super().load()

    try:
        from groq import AsyncGroq
    except ImportError as ie:
        raise ImportError(
            "Groq Python client is not installed. Please install it using"
            ' `pip install "distilabel[groq]"`.'
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = AsyncGroq(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="groq",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output
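
As the `load()` implementation above shows, the API key falls back to the `GROQ_API_KEY` environment variable when the `api_key` attribute is not set, and a `ValueError` is raised if neither is available. A minimal sketch (the key value is a placeholder):

import os

from distilabel.models.llms import GroqLLM

os.environ["GROQ_API_KEY"] = "gsk_..."  # placeholder, not a real key

llm = GroqLLM(model="llama3-70b-8192")
llm.load()  # resolves the key from the environment variable
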
agenerate(input, seed=None, max_new_tokens=128, temperature=1.0, top_p=1.0, stop=None) async

Generates responses for the given input using the Groq async client.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `FormattedInput` | a single input in chat format to generate responses for. | *required* |
| `seed` | `Optional[int]` | the seed to use for the generation. Defaults to `None`. | `None` |
| `max_new_tokens` | `int` | the maximum number of new tokens that the model will generate. Defaults to `128`. | `128` |
| `temperature` | `float` | the temperature to use for the generation. Defaults to `1.0`. | `1.0` |
| `top_p` | `float` | the top-p value to use for the generation. Defaults to `1.0`. | `1.0` |
| `stop` | `Optional[str]` | the stop sequence to use for the generation. Defaults to `None`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `GenerateOutput` | A list of lists of strings containing the generated responses for each input. |

References
  • https://console.groq.com/docs/text-chat
Source code in src/distilabel/models/llms/groq.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    seed: Optional[int] = None,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[str] = None,
) -> "GenerateOutput":
    """Generates `num_generations` responses for the given input using the Groq async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        seed: the seed to use for the generation. Defaults to `None`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: the stop sequence to use for the generation. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.

    References:
        - https://console.groq.com/docs/text-chat
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="groq",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "seed": seed,
        "temperature": temperature,
        "max_tokens": max_new_tokens,
        "top_p": top_p,
        "stream": False,
        "stop": stop,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)

    completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
    if structured_output:
        return prepare_output(
            [completion.model_dump_json()],
            **self._get_llm_statistics(completion._raw_response),
        )

    generations = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using the Groq client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return prepare_output(generations, **self._get_llm_statistics(completion))
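
`agenerate` is a coroutine handling a single input; since `_num_generations_param_supported` is `False`, multiple generations are obtained through concurrent requests rather than a `num_generations` argument. It is normally driven through `generate_outputs`, but it can also be awaited directly. A minimal sketch, assuming a valid API key:

import asyncio

from distilabel.models.llms import GroqLLM

llm = GroqLLM(model="llama3-70b-8192", api_key="api.key")
llm.load()

result = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        seed=42,  # fixed seed for reproducible generations
        max_new_tokens=64,
    )
)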

InferenceEndpointsLLM

Bases: InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin

InferenceEndpoints LLM implementation running the async API client.

This LLM will internally use huggingface_hub.AsyncInferenceClient.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model_id` | `Optional[str]` | the model ID to use for the LLM as available in the Hugging Face Hub, which will be used to resolve the base URL for the serverless Inference Endpoints API requests. Defaults to `None`. |
| `endpoint_name` | `Optional[RuntimeParameter[str]]` | the name of the Inference Endpoint to use for the LLM. Defaults to `None`. |
| `endpoint_namespace` | `Optional[RuntimeParameter[str]]` | the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`. |
| `base_url` | `Optional[RuntimeParameter[str]]` | the base URL to use for the Inference Endpoints API requests. |
| `api_key` | `Optional[RuntimeParameter[SecretStr]]` | the API key to authenticate the requests to the Inference Endpoints API. |
| `tokenizer_id` | `Optional[str]` | the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to `None`, but defining one is recommended to properly format the prompt. |
| `model_display_name` | `Optional[str]` | the model display name to use for the LLM. Defaults to `None`. |
| `use_magpie_template` | `bool` | a flag used to enable/disable applying the Magpie pre-query template. Defaults to `False`. |
| `magpie_pre_query_template` | `Union[MagpieAvailablePreQueryTemplates, str, None]` | the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2" or another pre-query template provided. Defaults to `None`. |
| `structured_output` | `Optional[RuntimeParameter[StructuredOutputType]]` | a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to `None`. |

Icon

:hugging:

Examples:

Free serverless Inference API (set the input_batch_size of the Task that uses this LLM to avoid "Model is overloaded" errors):

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Dedicated Inference Endpoints:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    endpoint_name="<ENDPOINT_NAME>",
    api_key="<HF_API_KEY>",
    endpoint_namespace="<USER|ORG>",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Dedicated Inference Endpoints or TGI:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    api_key="<HF_API_KEY>",
    base_url="<BASE_URL>",
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.models.llms import InferenceEndpointsLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    api_key="api.key",
    structured_output={"format": "json", "schema": User.model_json_schema()}
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
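
The `structured_output` dictionary is forwarded to the endpoint as a grammar, so besides "json" a "regex" format can be used. A minimal sketch (the pattern is purely illustrative):

from distilabel.models.llms import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
    structured_output={"format": "regex", "schema": r"(Paris|London|Madrid)"},
)

llm.load()

output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Name a European capital."}]])
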
Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
class InferenceEndpointsLLM(
    InferenceEndpointsBaseClient, AsyncLLM, MagpieChatTemplateMixin
):
    """InferenceEndpoints LLM implementation running the async API client.

    This LLM will internally use `huggingface_hub.AsyncInferenceClient`.

    Attributes:
        model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.
        tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub.
            Defaults to `None`, but defining one is recommended to properly format the prompt.
        model_display_name: the model display name to use for the LLM. Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.
        structured_output: a dictionary containing the structured output configuration or
            if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`.
            Defaults to None.

    Icon:
        `:hugging:`

    Examples:
        Free serverless Inference API (set the `input_batch_size` of the Task that uses this LLM to avoid "Model is overloaded" errors):

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Dedicated Inference Endpoints:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            endpoint_name="<ENDPOINT_NAME>",
            api_key="<HF_API_KEY>",
            endpoint_namespace="<USER|ORG>",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Dedicated Inference Endpoints or TGI:

        ```python
        from distilabel.models.llms.huggingface import InferenceEndpointsLLM

        llm = InferenceEndpointsLLM(
            api_key="<HF_API_KEY>",
            base_url="<BASE_URL>",
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.models.llms import InferenceEndpointsLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            api_key="api.key",
            structured_output={"format": "json", "schema": User.model_json_schema()}
        )

        llm.load()

        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the Tour De France"}]])
        ```
    """

    def load(self) -> None:
        # Sets the logger and calls the load method of the BaseClient
        self._num_generations_param_supported = False
        AsyncLLM.load(self)
        InferenceEndpointsBaseClient.load(self)

    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
        provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
        favour of the dynamically calculated one."""

        if self.base_url and (self.model_id or self.endpoint_name):
            self._logger.warning(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
                " or overwritten with the one generated from either of those args, for serverless"
                " or dedicated inference endpoints, respectively."
            )

        if self.use_magpie_template and self.tokenizer_id is None:
            raise ValueError(
                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
                " set a `tokenizer_id` and try again."
            )

        if (
            self.model_id
            and self.tokenizer_id is None
            and self.structured_output is not None
        ):
            self.tokenizer_id = self.model_id

        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        raise ValidationError(
            f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
        )

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _get_structured_output(
        self, input: FormattedInput
    ) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
        """Gets the structured output (if any) for the given input.

        Args:
            input: a single input in chat format to generate responses for.

        Returns:
            The input and the structured output that will be passed as `grammar` to the
            inference endpoint or `None` if not required.
        """
        structured_output = None

        # Specific structured output per input
        if isinstance(input, tuple):
            input, structured_output = input
            structured_output = {
                "type": structured_output["format"],  # type: ignore
                "value": structured_output["schema"],  # type: ignore
            }

        # Same structured output for all the inputs
        if structured_output is None and self.structured_output is not None:
            try:
                structured_output = {
                    "type": self.structured_output["format"],  # type: ignore
                    "value": self.structured_output["schema"],  # type: ignore
                }
            except KeyError as e:
                raise ValueError(
                    "To use the structured output you have to inform the `format` and `schema` in "
                    "the `structured_output` attribute."
                ) from e

        if structured_output:
            if isinstance(structured_output["value"], ModelMetaclass):
                structured_output["value"] = structured_output[
                    "value"
                ].model_json_schema()

        return input, structured_output

    async def _generate_with_text_generation(
        self,
        input: str,
        max_new_tokens: int = 128,
        repetition_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_n_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        stop_sequences: Union[List[str], None] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        watermark: bool = False,
        structured_output: Union[Dict[str, Any], None] = None,
    ) -> GenerateOutput:
        generation: Union["TextGenerationOutput", None] = None
        try:
            generation = await self._aclient.text_generation(  # type: ignore
                prompt=input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_n_tokens=top_n_tokens,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, sys.maxsize),
                watermark=watermark,
                grammar=structured_output,  # type: ignore
                details=True,
            )
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
        return prepare_output(
            generations=[generation.generated_text] if generation else [None],
            input_tokens=[
                compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1
            ],
            output_tokens=[
                generation.details.generated_tokens
                if generation and generation.details
                else 0
            ],
            logprobs=self._get_logprobs_from_text_generation(generation)
            if generation
            else None,  # type: ignore
        )

    def _get_logprobs_from_text_generation(
        self, generation: "TextGenerationOutput"
    ) -> Union[List[List[List["Logprob"]]], None]:
        if generation.details is None or generation.details.top_tokens is None:
            return None

        return [
            [
                [
                    {"token": top_logprob["text"], "logprob": top_logprob["logprob"]}
                    for top_logprob in token_logprobs
                ]
                for token_logprobs in generation.details.top_tokens
            ]
        ]

    async def _generate_with_chat_completion(
        self,
        input: "StandardInput",
        max_new_tokens: int = 128,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: bool = False,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: float = 1.0,
        tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[PositiveInt] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        message = None
        completion: Union["ChatCompletionOutput", None] = None
        output_logprobs = None
        try:
            completion = await self._aclient.chat_completion(  # type: ignore
                messages=input,  # type: ignore
                max_tokens=max_new_tokens,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                logprobs=logprobs,
                presence_penalty=presence_penalty,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=seed or random.randint(0, sys.maxsize),
                stop=stop_sequences,
                temperature=temperature,
                tool_choice=tool_choice,  # type: ignore
                tool_prompt=tool_prompt,
                tools=tools,  # type: ignore
                top_logprobs=top_logprobs,
                top_p=top_p,
            )
            choice = completion.choices[0]  # type: ignore
            if (message := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            if choice_logprobs := self._get_logprobs_from_choice(choice):
                output_logprobs = [choice_logprobs]
        except Exception as e:
            self._logger.warning(  # type: ignore
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
        return prepare_output(
            generations=[message],
            input_tokens=[completion.usage.prompt_tokens] if completion else None,
            output_tokens=[completion.usage.completion_tokens] if completion else None,
            logprobs=output_logprobs,
        )

    def _get_logprobs_from_choice(
        self, choice: "ChatCompletionOutputComplete"
    ) -> Union[List[List["Logprob"]], None]:
        if choice.logprobs is None:
            return None

        return [
            [
                {"token": top_logprob.token, "logprob": top_logprob.logprob}
                for top_logprob in token_logprobs.top_logprobs
            ]
            for token_logprobs in choice.logprobs.content
        ]

    def _check_stop_sequences(
        self,
        stop_sequences: Optional[Union[str, List[str]]] = None,
    ) -> Union[List[str], None]:
        """Checks that no more than 4 stop sequences are provided.

        Args:
            stop_sequences: the stop sequences to be checked.

        Returns:
            The stop sequences.
        """
        if stop_sequences is not None:
            if isinstance(stop_sequences, str):
                stop_sequences = [stop_sequences]
            if len(stop_sequences) > 4:
                warnings.warn(
                    "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
                    UserWarning,
                    stacklevel=2,
                )
                stop_sequences = stop_sequences[:4]
        return stop_sequences

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_new_tokens: int = 128,
        frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
        logit_bias: Optional[List[float]] = None,
        logprobs: bool = False,
        presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: float = 1.0,
        tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
        tool_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[PositiveInt] = None,
        top_n_tokens: Optional[PositiveInt] = None,
        top_p: Optional[float] = None,
        do_sample: bool = False,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        top_k: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        num_generations: int = 1,
    ) -> GenerateOutput:
        """Generates completions for the given input using the async client. This method
        uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`.
        `chat_completion` method will be used only if no `tokenizer_id` has been specified.
        Some arguments of this function are specific to the `text_generation` method, while
        some others are specific to the `chat_completion` method.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
                new tokens based on their existing frequency in the text so far, decreasing
                model's likelihood to repeat the same line verbatim. Defaults to `None`.
            logit_bias: modify the likelihood of specified tokens appearing in the completion.
                This argument is exclusive to the `chat_completion` method and will be used
                only if `tokenizer_id` is `None`.
                Defaults to `None`.
            logprobs: whether to return the log probabilities or not. This argument is exclusive
                to the `chat_completion` method and will be used only if `tokenizer_id`
                is `None`. Defaults to `False`.
            presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
                new tokens based on whether they appear in the text so far, increasing the
                model likelihood to talk about new topics. This argument is exclusive to
                the `chat_completion` method and will be used only if `tokenizer_id` is
                `None`. Defaults to `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: either a single string or a list of strings containing the sequences
                to stop the generation at. Defaults to `None`, but will be set to the
                `tokenizer.eos_token` if available.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            tool_choice: the name of the tool the model should call. It can be a dictionary
                like `{"function_name": "my_tool"}` or "auto". If not provided, then the
                model won't use any tool. This argument is exclusive to the `chat_completion`
                method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
            tool_prompt: A prompt to be appended before the tools. This argument is exclusive
                to the `chat_completion` method and will be used only if `tokenizer_id`
                is `None`. Defaults to `None`.
            tools: a list of tools definitions that the LLM can use.
                This argument is exclusive to the `chat_completion` method and will be used
                only if `tokenizer_id` is `None`. Defaults to `None`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. This argument is exclusive to the `chat_completion` method and
                will be used only if `tokenizer_id` is `None`. Defaults to `None`.
            top_n_tokens: the number of top log probabilities to return per output token
                generated. This argument is exclusive of the `text_generation` method and
                will be only used if `tokenizer_id` is not `None`. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.
            do_sample: whether to use sampling for the generation. This argument is exclusive
                of the `text_generation` method and will be only used if `tokenizer_id` is not
                `None`. Defaults to `False`.
            repetition_penalty: the repetition penalty to use for the generation. This argument
                is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            return_full_text: whether to return the full text of the completion or just
                the generated text. Defaults to `False`, meaning that only the generated
                text will be returned. This argument is exclusive of the `text_generation`
                method and will be only used if `tokenizer_id` is not `None`.
            top_k: the top-k value to use for the generation. This argument is exclusive
                of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            typical_p: the typical-p value to use for the generation. This argument is exclusive
                of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `None`.
            watermark: whether to add the watermark to the generated text. This argument
                is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
                is not `None`. Defaults to `False`.
            num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure
                the validation succeeds.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        stop_sequences = self._check_stop_sequences(stop_sequences)

        if isinstance(input, str) or self.tokenizer_id is not None:
            structured_output = None
            if not isinstance(input, str):
                input, structured_output = self._get_structured_output(input)
                input = self.prepare_input(input)

            return await self._generate_with_text_generation(
                input=input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_n_tokens=top_n_tokens,
                top_p=top_p,
                top_k=top_k,
                stop_sequences=stop_sequences,
                return_full_text=return_full_text,
                seed=seed,
                watermark=watermark,
                structured_output=structured_output,
            )

        return await self._generate_with_chat_completion(
            input=input,  # type: ignore
            max_new_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            presence_penalty=presence_penalty,
            seed=seed,
            stop_sequences=stop_sequences,
            temperature=temperature,
            tool_choice=tool_choice,
            tool_prompt=tool_prompt,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
        )
only_one_of_model_id_endpoint_name_or_base_url_provided()

Validates that only one of model_id or endpoint_name is provided; and if base_url is also provided, a warning will be shown informing the user that the provided base_url will be ignored in favour of the dynamically calculated one.

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
@model_validator(mode="after")  # type: ignore
def only_one_of_model_id_endpoint_name_or_base_url_provided(
    self,
) -> "InferenceEndpointsLLM":
    """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
    provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
    favour of the dynamically calculated one."""

    if self.base_url and (self.model_id or self.endpoint_name):
        self._logger.warning(  # type: ignore
            f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
            " or `endpoint_name` is also provided, the `base_url` will either be ignored"
            " or overwritten with the one generated from either of those args, for serverless"
            " or dedicated inference endpoints, respectively."
        )

    if self.use_magpie_template and self.tokenizer_id is None:
        raise ValueError(
            "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
            " set a `tokenizer_id` and try again."
        )

    if (
        self.model_id
        and self.tokenizer_id is None
        and self.structured_output is not None
    ):
        self.tokenizer_id = self.model_id

    if self.base_url and not (self.model_id or self.endpoint_name):
        return self

    if self.model_id and not self.endpoint_name:
        return self

    if self.endpoint_name and not self.model_id:
        return self

    raise ValidationError(
        f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
        f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
        f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
    )
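
In practice the validator accepts exactly one way of resolving the endpoint. A minimal sketch of both outcomes (the endpoint name is hypothetical):

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

# Valid: only `model_id` is provided.
llm = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct")

# Invalid: `model_id` and `endpoint_name` are mutually exclusive, so this raises.
try:
    InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        endpoint_name="my-endpoint",  # hypothetical endpoint name
    )
except Exception as e:
    print(e)
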
prepare_input(input)

Prepares the input (applying the chat template and tokenization) for the provided input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `StandardInput` | the input list containing chat items. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `str` | The prompt to send to the LLM. |

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    prompt: str = (
        self._tokenizer.apply_chat_template(  # type: ignore
            conversation=input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
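
With a `tokenizer_id` set, `load()` fetches the tokenizer so that `prepare_input` can render the conversation through the model's chat template (and the Magpie pre-query template, when enabled). A minimal sketch, assuming the tokenizer can be downloaded:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
)
llm.load()

# A single prompt string ending with the assistant generation prompt.
prompt = llm.prepare_input([{"role": "user", "content": "Hello world!"}])
print(prompt)
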
_get_structured_output(input)

Gets the structured output (if any) for the given input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `FormattedInput` | a single input in chat format to generate responses for. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[StandardInput, Union[Dict[str, Any], None]]` | The input and the structured output that will be passed as `grammar` to the inference endpoint, or `None` if not required. |

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
def _get_structured_output(
    self, input: FormattedInput
) -> Tuple["StandardInput", Union[Dict[str, Any], None]]:
    """Gets the structured output (if any) for the given input.

    Args:
        input: a single input in chat format to generate responses for.

    Returns:
        The input and the structured output that will be passed as `grammar` to the
        inference endpoint or `None` if not required.
    """
    structured_output = None

    # Specific structured output per input
    if isinstance(input, tuple):
        input, structured_output = input
        structured_output = {
            "type": structured_output["format"],  # type: ignore
            "value": structured_output["schema"],  # type: ignore
        }

    # Same structured output for all the inputs
    if structured_output is None and self.structured_output is not None:
        try:
            structured_output = {
                "type": self.structured_output["format"],  # type: ignore
                "value": self.structured_output["schema"],  # type: ignore
            }
        except KeyError as e:
            raise ValueError(
                "To use the structured output you have to inform the `format` and `schema` in "
                "the `structured_output` attribute."
            ) from e

    if structured_output:
        if isinstance(structured_output["value"], ModelMetaclass):
            structured_output["value"] = structured_output[
                "value"
            ].model_json_schema()

    return input, structured_output
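
Because the method also accepts a `(conversation, structured_output)` tuple, the structured output can be set per input instead of globally. A minimal sketch (the `Animal` schema is hypothetical):

from pydantic import BaseModel

from distilabel.models.llms import InferenceEndpointsLLM

class Animal(BaseModel):  # hypothetical schema, for illustration only
    name: str
    species: str

llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
)

llm.load()

output = llm.generate_outputs(
    inputs=[
        (
            [{"role": "user", "content": "Describe a cat."}],
            {"format": "json", "schema": Animal.model_json_schema()},
        )
    ]
)
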
_check_stop_sequences(stop_sequences=None)

Checks that no more than 4 stop sequences are provided.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stop_sequences` | `Optional[Union[str, List[str]]]` | the stop sequences to be checked. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Union[List[str], None]` | The stop sequences. |

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
def _check_stop_sequences(
    self,
    stop_sequences: Optional[Union[str, List[str]]] = None,
) -> Union[List[str], None]:
    """Checks that no more than 4 stop sequences are provided.

    Args:
        stop_sequences: the stop sequences to be checked.

    Returns:
        The stop sequences.
    """
    if stop_sequences is not None:
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        if len(stop_sequences) > 4:
            warnings.warn(
                "Only up to 4 stop sequences are allowed, so keeping the first 4 items only.",
                UserWarning,
                stacklevel=2,
            )
            stop_sequences = stop_sequences[:4]
    return stop_sequences
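
A minimal sketch of the behavior: a single string is wrapped into a list, and anything beyond four sequences is dropped with a `UserWarning`:

from distilabel.models.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct")

print(llm._check_stop_sequences("</s>"))
# ['</s>']
print(llm._check_stop_sequences(["</s>", "\n\n", "User:", "Q:", "A:"]))
# UserWarning emitted; returns ['</s>', '\n\n', 'User:', 'Q:']
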
agenerate(input, max_new_tokens=128, frequency_penalty=None, logit_bias=None, logprobs=False, presence_penalty=None, seed=None, stop_sequences=None, temperature=1.0, tool_choice=None, tool_prompt=None, tools=None, top_logprobs=None, top_n_tokens=None, top_p=None, do_sample=False, repetition_penalty=None, return_full_text=False, top_k=None, typical_p=None, watermark=False, num_generations=1) async

Generates completions for the given input using the async client. This method uses two methods of the huggingface_hub.AsyncClient: chat_completion and text_generation. chat_completion method will be used only if no tokenizer_id has been specified. Some arguments of this function are specific to the text_generation method, while some others are specific to the chat_completion method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `FormattedInput` | a single input in chat format to generate responses for. | *required* |
| `max_new_tokens` | `int` | the maximum number of new tokens that the model will generate. Defaults to `128`. | `128` |
| `frequency_penalty` | `Optional[Annotated[float, Field(ge=-2.0, le=2.0)]]` | a value between `-2.0` and `2.0`. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. Defaults to `None`. | `None` |
| `logit_bias` | `Optional[List[float]]` | modify the likelihood of specified tokens appearing in the completion. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. | `None` |
| `logprobs` | `bool` | whether to return the log probabilities or not. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `False`. | `False` |
| `presence_penalty` | `Optional[Annotated[float, Field(ge=-2.0, le=2.0)]]` | a value between `-2.0` and `2.0`. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. | `None` |
| `seed` | `Optional[int]` | the seed to use for the generation. Defaults to `None`. | `None` |
| `stop_sequences` | `Optional[List[str]]` | either a single string or a list of strings containing the sequences to stop the generation at. Defaults to `None`, but will be set to the `tokenizer.eos_token` if available. | `None` |
| `temperature` | `float` | the temperature to use for the generation. Defaults to `1.0`. | `1.0` |
| `tool_choice` | `Optional[Union[Dict[str, str], Literal['auto']]]` | the name of the tool the model should call. It can be a dictionary like `{"function_name": "my_tool"}` or "auto". If not provided, then the model won't use any tool. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. | `None` |
| `tool_prompt` | `Optional[str]` | a prompt to be appended before the tools. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. | `None` |
| `tools` | `Optional[List[Dict[str, Any]]]` | a list of tool definitions that the LLM can use. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. | `None` |
| `top_logprobs` | `Optional[PositiveInt]` | the number of top log probabilities to return per output token generated. This argument is exclusive to the `chat_completion` method and will be used only if `tokenizer_id` is `None`. Defaults to `None`. | `None` |
| `top_n_tokens` | `Optional[PositiveInt]` | the number of top log probabilities to return per output token generated. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. Defaults to `None`. | `None` |
| `top_p` | `Optional[float]` | the top-p value to use for the generation. Defaults to `None`. | `None` |
| `do_sample` | `bool` | whether to use sampling for the generation. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. Defaults to `False`. | `False` |
| `repetition_penalty` | `Optional[float]` | the repetition penalty to use for the generation. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. Defaults to `None`. | `None` |
| `return_full_text` | `bool` | whether to return the full text of the completion or just the generated text. Defaults to `False`, meaning that only the generated text will be returned. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. | `False` |
| `top_k` | `Optional[int]` | the top-k value to use for the generation. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. Defaults to `None`. | `None` |
| `typical_p` | `Optional[float]` | the typical-p value to use for the generation. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. Defaults to `None`. | `None` |
| `watermark` | `bool` | whether to add a watermark to the generated text. This argument is exclusive to the `text_generation` method and will be used only if `tokenizer_id` is not `None`. Defaults to `False`. | `False` |
| `num_generations` | `int` | the number of generations to generate. Defaults to `1`. It's here to ensure the validation succeeds. | `1` |

Returns:

| Type | Description |
| --- | --- |
| `GenerateOutput` | A list of lists of strings containing the generated responses for each input. |

Source code in src/distilabel/models/llms/huggingface/inference_endpoints.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_new_tokens: int = 128,
    frequency_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
    logit_bias: Optional[List[float]] = None,
    logprobs: bool = False,
    presence_penalty: Optional[Annotated[float, Field(ge=-2.0, le=2.0)]] = None,
    seed: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    temperature: float = 1.0,
    tool_choice: Optional[Union[Dict[str, str], Literal["auto"]]] = None,
    tool_prompt: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    top_logprobs: Optional[PositiveInt] = None,
    top_n_tokens: Optional[PositiveInt] = None,
    top_p: Optional[float] = None,
    do_sample: bool = False,
    repetition_penalty: Optional[float] = None,
    return_full_text: bool = False,
    top_k: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: bool = False,
    num_generations: int = 1,
) -> GenerateOutput:
    """Generates completions for the given input using the async client. This method
    uses two methods of the `huggingface_hub.AsyncClient`: `chat_completion` and `text_generation`.
    `chat_completion` method will be used only if no `tokenizer_id` has been specified.
    Some arguments of this function are specific to the `text_generation` method, while
    some others are specific to the `chat_completion` method.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
            new tokens based on their existing frequency in the text so far, decreasing
            model's likelihood to repeat the same line verbatim. Defaults to `None`.
        logit_bias: modify the likelihood of specified tokens appearing in the completion.
            This argument is exclusive to the `chat_completion` method and will be used
            only if `tokenizer_id` is `None`.
            Defaults to `None`.
        logprobs: whether to return the log probabilities or not. This argument is exclusive
            to the `chat_completion` method and will be used only if `tokenizer_id`
            is `None`. Defaults to `False`.
        presence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
            new tokens based on whether they appear in the text so far, increasing the
            model likelihood to talk about new topics. This argument is exclusive to
            the `chat_completion` method and will be used only if `tokenizer_id` is
            `None`. Defaults to `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: either a single string or a list of strings containing the sequences
            to stop the generation at. Defaults to `None`, but will be set to the
            `tokenizer.eos_token` if available.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        tool_choice: the name of the tool the model should call. It can be a dictionary
            like `{"function_name": "my_tool"}` or "auto". If not provided, then the
            model won't use any tool. This argument is exclusive to the `chat_completion`
            method and will be used only if `tokenizer_id` is `None`. Defaults to `None`.
        tool_prompt: A prompt to be appended before the tools. This argument is exclusive
            to the `chat_completion` method and will be used only if `tokenizer_id`
            is `None`. Defaults to `None`.
        tools: a list of tools definitions that the LLM can use.
            This argument is exclusive to the `chat_completion` method and will be used
            only if `tokenizer_id` is `None`. Defaults to `None`.
        top_logprobs: the number of top log probabilities to return per output token
            generated. This argument is exclusive to the `chat_completion` method and
            will be used only if `tokenizer_id` is `None`. Defaults to `None`.
        top_n_tokens: the number of top log probabilities to return per output token
            generated. This argument is exclusive of the `text_generation` method and
            will be only used if `tokenizer_id` is not `None`. Defaults to `None`.
        top_p: the top-p value to use for the generation. Defaults to `None`.
        do_sample: whether to use sampling for the generation. This argument is exclusive
            of the `text_generation` method and will be only used if `tokenizer_id` is not
            `None`. Defaults to `False`.
        repetition_penalty: the repetition penalty to use for the generation. This argument
            is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        return_full_text: whether to return the full text of the completion or just
            the generated text. Defaults to `False`, meaning that only the generated
            text will be returned. This argument is exclusive of the `text_generation`
            method and will be only used if `tokenizer_id` is not `None`.
        top_k: the top-k value to use for the generation. This argument is exclusive
            of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        typical_p: the typical-p value to use for the generation. This argument is exclusive
            of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        watermark: whether to add the watermark to the generated text. This argument
            is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
            is not `None`. Defaults to `None`.
        num_generations: the number of generations to generate. Defaults to `1`. It's here to ensure
            the validation succeds.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    stop_sequences = self._check_stop_sequences(stop_sequences)

    if isinstance(input, str) or self.tokenizer_id is not None:
        structured_output = None
        if not isinstance(input, str):
            input, structured_output = self._get_structured_output(input)
            input = self.prepare_input(input)

        return await self._generate_with_text_generation(
            input=input,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            typical_p=typical_p,
            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            temperature=temperature,
            top_n_tokens=top_n_tokens,
            top_p=top_p,
            top_k=top_k,
            stop_sequences=stop_sequences,
            return_full_text=return_full_text,
            seed=seed,
            watermark=watermark,
            structured_output=structured_output,
        )

    return await self._generate_with_chat_completion(
        input=input,  # type: ignore
        max_new_tokens=max_new_tokens,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        logprobs=logprobs,
        presence_penalty=presence_penalty,
        seed=seed,
        stop_sequences=stop_sequences,
        temperature=temperature,
        tool_choice=tool_choice,
        tool_prompt=tool_prompt,
        tools=tools,
        top_logprobs=top_logprobs,
        top_p=top_p,
    )
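
As a usage sketch of the branching above (hedged: this assumes the class documented in this section is distilabel's `InferenceEndpointsLLM`, whose `model_id` and `tokenizer_id` attributes are used below with placeholder model ids), `tokenizer_id` decides which endpoint receives the request:

```python
from distilabel.models.llms import InferenceEndpointsLLM

# Without `tokenizer_id`, chat inputs are routed to `chat_completion`, so the
# chat-only arguments such as `tools`, `tool_choice` or `logprobs` apply.
chat_llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
)

# With `tokenizer_id` set, the chat template is applied locally and the raw
# prompt is sent to `text_generation`, enabling `do_sample`, `top_k`,
# `watermark` and the other `text_generation`-only arguments instead.
text_llm = InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
)
```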

TransformersLLM

Bases: LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin

Hugging Face transformers library LLM implementation using the text generation pipeline.

Attributes:

Name Type Description
model str

the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files.

revision str

if model refers to a Hugging Face Hub repository, then the revision (e.g. a branch name or a commit id) to use. Defaults to "main".

torch_dtype str

the torch dtype to use for the model e.g. "float16", "float32", etc. Defaults to "auto".

trust_remote_code bool

whether to allow fetching and executing remote code from the model repository on the Hub. Defaults to False.

model_kwargs Optional[Dict[str, Any]]

additional dictionary of keyword arguments that will be passed to the from_pretrained method of the model.

tokenizer Optional[str]

the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer config files. If not provided, the one associated to the model will be used. Defaults to None.

use_fast bool

whether to use a fast tokenizer or not. Defaults to True.

chat_template Optional[str]

a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, then the ChatML template will be used. Defaults to None.

device Optional[Union[str, int]]

the name or index of the device where the model will be loaded. Defaults to None.

device_map Optional[Union[str, Dict[str, Any]]]

a dictionary mapping each layer of the model to a device, or a mode like "sequential" or "auto". Defaults to None.

token Optional[SecretStr]

the Hugging Face Hub token that will be used to authenticate to the Hugging Face Hub. If not provided, the HF_TOKEN environment variable or the huggingface_hub package local configuration will be used. Defaults to None.

structured_output Optional[RuntimeParameter[OutlinesStructuredOutputType]]

a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None.

use_magpie_template bool

a flag used to enable/disable applying the Magpie pre-query template. Defaults to False.

magpie_pre_query_template Union[MagpieAvailablePreQueryTemplates, str, None]

the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow-up user message. Valid values are "llama3", "qwen2", or a custom pre-query template string. Defaults to None.

Icon

:hugging:

Examples:

Generate text:

from distilabel.models.llms import TransformersLLM

llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
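
Generate structured data (a hedged sketch: the `{"format": "json", "schema": ...}` dictionary follows the `OutlinesStructuredOutputType` shape described in the attributes above; the model id is only an example):

```python
from pydantic import BaseModel
from distilabel.models.llms import TransformersLLM

class User(BaseModel):
    name: str
    last_name: str

llm = TransformersLLM(
    model="microsoft/Phi-3-mini-4k-instruct",
    # assumed outlines-style structured output configuration
    structured_output={"format": "json", "schema": User},
)

llm.load()

output = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Create a user profile for John Doe"}]]
)
```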
Source code in src/distilabel/models/llms/huggingface/transformers.py
class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
    """Hugging Face `transformers` library LLM implementation using the text generation
    pipeline.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        revision: if `model` refers to a Hugging Face Hub repository, then the revision
            (e.g. a branch name or a commit id) to use. Defaults to `"main"`.
        torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc.
            Defaults to `"auto"`.
        trust_remote_code: whether to allow fetching and executing remote code from
            the model repository on the Hub. Defaults to `False`.
        model_kwargs: additional dictionary of keyword arguments that will be passed to
            the `from_pretrained` method of the model.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_fast: whether to use a fast tokenizer or not. Defaults to `True`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then the ChatML template will be used. Defaults to `None`.
        device: the name or index of the device where the model will be loaded. Defaults
            to `None`.
        device_map: a dictionary mapping each layer of the model to a device, or a mode
            like `"sequential"` or `"auto"`. Defaults to `None`.
        token: the Hugging Face Hub token that will be used to authenticate to the Hugging
            Face Hub. If not provided, the `HF_TOKEN` environment variable or the `huggingface_hub` package
            local configuration will be used. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration or,
            if more fine-grained control is needed, an instance of `OutlinesStructuredOutput`.
            Defaults to `None`.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow-up user message. Valid
            values are "llama3", "qwen2", or a custom pre-query template string. Defaults
            to `None`.

    Icon:
        `:hugging:`

    Examples:
        Generate text:

        ```python
        from distilabel.models.llms import TransformersLLM

        llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")

        llm.load()

        # Call the model
        output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str
    revision: str = "main"
    torch_dtype: str = "auto"
    trust_remote_code: bool = False
    model_kwargs: Optional[Dict[str, Any]] = None
    tokenizer: Optional[str] = None
    use_fast: bool = True
    chat_template: Optional[str] = None
    device: Optional[Union[str, int]] = None
    device_map: Optional[Union[str, Dict[str, Any]]] = None
    token: Optional[SecretStr] = Field(
        default_factory=lambda: os.getenv(HF_TOKEN_ENV_VAR)
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _pipeline: Optional["Pipeline"] = PrivateAttr(...)
    _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
    _logits_processor: Union[Callable, None] = PrivateAttr(default=None)

    def load(self) -> None:
        """Loads the model and tokenizer and creates the text generation pipeline. In addition,
        it will configure the tokenizer chat template."""
        if self.device == "cuda":
            CudaDevicePlacementMixin.load(self)

        try:
            from transformers import pipeline
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
            ) from ie

        token = self.token.get_secret_value() if self.token is not None else self.token

        self._pipeline = pipeline(
            "text-generation",
            model=self.model,
            revision=self.revision,
            torch_dtype=self.torch_dtype,
            trust_remote_code=self.trust_remote_code,
            model_kwargs=self.model_kwargs or {},
            tokenizer=self.tokenizer or self.model,
            use_fast=self.use_fast,
            device=self.device,
            device_map=self.device_map,
            token=token,
            return_full_text=False,
        )

        if self.chat_template is not None:
            self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore

        if self._pipeline.tokenizer.pad_token is None:  # type: ignore
            self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

        if self.structured_output:
            processor = self._prepare_structured_output(self.structured_output)
            if _is_outlines_version_below_0_1_0():
                self._prefix_allowed_tokens_fn = processor
            else:
                self._logits_processor = [processor]

        super().load()

    def unload(self) -> None:
        """Unloads the `transformers` model."""
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if self._pipeline.tokenizer.chat_template is None:  # type: ignore
            return input[0]["content"]

        prompt: str = (
            self._pipeline.tokenizer.apply_chat_template(  # type: ignore
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[StandardInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        temperature: float = 0.1,
        repetition_penalty: float = 1.1,
        top_p: float = 1.0,
        top_k: int = 0,
        do_sample: bool = True,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults
                to `1.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `0`.
            do_sample: whether to use sampling or not. Defaults to `True`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        prepared_inputs = [self.prepare_input(input=input) for input in inputs]

        outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
            prepared_inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            num_return_sequences=num_generations,
            prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
            pad_token_id=self._pipeline.tokenizer.eos_token_id,
            logits_processor=self._logits_processor,
        )
        llm_output = [
            [generation["generated_text"] for generation in output]
            for output in outputs
        ]

        result = []
        for input, output in zip(inputs, llm_output):
            result.append(
                prepare_output(
                    output,
                    input_tokens=[
                        compute_tokens(input, self._pipeline.tokenizer.encode)
                    ],
                    output_tokens=[
                        compute_tokens(row, self._pipeline.tokenizer.encode)
                        for row in output
                    ],
                )
            )

        return result

    def get_last_hidden_states(
        self, inputs: List["StandardInput"]
    ) -> List["HiddenState"]:
        """Gets the last `hidden_states` of the model for the given inputs. It doesn't
        execute the task head.

        Args:
            inputs: a list of inputs in chat format to generate the embeddings for.

        Returns:
            A list containing the last hidden state for each sequence using a NumPy array
            with shape [num_tokens, hidden_size].
        """
        model: "PreTrainedModel" = (
            self._pipeline.model.model  # type: ignore
            if hasattr(self._pipeline.model, "model")  # type: ignore
            else next(self._pipeline.model.children())  # type: ignore
        )
        tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
        input_ids = tokenizer(
            [self.prepare_input(input) for input in inputs],  # type: ignore
            return_tensors="pt",
            padding=True,
        ).to(model.device)
        last_hidden_states = model(**input_ids)["last_hidden_state"]

        return [
            seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
            for seq_last_hidden_state, attention_mask in zip(
                last_hidden_states,
                input_ids["attention_mask"],  # type: ignore
            )
        ]

    def _prepare_structured_output(
        self, structured_output: Optional[OutlinesStructuredOutputType] = None
    ) -> Union[Callable, List[Callable]]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        result = prepare_guided_output(
            structured_output, "transformers", self._pipeline
        )
        if schema := result.get("schema"):
            self.structured_output["schema"] = schema
        return result["processor"]
model_name property

Returns the model name used for the LLM.

load()

Loads the model and tokenizer and creates the text generation pipeline. In addition, it will configure the tokenizer chat template.

Source code in src/distilabel/models/llms/huggingface/transformers.py
def load(self) -> None:
    """Loads the model and tokenizer and creates the text generation pipeline. In addition,
    it will configure the tokenizer chat template."""
    if self.device == "cuda":
        CudaDevicePlacementMixin.load(self)

    try:
        from transformers import pipeline
    except ImportError as ie:
        raise ImportError(
            "Transformers is not installed. Please install it using `pip install 'distilabel[hf-transformers]'`."
        ) from ie

    token = self.token.get_secret_value() if self.token is not None else self.token

    self._pipeline = pipeline(
        "text-generation",
        model=self.model,
        revision=self.revision,
        torch_dtype=self.torch_dtype,
        trust_remote_code=self.trust_remote_code,
        model_kwargs=self.model_kwargs or {},
        tokenizer=self.tokenizer or self.model,
        use_fast=self.use_fast,
        device=self.device,
        device_map=self.device_map,
        token=token,
        return_full_text=False,
    )

    if self.chat_template is not None:
        self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore

    if self._pipeline.tokenizer.pad_token is None:  # type: ignore
        self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

    if self.structured_output:
        processor = self._prepare_structured_output(self.structured_output)
        if _is_outlines_version_below_0_1_0():
            self._prefix_allowed_tokens_fn = processor
        else:
            self._logits_processor = [processor]

    super().load()
unload()

Unloads the transformers model.

Source code in src/distilabel/models/llms/huggingface/transformers.py
def unload(self) -> None:
    """Unloads the `transformers` model."""
    CudaDevicePlacementMixin.unload(self)
    super().unload()
prepare_input(input)

Prepares the input (applying the chat template and tokenization) for the provided input.

Parameters:

Name Type Description Default
input StandardInput

the input list containing chat items.

required

Returns:

Type Description
str

The prompt to send to the LLM.

Source code in src/distilabel/models/llms/huggingface/transformers.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if self._pipeline.tokenizer.chat_template is None:  # type: ignore
        return input[0]["content"]

    prompt: str = (
        self._pipeline.tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)
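
A short illustrative call (assuming a `TransformersLLM` loaded as in the example above; the exact output depends on the tokenizer's chat template):

```python
# `prepare_input` applies the chat template and appends the generation
# prompt, returning the raw string fed to the text generation pipeline.
prompt = llm.prepare_input([{"role": "user", "content": "Hello world!"}])
print(prompt)
```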
generate(inputs, num_generations=1, max_new_tokens=128, temperature=0.1, repetition_penalty=1.1, top_p=1.0, top_k=0, do_sample=True)

Generates num_generations responses for each input using the text generation pipeline.

Parameters:

Name Type Description Default
inputs List[StandardInput]

a list of inputs in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
temperature float

the temperature to use for the generation. Defaults to 0.1.

0.1
repetition_penalty float

the repetition penalty to use for the generation. Defaults to 1.1.

1.1
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
top_k int

the top-k value to use for the generation. Defaults to 0.

0
do_sample bool

whether to use sampling or not. Defaults to True.

True

Returns:

Type Description
List[GenerateOutput]

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/models/llms/huggingface/transformers.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[StandardInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    temperature: float = 0.1,
    repetition_penalty: float = 1.1,
    top_p: float = 1.0,
    top_k: int = 0,
    do_sample: bool = True,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults
            to `1.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `0`.
        do_sample: whether to use sampling or not. Defaults to `True`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    prepared_inputs = [self.prepare_input(input=input) for input in inputs]

    outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
        prepared_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        num_return_sequences=num_generations,
        prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
        pad_token_id=self._pipeline.tokenizer.eos_token_id,
        logits_processor=self._logits_processor,
    )
    llm_output = [
        [generation["generated_text"] for generation in output]
        for output in outputs
    ]

    result = []
    for input, output in zip(inputs, llm_output):
        result.append(
            prepare_output(
                output,
                input_tokens=[
                    compute_tokens(input, self._pipeline.tokenizer.encode)
                ],
                output_tokens=[
                    compute_tokens(row, self._pipeline.tokenizer.encode)
                    for row in output
                ],
            )
        )

    return result
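
An illustrative call with explicit sampling parameters (hedged: `llm` is a loaded `TransformersLLM` as in the examples above, and the values are arbitrary):

```python
outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Write a haiku about the sea."}]],
    num_generations=2,   # two candidate completions per input
    max_new_tokens=64,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)
```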
get_last_hidden_states(inputs)

Gets the last hidden_states of the model for the given inputs. It doesn't execute the task head.

Parameters:

Name Type Description Default
inputs List[StandardInput]

a list of inputs in chat format to generate the embeddings for.

required

Returns:

Type Description
List[HiddenState]

A list containing the last hidden state for each sequence using a NumPy array with shape [num_tokens, hidden_size].

Source code in src/distilabel/models/llms/huggingface/transformers.py
def get_last_hidden_states(
    self, inputs: List["StandardInput"]
) -> List["HiddenState"]:
    """Gets the last `hidden_states` of the model for the given inputs. It doesn't
    execute the task head.

    Args:
        inputs: a list of inputs in chat format to generate the embeddings for.

    Returns:
        A list containing the last hidden state for each sequence using a NumPy array
        with shape [num_tokens, hidden_size].
    """
    model: "PreTrainedModel" = (
        self._pipeline.model.model  # type: ignore
        if hasattr(self._pipeline.model, "model")  # type: ignore
        else next(self._pipeline.model.children())  # type: ignore
    )
    tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
    input_ids = tokenizer(
        [self.prepare_input(input) for input in inputs],  # type: ignore
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    last_hidden_states = model(**input_ids)["last_hidden_state"]

    return [
        seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
        for seq_last_hidden_state, attention_mask in zip(
            last_hidden_states,
            input_ids["attention_mask"],  # type: ignore
        )
    ]
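
A closing usage sketch (assuming the same loaded `llm` as in the examples above; each returned NumPy array has shape `[num_tokens, hidden_size]`):

```python
hidden_states = llm.get_last_hidden_states(
    inputs=[[{"role": "user", "content": "Hello world!"}]]
)
print(hidden_states[0].shape)  # (num_tokens, hidden_size)
```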