Skip to content

Routing batch function

RoutingBatchFunc = Callable[[List[str]], List[str]] module-attribute

Type alias for a routing batch function. It takes a list of all the downstream steps and returns a list with the names of the steps that should receive the batch.

RoutingBatchFunction

Bases: BaseModel, _Serializable

A thin wrapper around a routing batch function that can be used to route batches from one upstream step to specific downstream steps.

Attributes:

Name Type Description
routing_function RoutingBatchFunc

The routing function that takes a list of all the downstream steps and returns a list with the names of the steps that should receive the batch.

_step Union[_Step, None]

The upstream step that is connected to the routing batch function.

_routed_batch_registry Dict[str, Dict[int, List[str]]]

A dictionary that keeps track of the batches that have been routed to specific downstream steps.

Source code in src/distilabel/pipeline/routing_batch_function.py
class RoutingBatchFunction(BaseModel, _Serializable):
    """A thin wrapper around a routing batch function that can be used to route batches
    from one upstream step to specific downstream steps.

    Attributes:
        routing_function: The routing function that takes a list of all the downstream steps
            and returns a list with the names of the steps that should receive the batch.
        description: An optional description of the routing function. It is included in
            the output of `dump` when set.
        _step: The upstream step that is connected to the routing batch function.
        _routed_batch_registry: A dictionary that keeps track of the batches that have been
            routed to specific downstream steps. It maps the upstream step name to a
            dictionary mapping the batch sequence number to the list of selected steps.
        _factory_function_module: The module of the factory function used to create this
            routing batch function, if any.
        _factory_function_name: The name of the factory function used to create this
            routing batch function, if any.
        _factory_function_kwargs: The keyword arguments the factory function was called
            with, if any.
    """

    routing_function: RoutingBatchFunc
    description: Optional[str] = None

    _step: Union["_Step", None] = PrivateAttr(default=None)
    _routed_batch_registry: Dict[str, Dict[int, List[str]]] = PrivateAttr(
        default_factory=dict
    )
    _factory_function_module: Union[str, None] = PrivateAttr(default=None)
    _factory_function_name: Union[str, None] = PrivateAttr(default=None)
    _factory_function_kwargs: Union[Dict[str, Any], None] = PrivateAttr(default=None)

    def route_batch(self, batch: "_Batch", steps: List[str]) -> List[str]:
        """Returns a list of selected downstream steps from `steps` to which the `batch`
        should be routed.

        The selection is also recorded in the routed batch registry.

        Args:
            batch: The batch that should be routed.
            steps: A list of all the downstream steps that can receive the batch.

        Returns:
            A list with the names of the steps that should receive the batch.
        """
        routed_steps = self.routing_function(steps)
        self._register_routed_batch(batch, routed_steps)
        return routed_steps

    def set_factory_function(
        self,
        factory_function_module: str,
        factory_function_name: str,
        factory_function_kwargs: Dict[str, Any],
    ) -> None:
        """Sets the factory function that was used to create the `routing_batch_function`.

        Args:
            factory_function_module: The module name where the factory function is defined.
            factory_function_name: The name of the factory function that was used to create
                the `routing_batch_function`.
            factory_function_kwargs: The keyword arguments that were used when calling the
                factory function.
        """
        self._factory_function_module = factory_function_module
        self._factory_function_name = factory_function_name
        self._factory_function_kwargs = factory_function_kwargs

    def __call__(self, batch: "_Batch", steps: List[str]) -> List[str]:
        """Returns a list of selected downstream steps from `steps` to which the `batch`
        should be routed. Equivalent to calling `route_batch`.

        Args:
            batch: The batch that should be routed.
            steps: A list of all the downstream steps that can receive the batch.

        Returns:
            A list with the names of the steps that should receive the batch.
        """
        return self.route_batch(batch, steps)

    def _register_routed_batch(self, batch: "_Batch", routed_steps: List[str]) -> None:
        """Registers a batch that has been routed to specific downstream steps.

        Args:
            batch: The batch that has been routed.
            routed_steps: The list of downstream steps that have been selected to receive
                the batch.
        """
        upstream_step = batch.step_name
        batch_seq_no = batch.seq_no
        # `setdefault` keeps the first registration for a given step and sequence number;
        # registering the same batch again does not overwrite the stored selection.
        self._routed_batch_registry.setdefault(upstream_step, {}).setdefault(
            batch_seq_no, routed_steps
        )

    def __rshift__(
        self, other: List["DownstreamConnectableSteps"]
    ) -> List["DownstreamConnectableSteps"]:
        """Connects a list of downstream steps to the upstream step of the routing batch
        function.

        Args:
            other: A list of downstream steps that should be connected to the upstream step
                of the routing batch function.

        Returns:
            The list of downstream steps that have been connected to the upstream step of the
            routing batch function.

        Raises:
            ValueError: If `other` is not a list, or if the routing batch function has not
                been connected to an upstream step yet.
        """
        if not isinstance(other, list):
            raise ValueError(
                f"Can only set a `routing_batch_function` for a list of steps. Got: {other}."
                " Please, review the right-hand side of the `routing_batch_function >> other`"
                " expression. It should be"
                " `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`."
            )

        if not self._step:
            raise ValueError(
                "Routing batch function doesn't have an upstream step. Cannot connect downstream"
                " steps before connecting the upstream step. Connect this routing batch"
                " function to an upstream step using the `>>` operator. For example:"
                " `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`."
            )

        for step in other:
            self._step.connect(step)
        return other

    def dump(self, **kwargs: Any) -> Dict[str, Any]:
        """Dumps the routing batch function to a dictionary, and the information of the
        factory function used to create this routing batch function.

        The upstream step must already be set, because its name is read to build the dump.

        Args:
            **kwargs: Additional keyword arguments. Currently unused by this implementation.

        Returns:
            A dictionary with the name of the upstream step, the description (only if set)
            and the factory function information (only if available).
        """
        dump_info: Dict[str, Any] = {"step": self._step.name}  # type: ignore

        if self.description:
            dump_info["description"] = self.description

        if type_info := self._get_type_info():
            dump_info[TYPE_INFO_KEY] = type_info

        return dump_info

    def _get_type_info(self) -> Dict[str, Any]:
        """Returns the information of the factory function used to create the routing batch
        function.

        Returns:
            A dictionary with the factory function information. Only the pieces of
            information that are set (module, name, kwargs) are included, so the dictionary
            is empty when no factory function was registered.
        """

        type_info = {}

        if self._factory_function_module:
            type_info["module"] = self._factory_function_module

        if self._factory_function_name:
            type_info["name"] = self._factory_function_name

        if self._factory_function_kwargs:
            type_info["kwargs"] = self._factory_function_kwargs

        return type_info

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Self:
        """Loads a routing batch function from a dictionary. It must contain the information
        of the factory function used to create the routing batch function.

        Args:
            data: A dictionary with the routing batch function information and the factory
                function information.

        Returns:
            The routing batch function created by calling the factory function with the
            stored keyword arguments, with its description and factory information restored.

        Raises:
            ValueError: If `data` has no factory function information, or if the module,
                name or keyword arguments of the factory function are missing or empty.
        """
        type_info = data.get(TYPE_INFO_KEY)
        if not type_info:
            step = data.get("step")
            raise ValueError(
                f"The routing batch function for step '{step}' was created without a factory"
                " function, and it cannot be reconstructed."
            )

        module = type_info.get("module")
        name = type_info.get("name")
        kwargs = type_info.get("kwargs")

        if not module or not name or not kwargs:
            raise ValueError(
                "The routing batch function was created with a factory function, but the"
                " information is incomplete. Cannot reconstruct the routing batch function."
            )

        # Rebuild the routing batch function by calling the factory again with the same
        # keyword arguments.
        routing_batch_function = _get_module_attr(module=module, name=name)(**kwargs)
        routing_batch_function.description = data.get("description")
        routing_batch_function.set_factory_function(
            factory_function_module=module,
            factory_function_name=name,
            factory_function_kwargs=kwargs,
        )

        return routing_batch_function

__call__(batch, steps)

Returns a list of selected downstream steps from steps to which the batch should be routed.

Parameters:

Name Type Description Default
batch _Batch

The batch that should be routed.

required
steps List[str]

A list of all the downstream steps that can receive the batch.

required

Returns:

Type Description
List[str]

A list with the names of the steps that should receive the batch.

Source code in src/distilabel/pipeline/routing_batch_function.py
def __call__(self, batch: "_Batch", steps: List[str]) -> List[str]:
    """Selects the downstream steps that should receive `batch` by delegating to
    `route_batch`.

    Args:
        batch: The batch to route.
        steps: The names of every downstream step that could receive the batch.

    Returns:
        The names of the downstream steps selected for the batch.
    """
    selected_steps = self.route_batch(batch, steps)
    return selected_steps

__rshift__(other)

Connects a list of downstream steps to the upstream step of the routing batch function.

Parameters:

Name Type Description Default
other List[DownstreamConnectableSteps]

A list of downstream steps that should be connected to the upstream step of the routing batch function.

required

Returns:

Type Description
List[DownstreamConnectableSteps]

The list of downstream steps that have been connected to the upstream step of the routing batch function.

Source code in src/distilabel/pipeline/routing_batch_function.py
def __rshift__(
    self, other: List["DownstreamConnectableSteps"]
) -> List["DownstreamConnectableSteps"]:
    """Connects a list of downstream steps to the upstream step of the routing batch
    function.

    Args:
        other: A list of downstream steps that should be connected to the upstream step
            of the routing batch function.

    Returns:
        The list of downstream steps that have been connected to the upstream step of the
        routing batch function.

    Raises:
        ValueError: If `other` is not a list, or if the routing batch function has not
            been connected to an upstream step yet.
    """
    if not isinstance(other, list):
        raise ValueError(
            f"Can only set a `routing_batch_function` for a list of steps. Got: {other}."
            " Please, review the right-hand side of the `routing_batch_function >> other`"
            " expression. It should be"
            " `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`."
        )

    if not self._step:
        raise ValueError(
            "Routing batch function doesn't have an upstream step. Cannot connect downstream"
            " steps before connecting the upstream step. Connect this routing batch"
            " function to an upstream step using the `>>` operator. For example:"
            " `upstream_step >> routing_batch_function >> [downstream_step_1, downstream_step_2, ...]`."
        )

    # Every downstream step is connected to the upstream step directly.
    for step in other:
        self._step.connect(step)
    return other

dump(**kwargs)

Dumps the routing batch function to a dictionary, and the information of the factory function used to create this routing batch function.

Parameters:

Name Type Description Default
**kwargs Any

Additional keyword arguments that should be included in the dump.

{}

Returns:

Type Description
Dict[str, Any]

A dictionary with the routing batch function information and the factory function information.

Source code in src/distilabel/pipeline/routing_batch_function.py
def dump(self, **kwargs: Any) -> Dict[str, Any]:
    """Serializes the routing batch function to a dictionary, together with the
    information of the factory function that created it.

    Args:
        **kwargs: Additional keyword arguments. Not used by this implementation.

    Returns:
        A dictionary with the name of the upstream step, the description (only if set)
        and the factory function information (only if available).
    """
    serialized: Dict[str, Any] = {"step": self._step.name}  # type: ignore

    if self.description:
        serialized["description"] = self.description

    # The factory information is only stored when there is something to store.
    factory_info = self._get_type_info()
    if factory_info:
        serialized[TYPE_INFO_KEY] = factory_info

    return serialized

from_dict(data) classmethod

Loads a routing batch function from a dictionary. It must contain the information of the factory function used to create the routing batch function.

Parameters:

Name Type Description Default
data Dict[str, Any]

A dictionary with the routing batch function information and the factory function information.

required
Source code in src/distilabel/pipeline/routing_batch_function.py
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Self:
    """Rebuilds a routing batch function from its dumped dictionary by calling again the
    factory function that originally created it.

    Args:
        data: A dictionary with the routing batch function information and the factory
            function information.

    Returns:
        The routing batch function returned by the factory function, with its description
        and factory information restored.

    Raises:
        ValueError: If `data` has no factory function information, or if the module,
            name or keyword arguments of the factory function are missing or empty.
    """
    factory_info = data.get(TYPE_INFO_KEY)
    if not factory_info:
        step = data.get("step")
        raise ValueError(
            f"The routing batch function for step '{step}' was created without a factory"
            " function, and it cannot be reconstructed."
        )

    module, name, kwargs = (
        factory_info.get(field) for field in ("module", "name", "kwargs")
    )

    if not (module and name and kwargs):
        raise ValueError(
            "The routing batch function was created with a factory function, but the"
            " information is incomplete. Cannot reconstruct the routing batch function."
        )

    # Call the factory again with the keyword arguments it was originally given.
    factory = _get_module_attr(module=module, name=name)
    restored = factory(**kwargs)
    restored.description = data.get("description")
    restored.set_factory_function(
        factory_function_module=module,
        factory_function_name=name,
        factory_function_kwargs=kwargs,
    )

    return restored

route_batch(batch, steps)

Returns a list of selected downstream steps from steps to which the batch should be routed.

Parameters:

Name Type Description Default
batch _Batch

The batch that should be routed.

required
steps List[str]

A list of all the downstream steps that can receive the batch.

required

Returns:

Type Description
List[str]

A list with the names of the steps that should receive the batch.

Source code in src/distilabel/pipeline/routing_batch_function.py
def route_batch(self, batch: "_Batch", steps: List[str]) -> List[str]:
    """Asks the routing function which of the downstream `steps` should receive the
    `batch`, and records that selection for the batch.

    Args:
        batch: The batch to route.
        steps: The names of every downstream step that could receive the batch.

    Returns:
        The names of the downstream steps selected for the batch.
    """
    selected_steps = self.routing_function(steps)
    # Record the selection before handing it back to the caller.
    self._register_routed_batch(batch, selected_steps)
    return selected_steps

set_factory_function(factory_function_module, factory_function_name, factory_function_kwargs)

Sets the factory function that was used to create the routing_batch_function.

Parameters:

Name Type Description Default
factory_function_module str

The module name where the factory function is defined.

required
factory_function_name str

The name of the factory function that was used to create the routing_batch_function.

required
factory_function_kwargs Dict[str, Any]

The keyword arguments that were used when calling the factory function.

required
Source code in src/distilabel/pipeline/routing_batch_function.py
def set_factory_function(
    self,
    factory_function_module: str,
    factory_function_name: str,
    factory_function_kwargs: Dict[str, Any],
) -> None:
    """Stores which factory function created the `routing_batch_function`, and with
    which keyword arguments it was called.

    Args:
        factory_function_module: The name of the module that defines the factory function.
        factory_function_name: The name of the factory function.
        factory_function_kwargs: The keyword arguments passed to the factory function.
    """
    (
        self._factory_function_module,
        self._factory_function_name,
        self._factory_function_kwargs,
    ) = (
        factory_function_module,
        factory_function_name,
        factory_function_kwargs,
    )

routing_batch_function(description=None)

Creates a routing batch function that can be used to route batches from one upstream step to specific downstream steps.

Parameters:

Name Type Description Default
description Optional[str]

An optional description for the routing batch function.

None

Returns:

Type Description
Callable[[RoutingBatchFunc], RoutingBatchFunction]

A decorator that wraps a routing function in a RoutingBatchFunction instance, which can be used with the >> operators and with the Pipeline.connect method when defining the pipeline.

Example:

import random
from typing import List

from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import LoadHubDataset, CombineColumns
from distilabel.steps.tasks import TextGeneration


# `routing_batch_function` returns a decorator, so it has to be called with parentheses.
@routing_batch_function()
def random_routing_batch(steps: List[str]) -> List[str]:
    return random.sample(steps, 2)


with Pipeline(name="routing-batch-function") as pipeline:
    load_data = LoadHubDataset()

    # One text generation task per LLM; each batch is routed to two of them.
    generations = []
    for llm in (
        OpenAILLM(model="gpt-4-0125-preview"),
        MistralLLM(model="mistral-large-2402"),
        VertexAILLM(model="gemini-1.5-pro"),
    ):
        task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
        generations.append(task)

    combine_columns = CombineColumns(columns=["generation", "model_name"])

    load_data >> random_routing_batch >> generations >> combine_columns
Source code in src/distilabel/pipeline/routing_batch_function.py
def routing_batch_function(
    description: Optional[str] = None,
) -> Callable[[RoutingBatchFunc], RoutingBatchFunction]:
    """Creates a decorator that turns a routing function into a `RoutingBatchFunction`,
    which can be used to route batches from one upstream step to specific downstream
    steps.

    If the decorator is applied inside another function (a "factory function"), the
    module, name and call arguments of that function are recorded in the resulting
    `RoutingBatchFunction`, so it can be dumped and later reconstructed by calling the
    factory again.

    Note:
        This function returns a decorator, so it must be called with parentheses
        (`@routing_batch_function()`), even when no description is given.

    Args:
        description: An optional description for the routing batch function.

    Returns:
        A decorator that takes a routing function and returns a `RoutingBatchFunction`
        instance that can be used with the `>>` operators and with the `Pipeline.connect`
        method when defining the pipeline.

    Example:

    ```python
    import random
    from typing import List

    from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
    from distilabel.pipeline import Pipeline, routing_batch_function
    from distilabel.steps import LoadHubDataset, CombineColumns
    from distilabel.steps.tasks import TextGeneration


    @routing_batch_function()
    def random_routing_batch(steps: List[str]) -> List[str]:
        return random.sample(steps, 2)


    with Pipeline(name="routing-batch-function") as pipeline:
        load_data = LoadHubDataset()

        generations = []
        for llm in (
            OpenAILLM(model="gpt-4-0125-preview"),
            MistralLLM(model="mistral-large-2402"),
            VertexAILLM(model="gemini-1.5-pro"),
        ):
            task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
            generations.append(task)

        combine_columns = CombineColumns(columns=["generation", "model_name"])

        load_data >> random_routing_batch >> generations >> combine_columns
    ```
    """

    def decorator(func: RoutingBatchFunc) -> RoutingBatchFunction:
        factory_function_name: Optional[str] = None
        factory_function_module: Optional[str] = None
        factory_function_kwargs: Optional[Dict[str, Any]] = None

        # Check if `routing_batch_function` was created using a factory function from an
        # installed package. Source context lines are not needed, hence `context=0`.
        stack = inspect.stack(context=0)
        try:
            if len(stack) > 2:
                factory_function_frame_info = stack[1]

                # Function factory path
                if factory_function_frame_info.function != "<module>":
                    factory_frame = factory_function_frame_info.frame
                    factory_function_name = factory_function_frame_info.function

                    # `inspect.getmodule` can return `None`; fall back to the module name
                    # stored in the frame globals instead of failing.
                    factory_module = inspect.getmodule(factory_frame)
                    factory_function_module = (
                        factory_module.__name__
                        if factory_module is not None
                        else factory_frame.f_globals.get("__name__")
                    )

                    # Function factory kwargs: keep only the factory *arguments*, copied
                    # into a plain dict. Storing `f_locals` itself would also capture
                    # unrelated local variables, which the factory would reject when it
                    # is called again with these kwargs, and would keep a reference to
                    # the frame's locals mapping.
                    arg_info = inspect.getargvalues(factory_frame)
                    factory_function_kwargs = {
                        arg_name: arg_info.locals[arg_name]
                        for arg_name in arg_info.args
                        if arg_name in arg_info.locals
                    }
                    if arg_info.keywords and arg_info.keywords in arg_info.locals:
                        factory_function_kwargs.update(arg_info.locals[arg_info.keywords])
        finally:
            # The stack includes this function's own frame; drop the reference so no
            # reference cycle is left behind.
            del stack

        routing_batch_function = RoutingBatchFunction(
            routing_function=func,
            description=description,
        )

        # The factory information is only stored when it is complete.
        if (
            factory_function_module
            and factory_function_name
            and factory_function_kwargs
        ):
            routing_batch_function.set_factory_function(
                factory_function_module=factory_function_module,
                factory_function_name=factory_function_name,
                factory_function_kwargs=factory_function_kwargs,
            )

        return routing_batch_function

    return decorator

sample_n_steps(n)

A simple function that creates a routing batch function that samples n steps from the list of all the downstream steps.

Parameters:

Name Type Description Default
n int

The number of steps to sample from the list of all the downstream steps.

required

Returns:

Type Description
RoutingBatchFunction

A RoutingBatchFunction instance that can be used with the >> operators and with the Pipeline.connect method when defining the pipeline.

Example:

from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadHubDataset, CombineColumns
from distilabel.steps.tasks import TextGeneration


# Each batch is routed to 2 randomly sampled downstream steps.
random_routing_batch = sample_n_steps(2)


with Pipeline(name="routing-batch-function") as pipeline:
    load_data = LoadHubDataset()

    # One text generation task per LLM.
    generations = []
    for llm in (
        OpenAILLM(model="gpt-4-0125-preview"),
        MistralLLM(model="mistral-large-2402"),
        VertexAILLM(model="gemini-1.5-pro"),
    ):
        task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
        generations.append(task)

    combine_columns = CombineColumns(columns=["generation", "model_name"])

    load_data >> random_routing_batch >> generations >> combine_columns
Source code in src/distilabel/pipeline/routing_batch_function.py
def sample_n_steps(n: int) -> RoutingBatchFunction:
    """A simple function that creates a routing batch function that samples `n` steps from
    the list of all the downstream steps.

    The sampling is done with `random.sample` each time a batch is routed, so it raises
    a `ValueError` at routing time if `n` is negative or larger than the number of
    downstream steps.

    Args:
        n: The number of steps to sample from the list of all the downstream steps.

    Returns:
        A `RoutingBatchFunction` instance that can be used with the `>>` operators and with
        the `Pipeline.connect` method when defining the pipeline.

    Example:

    ```python
    from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
    from distilabel.pipeline import Pipeline, sample_n_steps
    from distilabel.steps import LoadHubDataset, CombineColumns
    from distilabel.steps.tasks import TextGeneration


    random_routing_batch = sample_n_steps(2)


    with Pipeline(name="routing-batch-function") as pipeline:
        load_data = LoadHubDataset()

        generations = []
        for llm in (
            OpenAILLM(model="gpt-4-0125-preview"),
            MistralLLM(model="mistral-large-2402"),
            VertexAILLM(model="gemini-1.5-pro"),
        ):
            task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
            generations.append(task)

        combine_columns = CombineColumns(columns=["generation", "model_name"])

        load_data >> random_routing_batch >> generations >> combine_columns
    ```
    """

    # NOTE: `routing_batch_function` inspects the frame of this call to record the factory
    # module, name and call arguments, so that the routing function can be rebuilt when
    # loaded from a dump. Avoid introducing extra local state before the decorator runs.
    @routing_batch_function(
        description=f"Sample {n} steps from the list of downstream steps."
    )
    def sample_n(steps: List[str]) -> List[str]:
        return random.sample(steps, n)

    return sample_n