VertexAILLM

Bases: AsyncLLM

VertexAI LLM implementation running the async API clients for Gemini.

  • Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

To use the VertexAILLM you must first configure Google Cloud authentication using one of these methods:

  • Setting the GOOGLE_CLOUD_CREDENTIALS environment variable
  • Using the gcloud auth application-default login command
  • Using the vertexai.init function from the google-cloud-aiplatform library (see the sketch below)
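For example, a minimal authentication sketch using vertexai.init (the project ID and region are placeholders):

import vertexai

# Initialize the Vertex AI SDK explicitly; alternatively, rely on the
# GOOGLE_CLOUD_CREDENTIALS environment variable or on
# `gcloud auth application-default login` and skip this call.
vertexai.init(project="my-gcp-project", location="us-central1")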

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | the model name to use for the LLM, e.g. "gemini-1.0-pro". [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models). |
| `_aclient` | `Optional[GenerativeModel]` | the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant to be used internally. Set in the `load` method. |

Icon

:simple-googlecloud:
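A minimal usage sketch (the model name is illustrative; generate is the synchronous batching wrapper inherited from AsyncLLM):

from distilabel.llms import VertexAILLM

llm = VertexAILLM(model="gemini-1.0-pro")
llm.load()

# One chat-formatted conversation per input; only the "user" and "model"
# roles are accepted by this LLM.
output = llm.generate(
    inputs=[[{"role": "user", "content": "What is Vertex AI?"}]],
    temperature=0.2,
)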

Source code in src/distilabel/llms/vertexai.py
class VertexAILLM(AsyncLLM):
    """VertexAI LLM implementation running the async API clients for Gemini.

    - Gemini API: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini

    To use the `VertexAILLM` you must first configure the Google Cloud authentication
    using one of these methods:

    - Setting the `GOOGLE_CLOUD_CREDENTIALS` environment variable
    - Using the `gcloud auth application-default login` command
    - Using the `vertexai.init` function from the `google-cloud-aiplatform` library

    Attributes:
        model: the model name to use for the LLM e.g. "gemini-1.0-pro". [Supported models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models).
        _aclient: the `GenerativeModel` to use for the Vertex AI Gemini API. It is meant
            to be used internally. Set in the `load` method.

    Icon:
        `:simple-googlecloud:`
    """

    model: str

    _num_generations_param_supported = False

    _aclient: Optional["GenerativeModel"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
        super().load()

        try:
            from vertexai.generative_models import GenerationConfig, GenerativeModel

            self._generation_config_class = GenerationConfig
        except ImportError as e:
            raise ImportError(
                "vertexai is not installed. Please install it using"
                " `pip install google-cloud-aiplatform`."
            ) from e

        if _is_gemini_model(self.model):
            self._aclient = GenerativeModel(model_name=self.model)
        else:
            raise NotImplementedError(
                "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
            )

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def _chattype_to_content(self, input: "StandardInput") -> List["Content"]:
        """Converts a chat type to a list of content items expected by the API.

        Args:
            input: the chat type to be converted.

        Returns:
            a list of `Content` items expected by the API.
        """
        from vertexai.generative_models import Content, Part

        contents = []
        for message in input:
            if message["role"] not in ["user", "model"]:
                raise ValueError(
                    "`VertexAILLM only supports the roles 'user' or 'model'."
                )
            contents.append(
                Content(
                    role=message["role"], parts=[Part.from_text(message["content"])]
                )
            )
        return contents

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: StandardInput,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_output_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        safety_settings: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

        Args:
            input: a single input in chat format to generate responses for.
            temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
            top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
            top_k: If specified, top-k sampling will be used. Defaults to `None`.
            max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
            stop_sequences: A list of stop sequences. Defaults to `None`.
            safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
            tools: A potential list of tools that can be used by the API. Defaults to `None`.

        Returns:
            A list of strings containing the generated responses for the input.
        """
        from vertexai.generative_models import GenerationConfig

        content: "GenerationResponse" = await self._aclient.generate_content_async(  # type: ignore
            contents=self._chattype_to_content(input),
            generation_config=GenerationConfig(
                candidate_count=1,  # only one candidate allowed per call
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_output_tokens=max_output_tokens,
                stop_sequences=stop_sequences,
            ),
            safety_settings=safety_settings,  # type: ignore
            tools=tools,  # type: ignore
            stream=False,
        )

        text = None
        try:
            text = content.candidates[0].text
        except ValueError:
            self._logger.warning(  # type: ignore
                f"Received no response using VertexAI client (model: '{self.model}')."
                f" Finish reason was: '{content.candidates[0].finish_reason}'."
            )

        return [text]

model_name: str property

Returns the model name used for the LLM.

agenerate(input, temperature=None, top_p=None, top_k=None, max_output_tokens=None, stop_sequences=None, safety_settings=None, tools=None) async

Generates num_generations responses for the given input using the VertexAI async client definition.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | `StandardInput` | a single input in chat format to generate responses for. | required |
| `temperature` | `Optional[float]` | Controls the randomness of predictions. Range: [0.0, 1.0]. | `None` |
| `top_p` | `Optional[float]` | If specified, nucleus sampling will be used. Range: (0.0, 1.0]. | `None` |
| `top_k` | `Optional[int]` | If specified, top-k sampling will be used. | `None` |
| `max_output_tokens` | `Optional[int]` | The maximum number of output tokens to generate per message. | `None` |
| `stop_sequences` | `Optional[List[str]]` | A list of stop sequences. | `None` |
| `safety_settings` | `Optional[Dict[str, Any]]` | Safety configuration for returned content from the API. | `None` |
| `tools` | `Optional[List[Dict[str, Any]]]` | A potential list of tools that can be used by the API. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `GenerateOutput` | A list of strings containing the generated responses for the input. |
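A hedged sketch of calling agenerate directly (model name and sampling values are illustrative):

import asyncio

from distilabel.llms import VertexAILLM

llm = VertexAILLM(model="gemini-1.0-pro")
llm.load()  # instantiates the underlying GenerativeModel client

# `agenerate` takes a single chat-formatted input; batching across many
# inputs is handled by the synchronous `generate` wrapper of `AsyncLLM`.
result = asyncio.run(
    llm.agenerate(
        input=[{"role": "user", "content": "Name three prime numbers."}],
        temperature=0.7,
        max_output_tokens=128,
    )
)
print(result)  # a list with one generated string (or None on empty response)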

Source code in src/distilabel/llms/vertexai.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: StandardInput,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    max_output_tokens: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    safety_settings: Optional[Dict[str, Any]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
) -> GenerateOutput:
    """Generates `num_generations` responses for the given input using the [VertexAI async client definition](https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini).

    Args:
        input: a single input in chat format to generate responses for.
        temperature: Controls the randomness of predictions. Range: [0.0, 1.0]. Defaults to `None`.
        top_p: If specified, nucleus sampling will be used. Range: (0.0, 1.0]. Defaults to `None`.
        top_k: If specified, top-k sampling will be used. Defaults to `None`.
        max_output_tokens: The maximum number of output tokens to generate per message. Defaults to `None`.
        stop_sequences: A list of stop sequences. Defaults to `None`.
        safety_settings: Safety configuration for returned content from the API. Defaults to `None`.
        tools: A potential list of tools that can be used by the API. Defaults to `None`.

    Returns:
        A list of strings containing the generated responses for the input.
    """
    from vertexai.generative_models import GenerationConfig

    content: "GenerationResponse" = await self._aclient.generate_content_async(  # type: ignore
        contents=self._chattype_to_content(input),
        generation_config=GenerationConfig(
            candidate_count=1,  # only one candidate allowed per call
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_output_tokens=max_output_tokens,
            stop_sequences=stop_sequences,
        ),
        safety_settings=safety_settings,  # type: ignore
        tools=tools,  # type: ignore
        stream=False,
    )

    text = None
    try:
        text = content.candidates[0].text
    except ValueError:
        self._logger.warning(  # type: ignore
            f"Received no response using VertexAI client (model: '{self.model}')."
            f" Finish reason was: '{content.candidates[0].finish_reason}'."
        )

    return [text]

load()

Loads the GenerativeModel class which has access to generate_content_async to benefit from async requests.
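The vertexai dependency is imported lazily inside load, so it only needs to be installed when the LLM is actually used. A minimal sketch of the resulting behavior (model names are illustrative):

from distilabel.llms import VertexAILLM

# Requires `google-cloud-aiplatform` to be installed, per the error message below.
llm = VertexAILLM(model="gemini-1.0-pro")
llm.load()  # raises ImportError if `vertexai` is missing

# Non-Gemini model names are rejected at load time:
# VertexAILLM(model="text-bison").load()  -> NotImplementedError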

Source code in src/distilabel/llms/vertexai.py
def load(self) -> None:
    """Loads the `GenerativeModel` class which has access to `generate_content_async` to benefit from async requests."""
    super().load()

    try:
        from vertexai.generative_models import GenerationConfig, GenerativeModel

        self._generation_config_class = GenerationConfig
    except ImportError as e:
        raise ImportError(
            "vertexai is not installed. Please install it using"
            " `pip install google-cloud-aiplatform`."
        ) from e

    if _is_gemini_model(self.model):
        self._aclient = GenerativeModel(model_name=self.model)
    else:
        raise NotImplementedError(
            "`VertexAILLM` is only implemented for `gemini` models that allow for `ChatType` data."
        )