vLLM

ClientvLLM

Bases: OpenAILLM, MagpieChatTemplateMixin

A client for the vLLM server implementing the OpenAI API specification.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `base_url` | | the base URL of the vLLM server. Defaults to "http://localhost:8000". |
| `max_retries` | | the maximum number of times to retry the request to the API before failing. Defaults to 6. |
| `timeout` | | the maximum time in seconds to wait for a response from the API. Defaults to 120. |
| `httpx_client_kwargs` | | extra kwargs that will be passed to the httpx.AsyncClient created to communicate with the vLLM server. Defaults to None. |
| `tokenizer` | Optional[str] | the Hugging Face Hub repo id or path of the tokenizer that will be used to apply the chat template and tokenize the inputs before sending them to the server. Defaults to None. |
| `tokenizer_revision` | Optional[str] | the revision of the tokenizer to load. Defaults to None. |
| `_aclient` | Optional[str] | the httpx.AsyncClient used to communicate with the vLLM server. Defaults to None. |

Runtime parameters
  • base_url: the base URL of the vLLM server. Defaults to "http://localhost:8000".
  • max_retries: the maximum number of times to retry the request to the API before failing. Defaults to 6.
  • timeout: the maximum time in seconds to wait for a response from the API. Defaults to 120.
  • httpx_client_kwargs: extra kwargs that will be passed to the httpx.AsyncClient created to communicate with the vLLM server. Defaults to None.

Examples:

Generate text:

```python
from distilabel.llms import ClientvLLM

llm = ClientvLLM(
    base_url="http://localhost:8000/v1",
    tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
)

llm.load()

results = llm.generate(
    inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
    temperature=0.7,
    top_p=1.0,
    max_new_tokens=256,
)
# [
#     [
#         "I'm functioning properly, thank you for asking. How can I assist you today?",
#         "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
#         "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
#     ]
# ]
```
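
The connection-related attributes (max_retries, timeout and httpx_client_kwargs) can be tuned when instantiating the client. A minimal sketch, using placeholder values for the endpoint and tokenizer:

```python
from distilabel.llms import ClientvLLM

# A minimal sketch: tune the connection-related attributes documented above.
# The base_url and tokenizer values are placeholders for your own deployment.
llm = ClientvLLM(
    base_url="http://localhost:8000/v1",
    tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct",
    max_retries=3,                          # retry failed requests up to 3 times
    timeout=60,                             # wait at most 60 seconds per request
    httpx_client_kwargs={"verify": False},  # extra kwargs forwarded to httpx.AsyncClient
)

llm.load()
```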
Source code in src/distilabel/llms/vllm.py
class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
    """A client for the `vLLM` server implementing the OpenAI API specification.

    Attributes:
        base_url: the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        httpx_client_kwargs: extra kwargs that will be passed to the `httpx.AsyncClient`
            created to communicate with the `vLLM` server. Defaults to `None`.
        tokenizer: the Hugging Face Hub repo id or path of the tokenizer that will be used
            to apply the chat template and tokenize the inputs before sending them to the
            server. Defaults to `None`.
        tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
        _aclient: the `httpx.AsyncClient` used to communicate with the `vLLM` server. Defaults
            to `None`.

    Runtime parameters:
        - `base_url`: the base URL of the `vLLM` server. Defaults to `"http://localhost:8000"`.
        - `max_retries`: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults
            to `120`.
        - `httpx_client_kwargs`: extra kwargs that will be passed to the `httpx.AsyncClient`
            created to communicate with the `vLLM` server. Defaults to `None`.

    Examples:

        Generate text:

        ```python
        from distilabel.llms import ClientvLLM

        llm = ClientvLLM(
            base_url="http://localhost:8000/v1",
            tokenizer="meta-llama/Meta-Llama-3.1-8B-Instruct"
        )

        llm.load()

        results = llm.generate(
            inputs=[[{"role": "user", "content": "Hello, how are you?"}]],
            temperature=0.7,
            top_p=1.0,
            max_new_tokens=256,
        )
        # [
        #     [
        #         "I'm functioning properly, thank you for asking. How can I assist you today?",
        #         "I'm doing well, thank you for asking. I'm a large language model, so I don't have feelings or emotions like humans do, but I'm here to help answer any questions or provide information you might need. How can I assist you today?",
        #         "I'm just a computer program, so I don't have feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have. What's on your mind?"
        #     ]
        # ]
        ```
    """

    model: str = ""  # Default value so it's not needed to `ClientvLLM(model="...")`
    tokenizer: Optional[str] = None
    tokenizer_revision: Optional[str] = None

    # We need the sync client to get the list of models
    _client: "OpenAI" = PrivateAttr(None)
    _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)

    def load(self) -> None:
        """Creates an `httpx.AsyncClient` to connect to the vLLM server and a tokenizer
        optionally."""

        self.api_key = SecretStr("EMPTY")

        # We need to first create the sync client to get the model name that will be used
        # in the `super().load()` when creating the logger.
        try:
            from openai import OpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        self._client = OpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),  # type: ignore
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        super().load()

        try:
            from transformers import AutoTokenizer
        except ImportError as ie:
            raise ImportError(
                "To use `ClientvLLM` you need to install `transformers`."
                "Please install it using `pip install transformers`."
            ) from ie

        self._tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer, revision=self.tokenizer_revision
        )

    @property
    def model_name(self) -> str:
        """Returns the name of the model served with vLLM server."""
        models = self._client.models.list()
        return models.data[0].id

    def _prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        prompt: str = (
            self._tokenizer.apply_chat_template(  # type: ignore
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,  # type: ignore
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        logit_bias: Optional[Dict[str, int]] = None,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for each input.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            logit_bias: modify the likelihood of specified tokens appearing in the completion.
                Defaults to `None`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: nucleus sampling. The value refers to the top-p tokens that should be
                considered for sampling. Defaults to `1.0`.

        Returns:
            A list of strings containing the generated responses for the given input.
        """

        completion = await self._aclient.completions.create(
            model=self.model_name,
            prompt=self._prepare_input(input),  # type: ignore
            n=num_generations,
            max_tokens=max_new_tokens,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            presence_penalty=presence_penalty,
            temperature=temperature,
            top_p=top_p,
        )

        generations = []
        for choice in completion.choices:
            if (text := choice.text) == "":
                self._logger.warning(  # type: ignore
                    f"Received no response from vLLM server (model: '{self.model_name}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(text)
        return generations

model_name: str property

Returns the name of the model served by the vLLM server.
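
A quick illustration, assuming a vLLM server is already running and serving a single model:

```python
# Assumes `llm` is a loaded `ClientvLLM` instance pointing at a running server;
# the printed value is whatever model id the server reports.
print(llm.model_name)
# meta-llama/Meta-Llama-3.1-8B-Instruct
```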

agenerate(input, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, logit_bias=None, presence_penalty=0.0, temperature=1.0, top_p=1.0) async

Generates num_generations responses for each input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | FormattedInput | a single input in chat format to generate responses for. | required |
| `num_generations` | int | the number of generations to create per input. Defaults to 1. | 1 |
| `max_new_tokens` | int | the maximum number of new tokens that the model will generate. Defaults to 128. | 128 |
| `frequency_penalty` | float | the repetition penalty to use for the generation. Defaults to 0.0. | 0.0 |
| `logit_bias` | Optional[Dict[str, int]] | modify the likelihood of specified tokens appearing in the completion. Defaults to None. | None |
| `presence_penalty` | float | the presence penalty to use for the generation. Defaults to 0.0. | 0.0 |
| `temperature` | float | the temperature to use for the generation. Defaults to 1.0. | 1.0 |
| `top_p` | float | nucleus sampling. The value refers to the top-p tokens that should be considered for sampling. Defaults to 1.0. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| GenerateOutput | A list of strings containing the generated responses for the given input. |
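
Since agenerate is a coroutine that works on a single input, it has to be awaited. A minimal sketch, assuming llm is an already loaded ClientvLLM instance:

```python
import asyncio

# A minimal sketch: `llm` is assumed to be an already loaded `ClientvLLM` instance.
async def main() -> None:
    generations = await llm.agenerate(
        input=[{"role": "user", "content": "Hello, how are you?"}],
        num_generations=2,
        max_new_tokens=64,
        temperature=0.7,
    )
    print(generations)  # list with 2 generated strings

asyncio.run(main())
```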

Source code in src/distilabel/llms/vllm.py
@validate_call
async def agenerate(  # type: ignore
    self,
    input: FormattedInput,
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    logit_bias: Optional[Dict[str, int]] = None,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> GenerateOutput:
    """Generates `num_generations` responses for each input.

    Args:
        input: a single input in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        logit_bias: modify the likelihood of specified tokens appearing in the completion.
            Defaults to `None`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: nucleus sampling. The value refers to the top-p tokens that should be
            considered for sampling. Defaults to `1.0`.

    Returns:
        A list of strings containing the generated responses for the given input.
    """

    completion = await self._aclient.completions.create(
        model=self.model_name,
        prompt=self._prepare_input(input),  # type: ignore
        n=num_generations,
        max_tokens=max_new_tokens,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        presence_penalty=presence_penalty,
        temperature=temperature,
        top_p=top_p,
    )

    generations = []
    for choice in completion.choices:
        if (text := choice.text) == "":
            self._logger.warning(  # type: ignore
                f"Received no response from vLLM server (model: '{self.model_name}')."
                f" Finish reason was: {choice.finish_reason}"
            )
        generations.append(text)
    return generations

load()

Creates an httpx.AsyncClient to connect to the vLLM server and, optionally, a tokenizer.

Source code in src/distilabel/llms/vllm.py
def load(self) -> None:
    """Creates an `httpx.AsyncClient` to connect to the vLLM server and a tokenizer
    optionally."""

    self.api_key = SecretStr("EMPTY")

    # We need to first create the sync client to get the model name that will be used
    # in the `super().load()` when creating the logger.
    try:
        from openai import OpenAI
    except ImportError as ie:
        raise ImportError(
            "OpenAI Python client is not installed. Please install it using"
            " `pip install openai`."
        ) from ie

    self._client = OpenAI(
        base_url=self.base_url,
        api_key=self.api_key.get_secret_value(),  # type: ignore
        max_retries=self.max_retries,  # type: ignore
        timeout=self.timeout,
    )

    super().load()

    try:
        from transformers import AutoTokenizer
    except ImportError as ie:
        raise ImportError(
            "To use `ClientvLLM` you need to install `transformers`."
            "Please install it using `pip install transformers`."
        ) from ie

    self._tokenizer = AutoTokenizer.from_pretrained(
        self.tokenizer, revision=self.tokenizer_revision
    )

vLLM

Bases: LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin

vLLM library LLM implementation.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `model` | str | the model Hugging Face Hub repo id or a path to a directory containing the model weights and configuration files. |
| `dtype` | str | the data type to use for the model. Defaults to auto. |
| `trust_remote_code` | bool | whether to trust the remote code when loading the model. Defaults to False. |
| `quantization` | Optional[str] | the quantization mode to use for the model. Defaults to None. |
| `revision` | Optional[str] | the revision of the model to load. Defaults to None. |
| `tokenizer` | Optional[str] | the tokenizer Hugging Face Hub repo id or a path to a directory containing the tokenizer files. If not provided, the tokenizer will be loaded from the model directory. Defaults to None. |
| `tokenizer_mode` | Literal['auto', 'slow'] | the mode to use for the tokenizer. Defaults to auto. |
| `tokenizer_revision` | Optional[str] | the revision of the tokenizer to load. Defaults to None. |
| `skip_tokenizer_init` | bool | whether to skip the initialization of the tokenizer. Defaults to False. |
| `chat_template` | Optional[str] | a chat template that will be used to build the prompts before sending them to the model. If not provided, the chat template defined in the tokenizer config will be used. If not provided and the tokenizer doesn't have a chat template, then the ChatML template will be used. Defaults to None. |
| `structured_output` | Optional[RuntimeParameter[OutlinesStructuredOutputType]] | a dictionary containing the structured output configuration or, if more fine-grained control is needed, an instance of OutlinesStructuredOutput. Defaults to None. |
| `seed` | int | the seed to use for the random number generator. Defaults to 0. |
| `extra_kwargs` | Optional[RuntimeParameter[Dict[str, Any]]] | additional dictionary of keyword arguments that will be passed to the LLM class of the vllm library. Defaults to {}. |
| `_model` | LLM | the vLLM model instance. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method. |
| `_tokenizer` | PreTrainedTokenizer | the tokenizer instance used to format the prompt before passing it to the LLM. This attribute is meant to be used internally and should not be accessed directly. It will be set in the load method. |
| `use_magpie_template` | bool | a flag used to enable/disable applying the Magpie pre-query template. Defaults to False. |
| `magpie_pre_query_template` | | the pre-query template to be applied to the prompt or sent to the LLM to generate an instruction or a follow up user message. Valid values are "llama3", "qwen2" or another provided pre-query template. Defaults to None. |

References
  • https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
Runtime parameters
  • extra_kwargs: additional dictionary of keyword arguments that will be passed to the LLM class of vllm library.
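
For instance, extra_kwargs can be used to forward engine arguments such as tensor_parallel_size to the vllm.LLM constructor. A sketch, assuming a machine with 2 GPUs:

```python
from distilabel.llms import vLLM

# A sketch assuming a machine with 2 GPUs: every key in `extra_kwargs` is
# forwarded as-is to the `vllm.LLM` constructor.
llm = vLLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    extra_kwargs={
        "tensor_parallel_size": 2,
        "gpu_memory_utilization": 0.90,
    },
)

llm.load()
```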

Examples:

Generate text:

```python
from distilabel.llms import vLLM

# You can pass a custom chat_template to the model
llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    chat_template="[INST] {{ messages[0]"content" }}\n{{ messages[1]"content" }}[/INST]",
)

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
```

Generate structured data:

```python
from pydantic import BaseModel
from distilabel.llms import vLLM

class User(BaseModel):
    name: str
    last_name: str
    id: int

llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0"
    structured_output={"format": "json", "schema": Character},
)

llm.load()

# Call the model
output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
```
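
Structured generation is not limited to JSON schemas. Assuming your distilabel version exposes the regex format of the Outlines integration, the output can instead be constrained to a pattern; a sketch (the model and pattern are illustrative):

```python
from distilabel.llms import vLLM

# A sketch assuming the "regex" structured output format is available:
# generation is constrained to strings matching the given pattern.
llm = vLLM(
    model="prometheus-eval/prometheus-7b-v2.0",
    structured_output={"format": "regex", "schema": r"\d{4}-\d{2}-\d{2}"},
)

llm.load()

output = llm.generate(
    inputs=[[{"role": "user", "content": "What is today's date? Reply with a date."}]]
)
```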
Source code in src/distilabel/llms/vllm.py
class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
    """`vLLM` library LLM implementation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        dtype: the data type to use for the model. Defaults to `auto`.
        trust_remote_code: whether to trust the remote code when loading the model. Defaults
            to `False`.
        quantization: the quantization mode to use for the model. Defaults to `None`.
        revision: the revision of the model to load. Defaults to `None`.
        tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing
            the tokenizer files. If not provided, the tokenizer will be loaded from the
            model directory. Defaults to `None`.
        tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`.
        tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`.
        skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults
            to `False`.
        chat_template: a chat template that will be used to build the prompts before
            sending them to the model. If not provided, the chat template defined in the
            tokenizer config will be used. If not provided and the tokenizer doesn't have
            a chat template, then ChatML template will be used. Defaults to `None`.
        structured_output: a dictionary containing the structured output configuration or if more
            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
        seed: the seed to use for the random number generator. Defaults to `0`.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `LLM` class of `vllm` library. Defaults to `{}`.
        _model: the `vLLM` model instance. This attribute is meant to be used internally
            and should not be accessed directly. It will be set in the `load` method.
        _tokenizer: the tokenizer instance used to format the prompt before passing it to
            the `LLM`. This attribute is meant to be used internally and should not be
            accessed directly. It will be set in the `load` method.
        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
            template. Defaults to `False`.
        magpie_pre_query_template: the pre-query template to be applied to the prompt or
            sent to the LLM to generate an instruction or a follow up user message. Valid
            values are "llama3", "qwen2" or another pre-query template provided. Defaults
            to `None`.

    References:
        - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py

    Runtime parameters:
        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to
            the `LLM` class of `vllm` library.

    Examples:

        Generate text:

        ```python
        from distilabel.llms import vLLM

        # You can pass a custom chat_template to the model
        llm = vLLM(
            model="prometheus-eval/prometheus-7b-v2.0",
            chat_template="[INST] {{ messages[0]\"content\" }}\\n{{ messages[1]\"content\" }}[/INST]",
        )

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```

        Generate structured data:

        ```python
        from pydantic import BaseModel
        from distilabel.llms import vLLM

        class User(BaseModel):
            name: str
            last_name: str
            id: int

        llm = vLLM(
            model="prometheus-eval/prometheus-7b-v2.0"
            structured_output={"format": "json", "schema": Character},
        )

        llm.load()

        # Call the model
        output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]])
        ```
    """

    model: str
    dtype: str = "auto"
    trust_remote_code: bool = False
    quantization: Optional[str] = None
    revision: Optional[str] = None

    tokenizer: Optional[str] = None
    tokenizer_mode: Literal["auto", "slow"] = "auto"
    tokenizer_revision: Optional[str] = None
    skip_tokenizer_init: bool = False
    chat_template: Optional[str] = None

    seed: int = 0

    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `vLLM` class of `vllm` library. See all the supported arguments at: "
        "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py",
    )
    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _model: "_vLLM" = PrivateAttr(None)
    _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None)
    _logits_processor: Optional[Callable] = PrivateAttr(default=None)

    def load(self) -> None:
        """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
        Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
        parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
        default value is ChatML format, unless explicitly provided.
        """
        super().load()

        CudaDevicePlacementMixin.load(self)

        try:
            from vllm import LLM as _vLLM
        except ImportError as ie:
            raise ImportError(
                "vLLM is not installed. Please install it using `pip install vllm`."
            ) from ie

        self._model = _vLLM(
            self.model,
            dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            quantization=self.quantization,
            revision=self.revision,
            tokenizer=self.tokenizer,
            tokenizer_mode=self.tokenizer_mode,
            tokenizer_revision=self.tokenizer_revision,
            skip_tokenizer_init=self.skip_tokenizer_init,
            seed=self.seed,
            **self.extra_kwargs,  # type: ignore
        )

        self._tokenizer = self._model.get_tokenizer()  # type: ignore
        if self.chat_template is not None:
            self._tokenizer.chat_template = self.chat_template  # type: ignore

        if self.structured_output:
            self._logits_processor = self._prepare_structured_output(
                self.structured_output
            )

    def unload(self) -> None:
        """Unloads the `vLLM` model."""
        CudaDevicePlacementMixin.unload(self)
        super().unload()

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    def prepare_input(self, input: "StandardInput") -> str:
        """Prepares the input (applying the chat template and tokenization) for the provided
        input.

        Args:
            input: the input list containing chat items.

        Returns:
            The prompt to send to the LLM.
        """
        if self._tokenizer.chat_template is None:
            return input[0]["content"]

        prompt: str = (
            self._tokenizer.apply_chat_template(
                input,  # type: ignore
                tokenize=False,
                add_generation_prompt=True,  # type: ignore
            )
            if input
            else ""
        )
        return super().apply_magpie_pre_query_template(prompt, input)

    def _prepare_batches(
        self, inputs: List[FormattedInput]
    ) -> Tuple[List[List[FormattedInput]], List[int]]:
        """Prepares the inputs by grouping them by the structured output.

        When we generate structured outputs with schemas obtained from a dataset, we need to
        prepare the data to try to send batches of inputs instead of single inputs to the model
        to take advantage of the engine. So we group the inputs by the structured output to be
        passed in the `generate` method.

        Args:
            inputs: The batch of inputs passed to the generate method. As we expect to be generating
                structured outputs, each element will be a tuple containing the instruction and the
                structured output.

        Returns:
            The prepared batches (sub-batches, so to speak) to be passed to the `generate` method,
            where each tuple contains a list of instructions (instead of a single instruction) and
            their structured output, together with the list of indices to restore the original order.
        """
        instruction_order = {}
        batches = {}
        for i, (instruction, structured_output) in enumerate(inputs):
            instruction = self.prepare_input(instruction)
            instruction_order[instruction] = i
            structured_output = json.dumps(structured_output)
            if structured_output not in batches:
                batches[structured_output] = [instruction]
            else:
                batches[structured_output].append(instruction)

        # Flatten the instructions in prepared_data
        flat_instructions = [
            instruction for _, group in batches.items() for instruction in group
        ]
        # Generate the list of indices based on the original order
        sorted_indices = [
            instruction_order[instruction] for instruction in flat_instructions
        ]
        return [
            (batch, json.loads(schema)) for schema, batch in batches.items()
        ], sorted_indices

    @validate_call
    def generate(  # type: ignore
        self,
        inputs: List[FormattedInput],
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        extra_sampling_params: Optional[Dict[str, Any]] = None,
    ) -> List[GenerateOutput]:
        """Generates `num_generations` responses for each input.

        Args:
            inputs: a list of inputs in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            top_k: the top-k value to use for the generation. Defaults to `-1`.
            extra_sampling_params: dictionary with additional arguments to be passed to
                the `SamplingParams` class from `vllm`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        from vllm import SamplingParams

        if extra_sampling_params is None:
            extra_sampling_params = {}
        structured_output = None

        if isinstance(inputs[0], tuple):
            prepared_batches, sorted_indices = self._prepare_batches(inputs)
        else:
            # Simulate a batch without the structured output content
            prepared_batches = [([self.prepare_input(input) for input in inputs], None)]
            sorted_indices = None

        # In case we have a single structured output for the whole dataset, we can reuse
        # the logits processor prepared in the `load` method for every batch
        logits_processors = None
        if self._logits_processor:
            logits_processors = [self._logits_processor]

        batched_outputs = []

        for prepared_inputs, structured_output in prepared_batches:
            if structured_output:
                logits_processors = [self._prepare_structured_output(structured_output)]

            sampling_params = SamplingParams(  # type: ignore
                n=num_generations,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                max_tokens=max_new_tokens,
                logits_processors=logits_processors,
                **extra_sampling_params,
            )

            batch_outputs = self._model.generate(
                prepared_inputs,
                sampling_params,
                use_tqdm=False,  # type: ignore
            )

            batched_outputs += [
                [output.text for output in outputs.outputs] for outputs in batch_outputs
            ]

        # If logits_processor is set, we need to sort the outputs back to the original order
        # (would be needed only if we have multiple structured outputs in the dataset)
        if sorted_indices is not None:
            batched_outputs = _sort_batches(
                batched_outputs, sorted_indices, num_generations=num_generations
            )
        return batched_outputs

    def _prepare_structured_output(
        self, structured_output: Optional[OutlinesStructuredOutputType] = None
    ) -> Union[Callable, None]:
        """Creates the appropriate function to filter tokens to generate structured outputs.

        Args:
            structured_output: the configuration dict to prepare the structured output.

        Returns:
            The callable that will be used to guide the generation of the model.
        """
        from distilabel.steps.tasks.structured_outputs.outlines import (
            prepare_guided_output,
        )

        result = prepare_guided_output(structured_output, "vllm", self._model)
        if (schema := result.get("schema")) and self.structured_output:
            self.structured_output["schema"] = schema
        return result["processor"]

model_name: str property

Returns the model name used for the LLM.

generate(inputs, num_generations=1, max_new_tokens=128, frequency_penalty=0.0, presence_penalty=0.0, temperature=1.0, top_p=1.0, top_k=-1, extra_sampling_params=None)

Generates num_generations responses for each input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | List[FormattedInput] | a list of inputs in chat format to generate responses for. | required |
| `num_generations` | int | the number of generations to create per input. Defaults to 1. | 1 |
| `max_new_tokens` | int | the maximum number of new tokens that the model will generate. Defaults to 128. | 128 |
| `frequency_penalty` | float | the repetition penalty to use for the generation. Defaults to 0.0. | 0.0 |
| `presence_penalty` | float | the presence penalty to use for the generation. Defaults to 0.0. | 0.0 |
| `temperature` | float | the temperature to use for the generation. Defaults to 1.0. | 1.0 |
| `top_p` | float | the top-p value to use for the generation. Defaults to 1.0. | 1.0 |
| `top_k` | int | the top-k value to use for the generation. Defaults to -1. | -1 |
| `extra_sampling_params` | Optional[Dict[str, Any]] | dictionary with additional arguments to be passed to the SamplingParams class from vllm. | None |

Returns:

| Type | Description |
| --- | --- |
| List[GenerateOutput] | A list of lists of strings containing the generated responses for each input. |
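
extra_sampling_params is forwarded directly to vllm.SamplingParams, so any field supported by the installed vLLM version can be passed. A sketch, assuming llm is an already loaded vLLM instance:

```python
# A sketch: `llm` is assumed to be an already loaded `vLLM` instance (see the examples above).
outputs = llm.generate(
    inputs=[[{"role": "user", "content": "Write a haiku about GPUs."}]],
    num_generations=2,
    max_new_tokens=64,
    temperature=0.8,
    top_k=50,
    # Forwarded as-is to `vllm.SamplingParams`; `min_p` is assumed to be supported
    # by the installed vLLM version.
    extra_sampling_params={"min_p": 0.05},
)
```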

Source code in src/distilabel/llms/vllm.py
@validate_call
def generate(  # type: ignore
    self,
    inputs: List[FormattedInput],
    num_generations: int = 1,
    max_new_tokens: int = 128,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = -1,
    extra_sampling_params: Optional[Dict[str, Any]] = None,
) -> List[GenerateOutput]:
    """Generates `num_generations` responses for each input.

    Args:
        inputs: a list of inputs in chat format to generate responses for.
        num_generations: the number of generations to create per input. Defaults to
            `1`.
        max_new_tokens: the maximum number of new tokens that the model will generate.
            Defaults to `128`.
        frequency_penalty: the repetition penalty to use for the generation. Defaults
            to `0.0`.
        presence_penalty: the presence penalty to use for the generation. Defaults to
            `0.0`.
        temperature: the temperature to use for the generation. Defaults to `1.0`.
        top_p: the top-p value to use for the generation. Defaults to `1.0`.
        top_k: the top-k value to use for the generation. Defaults to `-1`.
        extra_sampling_params: dictionary with additional arguments to be passed to
            the `SamplingParams` class from `vllm`.

    Returns:
        A list of lists of strings containing the generated responses for each input.
    """
    from vllm import SamplingParams

    if extra_sampling_params is None:
        extra_sampling_params = {}
    structured_output = None

    if isinstance(inputs[0], tuple):
        prepared_batches, sorted_indices = self._prepare_batches(inputs)
    else:
        # Simulate a batch without the structured output content
        prepared_batches = [([self.prepare_input(input) for input in inputs], None)]
        sorted_indices = None

    # In case we have a single structured output for the whole dataset, we can reuse
    # the logits processor prepared in the `load` method for every batch
    logits_processors = None
    if self._logits_processor:
        logits_processors = [self._logits_processor]

    batched_outputs = []

    for prepared_inputs, structured_output in prepared_batches:
        if structured_output:
            logits_processors = [self._prepare_structured_output(structured_output)]

        sampling_params = SamplingParams(  # type: ignore
            n=num_generations,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_new_tokens,
            logits_processors=logits_processors,
            **extra_sampling_params,
        )

        batch_outputs = self._model.generate(
            prepared_inputs,
            sampling_params,
            use_tqdm=False,  # type: ignore
        )

        batched_outputs += [
            [output.text for output in outputs.outputs] for outputs in batch_outputs
        ]

    # If logits_processor is set, we need to sort the outputs back to the original order
    # (would be needed only if we have multiple structured outputs in the dataset)
    if sorted_indices is not None:
        batched_outputs = _sort_batches(
            batched_outputs, sorted_indices, num_generations=num_generations
        )
    return batched_outputs

load()

Loads the vLLM model using either the path or the Hugging Face Hub repository id. Additionally, this method sets the chat_template for the tokenizer, so that the list of OpenAI formatted inputs can be parsed using the format expected by the model. If not explicitly provided, the chat template defaults to the ChatML format.

Source code in src/distilabel/llms/vllm.py
def load(self) -> None:
    """Loads the `vLLM` model using either the path or the Hugging Face Hub repository id.
    Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly
    parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the
    default value is ChatML format, unless explicitly provided.
    """
    super().load()

    CudaDevicePlacementMixin.load(self)

    try:
        from vllm import LLM as _vLLM
    except ImportError as ie:
        raise ImportError(
            "vLLM is not installed. Please install it using `pip install vllm`."
        ) from ie

    self._model = _vLLM(
        self.model,
        dtype=self.dtype,
        trust_remote_code=self.trust_remote_code,
        quantization=self.quantization,
        revision=self.revision,
        tokenizer=self.tokenizer,
        tokenizer_mode=self.tokenizer_mode,
        tokenizer_revision=self.tokenizer_revision,
        skip_tokenizer_init=self.skip_tokenizer_init,
        seed=self.seed,
        **self.extra_kwargs,  # type: ignore
    )

    self._tokenizer = self._model.get_tokenizer()  # type: ignore
    if self.chat_template is not None:
        self._tokenizer.chat_template = self.chat_template  # type: ignore

    if self.structured_output:
        self._logits_processor = self._prepare_structured_output(
            self.structured_output
        )

prepare_input(input)

Prepares the input (applying the chat template and tokenization) for the provided input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input` | StandardInput | the input list containing chat items. | required |

Returns:

| Type | Description |
| --- | --- |
| str | The prompt to send to the LLM. |
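
A small sketch, assuming llm is an already loaded vLLM instance; the exact prompt depends on the tokenizer's chat template, so the output shown is only illustrative (Llama 3 style template):

```python
# A sketch: `llm` is assumed to be an already loaded `vLLM` instance.
prompt = llm.prepare_input([{"role": "user", "content": "Hello!"}])
print(prompt)
# Illustrative output for a Llama 3 style chat template (actual text depends on the tokenizer):
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```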

Source code in src/distilabel/llms/vllm.py
def prepare_input(self, input: "StandardInput") -> str:
    """Prepares the input (applying the chat template and tokenization) for the provided
    input.

    Args:
        input: the input list containing chat items.

    Returns:
        The prompt to send to the LLM.
    """
    if self._tokenizer.chat_template is None:
        return input[0]["content"]

    prompt: str = (
        self._tokenizer.apply_chat_template(
            input,  # type: ignore
            tokenize=False,
            add_generation_prompt=True,  # type: ignore
        )
        if input
        else ""
    )
    return super().apply_magpie_pre_query_template(prompt, input)

unload()

Unloads the vLLM model.

Source code in src/distilabel/llms/vllm.py
def unload(self) -> None:
    """Unloads the `vLLM` model."""
    CudaDevicePlacementMixin.unload(self)
    super().unload()