AnthropicLLM

Bases: AsyncLLM

Anthropic LLM implementation running the Async API client.

Attributes:

Name Type Description
model str

the name of the model to use for the LLM, e.g. "claude-3-opus-20240229", "claude-3-sonnet-20240229", etc. Available models are listed in the Anthropic Models overview (https://docs.anthropic.com/claude/docs/models-overview).

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Anthropic API. If not provided, it will be read from ANTHROPIC_API_KEY environment variable.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Anthropic API. Defaults to the value of the ANTHROPIC_BASE_URL environment variable, or https://api.anthropic.com if not set.

timeout RuntimeParameter[float]

the maximum time in seconds to wait for a response. Defaults to 600.0.

max_retries RuntimeParameter[int]

The maximum number of times to retry the request before failing. Defaults to 6.

http_client Optional[AsyncClient]

if provided, an alternative HTTP client to use for calling Anthropic API. Defaults to None.

_api_key_env_var str

the name of the environment variable to use for the API key. It is meant to be used internally.

_aclient Optional[AsyncAnthropic]

the AsyncAnthropic client to use for the Anthropic API. It is meant to be used internally. Set in the load method.

Runtime parameters
  • api_key: the API key to authenticate the requests to the Anthropic API. If not provided, it will be read from ANTHROPIC_API_KEY environment variable.
  • base_url: the base URL to use for the Anthropic API. Defaults to "https://api.anthropic.com".
  • timeout: the maximum time in seconds to wait for a response. Defaults to 600.0.
  • max_retries: the maximum number of times to retry the request before failing. Defaults to 6.
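
For reference, a minimal usage sketch, assuming AnthropicLLM is exported from distilabel.llms (the API key may also be picked up from the ANTHROPIC_API_KEY environment variable):

from distilabel.llms import AnthropicLLM

llm = AnthropicLLM(
    model="claude-3-opus-20240229",
    api_key="<ANTHROPIC_API_KEY>",  # optional if the ANTHROPIC_API_KEY env var is set
    timeout=600.0,
    max_retries=6,
)
llm.load()

# Synchronous request
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
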
Source code in src/distilabel/llms/anthropic.py
class AnthropicLLM(AsyncLLM):
    """Anthropic LLM implementation running the Async API client.

    Attributes:
        model: the name of the model to use for the LLM e.g. "claude-3-opus-20240229",
            "claude-3-sonnet-20240229", etc. Available models can be checked here:
            [Anthropic: Models overview](https://docs.anthropic.com/claude/docs/models-overview).
        api_key: the API key to authenticate the requests to the Anthropic API. If not provided,
            it will be read from `ANTHROPIC_API_KEY` environment variable.
        base_url: the base URL to use for the Anthropic API. Defaults to `None` which means
            that `https://api.anthropic.com` will be used internally.
        timeout: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        max_retries: The maximum number of times to retry the request before failing. Defaults
            to `6`.
        http_client: if provided, an alternative HTTP client to use for calling Anthropic
            API. Defaults to `None`.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.
        _aclient: the `AsyncAnthropic` client to use for the Anthropic API. It is meant
            to be used internally. Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Anthropic API. If not
            provided, it will be read from `ANTHROPIC_API_KEY` environment variable.
        - `base_url`: the base URL to use for the Anthropic API. Defaults to `"https://api.anthropic.com"`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `600.0`.
        - `max_retries`: the maximum number of times to retry the request before failing.
            Defaults to `6`.
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANTHROPIC_BASE_URL", "https://api.anthropic.com"
        ),
        description="The base URL to use for the Anthropic API.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANTHROPIC_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anthropic API.",
    )
    timeout: RuntimeParameter[float] = Field(
        default=600.0,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    http_client: Optional[AsyncClient] = Field(default=None, exclude=True)

    _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...)

    def _check_model_exists(self) -> None:
        """Checks if the specified model exists in the available models."""
        from anthropic import AsyncAnthropic

        annotation = get_type_hints(AsyncAnthropic().messages.create).get("model", None)
        models = [
            value
            for type_ in get_args(annotation)
            if get_origin(type_) is Literal
            for value in get_args(type_)
        ]

        if self.model not in models:
            raise ValueError(
                f"Model {self.model} does not exist among available models. "
                f"The available models are {', '.join(models)}"
            )

    def load(self) -> None:
        """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
        super().load()

        try:
            from anthropic import AsyncAnthropic
        except ImportError as ie:
            raise ImportError(
                "Anthropic Python client is not installed. Please install it using"
                " `pip install anthropic`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._check_model_exists()

        self._aclient = AsyncAnthropic(
            api_key=self.api_key.get_secret_value(),
            base_url=self.base_url,
            timeout=self.timeout,
            http_client=self.http_client,
            max_retries=self.max_retries,
        )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        max_tokens: int = 128,
        stop_sequences: Union[List[str], None] = None,
        temperature: float = 1.0,
        top_p: Union[float, None] = None,
        top_k: Union[int, None] = None,
    ) -> GenerateOutput:
        """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

        Args:
            input: a single input in chat format to generate responses for.
            max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
            stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `NOT_GIVEN`.
            temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `NOT_GIVEN`.
            top_k: the top-k value to use for the generation. Defaults to `NOT_GIVEN`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from anthropic._types import NOT_GIVEN

        completion = await self._aclient.messages.create(  # type: ignore
            model=self.model,
            system=(
                input.pop(0)["content"]
                if input and input[0]["role"] == "system"
                else NOT_GIVEN
            ),
            messages=input,  # type: ignore
            max_tokens=max_tokens,
            stream=False,
            stop_sequences=NOT_GIVEN if stop_sequences is None else stop_sequences,
            temperature=temperature,
            top_p=NOT_GIVEN if top_p is None else top_p,
            top_k=NOT_GIVEN if top_k is None else top_k,
        )
        generations = []
        if (content := completion.content[0].text) is None:
            self._logger.warning(
                f"Received no response using Anthropic client (model: '{self.model}')."
                f" Finish reason was: {completion.stop_reason}"
            )
        generations.append(content)
        return generations

    # TODO: remove this function once Anthropic client allows `n` parameter
    @override
    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously awaiting for the response of each input sent to `agenerate`.
        """

        async def agenerate(
            inputs: List["ChatType"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
                asyncio.create_task(self.agenerate(input=input, **kwargs))
                for input in inputs
                for _ in range(num_generations)
            ]
            return [outputs[0] for outputs in await asyncio.gather(*tasks)]

        outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
        return list(grouper(outputs, n=num_generations, incomplete="ignore"))

model_name: str property

Returns the model name used for the LLM.

agenerate(input, max_tokens=128, stop_sequences=None, temperature=1.0, top_p=None, top_k=None) async

Generates a response asynchronously, using the Anthropic Async API definition (https://github.com/anthropics/anthropic-sdk-python).

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
max_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
stop_sequences Union[List[str], None]

custom text sequences that will cause the model to stop generating. Defaults to NOT_GIVEN.

None
temperature float

the temperature to use for the generation. Set only if top_p is None. Defaults to 1.0.

1.0
top_p Union[float, None]

the top-p value to use for the generation. Defaults to NOT_GIVEN.

None
top_k Union[int, None]

the top-k value to use for the generation. Defaults to NOT_GIVEN.

None

Returns:

Type Description
GenerateOutput

A list of strings containing the generated responses for the given input.

Source code in src/distilabel/llms/anthropic.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    max_tokens: int = 128,
    stop_sequences: Union[List[str], None] = None,
    temperature: float = 1.0,
    top_p: Union[float, None] = None,
    top_k: Union[int, None] = None,
) -> GenerateOutput:
    """Generates a response asynchronously, using the [Anthropic Async API definition](https://github.com/anthropics/anthropic-sdk-python).

    Args:
        input: a single input in chat format to generate responses for.
        max_tokens: the maximum number of new tokens that the model will generate. Defaults to `128`.
        stop_sequences: custom text sequences that will cause the model to stop generating. Defaults to `NOT_GIVEN`.
        temperature: the temperature to use for the generation. Set only if top_p is None. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `NOT_GIVEN`.
        top_k: the top-k value to use for the generation. Defaults to `NOT_GIVEN`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from anthropic._types import NOT_GIVEN

    completion = await self._aclient.messages.create(  # type: ignore
        model=self.model,
        system=(
            input.pop(0)["content"]
            if input and input[0]["role"] == "system"
            else NOT_GIVEN
        ),
        messages=input,  # type: ignore
        max_tokens=max_tokens,
        stream=False,
        stop_sequences=NOT_GIVEN if stop_sequences is None else stop_sequences,
        temperature=temperature,
        top_p=NOT_GIVEN if top_p is None else top_p,
        top_k=NOT_GIVEN if top_k is None else top_k,
    )
    generations = []
    if (content := completion.content[0].text) is None:
        self._logger.warning(
            f"Received no response using Anthropic client (model: '{self.model}')."
            f" Finish reason was: {completion.stop_reason}"
        )
    generations.append(content)
    return generations
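
A hedged sketch of awaiting agenerate directly (it is a coroutine; llm is assumed to be a loaded AnthropicLLM as in the sketch above). Note that a leading system message is popped from the input and passed as the system prompt:

import asyncio

async def main() -> None:
    generations = await llm.agenerate(
        input=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello world!"},
        ],
        max_tokens=128,
        temperature=1.0,
    )
    print(generations)  # a list with a single generated string

asyncio.run(main())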

generate(inputs, num_generations=1, **kwargs)

Method to generate a list of responses asynchronously, returning the output synchronously after awaiting the response for each input sent to agenerate.

Source code in src/distilabel/llms/anthropic.py
@override
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of responses asynchronously, returning the output
    synchronously awaiting for the response of each input sent to `agenerate`.
    """

    async def agenerate(
        inputs: List["ChatType"], **kwargs: Any
    ) -> "GenerateOutput":
        """Internal function to parallelize the asynchronous generation of responses."""
        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        return [outputs[0] for outputs in await asyncio.gather(*tasks)]

    outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
    return list(grouper(outputs, n=num_generations, incomplete="ignore"))
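
The grouper call batches the flat list of generations back per input, so each element of the returned list groups the num_generations responses for the corresponding input. A sketch of the expected shape (llm is again assumed to be a loaded AnthropicLLM):

outputs = llm.generate(
    inputs=[
        [{"role": "user", "content": "What is 2 + 2?"}],
        [{"role": "user", "content": "Name a prime number."}],
    ],
    num_generations=2,
    max_tokens=64,
)
# len(outputs) == 2 (one entry per input), each entry grouping 2 generations
for input_generations in outputs:
    print(len(input_generations), input_generations)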

load()

Loads the AsyncAnthropic client to use the Anthropic async API.

Source code in src/distilabel/llms/anthropic.py
def load(self) -> None:
    """Loads the `AsyncAnthropic` client to use the Anthropic async API."""
    super().load()

    try:
        from anthropic import AsyncAnthropic
    except ImportError as ie:
        raise ImportError(
            "Anthropic Python client is not installed. Please install it using"
            " `pip install anthropic`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._check_model_exists()

    self._aclient = AsyncAnthropic(
        api_key=self.api_key.get_secret_value(),
        base_url=self.base_url,
        timeout=self.timeout,
        http_client=self.http_client,
        max_retries=self.max_retries,
    )

AnyscaleLLM

Bases: OpenAILLM

Anyscale LLM implementation running the async API client of OpenAI, since the Anyscale API is OpenAI-compatible.

Attributes:

Name Type Description
model

the model name to use for the LLM, e.g., google/gemma-7b-it. See the supported models under the "Text Generation -> Supported Models" section of the Anyscale endpoints documentation (https://docs.endpoints.anyscale.com/).

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Anyscale API requests. Defaults to None, which means that the value set for the environment variable ANYSCALE_BASE_URL will be used, or "https://api.endpoints.anyscale.com/v1" if not set.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Anyscale API. Defaults to None which means that the value set for the environment variable ANYSCALE_API_KEY will be used, or None if not set.

_api_key_env_var str

the name of the environment variable to use for the API key. It is meant to be used internally.

Source code in src/distilabel/llms/anyscale.py
class AnyscaleLLM(OpenAILLM):
    """Anyscale LLM implementation running the async API client of OpenAI because of
    duplicate API behavior.

    Attributes:
        model: the model name to use for the LLM, e.g., `google/gemma-7b-it`. See the
            supported models under the "Text Generation -> Supported Models" section
            [here](https://docs.endpoints.anyscale.com/).
        base_url: the base URL to use for the Anyscale API requests. Defaults to `None`, which
            means that the value set for the environment variable `ANYSCALE_BASE_URL` will be used, or
            "https://api.endpoints.anyscale.com/v1" if not set.
        api_key: the API key to authenticate the requests to the Anyscale API. Defaults to `None` which
            means that the value set for the environment variable `ANYSCALE_API_KEY` will be used, or
            `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key.
            It is meant to be used internally.
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "ANYSCALE_BASE_URL", "https://api.endpoints.anyscale.com/v1"
        ),
        description="The base URL to use for the Anyscale API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ANYSCALE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Anyscale API.",
    )

    _api_key_env_var: str = PrivateAttr(_ANYSCALE_API_KEY_ENV_VAR_NAME)
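
Since AnyscaleLLM only overrides the OpenAILLM defaults (base URL and API key environment variable), usage is analogous to OpenAILLM. A minimal sketch, assuming the class is exported from distilabel.llms and ANYSCALE_API_KEY is set in the environment:

from distilabel.llms import AnyscaleLLM

llm = AnyscaleLLM(model="google/gemma-7b-it")  # api_key is read from ANYSCALE_API_KEY
llm.load()

output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])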

AsyncLLM

Bases: LLM

Abstract class for asynchronous LLMs, so as to benefit from the async capabilities of each LLM implementation. This class is meant to be subclassed by each LLM, and the method agenerate needs to be implemented to provide the asynchronous generation of responses.

Attributes:

Name Type Description
_event_loop AbstractEventLoop

the event loop to be used for the asynchronous generation of responses.

Source code in src/distilabel/llms/base.py
class AsyncLLM(LLM):
    """Abstract class for asynchronous LLMs, so as to benefit from the async capabilities
    of each LLM implementation. This class is meant to be subclassed by each LLM, and the
    method `agenerate` needs to be implemented to provide the asynchronous generation of
    responses.

    Attributes:
        _event_loop: the event loop to be used for the asynchronous generation of responses.
    """

    _event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)

    @property
    def generate_parameters(self) -> List[inspect.Parameter]:
        """Returns the parameters of the `agenerate` method.

        Returns:
            A list containing the parameters of the `agenerate` method.
        """
        return list(inspect.signature(self.agenerate).parameters.values())

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `agenerate` method.

        Returns:
            The parsed docstring of the `agenerate` method.
        """
        return parse_google_docstring(self.agenerate)

    @property
    def event_loop(self) -> "asyncio.AbstractEventLoop":
        if self._event_loop is None:
            try:
                self._event_loop = asyncio.get_running_loop()
                if self._event_loop.is_closed():
                    self._event_loop = asyncio.new_event_loop()  # type: ignore
            except RuntimeError:
                self._event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._event_loop)
        return self._event_loop

    @abstractmethod
    async def agenerate(
        self, input: "ChatType", num_generations: int = 1, **kwargs: Any
    ) -> List[Union[str, None]]:
        """Method to generate a `num_generations` responses for a given input asynchronously,
        and executed concurrently in `generate` method.
        """
        pass

    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously awaiting for the response of each input sent to `agenerate`.
        """

        async def agenerate(
            inputs: List["ChatType"], **kwargs: Any
        ) -> List[List[Union[str, None]]]:
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
                asyncio.create_task(
                    self.agenerate(
                        input=input, num_generations=num_generations, **kwargs
                    )
                )
                for input in inputs
            ]
            return await asyncio.gather(*tasks)

        return self.event_loop.run_until_complete(agenerate(inputs, **kwargs))

    def __del__(self) -> None:
        """Closes the event loop when the object is deleted."""
        if sys.meta_path is None:
            return
        if self.event_loop is not None:
            self.event_loop.close()
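
As a minimal sketch of the contract a subclass has to fulfil (the EchoLLM class below is hypothetical and only for illustration; List[Dict[str, str]] stands in for the ChatType alias): implement the agenerate coroutine plus the model_name property expected by the LLM base, and AsyncLLM takes care of the event loop and of running the requests concurrently in generate.

from typing import Any, Dict, List, Union

from distilabel.llms.base import AsyncLLM


class EchoLLM(AsyncLLM):
    """Hypothetical LLM that echoes the last user message, for illustration only."""

    @property
    def model_name(self) -> str:
        return "echo"

    async def agenerate(  # type: ignore
        self, input: List[Dict[str, str]], num_generations: int = 1, **kwargs: Any
    ) -> List[Union[str, None]]:
        # Return `num_generations` copies of the last message content.
        return [input[-1]["content"] for _ in range(num_generations)]


llm = EchoLLM()
llm.load()
print(llm.generate(inputs=[[{"role": "user", "content": "ping"}]], num_generations=2))
# [['ping', 'ping']]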

generate_parameters: List[inspect.Parameter] property

Returns the parameters of the agenerate method.

Returns:

Type Description
List[Parameter]

A list containing the parameters of the agenerate method.

generate_parsed_docstring: Docstring cached property

Returns the parsed docstring of the agenerate method.

Returns:

Type Description
Docstring

The parsed docstring of the agenerate method.

__del__()

Closes the event loop when the object is deleted.

Source code in src/distilabel/llms/base.py
def __del__(self) -> None:
    """Closes the event loop when the object is deleted."""
    if sys.meta_path is None:
        return
    if self.event_loop is not None:
        self.event_loop.close()

agenerate(input, num_generations=1, **kwargs) abstractmethod async

Method to generate num_generations responses for a given input asynchronously, executed concurrently by the generate method.

Source code in src/distilabel/llms/base.py
@abstractmethod
async def agenerate(
    self, input: "ChatType", num_generations: int = 1, **kwargs: Any
) -> List[Union[str, None]]:
    """Method to generate a `num_generations` responses for a given input asynchronously,
    and executed concurrently in `generate` method.
    """
    pass

generate(inputs, num_generations=1, **kwargs)

Method to generate a list of responses asynchronously, returning the output synchronously after awaiting the response for each input sent to agenerate.

Source code in src/distilabel/llms/base.py
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of responses asynchronously, returning the output
    synchronously awaiting for the response of each input sent to `agenerate`.
    """

    async def agenerate(
        inputs: List["ChatType"], **kwargs: Any
    ) -> List[List[Union[str, None]]]:
        """Internal function to parallelize the asynchronous generation of responses."""
        tasks = [
            asyncio.create_task(
                self.agenerate(
                    input=input, num_generations=num_generations, **kwargs
                )
            )
            for input in inputs
        ]
        return await asyncio.gather(*tasks)

    return self.event_loop.run_until_complete(agenerate(inputs, **kwargs))

AzureOpenAILLM

Bases: OpenAILLM

Azure OpenAI LLM implementation running the async API client of OpenAI, since the Azure OpenAI API mirrors the OpenAI API, but with Azure-specific parameters.

Attributes:

Name Type Description
model

the model name to use for the LLM i.e. the name of the Azure deployment.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Azure OpenAI API requests, i.e. the Azure OpenAI endpoint. Defaults to the value of the AZURE_OPENAI_ENDPOINT environment variable, or None if not set.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Azure OpenAI API. Defaults to None which means that the value set for the environment variable AZURE_OPENAI_API_KEY will be used, or None if not set.

api_version Optional[RuntimeParameter[str]]

the API version to use for the Azure OpenAI API. Defaults to None which means that the value set for the environment variable OPENAI_API_VERSION will be used, or None if not set.

Source code in src/distilabel/llms/azure.py
class AzureOpenAILLM(OpenAILLM):
    """Azure OpenAI LLM implementation running the async API client of OpenAI because of
    duplicate API behavior, but with Azure-specific parameters.

    Attributes:
        model: the model name to use for the LLM i.e. the name of the Azure deployment.
        base_url: the base URL to use for the Azure OpenAI API can be set with `AZURE_OPENAI_ENDPOINT`.
            Defaults to `None` which means that the value set for the environment variable
            `AZURE_OPENAI_ENDPOINT` will be used, or `None` if not set.
        api_key: the API key to authenticate the requests to the Azure OpenAI API. Defaults to `None`
            which means that the value set for the environment variable `AZURE_OPENAI_API_KEY` will be
            used, or `None` if not set.
        api_version: the API version to use for the Azure OpenAI API. Defaults to `None` which means
            that the value set for the environment variable `OPENAI_API_VERSION` will be used, or
            `None` if not set.
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME),
        description="The base URL to use for the Azure OpenAI API requests i.e. the Azure OpenAI endpoint.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Azure OpenAI API.",
    )

    api_version: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv("OPENAI_API_VERSION"),
        description="The API version to use for the Azure OpenAI API.",
    )

    _base_url_env_var: str = PrivateAttr(_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME)
    _api_key_env_var: str = PrivateAttr(_AZURE_OPENAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncAzureOpenAI"] = PrivateAttr(...)  # type: ignore

    @override
    def load(self) -> None:
        """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
        super().load()

        try:
            from openai import AsyncAzureOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        # TODO: May be worth adding the AD auth too? Also the `organization`?
        self._aclient = AsyncAzureOpenAI(  # type: ignore
            azure_endpoint=self.base_url,  # type: ignore
            azure_deployment=self.model,
            api_version=self.api_version,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )
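
A hedged configuration sketch, assuming AzureOpenAILLM is exported from distilabel.llms; the endpoint, deployment name and API version below are placeholders. The environment variables are read when the instance is created, so they must be set beforehand:

import os

from distilabel.llms import AzureOpenAILLM

os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<RESOURCE>.openai.azure.com/"
os.environ["AZURE_OPENAI_API_KEY"] = "<AZURE_OPENAI_API_KEY>"
os.environ["OPENAI_API_VERSION"] = "<API_VERSION>"

llm = AzureOpenAILLM(model="<AZURE_DEPLOYMENT_NAME>")
llm.load()

output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])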

load()

Loads the AsyncAzureOpenAI client to benefit from async requests.

Source code in src/distilabel/llms/azure.py
@override
def load(self) -> None:
    """Loads the `AsyncAzureOpenAI` client to benefit from async requests."""
    super().load()

    try:
        from openai import AsyncAzureOpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install openai`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    # TODO: May be worth adding the AD auth too? Also the `organization`?
    self._aclient = AsyncAzureOpenAI(  # type: ignore
        azure_endpoint=self.base_url,  # type: ignore
        azure_deployment=self.model,
        api_version=self.api_version,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

CohereLLM

Bases: AsyncLLM

Cohere API implementation using the async client for concurrent text generation.

Attributes:

Name Type Description
model str

the name of the model from the Cohere API to use for the generation.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Cohere API requests. Defaults to "https://api.cohere.ai/v1".

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

client_name RuntimeParameter[str]

the name of the client to use for the API requests. Defaults to "distilabel".

_ChatMessage Type[ChatMessage]

the ChatMessage class from the cohere package.

_aclient AsyncClient

the AsyncClient client from the cohere package.

Runtime parameters
  • base_url: the base URL to use for the Cohere API requests. Defaults to "https://api.cohere.ai/v1".
  • api_key: the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • client_name: the name of the client to use for the API requests. Defaults to "distilabel".
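
A minimal usage sketch, assuming CohereLLM is exported from distilabel.llms and COHERE_API_KEY is set in the environment (the model name below is just an example):

from distilabel.llms import CohereLLM

llm = CohereLLM(model="command-r")
llm.load()

output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
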
Source code in src/distilabel/llms/cohere.py
class CohereLLM(AsyncLLM):
    """Cohere API implementation using the async client for concurrent text generation.


    Attributes:
        model: the name of the model from the Cohere API to use for the generation.
        base_url: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        api_key: the API key to authenticate the requests to the Cohere API. Defaults to
            the value of the `COHERE_API_KEY` environment variable.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        client_name: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.
        _ChatMessage: the `ChatMessage` class from the `cohere` package.
        _aclient: the `AsyncClient` client from the `cohere` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Cohere API requests. Defaults to
            `"https://api.cohere.ai/v1"`.
        - `api_key`: the API key to authenticate the requests to the Cohere API. Defaults
            to the value of the `COHERE_API_KEY` environment variable.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `client_name`: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "COHERE_BASE_URL", "https://api.cohere.ai/v1"
        ),
        description="The base URL to use for the Cohere API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_COHERE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Cohere API.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    client_name: RuntimeParameter[str] = Field(
        default="distilabel",
        description="The name of the client to use for the API requests.",
    )

    _ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
    _aclient: "AsyncClient" = PrivateAttr(...)

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def load(self) -> None:
        """Loads the `AsyncClient` client from the `cohere` package."""

        super().load()

        try:
            from cohere import AsyncClient, ChatMessage
        except ImportError as ie:
            raise ImportError(
                "The `cohere` package is required to use the `CohereLLM` class."
            ) from ie

        self._ChatMessage = ChatMessage

        self._aclient = AsyncClient(
            api_key=self.api_key.get_secret_value(),  # type: ignore
            client_name=self.client_name,
            base_url=self.base_url,
            timeout=self.timeout,
        )

    def _format_chat_to_cohere(
        self, input: "ChatType"
    ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
        """Formats the chat input to the Cohere Chat API conversational format.

        Args:
            input: The chat input to format.

        Returns:
            A tuple containing the system, chat history, and message.
        """
        system = None
        message = None
        chat_history = []
        for item in input:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system = content
            elif role == "user":
                message = content
            elif role == "assistant":
                if message is None:
                    raise ValueError(
                        "An assistant message but be preceded by a user message."
                    )
                chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
                chat_history.append(self._ChatMessage(role="CHATBOT", message=content))
                message = None

        if message is None:
            raise ValueError("The chat input must end with a user message.")

        return system, chat_history, message

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        seed: Optional[float] = None,
        stop_sequences: Optional[Sequence[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        raw_prompting: Optional[bool] = None,
    ) -> Union[str, None]:
        """Generates a response from the LLM given an input.

        Args:
            input: a single input in chat format to generate responses for.
            temperature: the temperature to use for the generation. Defaults to `None`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            k: the number of highest probability vocabulary tokens to keep for the generation.
                Defaults to `None`.
            p: the nucleus sampling probability to use for the generation. Defaults to
                `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: a list of sequences to use as stopping criteria for the generation.
                Defaults to `None`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `None`.
            raw_prompting: a flag to use raw prompting for the generation. Defaults to
                `None`.

        Returns:
            The generated response from the Cohere API model.
        """
        system, chat_history, message = self._format_chat_to_cohere(input)

        response = await self._aclient.chat(  # type: ignore
            message=message,
            model=self.model,
            preamble=system,
            chat_history=chat_history,
            temperature=temperature,
            max_tokens=max_tokens,
            k=k,
            p=p,
            seed=seed,
            stop_sequences=stop_sequences,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            raw_prompting=raw_prompting,
        )

        if (text := response.text) == "":
            self._logger.warning(
                f"Received no response using Cohere client (model: '{self.model}')."
                f" Finish reason was: {response.finish_reason}"
            )
            return None

        return text

    @override
    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously awaiting for the response of each input sent to `agenerate`."""

        async def agenerate(
            inputs: List["ChatType"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
                asyncio.create_task(self.agenerate(input=input, **kwargs))
                for input in inputs
                for _ in range(num_generations)
            ]
            return await asyncio.gather(*tasks)

        outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
        return list(grouper(outputs, n=num_generations, incomplete="ignore"))

model_name: str property

Returns the model name used for the LLM.

agenerate(input, temperature=None, max_tokens=None, k=None, p=None, seed=None, stop_sequences=None, frequency_penalty=None, presence_penalty=None, raw_prompting=None) async

Generates a response from the LLM given an input.

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
temperature Optional[float]

the temperature to use for the generation. Defaults to None.

None
max_tokens Optional[int]

the maximum number of new tokens that the model will generate. Defaults to None.

None
k Optional[int]

the number of highest probability vocabulary tokens to keep for the generation. Defaults to None.

None
p Optional[float]

the nucleus sampling probability to use for the generation. Defaults to None.

None
seed Optional[float]

the seed to use for the generation. Defaults to None.

None
stop_sequences Optional[Sequence[str]]

a list of sequences to use as stopping criteria for the generation. Defaults to None.

None
frequency_penalty Optional[float]

the frequency penalty to use for the generation. Defaults to None.

None
presence_penalty Optional[float]

the presence penalty to use for the generation. Defaults to None.

None
raw_prompting Optional[bool]

a flag to use raw prompting for the generation. Defaults to None.

None

Returns:

Type Description
Union[str, None]

The generated response from the Cohere API model.

Source code in src/distilabel/llms/cohere.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    k: Optional[int] = None,
    p: Optional[float] = None,
    seed: Optional[float] = None,
    stop_sequences: Optional[Sequence[str]] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    raw_prompting: Optional[bool] = None,
) -> Union[str, None]:
    """Generates a response from the LLM given an input.

    Args:
        input: a single input in chat format to generate responses for.
        temperature: the temperature to use for the generation. Defaults to `None`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        k: the number of highest probability vocabulary tokens to keep for the generation.
            Defaults to `None`.
        p: the nucleus sampling probability to use for the generation. Defaults to
            `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: a list of sequences to use as stopping criteria for the generation.
            Defaults to `None`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `None`.
        raw_prompting: a flag to use raw prompting for the generation. Defaults to
            `None`.

    Returns:
        The generated response from the Cohere API model.
    """
    system, chat_history, message = self._format_chat_to_cohere(input)

    response = await self._aclient.chat(  # type: ignore
        message=message,
        model=self.model,
        preamble=system,
        chat_history=chat_history,
        temperature=temperature,
        max_tokens=max_tokens,
        k=k,
        p=p,
        seed=seed,
        stop_sequences=stop_sequences,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        raw_prompting=raw_prompting,
    )

    if (text := response.text) == "":
        self._logger.warning(
            f"Received no response using Cohere client (model: '{self.model}')."
            f" Finish reason was: {response.finish_reason}"
        )
        return None

    return text
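
Note the constraints enforced by _format_chat_to_cohere: every assistant message must be preceded by a user message, and the chat must end with a user message, which becomes the message argument, while earlier user/assistant turns become the chat_history and any system message becomes the preamble. A sketch of a valid input (llm is assumed to be a loaded CohereLLM as in the sketch above):

import asyncio

chat = [
    {"role": "system", "content": "You are a concise assistant."},  # -> preamble
    {"role": "user", "content": "Hi!"},                             # -> chat_history (USER)
    {"role": "assistant", "content": "Hello, how can I help?"},     # -> chat_history (CHATBOT)
    {"role": "user", "content": "Summarize the Cohere Chat API."},  # -> message (must be last)
]

response = asyncio.run(llm.agenerate(input=chat, temperature=0.3, max_tokens=256))
# `response` is the generated text, or None if the model returned an empty response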

generate(inputs, num_generations=1, **kwargs)

Method to generate a list of responses asynchronously, returning the output synchronously after awaiting the response for each input sent to agenerate.

Source code in src/distilabel/llms/cohere.py
@override
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of responses asynchronously, returning the output
    synchronously awaiting for the response of each input sent to `agenerate`."""

    async def agenerate(
        inputs: List["ChatType"], **kwargs: Any
    ) -> "GenerateOutput":
        """Internal function to parallelize the asynchronous generation of responses."""
        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        return await asyncio.gather(*tasks)

    outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
    return list(grouper(outputs, n=num_generations, incomplete="ignore"))

load()

Loads the AsyncClient client from the cohere package.

Source code in src/distilabel/llms/cohere.py
def load(self) -> None:
    """Loads the `AsyncClient` client from the `cohere` package."""

    super().load()

    try:
        from cohere import AsyncClient, ChatMessage
    except ImportError as ie:
        raise ImportError(
            "The `cohere` package is required to use the `CohereLLM` class."
        ) from ie

    self._ChatMessage = ChatMessage

    self._aclient = AsyncClient(
        api_key=self.api_key.get_secret_value(),  # type: ignore
        client_name=self.client_name,
        base_url=self.base_url,
        timeout=self.timeout,
    )

CudaDevicePlacementMixin

Bases: BaseModel

Mixin class to assign CUDA devices to the LLM based on the cuda_devices attribute and the device placement information provided in _device_llm_placement_map. Providing the device placement information is optional, but if it is provided, it will be used to assign CUDA devices to the LLMs, trying to avoid using the same device for different LLMs.

Attributes:

Name Type Description
cuda_devices Union[List[int], Literal['auto']]

a list with the ID of the CUDA devices to be used by the LLM. If set to "auto", the devices will be automatically assigned based on the device placement information provided in _device_llm_placement_map. If set to a list of devices, it will be checked if the devices are available to be used by the LLM. If not, a warning will be logged.

_llm_identifier Union[str, None]

the identifier of the LLM to be used as key in _device_llm_placement_map.

_device_llm_placement_map Union[DictProxy[str, Any], None]

a dictionary with the device placement information for each LLM.

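A hedged sketch of how an LLM can adopt the mixin (MyLocalLLM below is hypothetical and only for illustration): loading it resolves cuda_devices and exports CUDA_VISIBLE_DEVICES accordingly.

from typing import Any, List

from distilabel.llms.base import LLM
from distilabel.llms.mixins import CudaDevicePlacementMixin


class MyLocalLLM(LLM, CudaDevicePlacementMixin):
    """Hypothetical local LLM adopting the CUDA device placement behaviour."""

    def load(self) -> None:
        super().load()                       # LLM.load(): sets up internals such as the logger
        CudaDevicePlacementMixin.load(self)  # resolves `cuda_devices`, sets CUDA_VISIBLE_DEVICES

    @property
    def model_name(self) -> str:
        return "my-local-model"

    def generate(
        self, inputs: List[Any], num_generations: int = 1, **kwargs: Any
    ) -> List[Any]:
        ...  # load and run the model here, after the CUDA devices have been assigned


llm = MyLocalLLM(cuda_devices=[0])  # or cuda_devices="auto" to pick a free device (needs pynvml)
llm.load()
# os.environ["CUDA_VISIBLE_DEVICES"] is now "0"
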
Source code in src/distilabel/llms/mixins.py
class CudaDevicePlacementMixin(BaseModel):
    """Mixin class to assign CUDA devices to the `LLM` based on the `cuda_devices` attribute
    and the device placement information provided in `_device_llm_placement_map`. Providing
    the device placement information is optional, but if it is provided, it will be used to
    assign CUDA devices to the `LLM`s, trying to avoid using the same device for different
    `LLM`s.

    Attributes:
        cuda_devices: a list with the ID of the CUDA devices to be used by the `LLM`. If set
            to "auto", the devices will be automatically assigned based on the device
            placement information provided in `_device_llm_placement_map`. If set to a list
            of devices, it will be checked if the devices are available to be used by the
            `LLM`. If not, a warning will be logged.
        _llm_identifier: the identifier of the `LLM` to be used as key in `_device_llm_placement_map`.
        _device_llm_placement_map: a dictionary with the device placement information for each
            `LLM`.
    """

    # TODO: this should be a runtime parameter
    cuda_devices: Union[List[int], Literal["auto"]] = Field(default="auto")

    _llm_identifier: Union[str, None] = PrivateAttr(default=None)
    _device_llm_placement_map: Union["DictProxy[str, Any]", None] = PrivateAttr(
        default=None
    )
    _device_llm_placement_lock: Union["Lock", None] = PrivateAttr(default=None)
    _available_cuda_devices: Union[List[int], None] = PrivateAttr(default=None)
    _can_check_cuda_devices: bool = PrivateAttr(default=False)

    def load(self) -> None:
        """Assign CUDA devices to the LLM based on the device placement information provided
        in `_device_llm_placement_map`."""

        try:
            import pynvml

            pynvml.nvmlInit()
            device_count = pynvml.nvmlDeviceGetCount()
            self._available_cuda_devices = list(range(device_count))
            self._can_check_cuda_devices = True
        except ImportError as ie:
            if self.cuda_devices == "auto":
                raise ImportError(
                    "The 'pynvml' library is not installed. It is required to automatically"
                    " assign CUDA devices to the `LLM`s. Please, install it and try again."
                ) from ie

            if self.cuda_devices:
                self._logger.warning(  # type: ignore
                    "The 'pynvml' library is not installed. It is recommended to install it"
                    " to check if the CUDA devices assigned to the LLM are available."
                )

        self._assign_cuda_devices()

    def set_device_placement_info(
        self,
        llm_identifier: str,
        device_llm_placement_map: "DictProxy[str, Any]",
        device_llm_placement_lock: "Lock",
    ) -> None:
        """Sets the value of `_device_llm_placement_map` to be used to assign CUDA devices
        to the LLM.

        Args:
            llm_identifier: the identifier of the LLM to be used as key in the device
                placement information.
            device_llm_placement_map: a dictionary with the device placement information for
                each LLM. It should have two keys. The first key is "lock" and its value is
                a lock object to be used to synchronize the access to the device placement
                information. The second key is "value" and its value is a dictionary with the
                device placement information for each LLM.
            device_llm_placement_lock: a lock object to be used to synchronize the access to
                `_device_llm_placement_map`.
        """
        self._llm_identifier = llm_identifier
        self._device_llm_placement_map = device_llm_placement_map
        self._device_llm_placement_lock = device_llm_placement_lock

    def _assign_cuda_devices(self) -> None:
        """Assigns CUDA devices to the LLM based on the device placement information provided
        in `_device_llm_placement_map`. If the `cuda_devices` attribute is set to "auto", it
        will be set to the first available CUDA device that is not going to be used by any
        other LLM. If the `cuda_devices` attribute is set to a list of devices, it will be
        checked if the devices are available to be used by the LLM. If not, a warning will be
        logged."""

        if self._device_llm_placement_map is not None:
            with self._device_llm_placement_lock:  # type: ignore
                if self.cuda_devices == "auto":
                    self.cuda_devices = [
                        self._get_cuda_device(self._device_llm_placement_map)
                    ]
                else:
                    self._check_cuda_devices(self._device_llm_placement_map)

                self._device_llm_placement_map[self._llm_identifier] = self.cuda_devices  # type: ignore

        # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices`
        # attribute. In this case, the `cuda_devices` attribute will be set to an empty list.
        if self.cuda_devices == "auto":
            self.cuda_devices = []

        self._set_cuda_visible_devices()

    def _check_cuda_devices(self, device_map: Dict[str, List[int]]) -> None:
        """Checks if the CUDA devices assigned to the LLM are also assigned to other LLMs.

        Args:
            device_map: a dictionary with the device placement information for each LLM.
        """
        for device in self.cuda_devices:
            for llm, devices in device_map.items():
                if device in devices:
                    self._logger.warning(
                        f"LLM with identifier '{llm}' is also going to use CUDA device "
                        f"'{device}'. This may lead to performance issues or running out"
                        " of memory depending on the device capabilities and the loaded"
                        " models."
                    )

    def _get_cuda_device(self, device_map: Dict[str, List[int]]) -> int:
        """Returns the first available CUDA device to be used by the LLM that is not going
        to be used by any other LLM.

        Args:
            device_map: a dictionary with the device placement information for each LLM.

        Returns:
            The first available CUDA device to be used by the LLM.

        Raises:
            RuntimeError: if there is no available CUDA device to be used by the LLM.
        """
        for device in self._available_cuda_devices:
            if all(device not in devices for devices in device_map.values()):
                return device

        raise RuntimeError(
            "Couldn't find an available CUDA device automatically to be used by the LLM"
            f" '{self._llm_identifier}'. For forcing the use of a specific device, set the"
            " `cuda_devices` attribute to a list with the desired device(s)."
        )

    def _set_cuda_visible_devices(self) -> None:
        """Sets the `CUDA_VISIBLE_DEVICES` environment variable to the list of CUDA devices
        to be used by the LLM.
        """
        if not self.cuda_devices:
            return

        if self._can_check_cuda_devices and not all(
            device in self._available_cuda_devices for device in self.cuda_devices
        ):
            raise RuntimeError(
                f"Invalid CUDA devices for LLM '{self._llm_identifier}': {self.cuda_devices}."
                f" The available devices are: {self._available_cuda_devices}. Please, review"
                " the 'cuda_devices' attribute and try again."
            )

        cuda_devices = ",".join([str(device) for device in self.cuda_devices])
        self._logger.info(
            f"🎮 LLM '{self._llm_identifier}' is going to use the following CUDA devices:"
            f" {self.cuda_devices}."
        )
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_devices

load()

Assign CUDA devices to the LLM based on the device placement information provided in _device_llm_placement_map.

Source code in src/distilabel/llms/mixins.py
def load(self) -> None:
    """Assign CUDA devices to the LLM based on the device placement information provided
    in `_device_llm_placement_map`."""

    try:
        import pynvml

        pynvml.nvmlInit()
        device_count = pynvml.nvmlDeviceGetCount()
        self._available_cuda_devices = list(range(device_count))
        self._can_check_cuda_devices = True
    except ImportError as ie:
        if self.cuda_devices == "auto":
            raise ImportError(
                "The 'pynvml' library is not installed. It is required to automatically"
                " assign CUDA devices to the `LLM`s. Please, install it and try again."
            ) from ie

        if self.cuda_devices:
            self._logger.warning(  # type: ignore
                "The 'pynvml' library is not installed. It is recommended to install it"
                " to check if the CUDA devices assigned to the LLM are available."
            )

    self._assign_cuda_devices()

set_device_placement_info(llm_identifier, device_llm_placement_map, device_llm_placement_lock)

Sets the value of _device_llm_placement_map to be used to assign CUDA devices to the LLM.

Parameters:

Name Type Description Default
llm_identifier str

the identifier of the LLM to be used as key in the device placement information.

required
device_llm_placement_map DictProxy[str, Any]

a dictionary with the device placement information for each LLM, mapping each LLM identifier to the list of CUDA devices assigned to it.

required
device_llm_placement_lock Lock

a lock object to be used to synchronize the access to _device_llm_placement_map.

required
Source code in src/distilabel/llms/mixins.py
def set_device_placement_info(
    self,
    llm_identifier: str,
    device_llm_placement_map: "DictProxy[str, Any]",
    device_llm_placement_lock: "Lock",
) -> None:
    """Sets the value of `_device_llm_placement_map` to be used to assign CUDA devices
    to the LLM.

    Args:
        llm_identifier: the identifier of the LLM to be used as key in the device
            placement information.
        device_llm_placement_map: a dictionary with the device placement information for
            each LLM. It should have two keys. The first key is "lock" and its value is
            a lock object to be used to synchronize the access to the device placement
            information. The second key is "value" and its value is a dictionary with the
            device placement information for each LLM.
        device_llm_placement_lock: a lock object to be used to synchronize the access to
            `_device_llm_placement_map`.
    """
    self._llm_identifier = llm_identifier
    self._device_llm_placement_map = device_llm_placement_map
    self._device_llm_placement_lock = device_llm_placement_lock
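
A hedged sketch of how the placement map can be shared across processes; this wiring is typically provided by the framework when the LLM runs inside a pipeline, so the manual setup below (reusing the hypothetical MyLocalLLM from the sketch above) is for illustration only:

from multiprocessing import Manager

manager = Manager()
device_map = manager.dict()  # shared mapping: llm identifier -> assigned CUDA devices
lock = manager.Lock()        # synchronizes access to the shared mapping

llm = MyLocalLLM(cuda_devices="auto")  # "auto" requires pynvml to list the available devices
llm.set_device_placement_info(
    llm_identifier="my_local_llm",
    device_llm_placement_map=device_map,
    device_llm_placement_lock=lock,
)
llm.load()  # picks the first CUDA device not already taken by another LLM in `device_map`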

InferenceEndpointsLLM

Bases: AsyncLLM

InferenceEndpoints LLM implementation running the async API client via either huggingface_hub.AsyncInferenceClient or openai.AsyncOpenAI.

Attributes:

Name Type Description
model_id Optional[str]

the model ID to use for the LLM as available in the Hugging Face Hub, which will be used to resolve the base URL for the serverless Inference Endpoints API requests. Defaults to None.

endpoint_name Optional[RuntimeParameter[str]]

the name of the Inference Endpoint to use for the LLM. Defaults to None.

endpoint_namespace Optional[RuntimeParameter[str]]

the namespace of the Inference Endpoint to use for the LLM. Defaults to None.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Inference Endpoints API requests.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Inference Endpoints API.

tokenizer_id Optional[str]

the tokenizer ID to use for the LLM as available in the Hugging Face Hub. Defaults to None, but defining one is recommended to properly format the prompt.

model_display_name Optional[str]

the model display name to use for the LLM. Defaults to None.

use_openai_client bool

whether to use the OpenAI client instead of the Hugging Face client.

Examples:

from distilabel.llms.huggingface import InferenceEndpointsLLM

# Free serverless Inference API
llm = InferenceEndpointsLLM(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
)

# Dedicated Inference Endpoints
llm = InferenceEndpointsLLM(
    endpoint_name="<ENDPOINT_NAME>",
    api_key="<HF_API_KEY>",
    endpoint_namespace="<USER|ORG>",
)

# Dedicated Inference Endpoints or TGI
llm = InferenceEndpointsLLM(
    api_key="<HF_API_KEY>",
    base_url="<BASE_URL>",
)

llm.load()

# Synchronous request
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

# Asynchronous request
output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
Source code in src/distilabel/llms/huggingface/inference_endpoints.py
class InferenceEndpointsLLM(AsyncLLM):
    """InferenceEndpoints LLM implementation running the async API client via either
    the `huggingface_hub.AsyncInferenceClient` or the `openai.AsyncOpenAI`.

    Attributes:
        model_id: the model ID to use for the LLM as available in the Hugging Face Hub, which
            will be used to resolve the base URL for the serverless Inference Endpoints API requests.
            Defaults to `None`.
        endpoint_name: the name of the Inference Endpoint to use for the LLM. Defaults to `None`.
        endpoint_namespace: the namespace of the Inference Endpoint to use for the LLM. Defaults to `None`.
        base_url: the base URL to use for the Inference Endpoints API requests.
        api_key: the API key to authenticate the requests to the Inference Endpoints API.
        tokenizer_id: the tokenizer ID to use for the LLM as available in the Hugging Face Hub.
            Defaults to `None`, but defining one is recommended to properly format the prompt.
        model_display_name: the model display name to use for the LLM. Defaults to `None`.
        use_openai_client: whether to use the OpenAI client instead of the Hugging Face client.

    Examples:
        ```python
        from distilabel.llms.huggingface import InferenceEndpointsLLM

        # Free serverless Inference API
        llm = InferenceEndpointsLLM(
            model_id="mistralai/Mistral-7B-Instruct-v0.2",
        )

        # Dedicated Inference Endpoints
        llm = InferenceEndpointsLLM(
            endpoint_name="<ENDPOINT_NAME>",
            api_key="<HF_API_KEY>",
            endpoint_namespace="<USER|ORG>",
        )

        # Dedicated Inference Endpoints or TGI
        llm = InferenceEndpointsLLM(
            api_key="<HF_API_KEY>",
            base_url="<BASE_URL>",
        )

        llm.load()

        # Synchronous request
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

        # Asynchronous request
        output = await llm.agenerate(input=[{"role": "user", "content": "Hello world!"}])
        ```
    """

    model_id: Optional[str] = None

    endpoint_name: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The name of the Inference Endpoint to use for the LLM.",
    )
    endpoint_namespace: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The namespace of the Inference Endpoint to use for the LLM.",
    )
    base_url: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The base URL to use for the Inference Endpoints API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default=os.getenv(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Inference Endpoints API.",
    )

    tokenizer_id: Optional[str] = None
    model_display_name: Optional[str] = None
    use_openai_client: bool = False

    _model_name: Optional[str] = PrivateAttr(default=None)
    _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(default=None)
    _api_key_env_var: str = PrivateAttr(_INFERENCE_ENDPOINTS_API_KEY_ENV_VAR_NAME)
    _aclient: Optional[Union["AsyncInferenceClient", "AsyncOpenAI"]] = PrivateAttr(...)

    @model_validator(mode="after")  # type: ignore
    def only_one_of_model_id_endpoint_name_or_base_url_provided(
        self,
    ) -> "InferenceEndpointsLLM":
        """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
        provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
        favour of the dynamically calculated one."""

        if self.base_url and (self.model_id or self.endpoint_name):
            self._logger.warning(  # type: ignore
                f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`"
                " is also provided, the `base_url` will either be ignored or overwritten with the one generated"
                " from either of those args, for serverless or dedicated inference endpoints, respectively."
            )

        if self.base_url and not (self.model_id or self.endpoint_name):
            return self

        if self.model_id and not self.endpoint_name:
            return self

        if self.endpoint_name and not self.model_id:
            return self

        raise ValidationError(
            "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too,"
            " it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name},"
            f" and `base_url`={self.base_url}."
        )

    def load(self) -> None:  # noqa: C901
        """Loads the either the `AsyncInferenceClient` or the `AsyncOpenAI` client to benefit
        from async requests, running the Hugging Face Inference Endpoint underneath via the
        `/v1/chat/completions` endpoint, exposed for the models running on TGI using the
        `text-generation` task.

        Raises:
            ImportError: if the `openai` Python client is not installed.
            ImportError: if the `huggingface-hub` Python client is not installed.
            ValueError: if the model is not currently deployed or is not running the TGI framework.
            ImportError: if the `transformers` Python client is not installed.
        """
        super().load()

        try:
            from huggingface_hub import (
                AsyncInferenceClient,
                InferenceClient,
                get_inference_endpoint,
            )
        except ImportError as ie:
            raise ImportError(
                "Hugging Face Hub Python client is not installed. Please install it using"
                " `pip install huggingface-hub`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        if self.model_id is not None:
            client = InferenceClient()
            status = client.get_model_status(self.model_id)

            if (
                status.state not in {"Loadable", "Loaded"}
                and status.framework != "text-generation-inference"
            ):
                raise ValueError(
                    f"Model {self.model_id} is not currently deployed or is not running the TGI framework"
                )

            self.base_url = client._resolve_url(
                model=self.model_id, task="text-generation"
            )

        if self.endpoint_name is not None:
            client = get_inference_endpoint(
                name=self.endpoint_name,
                namespace=self.endpoint_namespace,
                token=self.api_key.get_secret_value(),
            )
            if client.status in ["paused", "scaledToZero"]:
                client.resume().wait(timeout=300)
            elif client.status in ["initializing"]:
                client.wait(timeout=300)

            self.base_url = client.url
            self._model_name = client.repository

        if self.use_openai_client:
            try:
                from openai import AsyncOpenAI
            except ImportError as ie:
                raise ImportError(
                    "OpenAI Python client is not installed. Please install it using"
                    " `pip install openai`."
                ) from ie

            self._aclient = AsyncOpenAI(
                base_url=self.base_url,
                api_key=self.api_key.get_secret_value(),
                max_retries=6,
            )
        else:
            self._aclient = AsyncInferenceClient(
                model=self.base_url,
                token=self.api_key.get_secret_value(),
            )

        if self.tokenizer_id:
            try:
                from transformers import AutoTokenizer
            except ImportError as ie:
                raise ImportError(
                    "Transformers Python client is not installed. Please install it using"
                    " `pip install transformers`."
                ) from ie

            self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)

    @property
    @override
    def model_name(self) -> Union[str, None]:  # type: ignore
        """Returns the model name used for the LLM."""
        return (
            self.model_display_name
            or self._model_name
            or self.model_id
            or self.endpoint_name
            or self.base_url
        )

    async def _openai_agenerate(
        self,
        input: "ChatType",
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        """Generates completions for the given input using the OpenAI async client."""
        completion = await self._aclient.chat.completions.create(  # type: ignore
            messages=input,  # type: ignore
            model="tgi",
            max_tokens=max_new_tokens,
            n=1,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            timeout=50,
        )
        if completion.choices[0].message.content is None:
            self._logger.warning(
                f"⚠️ Received no response using OpenAI client (model: '{self.model_name}')."
                f" Finish reason was: {completion.choices[0].finish_reason}"
            )
        return [completion.choices[0].message.content]

    # TODO: add `num_generations` parameter once either TGI or `AsyncInferenceClient` allows `n` parameter
    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repetition_penalty: Optional[float] = None,
        temperature: float = 1.0,
        do_sample: bool = False,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        typical_p: Optional[float] = None,
    ) -> "GenerateOutput":
        """Generates completions for the given input using the OpenAI async client.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`. Only applies if `use_openai_client=True`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`. Only applies if `use_openai_client=True`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults
                to `None`. Only applies if `use_openai_client=False`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            do_sample: whether to use sampling for the generation. Defaults to `False`.
                Only applies if `use_openai_client=False`.
            top_k: the top-k value to use for the generation. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.
            typical_p: the typical-p value to use for the generation. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """

        if self.use_openai_client:
            return await self._openai_agenerate(
                input=input,
                max_new_tokens=max_new_tokens,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                temperature=temperature,
                top_p=top_p,
            )

        if self._tokenizer is not None:
            prompt = self._tokenizer.apply_chat_template(  # type: ignore
                conversation=input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            prompt = "\n".join([message["content"] for message in input])

        try:
            completion = await self._aclient.text_generation(  # type: ignore
                prompt=prompt,  # type: ignore
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                typical_p=typical_p,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                # NOTE: here to ensure that the cache is not used and a different response is
                # generated every time
                seed=random.randint(0, 2147483647),
            )
            return [completion]
        except Exception as e:
            self._logger.warning(
                f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
                f" Finish reason was: {e}"
            )
            return [None]

    # TODO: remove this function once `AsyncInferenceClient` allows `n` parameter
    @override
    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously by awaiting the response for each input sent to `agenerate`.
        """

        async def agenerate(
            inputs: List["ChatType"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
                asyncio.create_task(self.agenerate(input=input, **kwargs))
                for input in inputs
                for _ in range(num_generations)
            ]
            return [outputs[0] for outputs in await asyncio.gather(*tasks)]

        outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
        return list(grouper(outputs, n=num_generations, incomplete="ignore"))

model_name: Union[str, None] property

Returns the model name used for the LLM.

agenerate(input, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, repetition_penalty=None, temperature=1.0, do_sample=False, top_k=None, top_p=None, typical_p=None) async

Generates completions for the given input using either the OpenAI or the Hugging Face async client, depending on the value of use_openai_client.

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
frequency_penalty float

the repetition penalty to use for the generation. Defaults to 0.0. Only applies if use_openai_client=True.

0.0
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0. Only applies if use_openai_client=True.

0.0
repetition_penalty Optional[float]

the repetition penalty to use for the generation. Defaults to None. Only applies if use_openai_client=False.

None
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
do_sample bool

whether to use sampling for the generation. Defaults to False. Only applies if use_openai_client=False.

False
top_k Optional[int]

the top-k value to use for the generation. Defaults to None.

None
top_p Optional[float]

the top-p value to use for the generation. Defaults to None.

None
typical_p Optional[float]

the typical-p value to use for the generation. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.
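
A minimal sketch of a direct asynchronous call, assuming the API key environment variable is already set and reusing the serverless model id from the examples above:

```python
import asyncio

from distilabel.llms.huggingface import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model_id="mistralai/Mistral-7B-Instruct-v0.2")
llm.load()

# Single asynchronous generation with a few of the sampling arguments.
output = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
)
```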

Source code in src/distilabel/llms/huggingface/inference_endpoints.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    repetition_penalty: Optional[float] = None,
    temperature: float = 1.0,
    do_sample: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
) -> "GenerateOutput":
    """Generates completions for the given input using the OpenAI async client.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`. Only applies if `use_openai_client=True`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`. Only applies if `use_openai_client=True`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults
            to `None`. Only applies if `use_openai_client=False`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        do_sample: whether to use sampling for the generation. Defaults to `False`.
            Only applies if `use_openai_client=False`.
        top_k: the top-k value to use for the generation. Defaults to `None`.
        top_p: the top-p value to use for the generation. Defaults to `None`.
        typical_p: the typical-p value to use for the generation. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """

    if self.use_openai_client:
        return await self._openai_agenerate(
            input=input,
            max_new_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
        )

    if self._tokenizer is not None:
        prompt = self._tokenizer.apply_chat_template(  # type: ignore
            conversation=input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        prompt = "\n".join([message["content"] for message in input])

    try:
        completion = await self._aclient.text_generation(  # type: ignore
            prompt=prompt,  # type: ignore
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            typical_p=typical_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            # NOTE: here to ensure that the cache is not used and a different response is
            # generated every time
            seed=random.randint(0, 2147483647),
        )
        return [completion]
    except Exception as e:
        self._logger.warning(
            f"⚠️ Received no response using Inference Client (model: '{self.model_name}')."
            f" Finish reason was: {e}"
        )
        return [None]

generate(inputs, num_generations=1, **kwargs)

Method to generate a list of responses asynchronously, returning the output synchronously by awaiting the response for each input sent to agenerate.
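
A minimal sketch of a synchronous batch call, reusing an already loaded client such as the one from the examples above (the prompts are arbitrary):

```python
inputs = [
    [{"role": "user", "content": "What is synthetic data?"}],
    [{"role": "user", "content": "Name three uses of an LLM."}],
]

# One `GenerateOutput` (a list of strings) per input; with `num_generations=2`
# each inner list contains two generations.
outputs = llm.generate(inputs=inputs, num_generations=2, max_new_tokens=64)
```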

Source code in src/distilabel/llms/huggingface/inference_endpoints.py
@override
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of responses asynchronously, returning the output
    synchronously by awaiting the response for each input sent to `agenerate`.
    """

    async def agenerate(
        inputs: List["ChatType"], **kwargs: Any
    ) -> "GenerateOutput":
        """Internal function to parallelize the asynchronous generation of responses."""
        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        return [outputs[0] for outputs in await asyncio.gather(*tasks)]

    outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
    return list(grouper(outputs, n=num_generations, incomplete="ignore"))

load()

Loads either the AsyncInferenceClient or the AsyncOpenAI client to benefit from async requests, running the Hugging Face Inference Endpoint underneath via the /v1/chat/completions endpoint, exposed for the models running on TGI using the text-generation task.

Raises:

Type Description
ImportError

if the openai Python client is not installed.

ImportError

if the huggingface-hub Python client is not installed.

ValueError

if the model is not currently deployed or is not running the TGI framework.

ImportError

if the transformers Python client is not installed.

Source code in src/distilabel/llms/huggingface/inference_endpoints.py
def load(self) -> None:  # noqa: C901
    """Loads the either the `AsyncInferenceClient` or the `AsyncOpenAI` client to benefit
    from async requests, running the Hugging Face Inference Endpoint underneath via the
    `/v1/chat/completions` endpoint, exposed for the models running on TGI using the
    `text-generation` task.

    Raises:
        ImportError: if the `openai` Python client is not installed.
        ImportError: if the `huggingface-hub` Python client is not installed.
        ValueError: if the model is not currently deployed or is not running the TGI framework.
        ImportError: if the `transformers` Python client is not installed.
    """
    super().load()

    try:
        from huggingface_hub import (
            AsyncInferenceClient,
            InferenceClient,
            get_inference_endpoint,
        )
    except ImportError as ie:
        raise ImportError(
            "Hugging Face Hub Python client is not installed. Please install it using"
            " `pip install huggingface-hub`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    if self.model_id is not None:
        client = InferenceClient()
        status = client.get_model_status(self.model_id)

        if (
            status.state not in {"Loadable", "Loaded"}
            and status.framework != "text-generation-inference"
        ):
            raise ValueError(
                f"Model {self.model_id} is not currently deployed or is not running the TGI framework"
            )

        self.base_url = client._resolve_url(
            model=self.model_id, task="text-generation"
        )

    if self.endpoint_name is not None:
        client = get_inference_endpoint(
            name=self.endpoint_name,
            namespace=self.endpoint_namespace,
            token=self.api_key.get_secret_value(),
        )
        if client.status in ["paused", "scaledToZero"]:
            client.resume().wait(timeout=300)
        elif client.status in ["initializing"]:
            client.wait(timeout=300)

        self.base_url = client.url
        self._model_name = client.repository

    if self.use_openai_client:
        try:
            from openai import AsyncOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        self._aclient = AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=6,
        )
    else:
        self._aclient = AsyncInferenceClient(
            model=self.base_url,
            token=self.api_key.get_secret_value(),
        )

    if self.tokenizer_id:
        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "Transformers Python client is not installed. Please install it using"
                " `pip install transformers`."
            ) from ie

        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)

only_one_of_model_id_endpoint_name_or_base_url_provided()

Validates that only one of model_id or endpoint_name is provided; and if base_url is also provided, a warning will be shown informing the user that the provided base_url will be ignored in favour of the dynamically calculated one.

Source code in src/distilabel/llms/huggingface/inference_endpoints.py
@model_validator(mode="after")  # type: ignore
def only_one_of_model_id_endpoint_name_or_base_url_provided(
    self,
) -> "InferenceEndpointsLLM":
    """Validates that only one of `model_id` or `endpoint_name` is provided; and if `base_url` is also
    provided, a warning will be shown informing the user that the provided `base_url` will be ignored in
    favour of the dynamically calculated one."""

    if self.base_url and (self.model_id or self.endpoint_name):
        self._logger.warning(  # type: ignore
            f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`"
            " is also provided, the `base_url` will either be ignored or overwritten with the one generated"
            " from either of those args, for serverless or dedicated inference endpoints, respectively."
        )

    if self.base_url and not (self.model_id or self.endpoint_name):
        return self

    if self.model_id and not self.endpoint_name:
        return self

    if self.endpoint_name and not self.model_id:
        return self

    raise ValidationError(
        "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too,"
        " it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name},"
        f" and `base_url`={self.base_url}."
    )

LLM

Bases: RuntimeParametersMixin, BaseModel, _Serializable, ABC

Base class for LLMs to be used in distilabel framework.

To implement an LLM subclass, you need to subclass this class and implement:

  • load method to load the LLM if needed. Don't forget to call super().load(), so the _logger attribute is initialized.
  • model_name property to return the model name used for the LLM.
  • generate method to generate num_generations per input in inputs.

Attributes:

Name Type Description
generation_kwargs Optional[RuntimeParameter[Dict[str, Any]]]

the kwargs to be propagated to either generate or agenerate methods within each LLM.

_logger Union[Logger, None]

the logger to be used for the LLM. It will be initialized when the load method is called.
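
As a minimal sketch (not part of the library) of how the three overrides listed above fit together, assuming the base class is importable as `from distilabel.llms.base import LLM`; the echo behaviour is purely illustrative:

```python
from typing import Any, Dict, List

from distilabel.llms.base import LLM


class EchoLLM(LLM):
    """Toy `LLM` that echoes the last message of each conversation."""

    @property
    def model_name(self) -> str:
        return "echo-llm"

    def generate(
        self,
        inputs: List[List[Dict[str, str]]],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List[List[str]]:
        # One list with `num_generations` strings per input conversation.
        return [
            [conversation[-1]["content"]] * num_generations
            for conversation in inputs
        ]


llm = EchoLLM()
llm.load()  # the base `load` just initializes `self._logger`
llm.generate(inputs=[[{"role": "user", "content": "ping"}]])  # [["ping"]]
```
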

Source code in src/distilabel/llms/base.py
class LLM(RuntimeParametersMixin, BaseModel, _Serializable, ABC):
    """Base class for `LLM`s to be used in `distilabel` framework.

    To implement an `LLM` subclass, you need to subclass this class and implement:
        - `load` method to load the `LLM` if needed. Don't forget to call `super().load()`,
            so the `_logger` attribute is initialized.
        - `model_name` property to return the model name used for the LLM.
        - `generate` method to generate `num_generations` per input in `inputs`.

    Attributes:
        generation_kwargs: the kwargs to be propagated to either `generate` or `agenerate`
            methods within each `LLM`.
        _logger: the logger to be used for the `LLM`. It will be initialized when the `load`
            method is called.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
        validate_default=True,
        validate_assignment=True,
    )

    generation_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="The kwargs to be propagated to either `generate` or `agenerate`"
        " methods within each `LLM`.",
    )

    _logger: Union[logging.Logger, None] = PrivateAttr(...)

    def load(self) -> None:
        """Method to be called to initialize the `LLM` and its logger."""
        self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        pass

    @abstractmethod
    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Abstract method to be implemented by each LLM to generate `num_generations`
        per input in `inputs`.

        Args:
            inputs: the list of inputs to generate responses for which follows OpenAI's
                API format:

                ```python
                [
                    {"role": "system", "content": "You're a helpful assistant..."},
                    {"role": "user", "content": "Give a template email for B2B communications..."},
                    {"role": "assistant", "content": "Sure, here's a template you can use..."},
                    {"role": "user", "content": "Modify the second paragraph..."}
                ]
                ```
            num_generations: the number of generations to generate per input.
            **kwargs: the additional kwargs to be used for the generation.
        """
        pass

    @property
    def generate_parameters(self) -> List["inspect.Parameter"]:
        """Returns the parameters of the `generate` method.

        Returns:
            A list containing the parameters of the `generate` method.
        """
        return list(inspect.signature(self.generate).parameters.values())

    @property
    def runtime_parameters_names(self) -> "RuntimeParametersNames":
        """Returns the runtime parameters of the `LLM`, which are combination of the
        attributes of the `LLM` type hinted with `RuntimeParameter` and the parameters
        of the `generate` method that are not `input` and `num_generations`.

        Returns:
            A dictionary with the name of the runtime parameters as keys and a boolean
            indicating if the parameter is optional or not.
        """
        runtime_parameters = super().runtime_parameters_names
        runtime_parameters["generation_kwargs"] = {}

        # runtime parameters from the `generate` method
        for param in self.generate_parameters:
            if param.name in ["input", "inputs", "num_generations"]:
                continue
            is_optional = param.default != inspect.Parameter.empty
            runtime_parameters["generation_kwargs"][param.name] = is_optional

        return runtime_parameters

    def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:
        """Gets the information of the runtime parameters of the `LLM` such as the name
        and the description. This function is meant to include the information of the runtime
        parameters in the serialized data of the `LLM`.

        Returns:
            A list containing the information for each runtime parameter of the `LLM`.
        """
        runtime_parameters_info = super().get_runtime_parameters_info()

        generation_kwargs_info = next(
            runtime_parameter_info
            for runtime_parameter_info in runtime_parameters_info
            if runtime_parameter_info["name"] == "generation_kwargs"
        )

        generate_docstring_args = self.generate_parsed_docstring["args"]

        generation_kwargs_info["keys"] = []
        for key, value in generation_kwargs_info["optional"].items():
            info = {"name": key, "optional": value}
            if description := generate_docstring_args.get(key):
                info["description"] = description
            generation_kwargs_info["keys"].append(info)

        generation_kwargs_info.pop("optional")

        return runtime_parameters_info

    @cached_property
    def generate_parsed_docstring(self) -> "Docstring":
        """Returns the parsed docstring of the `generate` method.

        Returns:
            The parsed docstring of the `generate` method.
        """
        return parse_google_docstring(self.generate)

    def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
        """Method to get the last hidden states of the model for a list of inputs.

        Args:
            inputs: the list of inputs to get the last hidden states from.

        Returns:
            A list containing the last hidden state for each sequence using a NumPy array
                with shape [num_tokens, hidden_size].
        """
        raise NotImplementedError(
            f"Method `get_last_hidden_states` is not implemented for `{self.__class__.__name__}`"
        )

generate_parameters: List[inspect.Parameter] property

Returns the parameters of the generate method.

Returns:

Type Description
List[Parameter]

A list containing the parameters of the generate method.

generate_parsed_docstring: Docstring cached property

Returns the parsed docstring of the generate method.

Returns:

Type Description
Docstring

The parsed docstring of the generate method.

model_name: str abstractmethod property

Returns the model name used for the LLM.

runtime_parameters_names: RuntimeParametersNames property

Returns the runtime parameters of the LLM, which are combination of the attributes of the LLM type hinted with RuntimeParameter and the parameters of the generate method that are not input and num_generations.

Returns:

Type Description
RuntimeParametersNames

A dictionary with the name of the runtime parameters as keys and a boolean

RuntimeParametersNames

indicating if the parameter is optional or not.
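
For instance, with the InferenceEndpointsLLM documented above (the model id is just an example), the sampling arguments of the generation method are not fixed attributes but can be grouped under the generation_kwargs runtime parameter:

```python
from distilabel.llms.huggingface import InferenceEndpointsLLM

# The accepted keys mirror the parameters of the generation method
# (`max_new_tokens`, `temperature`, `top_p`, ...), excluding `input`,
# `inputs` and `num_generations`.
llm = InferenceEndpointsLLM(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
    generation_kwargs={"max_new_tokens": 256, "temperature": 0.7},
)
```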

generate(inputs, num_generations=1, **kwargs) abstractmethod

Abstract method to be implemented by each LLM to generate num_generations per input in inputs.

Parameters:

Name Type Description Default
inputs List[ChatType]

the list of inputs to generate responses for which follows OpenAI's API format:

[
    {"role": "system", "content": "You're a helpful assistant..."},
    {"role": "user", "content": "Give a template email for B2B communications..."},
    {"role": "assistant", "content": "Sure, here's a template you can use..."},
    {"role": "user", "content": "Modify the second paragraph..."}
]
required
num_generations int

the number of generations to generate per input.

1
**kwargs Any

the additional kwargs to be used for the generation.

{}
Source code in src/distilabel/llms/base.py
@abstractmethod
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Abstract method to be implemented by each LLM to generate `num_generations`
    per input in `inputs`.

    Args:
        inputs: the list of inputs to generate responses for which follows OpenAI's
            API format:

            ```python
            [
                {"role": "system", "content": "You're a helpful assistant..."},
                {"role": "user", "content": "Give a template email for B2B communications..."},
                {"role": "assistant", "content": "Sure, here's a template you can use..."},
                {"role": "user", "content": "Modify the second paragraph..."}
            ]
            ```
        num_generations: the number of generations to generate per input.
        **kwargs: the additional kwargs to be used for the generation.
    """
    pass

get_last_hidden_states(inputs)

Method to get the last hidden states of the model for a list of inputs.

Parameters:

Name Type Description Default
inputs List[ChatType]

the list of inputs to get the last hidden states from.

required

Returns:

Type Description
List[HiddenState]

A list containing the last hidden state for each sequence using a NumPy array with shape [num_tokens, hidden_size].

Source code in src/distilabel/llms/base.py
def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
    """Method to get the last hidden states of the model for a list of inputs.

    Args:
        inputs: the list of inputs to get the last hidden states from.

    Returns:
        A list containing the last hidden state for each sequence using a NumPy array
            with shape [num_tokens, hidden_size].
    """
    raise NotImplementedError(
        f"Method `get_last_hidden_states` is not implemented for `{self.__class__.__name__}`"
    )

get_runtime_parameters_info()

Gets the information of the runtime parameters of the LLM such as the name and the description. This function is meant to include the information of the runtime parameters in the serialized data of the LLM.

Returns:

Type Description
List[Dict[str, Any]]

A list containing the information for each runtime parameter of the LLM.

Source code in src/distilabel/llms/base.py
def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:
    """Gets the information of the runtime parameters of the `LLM` such as the name
    and the description. This function is meant to include the information of the runtime
    parameters in the serialized data of the `LLM`.

    Returns:
        A list containing the information for each runtime parameter of the `LLM`.
    """
    runtime_parameters_info = super().get_runtime_parameters_info()

    generation_kwargs_info = next(
        runtime_parameter_info
        for runtime_parameter_info in runtime_parameters_info
        if runtime_parameter_info["name"] == "generation_kwargs"
    )

    generate_docstring_args = self.generate_parsed_docstring["args"]

    generation_kwargs_info["keys"] = []
    for key, value in generation_kwargs_info["optional"].items():
        info = {"name": key, "optional": value}
        if description := generate_docstring_args.get(key):
            info["description"] = description
        generation_kwargs_info["keys"].append(info)

    generation_kwargs_info.pop("optional")

    return runtime_parameters_info

load()

Method to be called to initialize the LLM and its logger.

Source code in src/distilabel/llms/base.py
def load(self) -> None:
    """Method to be called to initialize the `LLM` and its logger."""
    self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")

LiteLLM

Bases: AsyncLLM

LiteLLM implementation running the async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large", etc.

verbose RuntimeParameter[bool]

whether to log the LiteLLM client's logs. Defaults to False.

Runtime parameters
  • verbose: whether to log the LiteLLM client's logs. Defaults to False.
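
A minimal usage sketch (the import path follows the source location shown below, and the API key of the underlying provider, e.g. OPENAI_API_KEY for OpenAI-hosted models, is assumed to be set in the environment):

```python
from distilabel.llms.litellm import LiteLLM

llm = LiteLLM(model="gpt-3.5-turbo")
llm.load()

# Synchronous request, following the same pattern as the other async LLMs.
output = llm.generate(
    inputs=[[{"role": "user", "content": "Hello world!"}]],
    num_generations=1,
)
```
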
Source code in src/distilabel/llms/litellm.py
class LiteLLM(AsyncLLM):
    """LiteLLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "gpt-3.5-turbo" or "mistral/mistral-large", etc.
        verbose: whether to log the LiteLLM client's logs. Defaults to `False`.

    Runtime parameters:
        - `verbose`: whether to log the LiteLLM client's logs. Defaults to `False`.
    """

    model: str
    verbose: RuntimeParameter[bool] = Field(
        default=False, description="Whether to log the LiteLLM client's logs."
    )

    _aclient: Optional[Callable] = PrivateAttr(...)

    def load(self) -> None:
        """
        Loads the `acompletion` LiteLLM client to benefit from async requests.
        """
        super().load()

        try:
            import litellm

            litellm.telemetry = False
        except ImportError as e:
            raise ImportError(
                "LiteLLM Python client is not installed. Please install it using"
                " `pip install litellm`."
            ) from e
        self._aclient = litellm.acompletion

        if not self.verbose:
            litellm.suppress_debug_info = True
            for key in logging.Logger.manager.loggerDict.keys():
                if "litellm" not in key.lower():
                    continue
                logging.getLogger(key).setLevel(logging.CRITICAL)

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        num_generations: int = 1,
        functions: Optional[List] = None,
        function_call: Optional[str] = None,
        temperature: Optional[float] = 1.0,
        top_p: Optional[float] = 1.0,
        stop: Optional[Union[str, list]] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        user: Optional[str] = None,
        metadata: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        model_list: Optional[list] = None,
        mock_response: Optional[str] = None,
        force_timeout: Optional[int] = 600,
        custom_llm_provider: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            functions: a list of functions to apply to the conversation messages. Defaults to
                `None`.
            function_call: the name of the function to call within the conversation. Defaults
                to `None`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: Up to 4 sequences where the LLM API will stop generating further tokens.
                Defaults to `None`.
            max_tokens: The maximum number of tokens in the generated completion. Defaults to
                `None`.
            presence_penalty: It is used to penalize new tokens based on their existence in the
                text so far. Defaults to `None`.
            frequency_penalty: It is used to penalize new tokens based on their frequency in the
                text so far. Defaults to `None`.
            logit_bias: Used to modify the probability of specific tokens appearing in the
                completion. Defaults to `None`.
            user: A unique identifier representing your end-user. This can help the LLM provider
                to monitor and detect abuse. Defaults to `None`.
            metadata: Pass in additional metadata to tag your completion calls - eg. prompt
                version, details, etc. Defaults to `None`.
            api_base: Base URL for the API. Defaults to `None`.
            api_version: API version. Defaults to `None`.
            api_key: API key. Defaults to `None`.
            model_list: List of api base, version, keys. Defaults to `None`.
            mock_response: If provided, return a mock completion response for testing or debugging
                purposes. Defaults to `None`.
            force_timeout: The maximum execution time in seconds for the completion request.
                Defaults to `600`.
            custom_llm_provider: Used for non-OpenAI LLMs, e.g. for Bedrock set
                `model="amazon.titan-tg1-large"` and `custom_llm_provider="bedrock"`. Defaults to
                `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        import litellm

        async def _call_aclient_until_n_choices() -> List["Choices"]:
            choices = []
            while len(choices) < num_generations:
                completion = await self._aclient(  # type: ignore
                    model=self.model,
                    messages=input,
                    n=num_generations,
                    functions=functions,
                    function_call=function_call,
                    temperature=temperature,
                    top_p=top_p,
                    stream=False,
                    stop=stop,
                    max_tokens=max_tokens,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    logit_bias=logit_bias,
                    user=user,
                    metadata=metadata,
                    api_base=api_base,
                    api_version=api_version,
                    api_key=api_key,
                    model_list=model_list,
                    mock_response=mock_response,
                    force_timeout=force_timeout,
                    custom_llm_provider=custom_llm_provider,
                )
                choices.extend(completion.choices)
            return choices

        # `litellm.drop_params` is used to enable/disable dropping the **kwargs parameters that the API does not support
        try:
            litellm.drop_params = False
            choices = await _call_aclient_until_n_choices()
        except litellm.exceptions.APIError as e:
            if "does not support parameters" in str(e):
                litellm.drop_params = True
                choices = await _call_aclient_until_n_choices()
            else:
                raise e

        generations = []
        for choice in choices:
            if (content := choice.message.content) is None:
                self._logger.warning(
                    f"Received no response using LiteLLM client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations

model_name: str property

Returns the model name used for the LLM.

agenerate(input, num_generations=1, functions=None, function_call=None, temperature=1.0, top_p=1.0, stop=None, max_tokens=None, presence_penalty=None, frequency_penalty=None, logit_bias=None, user=None, metadata=None, api_base=None, api_version=None, api_key=None, model_list=None, mock_response=None, force_timeout=600, custom_llm_provider=None) async

Generates num_generations responses for the given input using the LiteLLM async client.

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
functions Optional[List]

a list of functions to apply to the conversation messages. Defaults to None.

None
function_call Optional[str]

the name of the function to call within the conversation. Defaults to None.

None
temperature Optional[float]

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p Optional[float]

the top-p value to use for the generation. Defaults to 1.0.

1.0
stop Optional[Union[str, list]]

Up to 4 sequences where the LLM API will stop generating further tokens. Defaults to None.

None
max_tokens Optional[int]

The maximum number of tokens in the generated completion. Defaults to None.

None
presence_penalty Optional[float]

It is used to penalize new tokens based on their existence in the text so far. Defaults to None.

None
frequency_penalty Optional[float]

It is used to penalize new tokens based on their frequency in the text so far. Defaults to None.

None
logit_bias Optional[dict]

Used to modify the probability of specific tokens appearing in the completion. Defaults to None.

None
user Optional[str]

A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. Defaults to None.

None
metadata Optional[dict]

Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. Defaults to None.

None
api_base Optional[str]

Base URL for the API. Defaults to None.

None
api_version Optional[str]

API version. Defaults to None.

None
api_key Optional[str]

API key. Defaults to None.

None
model_list Optional[list]

List of api base, version, keys. Defaults to None.

None
mock_response Optional[str]

If provided, return a mock completion response for testing or debugging purposes. Defaults to None.

None
force_timeout Optional[int]

The maximum execution time in seconds for the completion request. Defaults to 600.

600
custom_llm_provider Optional[str]

Used for non-OpenAI LLMs, e.g. for Bedrock set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock". Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.
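
A hedged sketch of a direct asynchronous call; mock_response is used here so the snippet can run without contacting a real provider:

```python
import asyncio

from distilabel.llms.litellm import LiteLLM

llm = LiteLLM(model="gpt-3.5-turbo")
llm.load()

generations = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        num_generations=2,
        temperature=0.7,
        max_tokens=64,
        mock_response="Hi there!",  # skips the real API call
    )
)
# e.g. ["Hi there!", "Hi there!"]
```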

Source code in src/distilabel/llms/litellm.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    num_generations: int = 1,
    functions: Optional[List] = None,
    function_call: Optional[str] = None,
    temperature: Optional[float] = 1.0,
    top_p: Optional[float] = 1.0,
    stop: Optional[Union[str, list]] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    logit_bias: Optional[dict] = None,
    user: Optional[str] = None,
    metadata: Optional[dict] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    api_key: Optional[str] = None,
    model_list: Optional[list] = None,
    mock_response: Optional[str] = None,
    force_timeout: Optional[int] = 600,
    custom_llm_provider: Optional[str] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the [LiteLLM async client](https://github.com/BerriAI/litellm).

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        functions: a list of functions to apply to the conversation messages. Defaults to
            `None`.
        function_call: the name of the function to call within the conversation. Defaults
            to `None`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: Up to 4 sequences where the LLM API will stop generating further tokens.
            Defaults to `None`.
        max_tokens: The maximum number of tokens in the generated completion. Defaults to
            `None`.
        presence_penalty: It is used to penalize new tokens based on their existence in the
            text so far. Defaults to `None`.
        frequency_penalty: It is used to penalize new tokens based on their frequency in the
            text so far. Defaults to `None`.
        logit_bias: Used to modify the probability of specific tokens appearing in the
            completion. Defaults to `None`.
        user: A unique identifier representing your end-user. This can help the LLM provider
            to monitor and detect abuse. Defaults to `None`.
        metadata: Pass in additional metadata to tag your completion calls - eg. prompt
            version, details, etc. Defaults to `None`.
        api_base: Base URL for the API. Defaults to `None`.
        api_version: API version. Defaults to `None`.
        api_key: API key. Defaults to `None`.
        model_list: List of api base, version, keys. Defaults to `None`.
        mock_response: If provided, return a mock completion response for testing or debugging
            purposes. Defaults to `None`.
        force_timeout: The maximum execution time in seconds for the completion request.
            Defaults to `600`.
        custom_llm_provider: Used for non-OpenAI LLMs, e.g. for Bedrock set
            `model="amazon.titan-tg1-large"` and `custom_llm_provider="bedrock"`. Defaults to
            `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    import litellm

    async def _call_aclient_until_n_choices() -> List["Choices"]:
        choices = []
        while len(choices) < num_generations:
            completion = await self._aclient(  # type: ignore
                model=self.model,
                messages=input,
                n=num_generations,
                functions=functions,
                function_call=function_call,
                temperature=temperature,
                top_p=top_p,
                stream=False,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                user=user,
                metadata=metadata,
                api_base=api_base,
                api_version=api_version,
                api_key=api_key,
                model_list=model_list,
                mock_response=mock_response,
                force_timeout=force_timeout,
                custom_llm_provider=custom_llm_provider,
            )
            choices.extend(completion.choices)
        return choices

    # `litellm.drop_params` is used to enable/disable dropping the **kwargs parameters that the API does not support
    try:
        litellm.drop_params = False
        choices = await _call_aclient_until_n_choices()
    except litellm.exceptions.APIError as e:
        if "does not support parameters" in str(e):
            litellm.drop_params = True
            choices = await _call_aclient_until_n_choices()
        else:
            raise e

    generations = []
    for choice in choices:
        if (content := choice.message.content) is None:
            self._logger.warning(
                f"Received no response using LiteLLM client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return generations

load()

Loads the acompletion LiteLLM client to benefit from async requests.

Source code in src/distilabel/llms/litellm.py
def load(self) -> None:
    """
    Loads the `acompletion` LiteLLM client to benefit from async requests.
    """
    super().load()

    try:
        import litellm

        litellm.telemetry = False
    except ImportError as e:
        raise ImportError(
            "LiteLLM Python client is not installed. Please install it using"
            " `pip install litellm`."
        ) from e
    self._aclient = litellm.acompletion

    if not self.verbose:
        litellm.suppress_debug_info = True
        for key in logging.Logger.manager.loggerDict.keys():
            if "litellm" not in key.lower():
                continue
            logging.getLogger(key).setLevel(logging.CRITICAL)

LlamaCppLLM

Bases: LLM

llama.cpp LLM implementation running the Python bindings for the C++ code.

Attributes:

Name Type Description
chat_format str

the chat format to use for the model. Defaults to chatml.

model_path RuntimeParameter[FilePath]

contains the path to the GGUF quantized model, compatible with the installed version of the llama.cpp Python bindings.

n_gpu_layers RuntimeParameter[int]

the number of layers to use for the GPU. Defaults to -1, meaning that the available GPU device will be used.

verbose RuntimeParameter[bool]

whether to print verbose output. Defaults to False.

_model Optional[Llama]

the Llama model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method.

Runtime parameters
  • model_path: the path to the GGUF quantized model.
  • n_gpu_layers: the number of layers to use for the GPU. Defaults to -1.
  • verbose: whether to print verbose output. Defaults to False.
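
For orientation, a minimal usage sketch (not from the source). The GGUF path is a placeholder for any model compatible with the installed llama.cpp Python bindings.

from pathlib import Path

from distilabel.llms.llamacpp import LlamaCppLLM

llm = LlamaCppLLM(
    model_path=Path("./models/model.gguf"),  # placeholder path to a local GGUF model
    n_gpu_layers=-1,  # offload all layers to the GPU when one is available
)
llm.load()

outputs = llm.generate(
    inputs=[[{"role": "user", "content": "What is llama.cpp?"}]],
    num_generations=1,
    max_new_tokens=64,
)
print(outputs)  # one list of generated strings per input
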
Source code in src/distilabel/llms/llamacpp.py
class LlamaCppLLM(LLM):
    """llama.cpp LLM implementation running the Python bindings for the C++ code.

    Attributes:
        chat_format: the chat format to use for the model. Defaults to `chatml`.
        model_path: contains the path to the GGUF quantized model, compatible with the
            installed version of the `llama.cpp` Python bindings.
        n_gpu_layers: the number of layers to use for the GPU. Defaults to `-1`, meaning that
            the available GPU device will be used.
        verbose: whether to print verbose output. Defaults to `False`.
        _model: the Llama model instance. This attribute is meant to be used internally and
            should not be accessed directly. It will be set in the `load` method.

    Runtime parameters:
        - `model_path`: the path to the GGUF quantized model.
        - `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`.
        - `verbose`: whether to print verbose output. Defaults to `False`.
    """

    chat_format: str = "chatml"
    model_path: RuntimeParameter[FilePath] = Field(
        default=None, description="The path to the GGUF quantized model."
    )
    n_gpu_layers: RuntimeParameter[int] = Field(
        default=-1,
        description="The number of layers that will be loaded in the GPU.",
    )
    verbose: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to print verbose output from llama.cpp library.",
    )

    _model: Optional["Llama"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `Llama` model from the `model_path`."""

        try:
            from llama_cpp import Llama
        except ImportError as ie:
            raise ImportError(
                "The `llama_cpp` package is required to use the `LlamaCppLLM` class."
            ) from ie

        self._model = Llama(
            model_path=self.model_path.as_posix(),
            chat_format=self.chat_format,
            n_gpu_layers=self.n_gpu_layers,
            verbose=self.verbose,
        )

        super().load()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self._model.model_path  # type: ignore

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[ChatType],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for the given input using the Llama model.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        batch_outputs = []
        for input in inputs:
            outputs = []
            for _ in range(num_generations):
                chat_completions: "CreateChatCompletionResponse" = (
                    self._model.create_chat_completion(  # type: ignore
                        messages=input,  # type: ignore
                        max_tokens=max_new_tokens,
                        frequency_penalty=frequency_penalty,
                        presence_penalty=presence_penalty,
                        temperature=temperature,
                        top_p=top_p,
                    )
                )
                outputs.append(chat_completions["choices"][0]["message"]["content"])
            batch_outputs.append(outputs)
        return batch_outputs

model_name: str property

Returns the model name used for the LLM.

generate(inputs, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0)

Generates num_generations responses for the given input using the Llama model.

Parameters:

Name Type Description Default
inputs List[ChatType]

a list of inputs in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
frequency_penalty float

the repetition penalty to use for the generation. Defaults to 0.0.

0.0
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0

Returns:

Type Description
List[GenerateOutput]

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/llms/llamacpp.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[ChatType],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for the given input using the Llama model.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    batch_outputs = []
    for input in inputs:
        outputs = []
        for _ in range(num_generations):
            chat_completions: "CreateChatCompletionResponse" = (
                self._model.create_chat_completion(  # type: ignore
                    messages=input,  # type: ignore
                    max_tokens=max_new_tokens,
                    frequency_penalty=frequency_penalty,
                    presence_penalty=presence_penalty,
                    temperature=temperature,
                    top_p=top_p,
                )
            )
            outputs.append(chat_completions["choices"][0]["message"]["content"])
        batch_outputs.append(outputs)
    return batch_outputs

load()

Loads the Llama model from the model_path.

Source code in src/distilabel/llms/llamacpp.py
def load(self) -> None:
    """Loads the `Llama` model from the `model_path`."""

    try:
        from llama_cpp import Llama
    except ImportError as ie:
        raise ImportError(
            "The `llama_cpp` package is required to use the `LlamaCppLLM` class."
        ) from ie

    self._model = Llama(
        model_path=self.model_path.as_posix(),
        chat_format=self.chat_format,
        n_gpu_layers=self.n_gpu_layers,
        verbose=self.verbose,
    )

    super().load()

MistralLLM

Bases: AsyncLLM

Mistral LLM implementation running the async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.

endpoint str

the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Mistral API. Defaults to None which means that the value set for the environment variable MISTRAL_API_KEY will be used, or None if not set.

max_retries RuntimeParameter[int]

the maximum number of retries to attempt when a request fails. Defaults to 6.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response. Defaults to 120.

max_concurrent_requests RuntimeParameter[int]

the maximum number of concurrent requests to send. Defaults to 64.

_api_key_env_var str

the name of the environment variable to use for the API key. It is meant to be used internally.

_aclient Optional[MistralAsyncClient]

the MistralAsyncClient to use for the Mistral API. It is meant to be used internally. Set in the load method.

Runtime parameters
  • api_key: the API key to authenticate the requests to the Mistral API.
  • max_retries: the maximum number of retries to attempt when a request fails. Defaults to 6.
  • timeout: the maximum time in seconds to wait for a response. Defaults to 120.
  • max_concurrent_requests: the maximum number of concurrent requests to send. Defaults to 64.
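
For orientation, a minimal usage sketch (not from the source), relying on the synchronous generate wrapper documented below and on the API key being available in the environment (or passed explicitly via api_key).

from distilabel.llms.mistral import MistralLLM

llm = MistralLLM(model="mistral-tiny")  # reads the API key from the environment if not passed
llm.load()

outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Give a one-line fun fact."}]],
    num_generations=1,
    max_new_tokens=64,  # forwarded to `agenerate` as a keyword argument
)
print(outputs)
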
Source code in src/distilabel/llms/mistral.py
class MistralLLM(AsyncLLM):
    """Mistral LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.
        endpoint: the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".
        api_key: the API key to authenticate the requests to the Mistral API. Defaults to `None` which
            means that the value set for the environment variable `MISTRAL_API_KEY` will be used, or
            `None` if not set.
        max_retries: the maximum number of retries to attempt when a request fails. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response. Defaults to `120`.
        max_concurrent_requests: the maximum number of concurrent requests to send. Defaults
            to `64`.
        _api_key_env_var: the name of the environment variable to use for the API key. It is meant to
            be used internally.
        _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Mistral API.
        - `max_retries`: the maximum number of retries to attempt when a request fails.
            Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`.
        - `max_concurrent_requests`: the maximum number of concurrent requests to send.
            Defaults to `64`.
    """

    model: str
    endpoint: str = "https://api.mistral.ai"
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_MISTRALAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Mistral API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_concurrent_requests: RuntimeParameter[int] = Field(
        default=64, description="The maximum number of concurrent requests to send."
    )

    _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["MistralAsyncClient"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `MistralAsyncClient` client to benefit from async requests."""
        super().load()

        try:
            from mistralai.async_client import MistralAsyncClient
        except ImportError as ie:
            raise ImportError(
                "MistralAI Python client is not installed. Please install it using"
                " `pip install mistralai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = MistralAsyncClient(
            api_key=self.api_key.get_secret_value(),
            endpoint=self.endpoint,
            max_retries=self.max_retries,
            timeout=self.timeout,
            max_concurrent_requests=self.max_concurrent_requests,
        )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    # TODO: add `num_generations` parameter once Mistral client allows `n` parameter
    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the MistralAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            temperature: the temperature to use for the generation. Defaults to `None`.
            top_p: the top-p value to use for the generation. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        completion = await self._aclient.chat(  # type: ignore
            messages=input,
            model=self.model,
            temperature=temperature,
            max_tokens=max_new_tokens,
            top_p=top_p,
        )
        generations = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(
                    f"Received no response using MistralAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations

    # TODO: remove this function once Mistral client allows `n` parameter
    @override
    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously awaiting for the response of each input sent to `agenerate`.
        """

        async def agenerate(
            inputs: List["ChatType"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            tasks = [
                asyncio.create_task(self.agenerate(input=input, **kwargs))
                for input in inputs
                for _ in range(num_generations)
            ]
            return [outputs[0] for outputs in await asyncio.gather(*tasks)]

        outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
        return list(grouper(outputs, n=num_generations, incomplete="ignore"))

model_name: str property

Returns the model name used for the LLM.

agenerate(input, max_new_tokens=None, temperature=None, top_p=None) async

Generates num_generations responses for the given input using the MistralAI async client.

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
max_new_tokens Optional[int]

the maximum number of new tokens that the model will generate. Defaults to None.

None
temperature Optional[float]

the temperature to use for the generation. Defaults to None.

None
top_p Optional[float]

the top-p value to use for the generation. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/llms/mistral.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the MistralAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        temperature: the temperature to use for the generation. Defaults to `None`.
        top_p: the top-p value to use for the generation. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    completion = await self._aclient.chat(  # type: ignore
        messages=input,
        model=self.model,
        temperature=temperature,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )
    generations = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(
                f"Received no response using MistralAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return generations

generate(inputs, num_generations=1, **kwargs)

Method to generate a list of responses asynchronously, returning the output synchronously after awaiting the response for each input sent to agenerate.

Source code in src/distilabel/llms/mistral.py
@override
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Method to generate a list of responses asynchronously, returning the output
    synchronously awaiting for the response of each input sent to `agenerate`.
    """

    async def agenerate(
        inputs: List["ChatType"], **kwargs: Any
    ) -> "GenerateOutput":
        """Internal function to parallelize the asynchronous generation of responses."""
        tasks = [
            asyncio.create_task(self.agenerate(input=input, **kwargs))
            for input in inputs
            for _ in range(num_generations)
        ]
        return [outputs[0] for outputs in await asyncio.gather(*tasks)]

    outputs = self.event_loop.run_until_complete(agenerate(inputs, **kwargs))
    return list(grouper(outputs, n=num_generations, incomplete="ignore"))

load()

Loads the MistralAsyncClient client to benefit from async requests.

Source code in src/distilabel/llms/mistral.py
def load(self) -> None:
    """Loads the `MistralAsyncClient` client to benefit from async requests."""
    super().load()

    try:
        from mistralai.async_client import MistralAsyncClient
    except ImportError as ie:
        raise ImportError(
            "MistralAI Python client is not installed. Please install it using"
            " `pip install mistralai`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = MistralAsyncClient(
        api_key=self.api_key.get_secret_value(),
        endpoint=self.endpoint,
        max_retries=self.max_retries,
        timeout=self.timeout,
        max_concurrent_requests=self.max_concurrent_requests,
    )

OllamaLLM

Bases: AsyncLLM

Ollama LLM implementation running the Async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "notus".

host Optional[RuntimeParameter[str]]

the Ollama server host.

timeout RuntimeParameter[int]

the timeout for the LLM. Defaults to 120.

_aclient Optional[AsyncClient]

the AsyncClient to use for the Ollama API. It is meant to be used internally. Set in the load method.

Runtime parameters
  • host: the Ollama server host.
  • timeout: the client timeout for the Ollama API. Defaults to 120.
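
For orientation, a minimal usage sketch (not from the source). The host value is the usual local Ollama endpoint and the model name is a placeholder for any model pulled into the server.

import asyncio

from distilabel.llms.ollama import OllamaLLM

llm = OllamaLLM(model="notus", host="http://localhost:11434")
llm.load()

async def main() -> None:
    generations = await llm.agenerate(
        input=[{"role": "user", "content": "Say hello."}],
        num_generations=1,
    )
    print(generations)

asyncio.run(main())
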
Source code in src/distilabel/llms/ollama.py
class OllamaLLM(AsyncLLM):
    """Ollama LLM implementation running the Async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "notus".
        host: the Ollama server host.
        timeout: the timeout for the LLM. Defaults to `120`.
        _aclient: the `AsyncClient` to use for the Ollama API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `host`: the Ollama server host.
        - `timeout`: the client timeout for the Ollama API. Defaults to `120`.
    """

    model: str
    host: Optional[RuntimeParameter[str]] = Field(
        default=None, description="The host of the Ollama API."
    )
    timeout: RuntimeParameter[int] = Field(
        default=120, description="The timeout for the Ollama API."
    )
    follow_redirects: bool = True

    _aclient: Optional["AsyncClient"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncClient` to use Ollama async API."""
        super().load()

        try:
            from ollama import AsyncClient

            self._aclient = AsyncClient(
                host=self.host,
                timeout=self.timeout,
                follow_redirects=self.follow_redirects,
            )
        except ImportError as e:
            raise ImportError(
                "Ollama Python client is not installed. Please install it using"
                " `pip install ollama`."
            ) from e

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        num_generations: int = 1,
        format: Literal["", "json"] = "",
        # TODO: include relevant options from `Options` in `agenerate` method.
        options: Union[Options, None] = None,
        keep_alive: Union[bool, None] = None,
    ) -> List[str]:
        """
        Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).

        Args:
            input: the input to use for the generation.
            num_generations: the number of generations to produce. Defaults to `1`.
            format: the format to use for the generation. Defaults to `""`.
            options: the options to use for the generation. Defaults to `None`.
            keep_alive: whether to keep the connection alive. Defaults to `None`.

        Returns:
            A list of strings as completion for the given input.
        """
        generations = []
        # TODO: remove this for-loop and override the `generate` method
        for _ in range(num_generations):
            completion = await self._aclient.chat(  # type: ignore
                model=self.model,
                messages=input,  # type: ignore
                stream=False,
                format=format,
                options=options,
                keep_alive=keep_alive,
            )
            # TODO: improve error handling
            generations.append(completion["message"]["content"])

        return generations

model_name: str property

Returns the model name used for the LLM.

agenerate(input, num_generations=1, format='', options=None, keep_alive=None) async

Generates a response asynchronously, using the Ollama Async API definition.

Parameters:

Name Type Description Default
input ChatType

the input to use for the generation.

required
num_generations int

the number of generations to produce. Defaults to 1.

1
format Literal['', 'json']

the format to use for the generation. Defaults to "".

''
options Union[Options, None]

the options to use for the generation. Defaults to None.

None
keep_alive Union[bool, None]

whether to keep the connection alive. Defaults to None.

None

Returns:

Type Description
List[str]

A list of strings as completion for the given input.

Source code in src/distilabel/llms/ollama.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    num_generations: int = 1,
    format: Literal["", "json"] = "",
    # TODO: include relevant options from `Options` in `agenerate` method.
    options: Union[Options, None] = None,
    keep_alive: Union[bool, None] = None,
) -> List[str]:
    """
    Generates a response asynchronously, using the [Ollama Async API definition](https://github.com/ollama/ollama-python).

    Args:
        input: the input to use for the generation.
        num_generations: the number of generations to produce. Defaults to `1`.
        format: the format to use for the generation. Defaults to `""`.
        options: the options to use for the generation. Defaults to `None`.
        keep_alive: whether to keep the connection alive. Defaults to `None`.

    Returns:
        A list of strings as completion for the given input.
    """
    generations = []
    # TODO: remove this for-loop and override the `generate` method
    for _ in range(num_generations):
        completion = await self._aclient.chat(  # type: ignore
            model=self.model,
            messages=input,  # type: ignore
            stream=False,
            format=format,
            options=options,
            keep_alive=keep_alive,
        )
        # TODO: improve error handling
        generations.append(completion["message"]["content"])

    return generations

load()

Loads the AsyncClient to use Ollama async API.

Source code in src/distilabel/llms/ollama.py
def load(self) -> None:
    """Loads the `AsyncClient` to use Ollama async API."""
    super().load()

    try:
        from ollama import AsyncClient

        self._aclient = AsyncClient(
            host=self.host,
            timeout=self.timeout,
            follow_redirects=self.follow_redirects,
        )
    except ImportError as e:
        raise ImportError(
            "Ollama Python client is not installed. Please install it using"
            " `pip install ollama`."
        ) from e

OpenAILLM

Bases: AsyncLLM

OpenAI LLM implementation running the async API client.

Attributes:

Name Type Description
model str

the model name to use for the LLM e.g. "gpt-3.5-turbo", "gpt-4", etc. Supported models can be found here.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the OpenAI API requests. Defaults to None, which means that the value set for the environment variable OPENAI_BASE_URL will be used, or "https://api.openai.com/v1" if not set.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the OpenAI API. Defaults to None which means that the value set for the environment variable OPENAI_API_KEY will be used, or None if not set.

max_retries RuntimeParameter[int]

the maximum number of times to retry the request to the API before failing. Defaults to 6.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

Runtime parameters
  • base_url: the base URL to use for the OpenAI API requests. Defaults to None.
  • api_key: the API key to authenticate the requests to the OpenAI API. Defaults to None.
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 6.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
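
For orientation, a minimal usage sketch (not from the source), assuming OPENAI_API_KEY is set in the environment (otherwise pass api_key explicitly).

import asyncio

from distilabel.llms.openai import OpenAILLM

llm = OpenAILLM(model="gpt-3.5-turbo")
llm.load()

async def main() -> None:
    generations = await llm.agenerate(
        input=[{"role": "user", "content": "Say hello."}],
        num_generations=1,
        max_new_tokens=64,
    )
    print(generations)

asyncio.run(main())
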
Source code in src/distilabel/llms/openai.py
class OpenAILLM(AsyncLLM):
    """OpenAI LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "gpt-3.5-turbo", "gpt-4", etc.
            Supported models can be found [here](https://platform.openai.com/docs/guides/text-generation).
        base_url: the base URL to use for the OpenAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `OPENAI_BASE_URL` will
            be used, or "https://api.openai.com/v1" if not set.
        api_key: the API key to authenticate the requests to the OpenAI API. Defaults to
            `None` which means that the value set for the environment variable `OPENAI_API_KEY`
            will be used, or `None` if not set.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.

    Runtime parameters:
        - `base_url`: the base URL to use for the OpenAI API requests. Defaults to `None`.
        - `api_key`: the API key to authenticate the requests to the OpenAI API. Defaults
            to `None`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "OPENAI_BASE_URL", "https://api.openai.com/v1"
        ),
        description="The base URL to use for the OpenAI API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_OPENAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the OpenAI API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )

    _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncOpenAI` client to benefit from async requests."""
        super().load()

        try:
            from openai import AsyncOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,
            timeout=self.timeout,
        )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        completion = await self._aclient.chat.completions.create(  # type: ignore
            messages=input,  # type: ignore
            model=self.model,
            max_tokens=max_new_tokens,
            n=num_generations,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
            timeout=50,
        )
        generations = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations

model_name: str property

Returns the model name used for the LLM.

agenerate(input, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0) async

Generates num_generations responses for the given input using the OpenAI async client.

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
frequency_penalty float

the repetition penalty to use for the generation. Defaults to 0.0.

0.0
presence_penalty float

the presence penalty to use for the generation. Defaults to 0.0.

0.0
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0

Returns:

Type Description
GenerateOutput

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/llms/openai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the OpenAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    completion = await self._aclient.chat.completions.create(  # type: ignore
        messages=input,  # type: ignore
        model=self.model,
        max_tokens=max_new_tokens,
        n=num_generations,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        temperature=temperature,
        top_p=top_p,
        timeout=50,
    )
    generations = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(
                f"Received no response using OpenAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return generations

load()

Loads the AsyncOpenAI client to benefit from async requests.

Source code in src/distilabel/llms/openai.py
def load(self) -> None:
    """Loads the `AsyncOpenAI` client to benefit from async requests."""
    super().load()

    try:
        from openai import AsyncOpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install openai`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = AsyncOpenAI(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),
        max_retries=self.max_retries,
        timeout=self.timeout,
    )

TogetherLLM

Bases: OpenAILLM

TogetherLLM implementation that reuses the async API client of OpenAI, since the Together API mirrors the OpenAI API.

Attributes:

Name Type Description
model

the model name to use for the LLM e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1". Supported models can be found here.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Together API can be set with TOGETHER_BASE_URL. Defaults to None which means that the value set for the environment variable TOGETHER_BASE_URL will be used, or "https://api.together.xyz/v1" if not set.

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Together API. Defaults to None which means that the value set for the environment variable TOGETHER_API_KEY will be used, or None if not set.

_api_key_env_var str

the name of the environment variable to use for the API key. It is meant to be used internally.

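Since TogetherLLM only overrides the base URL and API key handling, usage mirrors OpenAILLM. A minimal sketch (not from the source), assuming TOGETHER_API_KEY is set in the environment:

import asyncio

from distilabel.llms.together import TogetherLLM

llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
llm.load()

async def main() -> None:
    generations = await llm.agenerate(  # inherited from `OpenAILLM`
        input=[{"role": "user", "content": "Say hello."}],
        num_generations=1,
        max_new_tokens=64,
    )
    print(generations)

asyncio.run(main())
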
Source code in src/distilabel/llms/together.py
class TogetherLLM(OpenAILLM):
    """TogetherLLM LLM implementation running the async API client of OpenAI because of
    duplicate API behavior.

    Attributes:
        model: the model name to use for the LLM e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1".
            Supported models can be found [here](https://api.together.xyz/models).
        base_url: the base URL to use for the Together API can be set with `TOGETHER_BASE_URL`.
            Defaults to `None` which means that the value set for the environment variable
            `TOGETHER_BASE_URL` will be used, or "https://api.together.xyz/v1" if not set.
        api_key: the API key to authenticate the requests to the Together API. Defaults to `None`
            which means that the value set for the environment variable `TOGETHER_API_KEY` will be
            used, or `None` if not set.
        _api_key_env_var: the name of the environment variable to use for the API key. It
            is meant to be used internally.
    """

    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "TOGETHER_BASE_URL", "https://api.together.xyz/v1"
        ),
        description="The base URL to use for the Together API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_TOGETHER_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Together API.",
    )

    _api_key_env_var: str = PrivateAttr(_TOGETHER_API_KEY_ENV_VAR_NAME)

TransformersLLM

Bases: LLM, CudaDevicePlacementMixin

Hugging Face transformers library LLM implementation using the text generation pipeline.

Attributes:

Name Type Description
model str

the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files.

revision str

if model refers to a Hugging Face Hub repository, then the revision (e.g. a branch name or a commit id) to use. Defaults to "main".

torch_dtype str

the torch dtype to use for the model e.g. "float16", "float32", etc. Defaults to "auto".

trust_remote_code bool

whether to trust or not remote (code in the Hugging Face Hub repository) code to load the model. Defaults to False.

model_kwargs Optional[Dict[str, Any]]

additional dictionary of keyword arguments that will be passed to the from_pretrained method of the model.

tokenizer Optional[str]

the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer config files. If not provided, the one associated to the model will be used. Defaults to None.

use_fast bool

whether to use a fast tokenizer or not. Defaults to True.

chat_template Optional[str]

a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, then ChatML template will be used. Defaults to None.

device Optional[Union[str, int]]

the name or index of the device where the model will be loaded. Defaults to None.

device_map Optional[Union[str, Dict[str, Any]]]

a dictionary mapping each layer of the model to a device, or a mode like "sequential" or "auto". Defaults to None.

token Optional[str]

the Hugging Face Hub token that will be used to authenticate to the Hugging Face Hub. If not provided, the HF_TOKEN environment or huggingface_hub package local configuration will be used. Defaults to None.

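For orientation, a minimal usage sketch (not from the source). The Hub repo id is only an example of a chat-capable model; any repo id or local path accepted by transformers works.

from distilabel.llms.huggingface.transformers import TransformersLLM

llm = TransformersLLM(
    model="HuggingFaceH4/zephyr-7b-beta",  # example repo id
    torch_dtype="float16",
    device="cuda",  # or an integer device index, or None for CPU
)
llm.load()

outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Explain attention in one sentence."}]],
    num_generations=1,
    max_new_tokens=64,
)
print(outputs)
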
Source code in src/distilabel/llms/huggingface/transformers.py
class TransformersLLM(LLM, CudaDevicePlacementMixin):
    """Hugging Face `transformers` library LLM implementation using the text generation
    pipeline.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        revision: if `model` refers to a Hugging Face Hub repository, then the revision
            (e.g. a branch name or a commit id) to use. Defaults to `"main"`.
        torch_dtype: the torch dtype to use for the model e.g. "float16", "float32", etc.
            Defaults to `"auto"`.
        trust_remote_code: whether to trust or not remote (code in the Hugging Face Hub
            repository) code to load the model. Defaults to `False`.
        model_kwargs: additional dictionary of keyword arguments that will be passed to
            the `from_pretrained` method of the model.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer config files. If not provided, the one associated to the `model`
            will be used. Defaults to `None`.
        use_fast: whether to use a fast tokenizer or not. Defaults to `True`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        device: the name or index of the device where the model will be loaded. Defaults
            to `None`.
        device_map: a dictionary mapping each layer of the model to a device, or a mode
            like `"sequential"` or `"auto"`. Defaults to `None`.
        token: the Hugging Face Hub token that will be used to authenticate to the Hugging
            Face Hub. If not provided, the `HF_TOKEN` environment or `huggingface_hub` package
            local configuration will be used. Defaults to `None`.
    """

    model: str
    revision: str = "main"
    torch_dtype: str = "auto"
    trust_remote_code: bool = False
    model_kwargs: Optional[Dict[str, Any]] = None
    tokenizer: Optional[str] = None
    use_fast: bool = True
    chat_template: Optional[str] = None
    device: Optional[Union[str, int]] = None
    device_map: Optional[Union[str, Dict[str, Any]]] = None
    token: Optional[str] = None

    _pipeline: Optional["Pipeline"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the model and tokenizer and creates the text generation pipeline. In addition,
        it will configure the tokenizer chat template."""
        super().load()

        if self.device == "cuda":
            CudaDevicePlacementMixin.load(self)

        try:
            from transformers import pipeline
        except ImportError as ie:
            raise ImportError(
                "Transformers is not installed. Please install it using `pip install transformers`."
            ) from ie

        self._pipeline = pipeline(
            "text-generation",
            model=self.model,
            revision=self.revision,
            torch_dtype=self.torch_dtype,
            trust_remote_code=self.trust_remote_code,
            model_kwargs=self.model_kwargs or {},
            tokenizer=self.tokenizer or self.model,
            use_fast=self.use_fast,
            device=self.device,
            device_map=self.device_map,
            token=self.token or os.getenv("HF_TOKEN"),
            return_full_text=False,
        )

        if self.chat_template is not None:
            self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore
        elif (
            self._pipeline.tokenizer.chat_template is None  # type: ignore
            and self._pipeline.tokenizer.default_chat_template is None  # type: ignore
        ):
            self._pipeline.tokenizer.chat_template = CHATML_TEMPLATE  # type: ignore

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: "ChatType") -> str:
        """Prepares the input by applying the chat template to the input, which is formatted
        as an OpenAI conversation, and adding the generation prompt.
        """
        return self._pipeline.tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,
        )

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[ChatType],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        temperature: float = 0.1,
        repetition_penalty: float = 1.1,
        top_p: float = 1.0,
        top_k: int = 0,
        do_sample: bool = True,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            repetition_penalty: the repetition penalty to use for the generation. Defaults
                to `1.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `0`.
            do_sample: whether to use sampling or not. Defaults to `True`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
            [self.prepare_input(input=input) for input in inputs],
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            num_return_sequences=num_generations,
        )
        return [
            [generation["generated_text"] for generation in output]
            for output in outputs
        ]

    def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
        """Gets the last `hidden_states` of the model for the given inputs. It doesn't
        execute the task head.

        Args:
            inputs: a list of inputs in chat format to generate the embeddings for.

        Returns:
            A list containing the last hidden state for each sequence using a NumPy array
            with shape [num_tokens, hidden_size].
        """
        model: "PreTrainedModel" = (
            self._pipeline.model.model  # type: ignore
            if hasattr(self._pipeline.model, "model")  # type: ignore
            else next(self._pipeline.model.children())  # type: ignore
        )
        tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
        input_ids = tokenizer(
            [self.prepare_input(input) for input in inputs],  # type: ignore
            return_tensors="pt",
            padding=True,
        ).to(model.device)
        last_hidden_states = model(**input_ids)["last_hidden_state"]

        return [
            seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
            for seq_last_hidden_state, attention_mask in zip(
                last_hidden_states,
                input_ids["attention_mask"],  # type: ignore
            )
        ]

model_name: str property

Returns the model name used for the LLM.

generate(inputs, num_generations=1, max_new_tokens=128, temperature=0.1, repetition_penalty=1.1, top_p=1.0, top_k=0, do_sample=True)

Generates num_generations responses for each input using the text generation pipeline.

Parameters:

Name Type Description Default
inputs List[ChatType]

a list of inputs in chat format to generate responses for.

required
num_generations int

the number of generations to create per input. Defaults to 1.

1
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
temperature float

the temperature to use for the generation. Defaults to 0.1.

0.1
repetition_penalty float

the repetition penalty to use for the generation. Defaults to 1.1.

1.1
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
top_k int

the top-k value to use for the generation. Defaults to 0.

0
do_sample bool

whether to use sampling or not. Defaults to True.

True

Returns:

Type Description
List[GenerateOutput]

A list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/llms/huggingface/transformers.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[ChatType],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    temperature: float = 0.1,
    repetition_penalty: float = 1.1,
    top_p: float = 1.0,
    top_k: int = 0,
    do_sample: bool = True,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        repetition_penalty: the repetition penalty to use for the generation. Defaults
            to `1.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `0`.
        do_sample: whether to use sampling or not. Defaults to `True`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    outputs: List[List[Dict[str, str]]] = self._pipeline(  # type: ignore
        [self.prepare_input(input=input) for input in inputs],
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        num_return_sequences=num_generations,
    )
    return [
        [generation["generated_text"] for generation in output]
        for output in outputs
    ]

get_last_hidden_states(inputs)

Gets the last hidden_states of the model for the given inputs. It doesn't execute the task head.

Parameters:

Name Type Description Default
inputs List[ChatType]

a list of inputs in chat format to generate the embeddings for.

required

Returns:

Type Description
List[HiddenState]

A list containing the last hidden state for each sequence using a NumPy array with shape [num_tokens, hidden_size].

Source code in src/distilabel/llms/huggingface/transformers.py
def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
    """Gets the last `hidden_states` of the model for the given inputs. It doesn't
    execute the task head.

    Args:
        inputs: a list of inputs in chat format to generate the embeddings for.

    Returns:
        A list containing the last hidden state for each sequence using a NumPy array
        with shape [num_tokens, hidden_size].
    """
    model: "PreTrainedModel" = (
        self._pipeline.model.model  # type: ignore
        if hasattr(self._pipeline.model, "model")  # type: ignore
        else next(self._pipeline.model.children())  # type: ignore
    )
    tokenizer: "PreTrainedTokenizer" = self._pipeline.tokenizer  # type: ignore
    input_ids = tokenizer(
        [self.prepare_input(input) for input in inputs],  # type: ignore
        return_tensors="pt",
        padding=True,
    ).to(model.device)
    last_hidden_states = model(**input_ids)["last_hidden_state"]

    return [
        seq_last_hidden_state[attention_mask.bool(), :].detach().cpu().numpy()
        for seq_last_hidden_state, attention_mask in zip(
            last_hidden_states,
            input_ids["attention_mask"],  # type: ignore
        )
    ]

load()

Loads the model and tokenizer and creates the text generation pipeline. In addition, it will configure the tokenizer chat template.

Source code in src/distilabel/llms/huggingface/transformers.py
def load(self) -> None:
    """Loads the model and tokenizer and creates the text generation pipeline. In addition,
    it will configure the tokenizer chat template."""
    super().load()

    if self.device == "cuda":
        CudaDevicePlacementMixin.load(self)

    try:
        from transformers import pipeline
    except ImportError as ie:
        raise ImportError(
            "Transformers is not installed. Please install it using `pip install transformers`."
        ) from ie

    self._pipeline = pipeline(
        "text-generation",
        model=self.model,
        revision=self.revision,
        torch_dtype=self.torch_dtype,
        trust_remote_code=self.trust_remote_code,
        model_kwargs=self.model_kwargs or {},
        tokenizer=self.tokenizer or self.model,
        use_fast=self.use_fast,
        device=self.device,
        device_map=self.device_map,
        token=self.token or os.getenv("HF_TOKEN"),
        return_full_text=False,
    )

    if self.chat_template is not None:
        self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore
    elif (
        self._pipeline.tokenizer.chat_template is None  # type: ignore
        and self._pipeline.tokenizer.default_chat_template is None  # type: ignore
    ):
        self._pipeline.tokenizer.chat_template = CHATML_TEMPLATE  # type: ignore
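
The sketch below shows the construction-time options that this load method reads (the attribute names come from the listing above; the model id and device are illustrative, and leaving chat_template unset falls back to the tokenizer's template or ChatML):

from distilabel.llms import TransformersLLM

llm = TransformersLLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device="cuda",            # triggers CudaDevicePlacementMixin.load()
    torch_dtype="float16",    # forwarded to transformers.pipeline
    trust_remote_code=False,
    chat_template=None,       # use the tokenizer's template, or ChatML if it has none
)
llm.load()  # builds the "text-generation" pipeline and configures the chat template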

prepare_input(input)

Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt.

Source code in src/distilabel/llms/huggingface/transformers.py
def prepare_input(self, input: "ChatType") -> str:
    """Prepares the input by applying the chat template to the input, which is formatted
    as an OpenAI conversation, and adding the generation prompt.
    """
    return self._pipeline.tokenizer.apply_chat_template(  # type: ignore
        input,  # type: ignore
        tokenize=False,
        add_generation_prompt=True,
    )
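
For reference, the same chat-template expansion can be reproduced with the transformers tokenizer API directly (the model id is illustrative and must ship a chat template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # the string that the text-generation pipeline will receive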

VertexAILLM

Bases: AsyncLLM

VertexAI LLM implementation running the async API clients for Gemini.

  • Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

To use the VertexAILLM it is necessary to have Google Cloud authentication configured, using one of these methods (a minimal vertexai.init sketch follows this list):

  • Setting the GOOGLE_CLOUD_CREDENTIALS environment variable
  • Running the gcloud auth application-default login command
  • Calling the vertexai.init function from the google-cloud-aiplatform library
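
A minimal sketch of the vertexai.init route; the project id and location are placeholders for your own values:

import vertexai

# Explicit SDK initialization; alternatively rely on
# `gcloud auth application-default login` before running your pipeline.
vertexai.init(project="my-gcp-project", location="us-central1")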

Attributes:

  • model (str): the model name to use for the LLM, e.g. "gemini-1.0-pro". Supported models.
  • _aclient (Optional[GenerativeModel]): the GenerativeModel to use for the Vertex AI Gemini API. It is meant to be used internally. Set in the load method.

Source code in src/distilabel/llms/vertexai.py
class VertexAILLM(AsyncLLM):
    """VertexAI LLM implementation running the async API clients for Gemini.

    - Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

    To use the `VertexAILLM` is necessary to have configured the Google Cloud authentication
    using one of these methods:

    - Setting `GOOGLE_CLOUD_CREDENTIALS` environment variable
    - Using `gcloud auth application-default login` command
    - Using `vertexai.init` function from the `google-cloud-aiplatform` library

    Attributes:
        model: the model name to use for the LLM e.g. "gemini-1.0-pro". [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models).
        _aclient: the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant
            to be used internally. Set in the `load` method.
    """

    model: str
    _aclient: Optional["GenerativeModel"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
        super().load()

        try:
            from vertexai.generative_models import GenerationConfig, GenerativeModel

            self._generation_config_class = GenerationConfig
        except ImportError as e:
            raise ImportError(
                "vertexai is not installed. Please install it using"
                " `pip install google-cloud-aiplatform`."
            ) from e

        if _is_gemini_model(self.model):
            self._aclient = GenerativeModel(model_name=self.model)
        else:
            raise NotImplementedError(
                "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
            )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def _chattype_to_content(self, input: "ChatType") -> List["Content"]:
        """Converts a chat type to a list of content items expected by the API.

        Args:
            input: the chat type to be converted.

        Returns:
            A list of `Content` items expected by the API.
        """
        from vertexai.generative_models import Content, Part

        contents = []
        for message in input:
            if message["role"] not in ["user", "model"]:
                raise ValueError(
                    "`VertexAILLM only supports the roles 'user' or 'model'."
                )
            contents.append(
                Content(
                    role=message["role"], parts=[Part.from_text(message["content"])]
                )
            )
        return contents

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        num_generations: int = 1,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_output_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        safety_settings: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
            top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
            top_k: If specified, top-k sampling will be used. Defaults to `None`.
            max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
            stop_sequences: A list of stop sequences. Defaults to `None`.
            safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
            tools: A potential list of tools that can be used by the API. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from vertexai.generative_models import GenerationConfig

        contents = self._chattype_to_content(input)
        generations = []
        # TODO: remove this for-loop and override `generate`
        for _ in range(num_generations):
            content = await self._aclient.generate_content_async(  # type: ignore
                contents=contents,
                generation_config=GenerationConfig(
                    candidate_count=1,  # only one candidate allowed per call
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    max_output_tokens=max_output_tokens,
                    stop_sequences=stop_sequences,
                ),
                safety_settings=safety_settings,
                tools=tools,
                stream=False,
            )

            text = None
            try:
                text = content.candidates[0].text
            except ValueError:
                self._logger.warning(
                    f"Received no response using VertexAI client (model: '{self.model}')."
                    f" Finish reason was: '{content.candidates[0].finish_reason}'."
                )
            generations.append(text)

        return generations
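
Note that _chattype_to_content only accepts the roles "user" and "model", so a conversation passed to this LLM should look like the following sketch (the contents are illustrative; a "system" message would raise a ValueError):

conversation = [
    {"role": "user", "content": "Summarize the theory of relativity."},
    {"role": "model", "content": "It links space and time into a single continuum."},
    {"role": "user", "content": "Now in one sentence."},
]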

model_name: str property

Returns the model name used for the LLM.

agenerate(input, num_generations=1, temperature=None, top_p=None, top_k=None, max_output_tokens=None, stop_sequences=None, safety_settings=None, tools=None) async

Generates num_generations responses for the given input using the VertexAI async client definition.

Parameters:

  • input (ChatType, required): a single input in chat format to generate responses for.
  • num_generations (int, default 1): the number of generations to create per input.
  • temperature (Optional[float], default None): controls the randomness of predictions. Range: [0.0, 1.0].
  • top_p (Optional[float], default None): if specified, nucleus sampling will be used. Range: (0.0, 1.0].
  • top_k (Optional[int], default None): if specified, top-k sampling will be used.
  • max_output_tokens (Optional[int], default None): the maximum number of output tokens to generate per message.
  • stop_sequences (Optional[List[str]], default None): a list of stop sequences.
  • safety_settings (Optional[Dict[str, Any]], default None): safety configuration for the content returned by the API.
  • tools (Optional[List[Dict[str, Any]]], default None): a potential list of tools that can be used by the API.

Returns:

  • GenerateOutput: a list containing the generated responses for the given input, one string per generation (or None when the API returned no response).

Source code in src/distilabel/llms/vertexai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    num_generations: int = 1,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    safety_settings: Optional[Dict[str, Any]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
        top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
        top_k: If specified, top-k sampling will be used. Defaults to `None`.
        max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
        stop_sequences: A list of stop sequences. Defaults to `None`.
        safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
        tools: A potential list of tools that can be used by the API. Defaults to `None`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from vertexai.generative_models import GenerationConfig

    contents = self._chattype_to_content(input)
    generations = []
    # TODO: remove this for-loop and override `generate`
    for _ in range(num_generations):
        content = await self._aclient.generate_content_async(  # type: ignore
            contents=contents,
            generation_config=GenerationConfig(
                candidate_count=1,  # only one candidate allowed per call
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_output_tokens=max_output_tokens,
                stop_sequences=stop_sequences,
            ),
            safety_settings=safety_settings,
            tools=tools,
            stream=False,
        )

        text = None
        try:
            text = content.candidates[0].text
        except ValueError:
            self._logger.warning(
                f"Received no response using VertexAI client (model: '{self.model}')."
                f" Finish reason was: '{content.candidates[0].finish_reason}'."
            )
        generations.append(text)

    return generations
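
A hedged usage sketch of this coroutine, run with asyncio (the model name is illustrative and Google Cloud authentication is assumed to be configured already):

import asyncio

from distilabel.llms import VertexAILLM

llm = VertexAILLM(model="gemini-1.0-pro")
llm.load()

generations = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Write a haiku about the sea."}],
        temperature=0.7,
        max_output_tokens=256,
    )
)
print(generations)  # one entry per generation, None if the response was filtered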

load()

Loads the GenerativeModel class which has access to generate_content_async to benefit from async requests.

Source code in src/distilabel/llms/vertexai.py
def load(self) -> None:
    """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
    super().load()

    try:
        from vertexai.generative_models import GenerationConfig, GenerativeModel

        self._generation_config_class = GenerationConfig
    except ImportError as e:
        raise ImportError(
            "vertexai is not installed. Please install it using"
            " `pip install google-cloud-aiplatform`."
        ) from e

    if _is_gemini_model(self.model):
        self._aclient = GenerativeModel(model_name=self.model)
    else:
        raise NotImplementedError(
            "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
        )

vLLM

Bases: LLM, CudaDevicePlacementMixin

vLLM library LLM implementation.

Attributes:

  • model (str): the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files.
  • model_kwargs (Optional[RuntimeParameter[Dict[str, Any]]]): additional dictionary of keyword arguments that will be passed to the LLM class of the vllm library.
  • chat_template (Optional[str]): a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, then the ChatML template will be used. Defaults to None.
  • _model (Optional[LLM]): the vLLM model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method.
  • _tokenizer (Optional[PreTrainedTokenizer]): the tokenizer instance used to format the prompt before passing it to the LLM. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method.
Runtime parameters
  • model_kwargs: additional dictionary of keyword arguments that will be passed to the LLM class of vllm library.
Source code in src/distilabel/llms/vllm.py
class vLLM(LLM, CudaDevicePlacementMixin):
    """`vLLM` library LLM implementation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        model_kwargs: additional dictionary of keyword arguments that will be passed to
            the `LLM` class of `vllm` library.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        _model: the `vLLM` model instance. This attribute is meant to be used internally
            and should not be accessed directly. It will be set in the `load` method.
        _tokenizer: the tokenizer instance used to format the prompt before passing it to
            the `LLM`. This attribute is meant to be used internally and should not be
            accessed directly. It will be set in the `load` method.

    Runtime parameters:
        - `model_kwargs`: additional dictionary of keyword arguments that will be passed to
            the `LLM` class of `vllm` library.
    """

    model: str
    model_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `LLM` class of `vllm` library.",
    )
    chat_template: Optional[str] = None

    _model: Optional["_vLLM"] = PrivateAttr(...)
    _tokenizer: Optional["PreTrainedTokenizer"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
        Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
        parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
        default value is ChatML format, unless explicitly provided.
        """
        super().load()

        CudaDevicePlacementMixin.load(self)

        try:
            from vllm import LLM as _vLLM
            from vllm import SamplingParams as _SamplingParams

            global SamplingParams
            SamplingParams = _SamplingParams
        except ImportError as ie:
            raise ImportError(
                "vLLM is not installed. Please install it using `pip install vllm`."
            ) from ie

        self._model = _vLLM(self.model, **self.model_kwargs)  # type: ignore
        self._tokenizer = self._model.get_tokenizer()  # type: ignore

        if self.chat_template is not None:
            self._tokenizer.chat_template = self.chat_template  # type: ignore
        elif (
            self._tokenizer.chat_template is None  # type: ignore
            and self._tokenizer.default_chat_template is None  # type: ignore
        ):
            self._tokenizer.chat_template = CHATML_TEMPLATE

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: "ChatType") -> str:
        """Prepares the input by applying the chat template to the input, which is formatted
        as an OpenAI conversation, and adding the generation prompt.
        """
        return self._tokenizer.apply_chat_template(  # type: ignore
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,  # type: ignore
        )

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[ChatType],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        extra_sampling_params: Optional[Dict[str, Any]] = None,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input using the text generation
        pipeline.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `-1`.
            extra_sampling_params: dictionary with additional arguments to be passed to
                the `SamplingParams` class from `vllm`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        if extra_sampling_params is None:
            extra_sampling_params = {}
        sampling_params = SamplingParams(  # type: ignore
            n=num_generations,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_new_tokens,
            **extra_sampling_params,
        )

        prepared_inputs = [self.prepare_input(input) for input in inputs]
        batch_outputs = self._model.generate(  # type: ignore
            prepared_inputs,
            sampling_params,
            use_tqdm=False,  # type: ignore
        )
        return [
            [output.text for output in outputs.outputs] for outputs in batch_outputs
        ]

model_name: str property

Returns the model name used for the LLM.

generate(inputs, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1, extra_sampling_params=None)

Generates num_generations responses for each input using the text generation pipeline.

Parameters:

  • inputs (List[ChatType], required): a list of inputs in chat format to generate responses for.
  • num_generations (int, default 1): the number of generations to create per input.
  • max_new_tokens (int, default 128): the maximum number of new tokens that the model will generate.
  • frequency_penalty (float, default 0.0): the frequency penalty to use for the generation.
  • presence_penalty (float, default 0.0): the presence penalty to use for the generation.
  • temperature (float, default 1.0): the temperature to use for the generation.
  • top_p (float, default 1.0): the top-p value to use for the generation.
  • top_k (int, default -1): the top-k value to use for the generation.
  • extra_sampling_params (Optional[Dict[str, Any]], default None): dictionary with additional arguments to be passed to the SamplingParams class from vllm.

Returns:

  • List[GenerateOutput]: a list of lists of strings containing the generated responses for each input.

Source code in src/distilabel/llms/vllm.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[ChatType],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
    extra_sampling_params: Optional[Dict[str, Any]] = None,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input using the text generation
    pipeline.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `-1`.
        extra_sampling_params: dictionary with additional arguments to be passed to
            the `SamplingParams` class from `vllm`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    if extra_sampling_params is None:
        extra_sampling_params = {}
    sampling_params = SamplingParams(  # type: ignore
        n=num_generations,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_new_tokens,
        **extra_sampling_params,
    )

    prepared_inputs = [self.prepare_input(input) for input in inputs]
    batch_outputs = self._model.generate(  # type: ignore
        prepared_inputs,
        sampling_params,
        use_tqdm=False,  # type: ignore
    )
    return [
        [output.text for output in outputs.outputs] for outputs in batch_outputs
    ]
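
A minimal usage sketch (the model id, prompt, and sampling values are illustrative; a CUDA-capable environment with vllm installed is assumed):

from distilabel.llms import vLLM

llm = vLLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm.load()

outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Name three prime numbers."}]],
    num_generations=2,
    max_new_tokens=64,
    temperature=0.7,
    extra_sampling_params={"stop": ["\n\n"]},  # forwarded to vllm.SamplingParams
)
# One inner list per input, containing `num_generations` generated strings.
print(outputs[0])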

load()

Loads the vLLM model from either a local path or a Hugging Face Hub repository id. It also sets the tokenizer's chat_template so that the list of OpenAI-formatted inputs is parsed in the format the model expects; if no template is provided and the tokenizer does not define one, the ChatML template is used.

Source code in src/distilabel/llms/vllm.py
def load(self) -> None:
    """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
    Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
    parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
    default value is ChatML format, unless explicitly provided.
    """
    super().load()

    CudaDevicePlacementMixin.load(self)

    try:
        from vllm import LLM as _vLLM
        from vllm import SamplingParams as _SamplingParams

        global SamplingParams
        SamplingParams = _SamplingParams
    except ImportError as ie:
        raise ImportError(
            "vLLM is not installed. Please install it using `pip install vllm`."
        ) from ie

    self._model = _vLLM(self.model, **self.model_kwargs)  # type: ignore
    self._tokenizer = self._model.get_tokenizer()  # type: ignore

    if self.chat_template is not None:
        self._tokenizer.chat_template = self.chat_template  # type: ignore
    elif (
        self._tokenizer.chat_template is None  # type: ignore
        and self._tokenizer.default_chat_template is None  # type: ignore
    ):
        self._tokenizer.chat_template = CHATML_TEMPLATE
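
As a sketch of the knobs this method reads at load time (the model id, model_kwargs values, and template below are illustrative; model_kwargs is forwarded verbatim to vllm.LLM):

from distilabel.llms import vLLM

llm = vLLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    model_kwargs={"dtype": "bfloat16", "gpu_memory_utilization": 0.85},
    # A deliberately tiny Jinja template; omit it to keep the tokenizer's own
    # template, or fall back to ChatML when the tokenizer defines none.
    chat_template=(
        "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
        "assistant: "
    ),
)
llm.load()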

prepare_input(input)

Prepares the input by applying the chat template to the input, which is formatted as an OpenAI conversation, and adding the generation prompt.

Source code in src/distilabel/llms/vllm.py
def prepare_input(self, input: "ChatType") -> str:
    """Prepares the input by applying the chat template to the input, which is formatted
    as an OpenAI conversation, and adding the generation prompt.
    """
    return self._tokenizer.apply_chat_template(  # type: ignore
        input,  # type: ignore
        tokenize=False,
        add_generation_prompt=True,  # type: ignore
    )