Skip to content

GroqLLM

GroqLLM

Bases: AsyncLLM

Groq API implementation using the async client for concurrent text generation.

Attributes:

Name Type Description
model str

the name of the model from the Groq API to use for the generation.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Groq API requests. Defaults to "https://api.groq.com".

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Groq API. Defaults to the value of the GROQ_API_KEY environment variable.

max_retries RuntimeParameter[int]

the maximum number of times to retry the request to the API before failing. Defaults to 2.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

_api_key_env_var str

the name of the environment variable to use for the API key.

_aclient Optional[AsyncGroq]

the AsyncGroq client from the groq package.

Runtime parameters
  • base_url: the base URL to use for the Groq API requests. Defaults to "https://api.groq.com".
  • api_key: the API key to authenticate the requests to the Groq API. Defaults to the value of the GROQ_API_KEY environment variable.
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 2.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
Source code in src/distilabel/llms/groq.py
class GroqLLM(AsyncLLM):
    """Groq API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Groq API to use for the generation.
        base_url: the base URL to use for the Groq API requests. Defaults to the value
            of the environment variable named by `_GROQ_API_BASE_URL_ENV_VAR_NAME` if
            it is set, otherwise to `"https://api.groq.com"`.
        api_key: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the environment variable named by `_GROQ_API_KEY_ENV_VAR_NAME`.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        _api_key_env_var: the name of the environment variable to use for the API key.
        _aclient: the `AsyncGroq` client from the `groq` package. It is created in `load`.

    Runtime parameters:
        - `base_url`: the base URL to use for the Groq API requests. Defaults to the
            value of the environment variable named by `_GROQ_API_BASE_URL_ENV_VAR_NAME`
            if it is set, otherwise to `"https://api.groq.com"`.
        - `api_key`: the API key to authenticate the requests to the Groq API. Defaults to
            the value of the environment variable named by `_GROQ_API_KEY_ENV_VAR_NAME`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `2`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
    """

    model: str

    # The defaults of `base_url` and `api_key` are resolved from the environment at
    # instantiation time (via `default_factory`), not at import time.
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            _GROQ_API_BASE_URL_ENV_VAR_NAME, "https://api.groq.com"
        ),
        description="The base URL to use for the Groq API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_GROQ_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Groq API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=2,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )

    _api_key_env_var: str = PrivateAttr(_GROQ_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncGroq"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncGroq` client to benefit from async requests.

        Raises:
            ImportError: if the `groq` package is not installed.
            ValueError: if no API key is available in `api_key`.
        """
        super().load()

        # `groq` is an optional dependency, so it is imported lazily here.
        try:
            from groq import AsyncGroq
        except ImportError as ie:
            raise ImportError(
                "Groq Python client is not installed. Please install it using"
                ' `pip install groq` or from the extras as `pip install "distilabel[groq]"`.'
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncGroq(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: ChatType,
        seed: Optional[int] = None,
        max_new_tokens: int = 128,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[str] = None,
    ) -> "GenerateOutput":
        """Generates the responses for a single input using the Groq async client.

        Args:
            input: a single input in chat format to generate responses for.
            seed: the seed to use for the generation. Defaults to `None`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: the stop sequence to use for the generation. Defaults to `None`.

        Returns:
            A list with the content of each choice returned in the completion. An
            entry is `None` (and a warning is logged) when a choice has no content.

        References:
            - https://console.groq.com/docs/text-chat
        """
        completion = await self._aclient.chat.completions.create(  # type: ignore
            messages=input,  # type: ignore
            model=self.model,
            seed=seed,  # type: ignore
            temperature=temperature,
            max_tokens=max_new_tokens,
            top_p=top_p,
            stream=False,
            stop=stop,
        )
        generations = []
        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using the Groq client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            # `None` is appended on purpose so that every choice keeps its position.
            generations.append(content)
        return generations

    # TODO: remove this function once Groq client allows `n` parameter
    @override
    def generate(
        self,
        inputs: List["ChatType"],
        num_generations: int = 1,
        **kwargs: Any,
    ) -> List["GenerateOutput"]:
        """Method to generate a list of responses asynchronously, returning the output
        synchronously after awaiting the response of each input sent to `agenerate`.

        Args:
            inputs: the inputs in chat format to generate responses for.
            num_generations: the number of `agenerate` calls issued per input.
                Defaults to `1`.
            **kwargs: extra keyword arguments forwarded to `agenerate`.

        Returns:
            The generations grouped in chunks of `num_generations` items, following
            the order of `inputs`.
        """

        async def _agenerate_all(
            inputs: List["ChatType"], **kwargs: Any
        ) -> "GenerateOutput":
            """Internal function to parallelize the asynchronous generation of responses."""
            # One task per (input, generation) pair; tasks of the same input are
            # contiguous so the flat result can be regrouped afterwards.
            tasks = [
                asyncio.create_task(self.agenerate(input=chat_input, **kwargs))
                for chat_input in inputs
                for _ in range(num_generations)
            ]
            # Only the first generation of each `agenerate` call is kept.
            return [outputs[0] for outputs in await asyncio.gather(*tasks)]

        outputs = self.event_loop.run_until_complete(_agenerate_all(inputs, **kwargs))
        return list(grouper(outputs, n=num_generations, incomplete="ignore"))

model_name: str property

Returns the model name used for the LLM.

agenerate(input, seed=None, max_new_tokens=128, temperature=1.0, top_p=1.0, stop=None) async

Generates the responses for a single input using the Groq async client.

Parameters:

Name Type Description Default
input ChatType

a single input in chat format to generate responses for.

required
seed Optional[int]

the seed to use for the generation. Defaults to None.

None
max_new_tokens int

the maximum number of new tokens that the model will generate. Defaults to 128.

128
temperature float

the temperature to use for the generation. Defaults to 1.0.

1.0
top_p float

the top-p value to use for the generation. Defaults to 1.0.

1.0
stop Optional[str]

the stop sequence to use for the generation. Defaults to None.

None

Returns:

Type Description
GenerateOutput

A list of strings containing the generated responses for the given input.

References
  • https://console.groq.com/docs/text-chat
Source code in src/distilabel/llms/groq.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: ChatType,
    seed: Optional[int] = None,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
    top_p: float = 1.0,
    stop: Optional[str] = None,
) -> "GenerateOutput":
    """Generates the responses for a single input using the Groq async client.

    Args:
        input: a single input in chat format to generate responses for.
        seed: the seed to use for the generation. Defaults to `None`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        stop: the stop sequence to use for the generation. Defaults to `None`.

    Returns:
        A list with the content of each choice returned in the completion. An
        entry is `None` (and a warning is logged) when a choice has no content.

    References:
        - https://console.groq.com/docs/text-chat
    """
    completion = await self._aclient.chat.completions.create(  # type: ignore
        messages=input,  # type: ignore
        model=self.model,
        seed=seed,  # type: ignore
        temperature=temperature,
        max_tokens=max_new_tokens,
        top_p=top_p,
        stream=False,
        stop=stop,
    )
    generations = []
    for choice in completion.choices:
        if (content := choice.message.content) is None:
            self._logger.warning(  # type: ignore
                f"Received no response using the Groq client (model: '{self.model}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        # `None` is appended on purpose so that every choice keeps its position.
        generations.append(content)
    return generations

generate(inputs, num_generations=1, **kwargs)

Method to generate a list of responses asynchronously, returning the output synchronously after awaiting the response of each input sent to agenerate.

Source code in src/distilabel/llms/groq.py
@override
def generate(
    self,
    inputs: List["ChatType"],
    num_generations: int = 1,
    **kwargs: Any,
) -> List["GenerateOutput"]:
    """Run `agenerate` concurrently for every input and return the results synchronously.

    Args:
        inputs: the inputs in chat format to generate responses for.
        num_generations: the number of `agenerate` calls issued per input.
            Defaults to `1`.
        **kwargs: extra keyword arguments forwarded to `agenerate`.

    Returns:
        The generations grouped in chunks of `num_generations` items, following
        the order of `inputs`.
    """

    async def _first_generations(
        inputs: List["ChatType"], **kwargs: Any
    ) -> "GenerateOutput":
        """Schedule one task per (input, generation) pair and await all of them."""
        pending = []
        for chat_input in inputs:
            for _ in range(num_generations):
                coroutine = self.agenerate(input=chat_input, **kwargs)
                pending.append(asyncio.create_task(coroutine))
        completed = await asyncio.gather(*pending)
        # Only the first generation of each `agenerate` call is kept.
        return [generations[0] for generations in completed]

    flat_outputs = self.event_loop.run_until_complete(
        _first_generations(inputs, **kwargs)
    )
    grouped_outputs = grouper(flat_outputs, n=num_generations, incomplete="ignore")
    return list(grouped_outputs)

load()

Loads the AsyncGroq client to benefit from async requests.

Source code in src/distilabel/llms/groq.py
def load(self) -> None:
    """Create the `AsyncGroq` client used to send asynchronous requests.

    Raises:
        ImportError: if the `groq` package is not installed.
        ValueError: if no API key is available in `api_key`.
    """
    super().load()

    # `groq` is an optional dependency, so it is imported lazily here.
    try:
        from groq import AsyncGroq
    except ImportError as ie:
        install_hint = (
            "Groq Python client is not installed. Please install it using"
            ' `pip install groq` or from the extras as `pip install "distilabel[groq]"`.'
        )
        raise ImportError(install_hint) from ie

    if self.api_key is None:
        missing_key_message = (
            f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
            f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
        )
        raise ValueError(missing_key_message)

    client_options = {
        "base_url": self.base_url,
        "api_key": self.api_key.get_secret_value(),
        "max_retries": self.max_retries,
        "timeout": self.timeout,
    }
    self._aclient = AsyncGroq(**client_options)  # type: ignore