MistralLLM

Bases: AsyncLLM

Mistral LLM implementation running the async API client.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | The model name to use for the LLM, e.g. `"mistral-tiny"`, `"mistral-large"`, etc. |
| `endpoint` | `str` | The endpoint to use for the Mistral API. Defaults to `"https://api.mistral.ai"`. |
| `api_key` | `Optional[RuntimeParameter[SecretStr]]` | The API key to authenticate the requests to the Mistral API. Defaults to `None`, which means the value of the `MISTRAL_API_KEY` environment variable will be used, or `None` if not set. |
| `max_retries` | `RuntimeParameter[int]` | The maximum number of retries to attempt when a request fails. Defaults to `6`. |
| `timeout` | `RuntimeParameter[int]` | The maximum time in seconds to wait for a response. Defaults to `120`. |
| `max_concurrent_requests` | `RuntimeParameter[int]` | The maximum number of concurrent requests to send. Defaults to `64`. |
| `structured_output` | `Optional[RuntimeParameter[InstructorStructuredOutputType]]` | A dictionary containing the structured output configuration using `instructor`. See `InstructorStructuredOutputType` in `distilabel.steps.tasks.structured_outputs.instructor` for the expected structure. |
| `_api_key_env_var` | `str` | The name of the environment variable to use for the API key. Meant to be used internally. |
| `_aclient` | `Optional[MistralAsyncClient]` | The `MistralAsyncClient` to use for the Mistral API. Meant to be used internally. Set in the `load` method. |

Runtime parameters
  • `api_key`: the API key to authenticate the requests to the Mistral API.
  • `max_retries`: the maximum number of retries to attempt when a request fails. Defaults to `6`.
  • `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`.
  • `max_concurrent_requests`: the maximum number of concurrent requests to send. Defaults to `64`.

These can also be set when instantiating the class, as the sketch below shows.
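
A minimal sketch of overriding these runtime parameters at instantiation, with illustrative values (`api_key` is omitted, so the `MISTRAL_API_KEY` environment variable is used):

```python
from distilabel.llms import MistralLLM

# Runtime parameters can be supplied directly as constructor arguments.
llm = MistralLLM(
    model="open-mixtral-8x22b",
    max_retries=6,               # retry a failed request up to 6 times
    timeout=120,                 # wait up to 120 seconds per response
    max_concurrent_requests=32,  # throttle concurrent requests
)

llm.load()  # the underlying MistralAsyncClient is created with these settings
```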

Examples:

Generate text:

```python
from distilabel.llms import MistralLLM

llm = MistralLLM(model="open-mixtral-8x22b")

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
```

Generate structured data:

```python
from pydantic import BaseModel
from distilabel.llms import MistralLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = MistralLLM(
    model="open-mixtral-8x22b",
    api_key="api.key",
    structured_output={"schema": User}
)

llm.load()

output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
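
When `structured_output` is set, each generation is the JSON serialization of the parsed model rather than free-form text (see the `completion.model_dump_json()` call in the source below), so it can be validated back into the schema. A minimal sketch, assuming `generate` returns one list of generations per input:

```python
# `output[0]` holds the generations for the first (and only) input, and
# `output[0][0]` is the JSON dump of the parsed `User` instance.
user = User.model_validate_json(output[0][0])
print(user.name, user.last_name, user.id)
```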
Source code in src/distilabel/llms/mistral.py
class MistralLLM(AsyncLLM):
    """Mistral LLM implementation running the async API client.

    Attributes:
        model: the model name to use for the LLM e.g. "mistral-tiny", "mistral-large", etc.
        endpoint: the endpoint to use for the Mistral API. Defaults to "https://api.mistral.ai".
        api_key: the API key to authenticate the requests to the Mistral API. Defaults to `None` which
            means that the value set for the environment variable `MISTRAL_API_KEY` will be used, or
            `None` if not set.
        max_retries: the maximum number of retries to attempt when a request fails. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response. Defaults to `120`.
        max_concurrent_requests: the maximum number of concurrent requests to send. Defaults
            to `64`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _api_key_env_var: the name of the environment variable to use for the API key. It is meant to
            be used internally.
        _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally.
            Set in the `load` method.

    Runtime parameters:
        - `api_key`: the API key to authenticate the requests to the Mistral API.
        - `max_retries`: the maximum number of retries to attempt when a request fails.
            Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response. Defaults to `120`.
        - `max_concurrent_requests`: the maximum number of concurrent requests to send.
            Defaults to `64`.

    Examples:

        Generate text:

        ```python
        from distilabel.llms import MistralLLM

        llm = MistralLLM(model="open-mixtral-8x22b")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.llms import MistralLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = MistralLLM(
            model="open-mixtral-8x22b",
            api_key="api.key",
            structured_output={"schema": User}
        )

        llm.load()

        output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    endpoint: str = "https://api.mistral.ai"
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_MISTRALAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Mistral API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    max_concurrent_requests: RuntimeParameter[int] = Field(
        default=64, description="The maximum number of concurrent requests to send."
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["MistralAsyncClient"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `MistralAsyncClient` client to benefit from async requests."""
        super().load()

        try:
            from mistralai.async_client import MistralAsyncClient
        except ImportError as ie:
            raise ImportError(
                "MistralAI Python client is not installed. Please install it using"
                " `pip install mistralai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = MistralAsyncClient(
            api_key=self.api_key.get_secret_value(),
            endpoint=self.endpoint,
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,  # type: ignore
            max_concurrent_requests=self.max_concurrent_requests,  # type: ignore
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="mistral",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    # TODO: add `num_generations` parameter once Mistral client allows `n` parameter
    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the MistralAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `0.1`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.

        Returns:
            A list of strings containing the generated responses for the input.
        """
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="mistral",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        generations = []
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)
            # TODO: This should work just with the _aclient.chat method, but it's not working.
            # We need to check instructor and see if we can create a PR.
            completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
        else:
            completion = await self._aclient.chat(**kwargs)  # type: ignore

        if structured_output:
            generations.append(completion.model_dump_json())
            return generations

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using MistralAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations

model_name: str property

Returns the model name used for the LLM.

agenerate(input, max_new_tokens=None, temperature=None, top_p=None) async

Generates responses for the given input using the MistralAI async client.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `FormattedInput` | A single input in chat format to generate responses for. | required |
| `max_new_tokens` | `Optional[int]` | The maximum number of new tokens that the model will generate. Defaults to `128`. | `None` |
| `temperature` | `Optional[float]` | The temperature to use for the generation. Defaults to `0.1`. | `None` |
| `top_p` | `Optional[float]` | The top-p value to use for the generation. Defaults to `1.0`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `GenerateOutput` | A list of strings containing the generated responses for the input. |
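
Since `agenerate` is a coroutine, it must be awaited or driven by an event loop when called outside of a pipeline. A minimal sketch, assuming the API key is available via the environment and using illustrative sampling values:

```python
import asyncio

from distilabel.llms import MistralLLM

llm = MistralLLM(model="open-mixtral-8x22b")
llm.load()

# `agenerate` takes a single chat-formatted input; here it is run
# to completion with asyncio.run.
output = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Hello world!"}],
        max_new_tokens=128,
        temperature=0.7,
        top_p=1.0,
    )
)
print(output)  # a list containing the generated response
```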

Source code in src/distilabel/llms/mistral.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    max_new_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the MistralAI async
    client.

    Args:
        input: a single input in chat format to generate responses for.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `0.1`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.

    Returns:
        A list of strings containing the generated responses for the input.
    """
    structured_output = None
    if isinstance(input, tuple):
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,
            client=self._aclient,
            framework="mistral",
        )
        self._aclient = result.get("client")

    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    kwargs = {
        "messages": input,  # type: ignore
        "model": self.model,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    generations = []
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)
        # TODO: This should work just with the _aclient.chat method, but it's not working.
        # We need to check instructor and see if we can create a PR.
        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
    else:
        completion = await self._aclient.chat(**kwargs)  # type: ignore

    if structured_output:
        generations.append(completion.model_dump_json())
        return generations

    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using MistralAI client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(content)
    return generations

load()

Loads the MistralAsyncClient client to benefit from async requests.
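
`load()` must be called before generating so that the underlying client exists; as the source below shows, it raises an `ImportError` if the `mistralai` package is missing and a `ValueError` if no API key can be resolved. A minimal sketch, assuming the key is provided through the `MISTRAL_API_KEY` environment variable (the value shown is a placeholder):

```python
import os

from distilabel.llms import MistralLLM

# Placeholder key: the default factory for `api_key` reads this
# environment variable when the attribute is not set explicitly.
os.environ["MISTRAL_API_KEY"] = "<your-api-key>"

llm = MistralLLM(model="open-mixtral-8x22b")
llm.load()  # creates the MistralAsyncClient with the configured settings
```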

Source code in src/distilabel/llms/mistral.py
def load(self) -> None:
    """Loads the `MistralAsyncClient` client to benefit from async requests."""
    super().load()

    try:
        from mistralai.async_client import MistralAsyncClient
    except ImportError as ie:
        raise ImportError(
            "MistralAI Python client is not installed. Please install it using"
            " `pip install mistralai`."
        ) from ie

    if self.api_key is None:
        raise ValueError(
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )

    self._aclient = MistralAsyncClient(
        api_key=self.api_key.get_secret_value(),
        endpoint=self.endpoint,
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,  # type: ignore
        max_concurrent_requests=self.max_concurrent_requests,  # type: ignore
    )

    if self.structured_output:
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="mistral",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output