Complexity scorer

ComplexityScorer

Bases: Task

This task scores a list of instructions by complexity so they can be ranked. It's an implementation of the complexity score task from the paper 'What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.

Attributes:

  • _template (Union[Template, None]): The Jinja2 template used to format the input data.

Input columns
  • instructions (List[str]): The list of instructions to be scored.
Output columns
  • scores (List[float]): The complexity score for each instruction.
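References
  • What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning (https://arxiv.org/abs/2312.15685)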
Source code in src/distilabel/steps/tasks/complexity_scorer.py
class ComplexityScorer(Task):
    """This task scores a list of instructions by complexity so they can be ranked. It's
    an implementation of the complexity score task from the paper 'What Makes Good Data
    for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.

    Attributes:
        _template: The Jinja2 template used to format the input data.

    Input columns:
        - instructions (`List[str]`): The list of instructions to be scored.

    Output columns:
        - scores (`List[float]`): The complexity score for each instruction.

    References:
        - [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
    """

    _template: Union[Template, None] = PrivateAttr(...)

    def load(self) -> None:
        super().load()
        self._template = Template(_COMPLEXITY_SCORER_TEMPLATE)

    @property
    def inputs(self) -> List[str]:
        return ["instructions"]

    @property
    def outputs(self) -> List[str]:
        return ["scores"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        return [{"role": "user", "content": self._template.render(**input)}]  # type: ignore

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
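        # No LLM output: return one `None` score per instruction.
        if output is None:
            return {"scores": [None] * len(input["instructions"])}

        scores = []
        # Each output line is expected to carry the score of one instruction.
        score_lines = output.split("\n")
        for i, line in enumerate(score_lines):
            match = _PARSE_SCORE_LINE_REGEX.match(line)
            # Lines that do not match the expected pattern get a `None` score.
            score = float(match.group(1)) if match else None
            scores.append(score)
            # Stop once one line per instruction has been read; extra lines are ignored.
            if i == len(input["instructions"]) - 1:
                break

        return {"scores": scores}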
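
The parsing loop in `format_output` can be tried in isolation with the sketch below. It is a minimal, standalone approximation and not the library's code: `parse_scores` is a hypothetical helper, and `SCORE_LINE_REGEX` is an assumed pattern, since the actual `_PARSE_SCORE_LINE_REGEX` is not shown in this section.

```python
import re
from typing import List, Optional

# Assumed pattern for lines such as "[1] Score: 3"; the real
# `_PARSE_SCORE_LINE_REGEX` is defined elsewhere and may differ.
SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+(?:\.\d+)?)", re.IGNORECASE)


def parse_scores(output: Optional[str], n_instructions: int) -> List[Optional[float]]:
    """Hypothetical helper mirroring the loop in `format_output`."""
    # No output at all: one `None` per instruction.
    if output is None:
        return [None] * n_instructions

    scores: List[Optional[float]] = []
    for i, line in enumerate(output.split("\n")):
        match = SCORE_LINE_REGEX.match(line)
        # Unparseable lines yield `None` instead of raising.
        scores.append(float(match.group(1)) if match else None)
        # Stop after one line per instruction.
        if i == n_instructions - 1:
            break
    return scores


print(parse_scores("[1] Score: 2\n[2] Score: 4.5\nno score here", 3))
# [2.0, 4.5, None]
print(parse_scores(None, 2))
# [None, None]
print(parse_scores("[1] Score: 1", 3))
# [1.0]
```

The last call shows a property of this loop worth knowing: when the output has fewer lines than there are instructions, the returned list is shorter than the instruction list, because nothing pads it with `None`.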