Embeddings

GenerateEmbeddings

Bases: Step

Generate embeddings for a text input using the last hidden state of an LLM, as described in the paper 'What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.

Attributes:

  • llm (LLM): The LLM to use to generate the embeddings.

Input columns
  • text (str, List[Dict[str, str]]): The input text or conversation to generate embeddings for.
Output columns
  • embedding (List[float]): The embedding of the input text or conversation.
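
As a rough illustration of what "the last hidden state" means for the `embedding` column, the sketch below takes a toy matrix of per-token hidden states and keeps only the final token's vector. The function name `last_token_embedding` and the toy values are assumptions made for this example; they are not part of distilabel.

```python
from typing import Dict, List, Union


def last_token_embedding(hidden_states: List[List[float]]) -> List[float]:
    """Return the hidden-state vector of the final token.

    `hidden_states` is a hypothetical (seq_len x hidden_size) matrix with one
    row per token, standing in for what an LLM would return.
    """
    if not hidden_states:
        raise ValueError("expected hidden states for at least one token")
    return list(hidden_states[-1])


# Toy values: 3 tokens, hidden size 4 (made up for illustration).
toy_hidden_states = [
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
    [0.9, 1.0, 1.1, 1.2],
]

row: Dict[str, Union[str, List[float]]] = {"text": "Hello, how are you?"}
row["embedding"] = last_token_embedding(toy_hidden_states)
print(row["embedding"])  # [0.9, 1.0, 1.1, 1.2]
```

The resulting row carries both the original `text` and a flat `List[float]` under `embedding`, matching the column types listed above.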
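References
  • What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning (https://arxiv.org/abs/2312.15685)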
Source code in src/distilabel/steps/tasks/generate_embeddings.py
class GenerateEmbeddings(Step):
    """Generate embeddings for a text input using the last hidden state of an `LLM`, as
    described in the paper 'What Makes Good Data for Alignment? A Comprehensive Study of
    Automatic Data Selection in Instruction Tuning'.

    Attributes:
        llm: The `LLM` to use to generate the embeddings.

    Input columns:
        - text (`str`, `List[Dict[str, str]]`): The input text or conversation to generate
            embeddings for.

    Output columns:
        - embedding (`List[float]`): The embedding of the input text or conversation.

    References:
        - [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685)
    """

    llm: LLM

    def load(self) -> None:
        """Loads the `LLM` used to generate the embeddings."""
        super().load()
        self.llm.load()

    @property
    def inputs(self) -> List[str]:
        """The input for the task is a `text` column containing either a string or a
        list of dictionaries in OpenAI chat-like format."""
        return ["text"]

    @property
    def outputs(self) -> List[str]:
        """The output for the task is an `embedding` column containing the embedding of
        the `text` input."""
        return ["embedding"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """Formats the input to be used by the LLM to generate the embeddings. The input
        can be in `ChatType` format or a string. If a string, it will be converted to a
        list of dictionaries in OpenAI chat-like format.

        Args:
            input: The input to format.

        Returns:
            The OpenAI chat-like format of the input.
        """
        text = input["text"]

        # A plain string is wrapped as a single user message
        if isinstance(text, str):
            return [{"role": "user", "content": text}]

        # The input is already in `ChatType` format, so return it unchanged
        if is_openai_format(text):
            return text

        raise ValueError(
            f"Couldn't format input for step {self.name}. The `text` input column has to"
            " be a string or a list of dictionaries in OpenAI chat-like format."
        )

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Generates an embedding for each input using the last hidden state of the `LLM`.

        Args:
            inputs: A list of Python dictionaries with the inputs of the task.

        Yields:
            A list of Python dictionaries with the outputs of the task.
        """
        formatted_inputs = [self.format_input(input) for input in inputs]
        last_hidden_states = self.llm.get_last_hidden_states(formatted_inputs)
        for input, hidden_state in zip(inputs, last_hidden_states):
            # Keep only the final row of the hidden states as the embedding
            input["embedding"] = hidden_state[-1].tolist()
        yield inputs

inputs: List[str] property

The input for the task is a text column containing either a string or a list of dictionaries in OpenAI chat-like format.

outputs: List[str] property

The output for the task is an embedding column containing the embedding of the text input.

format_input(input)

Formats the input to be used by the LLM to generate the embeddings. The input can be in ChatType format or a string. If a string, it will be converted to a list of dictionaries in OpenAI chat-like format.

Parameters:

  • input (Dict[str, Any], required): The input to format.

Returns:

  • ChatType: The OpenAI chat-like format of the input.
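
The normalization described above can be sketched without distilabel installed. In the sketch, `Chat`, `looks_like_openai_chat`, and `format_input_sketch` are hypothetical names introduced for illustration; in particular, `looks_like_openai_chat` is only a guess at what a check such as `is_openai_format` might verify, and the real helper may be stricter or looser.

```python
from typing import Any, Dict, List

# Hypothetical alias standing in for distilabel's `ChatType`.
Chat = List[Dict[str, str]]


def looks_like_openai_chat(value: Any) -> bool:
    """Assumed check: a list of dicts with string `role` and `content` keys."""
    return isinstance(value, list) and all(
        isinstance(message, dict)
        and isinstance(message.get("role"), str)
        and isinstance(message.get("content"), str)
        for message in value
    )


def format_input_sketch(row: Dict[str, Any]) -> Chat:
    """Normalize the `text` column into a chat-like list of messages."""
    text = row["text"]

    # A plain string becomes a single user message.
    if isinstance(text, str):
        return [{"role": "user", "content": text}]

    # A conversation that already has the expected shape is returned as-is.
    if looks_like_openai_chat(text):
        return text

    raise ValueError(
        "The `text` column has to be a string or a list of dictionaries in "
        "OpenAI chat-like format."
    )


print(format_input_sketch({"text": "What is an embedding?"}))
# [{'role': 'user', 'content': 'What is an embedding?'}]
```

Anything that is neither a string nor a chat-like list (for example, an integer) raises `ValueError` rather than being silently coerced.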

Source code in src/distilabel/steps/tasks/generate_embeddings.py
def format_input(self, input: Dict[str, Any]) -> "ChatType":
    """Formats the input to be used by the LLM to generate the embeddings. The input
    can be in `ChatType` format or a string. If a string, it will be converted to a
    list of dictionaries in OpenAI chat-like format.

    Args:
        input: The input to format.

    Returns:
        The OpenAI chat-like format of the input.
    """
    text = input["text"]

    # A plain string is wrapped as a single user message
    if isinstance(text, str):
        return [{"role": "user", "content": text}]

    # The input is already in `ChatType` format, so return it unchanged
    if is_openai_format(text):
        return text

    raise ValueError(
        f"Couldn't format input for step {self.name}. The `text` input column has to"
        " be a string or a list of dictionaries in OpenAI chat-like format."
    )

load()

Loads the LLM used to generate the embeddings.

Source code in src/distilabel/steps/tasks/generate_embeddings.py
def load(self) -> None:
    """Loads the `LLM` used to generate the embeddings."""
    super().load()
    self.llm.load()

process(inputs)

Generates an embedding for each input using the last hidden state of the LLM.

Parameters:

  • inputs (StepInput, required): A list of Python dictionaries with the inputs of the task.

Yields:

  • StepOutput: A list of Python dictionaries with the outputs of the task.
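
To make the batch flow concrete, here is a minimal sketch of the same loop using a stub in place of a real `LLM`. `StubLLM` and `process_sketch` are hypothetical names, and the stub's hidden states are fabricated numbers in plain nested lists; the actual step calls `.tolist()` on the selected row, which suggests array-like objects, whereas the sketch uses lists so it runs without extra dependencies.

```python
from typing import Any, Dict, Iterator, List

Chat = List[Dict[str, str]]  # hypothetical alias standing in for `ChatType`


class StubLLM:
    """Hypothetical stand-in for an `LLM` that returns fake hidden states."""

    def get_last_hidden_states(self, chats: List[Chat]) -> List[List[List[float]]]:
        # One (seq_len x hidden_size) matrix per conversation. Here seq_len is
        # len(chat) + 1 and hidden_size is 2, purely for illustration.
        return [
            [[float(i), float(len(chat))] for i in range(len(chat) + 1)]
            for chat in chats
        ]


def process_sketch(
    rows: List[Dict[str, Any]], llm: StubLLM
) -> Iterator[List[Dict[str, Any]]]:
    """Attach an `embedding` to every row, then yield the whole batch once."""
    chats = [
        [{"role": "user", "content": row["text"]}]
        if isinstance(row["text"], str)
        else row["text"]
        for row in rows
    ]
    hidden_states = llm.get_last_hidden_states(chats)
    for row, states in zip(rows, hidden_states):
        # Keep only the final row of the matrix as the embedding.
        row["embedding"] = states[-1]
    yield rows


rows = [
    {"text": "hi"},
    {"text": [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}]},
]
batches = list(process_sketch(rows, StubLLM()))
print(batches[0][0]["embedding"])  # [1.0, 1.0]
print(batches[0][1]["embedding"])  # [2.0, 2.0]
```

Note that the rows are updated in place and the generator yields a single batch, which mirrors the shape of the method documented here.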

Source code in src/distilabel/steps/tasks/generate_embeddings.py
def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
    """Generates an embedding for each input using the last hidden state of the `LLM`.

    Args:
        inputs: A list of Python dictionaries with the inputs of the task.

    Yields:
        A list of Python dictionaries with the outputs of the task.
    """
    formatted_inputs = [self.format_input(input) for input in inputs]
    last_hidden_states = self.llm.get_last_hidden_states(formatted_inputs)
    for input, hidden_state in zip(inputs, last_hidden_states):
        # Keep only the final row of the hidden states as the embedding
        input["embedding"] = hidden_state[-1].tolist()
    yield inputs