Skip to content



Bases: AsyncLLM

Cohere API implementation using the async client for concurrent text generation.


Name Type Description
model str

the name of the model from the Cohere API to use for the generation.

base_url Optional[RuntimeParameter[str]]

the base URL to use for the Cohere API requests. Defaults to "".

api_key Optional[RuntimeParameter[SecretStr]]

the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.

timeout RuntimeParameter[int]

the maximum time in seconds to wait for a response from the API. Defaults to 120.

client_name RuntimeParameter[str]

the name of the client to use for the API requests. Defaults to "distilabel".

structured_output Optional[RuntimeParameter[InstructorStructuredOutputType]]

a dictionary containing the structured output configuration using instructor. You can take a look at the dictionary structure in InstructorStructuredOutputType from distilabel.steps.tasks.structured_outputs.instructor.

_ChatMessage Type[ChatMessage]

the ChatMessage class from the cohere package.

_aclient AsyncClient

the AsyncClient client from the cohere package.

Runtime parameters
  • base_url: the base URL to use for the Cohere API requests. Defaults to "".
  • api_key: the API key to authenticate the requests to the Cohere API. Defaults to the value of the COHERE_API_KEY environment variable.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • client_name: the name of the client to use for the API requests. Defaults to "distilabel".


Generate text:

from distilabel.llms import CohereLLM

llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

# The async client is only created by `load`, so it has to run before generating.
llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])

Generate structured data:

from pydantic import BaseModel
from distilabel.llms import CohereLLM


class User(BaseModel):
    """Schema the model output is expected to follow."""

    name: str
    last_name: str
    id: int


# `model` is a required field of `CohereLLM`; `structured_output` carries the schema.
llm = CohereLLM(
    model="CohereForAI/c4ai-command-r-plus",
    structured_output={"schema": User},
)

# The async client is only created by `load`, so it has to run before generating.
llm.load()

output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
Source code in src/distilabel/llms/
class CohereLLM(AsyncLLM):
    """Cohere API implementation using the async client for concurrent text generation.

    Attributes:
        model: the name of the model from the Cohere API to use for the generation.
        base_url: the base URL to use for the Cohere API requests. Defaults to the
            value of the `COHERE_BASE_URL` environment variable, or `""` if it is
            not set.
        api_key: the API key to authenticate the requests to the Cohere API. Defaults to
            the value of the `COHERE_API_KEY` environment variable.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        client_name: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.
        structured_output: a dictionary containing the structured output configuration
            using `instructor`. You can take a look at the dictionary structure in
            `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
        _ChatMessage: the `ChatMessage` class from the `cohere` package.
        _aclient: the `AsyncClient` client from the `cohere` package.

    Runtime parameters:
        - `base_url`: the base URL to use for the Cohere API requests. Defaults to
            the value of the `COHERE_BASE_URL` environment variable, or `""` if it
            is not set.
        - `api_key`: the API key to authenticate the requests to the Cohere API. Defaults
            to the value of the `COHERE_API_KEY` environment variable.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `client_name`: the name of the client to use for the API requests. Defaults to
            `"distilabel"`.

    Examples:
        Generate text:

        ```python
        from distilabel.llms import CohereLLM

        llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.llms import CohereLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = CohereLLM(
            model="CohereForAI/c4ai-command-r-plus",
            structured_output={"schema": User},
        )

        llm.load()

        output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    # NOTE(review): the empty-string fallback looks like extraction residue (a URL
    # may have been stripped from this listing) — confirm the intended default.
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "COHERE_BASE_URL", ""
        ),
        description="The base URL to use for the Cohere API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_COHERE_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the Cohere API.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    client_name: RuntimeParameter[str] = Field(
        default="distilabel",
        description="The name of the client to use for the API requests.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _num_generations_param_supported = False

    _ChatMessage: Type["ChatMessage"] = PrivateAttr(...)
    _aclient: "AsyncClient" = PrivateAttr(...)

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def load(self) -> None:
        """Loads the `AsyncClient` client from the `cohere` package.

        Raises:
            ImportError: if the `cohere` package is not installed.
        """

        super().load()

        # `cohere` is an optional dependency, so it is imported lazily here.
        try:
            from cohere import AsyncClient, ChatMessage
        except ImportError as ie:
            raise ImportError(
                "The `cohere` package is required to use the `CohereLLM` class."
            ) from ie

        self._ChatMessage = ChatMessage

        self._aclient = AsyncClient(
            api_key=self.api_key.get_secret_value(),  # type: ignore
            client_name=self.client_name,
            base_url=self.base_url,
            timeout=self.timeout,
        )

        if self.structured_output:
            # The helper hands back the client to use from now on and, optionally,
            # an updated structured output configuration.
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    def _format_chat_to_cohere(
        self, input: "FormattedInput"
    ) -> Tuple[Union[str, None], List["ChatMessage"], str]:
        """Formats the chat input to the Cohere Chat API conversational format.

        Args:
            input: The chat input to format.

        Returns:
            A tuple containing the system, chat history, and message.

        Raises:
            ValueError: if an assistant message is not preceded by a user message,
                or if the chat input does not end with a user message.
        """
        system = None
        message = None
        chat_history = []
        for item in input:
            role = item["role"]
            content = item["content"]
            if role == "system":
                system = content
            elif role == "user":
                # Held back until we know whether an assistant reply follows it.
                message = content
            elif role == "assistant":
                if message is None:
                    raise ValueError(
                        "An assistant message must be preceded by a user message."
                    )
                # A completed user/assistant exchange goes to the history as a pair.
                chat_history.append(self._ChatMessage(role="USER", message=message))  # type: ignore
                chat_history.append(self._ChatMessage(role="CHATBOT", message=content))  # type: ignore
                message = None

        # The last user message is sent as the message to answer, not as history.
        if message is None:
            raise ValueError("The chat input must end with a user message.")

        return system, chat_history, message

    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        seed: Optional[float] = None,
        stop_sequences: Optional[Sequence[str]] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        raw_prompting: Optional[bool] = None,
    ) -> GenerateOutput:
        """Generates a response from the LLM given an input.

        Args:
            input: a single input in chat format to generate responses for.
            temperature: the temperature to use for the generation. Defaults to `None`.
            max_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `None`.
            k: the number of highest probability vocabulary tokens to keep for the generation.
                Defaults to `None`.
            p: the nucleus sampling probability to use for the generation. Defaults to
                `None`.
            seed: the seed to use for the generation. Defaults to `None`.
            stop_sequences: a list of sequences to use as stopping criteria for the generation.
                Defaults to `None`.
            frequency_penalty: the frequency penalty to use for the generation. Defaults
                to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `None`.
            raw_prompting: a flag to use raw prompting for the generation. Defaults to
                `None`.

        Returns:
            The generated response from the Cohere API model.
        """
        structured_output = None
        if isinstance(input, tuple):
            # A tuple input carries its own structured output configuration.
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,  # type: ignore
                client=self._aclient,
                framework="cohere",
            )
            self._aclient = result.get("client")  # type: ignore

        # Fall back to the configuration set on the instance.
        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        system, chat_history, message = self._format_chat_to_cohere(input)

        kwargs = {
            "message": message,
            "model": self.model,
            "preamble": system,
            "chat_history": chat_history,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "k": k,
            "p": p,
            "seed": seed,
            "stop_sequences": stop_sequences,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "raw_prompting": raw_prompting,
        }
        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

        response = await self._aclient.chat(**kwargs)  # type: ignore

        if structured_output:
            return [response.model_dump_json()]

        if (text := response.text) == "":
            self._logger.warning(  # type: ignore
                f"Received no response using Cohere client (model: '{self.model}')."
                f" Finish reason was: {response.finish_reason}"
            )
            return [None]

        return [text]

model_name: str property

Returns the model name used for the LLM.

agenerate(input, temperature=None, max_tokens=None, k=None, p=None, seed=None, stop_sequences=None, frequency_penalty=None, presence_penalty=None, raw_prompting=None) async

Generates a response from the LLM given an input.


Name Type Description Default
input FormattedInput

a single input in chat format to generate responses for.

temperature Optional[float]

the temperature to use for the generation. Defaults to None.

max_tokens Optional[int]

the maximum number of new tokens that the model will generate. Defaults to None.

k Optional[int]

the number of highest probability vocabulary tokens to keep for the generation. Defaults to None.

p Optional[float]

the nucleus sampling probability to use for the generation. Defaults to None.

seed Optional[float]

the seed to use for the generation. Defaults to None.

stop_sequences Optional[Sequence[str]]

a list of sequences to use as stopping criteria for the generation. Defaults to None.

frequency_penalty Optional[float]

the frequency penalty to use for the generation. Defaults to None.

presence_penalty Optional[float]

the presence penalty to use for the generation. Defaults to None.

raw_prompting Optional[bool]

a flag to use raw prompting for the generation. Defaults to None.



Type Description

The generated response from the Cohere API model.

Source code in src/distilabel/llms/
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    k: Optional[int] = None,
    p: Optional[float] = None,
    seed: Optional[float] = None,
    stop_sequences: Optional[Sequence[str]] = None,
    frequency_penalty: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    raw_prompting: Optional[bool] = None,
) -> GenerateOutput:
    """Generates a response from the LLM given an input.

    Args:
        input: a single input in chat format to generate responses for.
        temperature: the temperature to use for the generation. Defaults to `None`.
        max_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `None`.
        k: the number of highest probability vocabulary tokens to keep for the generation.
            Defaults to `None`.
        p: the nucleus sampling probability to use for the generation. Defaults to
            `None`.
        seed: the seed to use for the generation. Defaults to `None`.
        stop_sequences: a list of sequences to use as stopping criteria for the generation.
            Defaults to `None`.
        frequency_penalty: the frequency penalty to use for the generation. Defaults
            to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `None`.
        raw_prompting: a flag to use raw prompting for the generation. Defaults to
            `None`.

    Returns:
        The generated response from the Cohere API model.
    """
    structured_output = None
    if isinstance(input, tuple):
        # A tuple input carries its own structured output configuration.
        input, structured_output = input
        result = self._prepare_structured_output(
            structured_output=structured_output,  # type: ignore
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore

    # Fall back to the configuration set on the instance.
    if structured_output is None and self.structured_output is not None:
        structured_output = self.structured_output

    system, chat_history, message = self._format_chat_to_cohere(input)

    kwargs = {
        "message": message,
        "model": self.model,
        "preamble": system,
        "chat_history": chat_history,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "k": k,
        "p": p,
        "seed": seed,
        "stop_sequences": stop_sequences,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "raw_prompting": raw_prompting,
    }
    if structured_output:
        kwargs = self._prepare_kwargs(kwargs, structured_output)  # type: ignore

    response = await self._aclient.chat(**kwargs)  # type: ignore

    if structured_output:
        return [response.model_dump_json()]

    if (text := response.text) == "":
        self._logger.warning(  # type: ignore
            f"Received no response using Cohere client (model: '{self.model}')."
            f" Finish reason was: {response.finish_reason}"
        )
        return [None]

    return [text]


Loads the AsyncClient client from the cohere package.

Source code in src/distilabel/llms/
def load(self) -> None:
    """Loads the `AsyncClient` client from the `cohere` package.

    Raises:
        ImportError: if the `cohere` package is not installed.
    """

    super().load()

    # `cohere` is an optional dependency, so it is imported lazily here.
    try:
        from cohere import AsyncClient, ChatMessage
    except ImportError as ie:
        raise ImportError(
            "The `cohere` package is required to use the `CohereLLM` class."
        ) from ie

    self._ChatMessage = ChatMessage

    self._aclient = AsyncClient(
        api_key=self.api_key.get_secret_value(),  # type: ignore
        client_name=self.client_name,
        base_url=self.base_url,
        timeout=self.timeout,
    )

    if self.structured_output:
        # The helper hands back the client to use from now on and, optionally,
        # an updated structured output configuration.
        result = self._prepare_structured_output(
            structured_output=self.structured_output,
            client=self._aclient,
            framework="cohere",
        )
        self._aclient = result.get("client")  # type: ignore
        if structured_output := result.get("structured_output"):
            self.structured_output = structured_output