
base

LLM

Bases: ABC

Source code in src/distilabel/llm/base.py
class LLM(ABC):
    def __init__(
        self,
        task: Task,
        num_threads: Union[int, None] = None,
        prompt_format: Union["SupportedFormats", None] = None,
        prompt_formatting_fn: Union[Callable[..., str], None] = None,
    ) -> None:
        """Initializes the LLM base class.

        Note:
            This class is intended to be used internally, but anyone can still create
            a subclass, implement the `abstractmethod`s and use it.

        Args:
            task (Task): the task to be performed by the LLM.
            num_threads (Union[int, None], optional): the number of threads to be used
                for parallel generation. If `None`, no parallel generation will be performed.
                Defaults to `None`.
            prompt_format (Union["SupportedFormats", None], optional): the format to be used
                for the prompt. If `None`, the default format of the task will be used, available
                formats are `openai`, `chatml`, `llama2`, `zephyr`, and `default`. Defaults to `None`,
                but `default` (concatenation of `system_prompt` and `formatted_prompt` with a line-break)
                will be used if no `prompt_formatting_fn` is provided.
            prompt_formatting_fn (Union[Callable[..., str], None], optional): a function to be
                applied to the prompt before generation. If `None`, no formatting will be applied.
                Defaults to `None`.
        """
        self.task = task

        self.thread_pool_executor = (
            ThreadPoolExecutor(max_workers=num_threads)
            if num_threads is not None
            else None
        )

        self.prompt_format = prompt_format
        self.prompt_formatting_fn = prompt_formatting_fn

    def __del__(self) -> None:
        """Shuts down the thread pool executor if it is not `None`."""
        if self.thread_pool_executor is not None:
            self.thread_pool_executor.shutdown()

    @property
    def num_threads(self) -> Union[int, None]:
        if self.thread_pool_executor:
            return self.thread_pool_executor._max_workers

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(task={self.task.__class__.__name__}, num_threads={self.num_threads}, promp_format='{self.prompt_format}', model='{self.model_name}')"

    def __rich_repr__(self) -> Generator[Any, None, None]:
        yield "task", self.task
        yield "num_threads", self.num_threads
        yield "prompt_format", self.prompt_format
        if self.prompt_formatting_fn is not None:
            args = f"({', '.join(self.prompt_formatting_fn.__code__.co_varnames)})"
            representation = self.prompt_formatting_fn.__name__ + args
            yield "prompt_formatting_fn", representation
        yield "model", self.model_name

    @property
    @abstractmethod
    def model_name(self) -> str:
        pass

    def _generate_prompts(
        self,
        inputs: List[Dict[str, Any]],
        default_format: Union["SupportedFormats", None] = None,
    ) -> List[Any]:
        """Generates the prompts to be used for generation.

        Args:
            inputs (List[Dict[str, Any]]): the inputs to be used for generation.
            default_format (Union["SupportedFormats", None], optional): the default format to be used
                for the prompt if no `prompt_format` is specified. Defaults to `None`.

        Returns:
            List[Any]: the generated prompts.

        Raises:
            ValueError: if the generated prompt is not of the expected type.
        """
        prompts = []
        for input in inputs:
            prompt = self.task.generate_prompt(**input)
            if not isinstance(prompt, Prompt) and self.prompt_formatting_fn is not None:
                warnings.warn(
                    "The method `generate_prompt` is not returning a `Prompt` class but a prompt"
                    f" of `type={type(prompt)}`, meaning that a pre-formatting has already been"
                    " applied in the `task.generate_prompt` method, so the usage of a `prompt_formatting_fn`"
                    " is discouraged.",
                    UserWarning,
                    stacklevel=2,
                )
                prompt = self.prompt_formatting_fn(prompt)
            elif isinstance(prompt, Prompt) and self.prompt_formatting_fn is None:
                if self.prompt_format is not None or default_format is not None:
                    prompt = prompt.format_as(
                        format=self.prompt_format or default_format  # type: ignore
                    )
                else:
                    warnings.warn(
                        "No `prompt_format` has been specified and no `default_format` is set, so"
                        " the prompt will be concatenated with a line-break and no specific formatting"
                        " by default.",
                        UserWarning,
                        stacklevel=2,
                    )
                    prompt = prompt.format_as(format="default")
            prompts.append(prompt)
        return prompts

    @abstractmethod
    def _generate(
        self, inputs: List[Dict[str, Any]], num_generations: int = 1
    ) -> List[List["LLMOutput"]]:
        pass

    def _get_valid_inputs(
        self, inputs: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Returns the valid inputs and the indices of the invalid inputs.

        A valid input is an input that contains all the arguments required by the task.

        Args:
            inputs (List[Dict[str, Any]]): the inputs to be used for generation.

        Returns:
            Tuple[List[Dict[str, Any]], List[int]]: a tuple containing the valid inputs and
                the indices of the invalid inputs.
        """

        valid_inputs = []
        not_valid_inputs_indices = []
        for i, input in enumerate(inputs):
            if not all(input_arg in input for input_arg in self.task.input_args_names):
                logger.warn(
                    f"Missing {self.task.__class__.__name__} input argument in batch element {i}"
                )
                not_valid_inputs_indices.append(i)
                continue

            valid_inputs.append(input)

        return valid_inputs, not_valid_inputs_indices

    def _fill_missing_inputs(
        self,
        generations: List[List[LLMOutput]],
        invalid_inputs_indices: List[int],
        num_generations: int,
    ) -> List[List[LLMOutput]]:
        """Fills the `generations` list with empty `LLMOutput`s for the inputs that were
        not valid for the associated task of this `LLM`.

        Args:
            generations (List[List[LLMOutput]]): the generations to be filled.
            invalid_inputs_indices (List[int]): the indices of the inputs that were not
                valid for the associated task of this `LLM`.
            num_generations (int): the number of generations to be performed for each input.

        Returns:
            List[List[LLMOutput]]: the filled generations.
        """

        filled_generations = generations.copy()
        for idx in invalid_inputs_indices:
            filled_generations.insert(
                idx,
                [
                    LLMOutput(
                        model_name=self.model_name,
                        prompt_used=None,
                        raw_output=None,
                        parsed_output=None,
                    )
                    for _ in range(num_generations)
                ],
            )
        return filled_generations

    def generate(
        self,
        inputs: List[Dict[str, Any]],
        num_generations: int = 1,
        progress_callback_func: Union[Callable, None] = None,
    ) -> Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]:
        """Generates the outputs for the given inputs using the LLM.

        Args:
            inputs (List[Dict[str, Any]]): the inputs to be used for generation.
            num_generations (int, optional): the number of generations to be performed for each input.
                Defaults to `1`.
            progress_callback_func (Union[Callable, None], optional): a function to be called at each
                generation step. Defaults to `None`.

        Returns:
            Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]: the generated outputs.
        """

        def _progress():
            if progress_callback_func is not None:
                progress_callback_func(advance=num_generations * len(inputs))

        valid_inputs, invalid_inputs_indices = self._get_valid_inputs(inputs)

        if self.thread_pool_executor is not None:
            futures = []
            for input in valid_inputs:
                future = self.thread_pool_executor.submit(
                    self._generate, [input], num_generations
                )
                futures.append(future)
            future = when_all_complete(
                futures=futures,
                callback=lambda generations: self._fill_missing_inputs(
                    generations, invalid_inputs_indices, num_generations
                ),
            )
            future.add_done_callback(lambda _: _progress())
            return future

        generations = self._generate(valid_inputs, num_generations)

        generations = self._fill_missing_inputs(
            generations, invalid_inputs_indices, num_generations
        )

        _progress()
        return generations

    @property
    def return_futures(self) -> bool:
        """Whether the `LLM` returns futures"""
        return self.thread_pool_executor is not None

    def validate_prompts(
        self,
        inputs: List[Dict[str, Any]],
        default_format: Union["SupportedFormats", None] = None,
    ) -> str:
        """Generates the prompts to be used for generation, can be used to check the prompts visually.

        Args:
            inputs (List[Dict[str, Any]]):
                The inputs to be used for generation.

        Returns:
            str: The prompts that would be used for the generation.

        Examples:
            >>> from distilabel.tasks import TextGenerationTask
            >>> llm.validate_prompts([{"input": "Your input"}])[0]
            You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
            If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
            I'm valid for text generation task
        """
        return self._generate_prompts(inputs, default_format=default_format)
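
Since `LLM` is an abstract base class, a subclass only needs to implement the `model_name` property and the `_generate` method. The following is a minimal, hypothetical sketch of such a subclass that simply echoes the formatted prompt back; the `EchoLLM` name, its behaviour, and the import path used for `LLMOutput` are illustrative assumptions, not part of the library.

# Minimal sketch of an `LLM` subclass. `EchoLLM` is hypothetical and the
# `LLMOutput` import path is assumed; it "generates" by echoing the prompt back.
from typing import Any, Dict, List

from distilabel.llm.base import LLM
from distilabel.llm.utils import LLMOutput  # assumed import path


class EchoLLM(LLM):
    @property
    def model_name(self) -> str:
        return "echo"

    def _generate(
        self, inputs: List[Dict[str, Any]], num_generations: int = 1
    ) -> List[List[LLMOutput]]:
        # Build the prompts with the base-class helper, then "generate" by echoing.
        prompts = self._generate_prompts(inputs)
        return [
            [
                LLMOutput(
                    model_name=self.model_name,
                    prompt_used=prompt,
                    raw_output=prompt,
                    parsed_output=None,
                )
                for _ in range(num_generations)
            ]
            for prompt in prompts
        ]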

return_futures: bool property

Whether the LLM returns futures

__del__()

Shuts down the thread pool executor if it is not None.

Source code in src/distilabel/llm/base.py
def __del__(self) -> None:
    """Shuts down the thread pool executor if it is not `None`."""
    if self.thread_pool_executor is not None:
        self.thread_pool_executor.shutdown()

__init__(task, num_threads=None, prompt_format=None, prompt_formatting_fn=None)

Initializes the LLM base class.

Note

This class is intended to be used internally, but anyone can still create a subclass, implement the abstractmethods and use it.

Parameters:

    task (Task, required): the task to be performed by the LLM.

    num_threads (Union[int, None]): the number of threads to be used for parallel generation. If None, no parallel generation will be performed. Defaults to None.

    prompt_format (Union['SupportedFormats', None]): the format to be used for the prompt. If None, the default format of the task will be used, available formats are openai, chatml, llama2, zephyr, and default. Defaults to None, but default (concatenation of system_prompt and formatted_prompt with a line-break) will be used if no prompt_formatting_fn is provided.

    prompt_formatting_fn (Union[Callable[..., str], None]): a function to be applied to the prompt before generation. If None, no formatting will be applied. Defaults to None.
Source code in src/distilabel/llm/base.py
def __init__(
    self,
    task: Task,
    num_threads: Union[int, None] = None,
    prompt_format: Union["SupportedFormats", None] = None,
    prompt_formatting_fn: Union[Callable[..., str], None] = None,
) -> None:
    """Initializes the LLM base class.

    Note:
        This class is intended to be used internally, but anyone can still create
        a subclass, implement the `abstractmethod`s and use it.

    Args:
        task (Task): the task to be performed by the LLM.
        num_threads (Union[int, None], optional): the number of threads to be used
            for parallel generation. If `None`, no parallel generation will be performed.
            Defaults to `None`.
        prompt_format (Union["SupportedFormats", None], optional): the format to be used
            for the prompt. If `None`, the default format of the task will be used, available
            formats are `openai`, `chatml`, `llama2`, `zephyr`, and `default`. Defaults to `None`,
            but `default` (concatenation of `system_prompt` and `formatted_prompt` with a line-break)
            will be used if no `prompt_formatting_fn` is provided.
        prompt_formatting_fn (Union[Callable[..., str], None], optional): a function to be
            applied to the prompt before generation. If `None`, no formatting will be applied.
            Defaults to `None`.
    """
    self.task = task

    self.thread_pool_executor = (
        ThreadPoolExecutor(max_workers=num_threads)
        if num_threads is not None
        else None
    )

    self.prompt_format = prompt_format
    self.prompt_formatting_fn = prompt_formatting_fn
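
As a usage sketch of these constructor arguments (reusing the hypothetical `EchoLLM` subclass from the sketch above; a real subclass or one of the built-in integrations would be used in practice):

# Hypothetical usage of the constructor; `EchoLLM` is the illustrative subclass
# sketched earlier, not a class shipped with the library.
from distilabel.tasks import TextGenerationTask

llm = EchoLLM(
    task=TextGenerationTask(),
    num_threads=4,           # a ThreadPoolExecutor with 4 workers is created
    prompt_format="llama2",  # one of: openai, chatml, llama2, zephyr, default
)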

generate(inputs, num_generations=1, progress_callback_func=None)

Generates the outputs for the given inputs using the LLM.

Parameters:

    inputs (List[Dict[str, Any]], required): the inputs to be used for generation.

    num_generations (int): the number of generations to be performed for each input. Defaults to 1.

    progress_callback_func (Union[Callable, None]): a function to be called at each generation step. Defaults to None.

Returns:

    Union[List[List['LLMOutput']], Future[List[List['LLMOutput']]]]: the generated outputs.

Source code in src/distilabel/llm/base.py
def generate(
    self,
    inputs: List[Dict[str, Any]],
    num_generations: int = 1,
    progress_callback_func: Union[Callable, None] = None,
) -> Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]:
    """Generates the outputs for the given inputs using the LLM.

    Args:
        inputs (List[Dict[str, Any]]): the inputs to be used for generation.
        num_generations (int, optional): the number of generations to be performed for each input.
            Defaults to `1`.
        progress_callback_func (Union[Callable, None], optional): a function to be called at each
            generation step. Defaults to `None`.

    Returns:
        Union[List[List["LLMOutput"]], Future[List[List["LLMOutput"]]]]: the generated outputs.
    """

    def _progress():
        if progress_callback_func is not None:
            progress_callback_func(advance=num_generations * len(inputs))

    valid_inputs, invalid_inputs_indices = self._get_valid_inputs(inputs)

    if self.thread_pool_executor is not None:
        futures = []
        for input in valid_inputs:
            future = self.thread_pool_executor.submit(
                self._generate, [input], num_generations
            )
            futures.append(future)
        future = when_all_complete(
            futures=futures,
            callback=lambda generations: self._fill_missing_inputs(
                generations, invalid_inputs_indices, num_generations
            ),
        )
        future.add_done_callback(lambda _: _progress())
        return future

    generations = self._generate(valid_inputs, num_generations)

    generations = self._fill_missing_inputs(
        generations, invalid_inputs_indices, num_generations
    )

    _progress()
    return generations
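
Because the return type depends on whether a thread pool was configured (see `return_futures`), callers may need to resolve a `Future`. A small sketch, reusing the hypothetical `llm` instance from the constructor sketch above:

# `generate` returns a `Future` when `num_threads` was set and a plain nested
# list otherwise; `llm` is the hypothetical instance from the sketch above.
result = llm.generate(inputs=[{"input": "Your input"}], num_generations=2)

generations = result.result() if llm.return_futures else result

# One inner list per input row, with `num_generations` outputs each.
print(generations[0][0])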

validate_prompts(inputs, default_format=None)

Generates the prompts to be used for generation; can be used to check the prompts visually.

Parameters:

    inputs (List[Dict[str, Any]], required): the inputs to be used for generation.

    default_format (Union['SupportedFormats', None]): the default format to be used for the prompt if no prompt_format is specified. Defaults to None.

Returns:

    str: the prompts that would be used for the generation.

Examples:

>>> from distilabel.tasks import TextGenerationTask
>>> llm.validate_prompts([{"input": "Your input"}])[0]
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
I'm valid for text generation task
Source code in src/distilabel/llm/base.py
def validate_prompts(
    self,
    inputs: List[Dict[str, Any]],
    default_format: Union["SupportedFormats", None] = None,
) -> str:
    """Generates the prompts to be used for generation, can be used to check the prompts visually.

    Args:
        inputs (List[Dict[str, Any]]):
            The inputs to be used for generation.

    Returns:
        str: The prompts that would be used for the generation.

    Examples:
        >>> from distilabel.tasks import TextGenerationTask
        >>> llm.validate_prompts([{"input": "Your input"}])[0]
        You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
        If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
        I'm valid for text generation task
    """
    return self._generate_prompts(inputs, default_format=default_format)
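
As a quick sanity check before launching a long generation run, the rendered prompts can be printed without calling the model; a small sketch reusing the hypothetical `llm` from above:

# Render the prompts exactly as `generate` would build them, without generating.
prompts = llm.validate_prompts([{"input": "Summarize the benefits of unit testing."}])
for prompt in prompts:
    print(prompt)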

LLMPool

LLMPool is a class that wraps multiple ProcessLLMs and performs generation in parallel using them. Depending on the number of LLMs and the parameter num_generations, the LLMPool will decide how many generations to perform for each LLM:

  • If num_generations is less than the number of LLMs, then num_generations LLMs will be chosen randomly and each of them will perform 1 generation.

  • If num_generations is equal to the number of LLMs, then each LLM will perform 1 generation.

  • If num_generations is greater than the number of LLMs, then each LLM will perform num_generations // num_llms generations, and the remaining num_generations % num_llms generations will be performed by num_generations % num_llms randomly chosen LLMs.
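
For example, with 3 LLMs and num_generations=5, each LLM is first assigned 5 // 3 = 1 generation and the remaining 5 % 3 = 2 generations are handed to 2 randomly chosen LLMs, so one possible split is {0: 2, 1: 2, 2: 1}. A standalone sketch of that rule (mirroring _get_num_generations_per_llm shown below; the function name here is illustrative only):

# Standalone sketch of the generation-splitting rule described above.
import random
from typing import Dict

def split_generations(num_generations: int, num_llms: int) -> Dict[int, int]:
    # Every LLM gets the integer share of the generations...
    per_llm = {i: num_generations // num_llms for i in range(num_llms)}
    # ...and the remainder goes, one extra each, to randomly chosen LLMs.
    for i in random.sample(range(num_llms), k=num_generations % num_llms):
        per_llm[i] += 1
    return per_llm

print(split_generations(5, 3))  # e.g. {0: 2, 1: 2, 2: 1}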

Attributes:

    llms (List[ProcessLLM]): the ProcessLLMs to be used for generation.

Source code in src/distilabel/llm/base.py
class LLMPool:
    """LLMPool is a class that wraps multiple `ProcessLLM`s and performs generation in
    parallel using them. Depending on the number of `LLM`s and the parameter `num_generations`,
    the `LLMPool` will decide how many generations to perform for each `LLM`:

    - If `num_generations` is less than the number of `LLM`s, then `num_generations` LLMs
    will be chosen randomly and each of them will perform 1 generation.


    - If `num_generations` is equal to the number of `LLM`s, then each `LLM` will perform
    1 generation.

    - If `num_generations` is greater than the number of `LLM`s, then each `LLM` will
    perform `num_generations // num_llms` generations, and the remaining `num_generations % num_llms`
    generations will be performed by `num_generations % num_llms` randomly chosen `LLM`s.

    Attributes:
        llms (List[ProcessLLM]): the `ProcessLLM`s to be used for generation.
    """

    def __init__(self, llms: List[ProcessLLM]) -> None:
        """Initializes the `LLMPool` class.

        Args:
            llms: the `ProcessLLM`s to be used for generation. The list must contain at
                least 2 `ProcessLLM`s.

        Raises:
            ValueError: if the `llms` argument contains fewer than 2 `ProcessLLM`s,
                contains elements that are not `ProcessLLM`s, or contains
                `ProcessLLM`s with different tasks.
        """
        if len(llms) < 2:
            raise ValueError(
                "The `llms` argument must contain at least 2 `ProcessLLM`s. If you want"
                " to use a single `ProcessLLM`, use the `ProcessLLM` directly instead."
            )

        if not all(isinstance(llm, ProcessLLM) for llm in llms):
            raise ValueError("The `llms` argument must contain only `ProcessLLM`s.")

        # Note: The following piece of code is used to check that all the `ProcessLLM`s
        # have the same task or a subclass of it.
        mros = [(type(llm.task), len(type(llm.task).mro())) for llm in llms]
        min_common_class = min(mros, key=lambda x: x[1])[0]
        if not all(isinstance(llm.task, min_common_class) for llm in llms):
            # This can fail for example with 3 different TextGenerationTasks
            # Task1(TextGenerationTask), Task2(TextGenerationTask), Task3(TextGenerationTask)
            # because they share the same parent class but we don't check the common one
            # TODO(plaguss): We check that they all have the same parent class, this should be simplified
            # with the previous check
            parent_classes = [type(llm.task).mro()[1] for llm in llms]
            if not len(set(parent_classes)) == 1:
                raise ValueError(
                    "All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)."
                )

        self.llms = llms
        self.num_llms = len(llms)

    def _get_num_generations_per_llm(self, num_generations: int) -> Dict[int, int]:
        """Returns the number of generations to be performed by each `LLM`.

        Args:
            num_generations: the number of generations to be performed.

        Returns:
            Dict[int, int]: a dictionary where the keys are the ids of the `LLM`s and the
            values are the number of generations to be performed by each `LLM`.
        """
        llms_ids = list(range(self.num_llms))
        generations_per_llm = {i: num_generations // self.num_llms for i in llms_ids}

        for i in random.sample(llms_ids, k=num_generations % self.num_llms):
            generations_per_llm[i] += 1

        return generations_per_llm

    def generate(
        self,
        inputs: List[Dict[str, Any]],
        num_generations: int = 1,
        progress_callback_func: Union[Callable, None] = None,
    ) -> List[List["LLMOutput"]]:
        """Generates the outputs for the given inputs using the pool of `ProcessLLM`s.

        Args:
            inputs (List[Dict[str, Any]]): the inputs to be used for generation.
            num_generations (int, optional): the number of generations to be performed for each input.
                Defaults to `1`.
            progress_callback_func (Union[Callable, None], optional): a function to be called at each
                generation step. Defaults to `None`.

        Returns:
            Future[List[List["LLMOutput"]]]: the generated outputs as a `Future`.
        """
        num_generations_per_llm = self._get_num_generations_per_llm(num_generations)

        futures = [
            llm.generate(
                inputs,
                num_generations=num_generations_per_llm[i],
                progress_callback_func=progress_callback_func,
            )
            for i, llm in enumerate(self.llms)
            if num_generations_per_llm[i] > 0
        ]
        llms_generations = [future.result() for future in futures]

        generations = []
        for llms_row_generations in zip(*llms_generations):
            row_generations = []
            for llm_row_generations in llms_row_generations:
                for generation in llm_row_generations:
                    row_generations.append(generation)
            generations.append(row_generations)

        return generations

    def teardown(self) -> None:
        """Stops the `ProcessLLM`s."""
        for llm in self.llms:
            llm.teardown()

    @property
    def task(self) -> "Task":
        """Returns the task that will be used by the `ProcessLLM`s of this pool.

        Returns:
            Task: the task that will be used by the `ProcessLLM`s of this pool.
        """
        return self.llms[0].task

    @property
    def return_futures(self) -> bool:
        """Whether the `LLM` returns futures"""
        return False
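
A usage sketch for the pool; `EchoLLM` and `load_echo_llm` are hypothetical helpers (the former from the `LLM` sketch above), standing in for loader functions that would normally build provider-backed `LLM`s:

# Hypothetical construction and use of an `LLMPool` with two `ProcessLLM`s.
from distilabel.llm.base import LLMPool, ProcessLLM
from distilabel.tasks import TextGenerationTask

def load_echo_llm(task):
    # A real `load_llm_fn` would build and return a concrete `LLM` here.
    return EchoLLM(task=task)

task = TextGenerationTask()
pool = LLMPool(
    llms=[
        ProcessLLM(task=task, load_llm_fn=load_echo_llm),
        ProcessLLM(task=task, load_llm_fn=load_echo_llm),
    ]
)

try:
    # Blocks until every `ProcessLLM` has produced its share of the generations.
    generations = pool.generate(inputs=[{"input": "Your input"}], num_generations=4)
    print(len(generations[0]))  # 4: each LLM's outputs for a row are concatenated
finally:
    pool.teardown()  # stop the child processes and their bridge threads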

return_futures: bool property

Whether the LLM returns futures

task: 'Task' property

Returns the task that will be used by the ProcessLLMs of this pool.

Returns:

    Task: the task that will be used by the ProcessLLMs of this pool.

__init__(llms)

Initializes the LLMPool class.

Parameters:

    llms (List[ProcessLLM], required): the ProcessLLMs to be used for generation. The list must contain at least 2 ProcessLLMs.

Raises:

    ValueError: if the llms argument contains fewer than 2 ProcessLLMs, contains elements that are not ProcessLLMs, or contains ProcessLLMs with different tasks.

Source code in src/distilabel/llm/base.py
def __init__(self, llms: List[ProcessLLM]) -> None:
    """Initializes the `LLMPool` class.

    Args:
        llms: the `ProcessLLM`s to be used for generation. The list must contain at
            least 2 `ProcessLLM`s.

    Raises:
        ValueError: if the `llms` argument contains fewer than 2 `ProcessLLM`s,
            contains elements that are not `ProcessLLM`s, or contains
            `ProcessLLM`s with different tasks.
    """
    if len(llms) < 2:
        raise ValueError(
            "The `llms` argument must contain at least 2 `ProcessLLM`s. If you want"
            " to use a single `ProcessLLM`, use the `ProcessLLM` directly instead."
        )

    if not all(isinstance(llm, ProcessLLM) for llm in llms):
        raise ValueError("The `llms` argument must contain only `ProcessLLM`s.")

    # Note: The following piece of code is used to check that all the `ProcessLLM`s
    # have the same task or a subclass of it.
    mros = [(type(llm.task), len(type(llm.task).mro())) for llm in llms]
    min_common_class = min(mros, key=lambda x: x[1])[0]
    if not all(isinstance(llm.task, min_common_class) for llm in llms):
        # This can fail for example with 3 different TextGenerationTasks
        # Task1(TextGenerationTask), Task2(TextGenerationTask), Task3(TextGenerationTask)
        # because they share the same parent class but we don't check the common one
        # TODO(plaguss): We check that they all have the same parent class, this should be simplified
        # with the previous check
        parent_classes = [type(llm.task).mro()[1] for llm in llms]
        if not len(set(parent_classes)) == 1:
            raise ValueError(
                "All the `ProcessLLM` in `llms` must share the same task (either as the instance or the parent class)."
            )

    self.llms = llms
    self.num_llms = len(llms)

generate(inputs, num_generations=1, progress_callback_func=None)

Generates the outputs for the given inputs using the pool of ProcessLLMs.

Parameters:

    inputs (List[Dict[str, Any]], required): the inputs to be used for generation.

    num_generations (int): the number of generations to be performed for each input. Defaults to 1.

    progress_callback_func (Union[Callable, None]): a function to be called at each generation step. Defaults to None.

Returns:

    List[List['LLMOutput']]: the generated outputs.

Source code in src/distilabel/llm/base.py
def generate(
    self,
    inputs: List[Dict[str, Any]],
    num_generations: int = 1,
    progress_callback_func: Union[Callable, None] = None,
) -> List[List["LLMOutput"]]:
    """Generates the outputs for the given inputs using the pool of `ProcessLLM`s.

    Args:
        inputs (List[Dict[str, Any]]): the inputs to be used for generation.
        num_generations (int, optional): the number of generations to be performed for each input.
            Defaults to `1`.
        progress_callback_func (Union[Callable, None], optional): a function to be called at each
            generation step. Defaults to `None`.

    Returns:
        Future[List[List["LLMOutput"]]]: the generated outputs as a `Future`.
    """
    num_generations_per_llm = self._get_num_generations_per_llm(num_generations)

    futures = [
        llm.generate(
            inputs,
            num_generations=num_generations_per_llm[i],
            progress_callback_func=progress_callback_func,
        )
        for i, llm in enumerate(self.llms)
        if num_generations_per_llm[i] > 0
    ]
    llms_generations = [future.result() for future in futures]

    generations = []
    for llms_row_generations in zip(*llms_generations):
        row_generations = []
        for llm_row_generations in llms_row_generations:
            for generation in llm_row_generations:
                row_generations.append(generation)
        generations.append(row_generations)

    return generations

teardown()

Stops the ProcessLLMs.

Source code in src/distilabel/llm/base.py
def teardown(self) -> None:
    """Stops the `ProcessLLM`s."""
    for llm in self.llms:
        llm.teardown()

ProcessLLM

A class that wraps an LLM and performs generation in a separate process. The result is a Future that will be set when the generation is completed.

This class creates a new child process that will load the LLM and perform the text generation. In order to communicate with this child process, a bridge thread is created in the main process. The bridge thread will send and receive the results from the child process using multiprocessing.Queues. The communication between the bridge thread and the main process is done using Futures. This architecture was inspired by the ProcessPoolExecutor from the concurrent.futures module and it's a simplified version of it.
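
A standalone usage sketch; as above, `EchoLLM` and `load_echo_llm` are hypothetical stand-ins for a user-provided loader function:

# Hypothetical standalone use of a `ProcessLLM`.
from distilabel.llm.base import ProcessLLM
from distilabel.tasks import TextGenerationTask

def load_echo_llm(task):
    return EchoLLM(task=task)  # a real loader would return a provider-backed `LLM`

process_llm = ProcessLLM(task=TextGenerationTask(), load_llm_fn=load_echo_llm)

# The child process and bridge thread are started lazily on the first call;
# `generate` returns a `Future` with the nested list of `LLMOutput`s.
future = process_llm.generate(inputs=[{"input": "Your input"}], num_generations=1)
generations = future.result()

process_llm.teardown()  # stop the generation process and the bridge thread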

Source code in src/distilabel/llm/base.py
class ProcessLLM:
    """A class that wraps an `LLM` and performs generation in a separate process. The
    result is a `Future` that will be set when the generation is completed.

    This class creates a new child process that will load the `LLM` and perform the
    text generation. In order to communicate with this child process, a bridge thread
    is created in the main process. The bridge thread will send and receive the results
    from the child process using `multiprocessing.Queue`s. The communication between the
    bridge thread and the main process is done using `Future`s. This architecture was
    inspired by the `ProcessPoolExecutor` from the `concurrent.futures` module and it's
    a simplified version of it.
    """

    def __init__(self, task: Task, load_llm_fn: Callable[[Task], LLM]) -> None:
        """Initializes the `ProcessLLM` class.

        Args:
            task: the task to be performed by the `LLM`. This task will be used by the
                child process when calling the `load_llm_fn`.
            load_llm_fn (Callable[[Task], LLM]): a function that will be executed in the
                child process to load the `LLM`. It must return an `LLM` instance.
        """
        self.task = task

        self._load_llm_fn = load_llm_fn

        # The bridge thread will act as a bridge between the main process and the child
        # process for communication. It will send the generation requests to the child
        # process and receive the results from the child process.
        self._bridge_thread = None

        # The child process which will load the `LLM` and perform the generation.
        self._generation_process = None

        # The `Semaphore` that will be used to synchronize the loading of the `LLM`.
        # `_BridgeThread` will be blocked until `_GenerationProcess` has called the
        # `load_llm_fn` and the `LLM` has been loaded.
        self._load_llm_sem = mp.Semaphore(0)

        # This thread will create text generation requests
        self.pending_text_generation_request: Dict[int, _TextGenerationRequest] = {}
        self.text_generation_request_count = 0
        self.text_generation_request_ids_queue: queue.Queue[int] = queue.Queue()

        # Queues for the communication between the `_BridgeThread` and the `_GenerationProcess`
        self._call_queue = mp.Queue()
        self._result_queue = mp.Queue()

        # Shared memory object for transferring the `model_name` to the main process
        # once the `LLM` is loaded
        self._model_name = mp.Array(c_char, MAX_MODEL_NAME_LENGTH)

    def _start_bridge_thread(self) -> None:
        """Starts the bridge thread and the generation process."""
        if self._bridge_thread is None:
            self._generation_process = _GenerationProcess(self)
            self._generation_process.start()
            pid = self._generation_process.pid
            logger.debug(f"Generation process with PID {pid} started!")

            self._bridge_thread = _BridgeThread(self)
            self._bridge_thread.start()
            logger.debug("Bridge thread for process with PID {pid} started!")

    def _add_text_generation_request(
        self,
        inputs: List[Dict[str, Any]],
        num_generations: int = 1,
        progress_callback_func: Union[Callable, None] = None,
    ) -> Future[List[List["LLMOutput"]]]:
        """Creates and send a new text generation request to the bridge thread. This thread
        and the bridge thread shares a dictionary used to store the text generation requests.
        This thread will add the text generation requests to the dictionary and the bridge
        thread will only read from it. In order for the bridge thread to know that a new
        text generation request has been added to the dictionary, this thread will put the
        id of the request in a queue. The bridge thread will read from this queue and get
        the text generation request from the dictionary.
        """

        def _progress():
            if progress_callback_func is not None:
                progress_callback_func(advance=num_generations * len(inputs))

        text_generation_request = _TextGenerationRequest(
            inputs=inputs, num_generations=num_generations
        )
        # Put the request information in the dictionary associated to the request id
        self.pending_text_generation_request[
            self.text_generation_request_count
        ] = text_generation_request
        # Put the request id in the queue (for the `_BridgeThread` to consume it)
        self.text_generation_request_ids_queue.put(self.text_generation_request_count)
        self.text_generation_request_count += 1
        text_generation_request.future.add_done_callback(lambda _: _progress())
        return text_generation_request.future

    def generate(
        self,
        inputs: List[Dict[str, Any]],
        num_generations: int = 1,
        progress_callback_func: Union[Callable, None] = None,
    ) -> Future[List[List["LLMOutput"]]]:
        """Generates the outputs for the given inputs using the `ProcessLLM` and its loaded
        `LLM`.

        Args:
            inputs (List[Dict[str, Any]]): the inputs to be used for generation.
            num_generations (int, optional): the number of generations to be performed for each input.
                Defaults to `1`.
            progress_callback_func (Union[Callable, None], optional): a function to be called at each
                generation step. Defaults to `None`.

        Returns:
            Future[List[List["LLMOutput"]]]: the generated outputs as a `Future`.
        """
        self._start_bridge_thread()
        return self._add_text_generation_request(
            inputs, num_generations, progress_callback_func
        )

    def teardown(self) -> None:
        """Stops the bridge thread and the generation process."""
        if self._generation_process is not None:
            self._generation_process.stop()
            self._generation_process.join()

        if self._bridge_thread is not None:
            self._bridge_thread.stop()
            self._bridge_thread.join()

    @cached_property
    def model_name(self) -> str:
        """Returns the model name of the `LLM` once it has been loaded."""
        with self._model_name:
            return "".join([c.decode() for c in self._model_name if c != b"\0"])

    @property
    def return_futures(self) -> bool:
        """Whether the `LLM` returns futures"""
        return True

model_name: str cached property

Returns the model name of the LLM once it has been loaded.

return_futures: bool property

Whether the LLM returns futures

__init__(task, load_llm_fn)

Initializes the ProcessLLM class.

Parameters:

    task (Task, required): the task to be performed by the LLM. This task will be used by the child process when calling the load_llm_fn.

    load_llm_fn (Callable[[Task], LLM], required): a function that will be executed in the child process to load the LLM. It must return an LLM instance.
Source code in src/distilabel/llm/base.py
def __init__(self, task: Task, load_llm_fn: Callable[[Task], LLM]) -> None:
    """Initializes the `ProcessLLM` class.

    Args:
        task: the task to be performed by the `LLM`. This task will be used by the
            child process when calling the `load_llm_fn`.
        load_llm_fn (Callable[[Task], LLM]): a function that will be executed in the
            child process to load the `LLM`. It must return an `LLM` instance.
    """
    self.task = task

    self._load_llm_fn = load_llm_fn

    # The bridge thread will act as a bridge between the main process and the child
    # process for communication. It will send the generation requests to the child
    # process and receive the results from the child process.
    self._bridge_thread = None

    # The child process which will load the `LLM` and perform the generation.
    self._generation_process = None

    # The `Semaphore` that will be used to synchronize the loading of the `LLM`.
    # `_BridgeThread` will be blocked until `_GenerationProcess` has called the
    # `load_llm_fn` and the `LLM` has been loaded.
    self._load_llm_sem = mp.Semaphore(0)

    # This thread will create text generation requests
    self.pending_text_generation_request: Dict[int, _TextGenerationRequest] = {}
    self.text_generation_request_count = 0
    self.text_generation_request_ids_queue: queue.Queue[int] = queue.Queue()

    # Queues for the communication between the `_BridgeThread` and the `_GenerationProcess`
    self._call_queue = mp.Queue()
    self._result_queue = mp.Queue()

    # Shared memory object for transferring the `model_name` to the main process
    # once the `LLM` is loaded
    self._model_name = mp.Array(c_char, MAX_MODEL_NAME_LENGTH)

generate(inputs, num_generations=1, progress_callback_func=None)

Generates the outputs for the given inputs using the ProcessLLM and its loaded LLM.

Parameters:

    inputs (List[Dict[str, Any]], required): the inputs to be used for generation.

    num_generations (int): the number of generations to be performed for each input. Defaults to 1.

    progress_callback_func (Union[Callable, None]): a function to be called at each generation step. Defaults to None.

Returns:

    Future[List[List['LLMOutput']]]: the generated outputs as a Future.

Source code in src/distilabel/llm/base.py
def generate(
    self,
    inputs: List[Dict[str, Any]],
    num_generations: int = 1,
    progress_callback_func: Union[Callable, None] = None,
) -> Future[List[List["LLMOutput"]]]:
    """Generates the outputs for the given inputs using the `ProcessLLM` and its loaded
    `LLM`.

    Args:
        inputs (List[Dict[str, Any]]): the inputs to be used for generation.
        num_generations (int, optional): the number of generations to be performed for each input.
            Defaults to `1`.
        progress_callback_func (Union[Callable, None], optional): a function to be called at each
            generation step. Defaults to `None`.

    Returns:
        Future[List[List["LLMOutput"]]]: the generated outputs as a `Future`.
    """
    self._start_bridge_thread()
    return self._add_text_generation_request(
        inputs, num_generations, progress_callback_func
    )

teardown()

Stops the bridge thread and the generation process.

Source code in src/distilabel/llm/base.py
def teardown(self) -> None:
    """Stops the bridge thread and the generation process."""
    if self._generation_process is not None:
        self._generation_process.stop()
        self._generation_process.join()

    if self._bridge_thread is not None:
        self._bridge_thread.stop()
        self._bridge_thread.join()