Bases: Task
This task scores each instruction in a list by its complexity so that the instructions
can be ranked. It implements the complexity score task from the paper 'What Makes Good Data
for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.
Attributes:
| Name | Type | Description |
| --- | --- | --- |
| _template | Union[Template, None] | The Jinja2 template used to format the input data. |
Input columns
- instructions (`List[str]`): The list of instructions to be scored.
Output columns
- scores (`List[float]`): The complexity score for each instruction.
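Since the task returns one score per instruction, ranking amounts to pairing each instruction with its score and sorting. The following is a minimal sketch and not part of distilabel: `rank_by_complexity` is a hypothetical helper, and it assumes a score may be `None` when the model's reply could not be parsed for that instruction.

```python
from typing import List, Optional, Tuple


def rank_by_complexity(
    instructions: List[str], scores: List[Optional[float]]
) -> List[Tuple[str, Optional[float]]]:
    """Hypothetical helper: pair instructions with scores, most complex first.

    Instructions without a score (None) are placed at the end.
    """
    pairs = list(zip(instructions, scores))
    scored = [p for p in pairs if p[1] is not None]
    unscored = [p for p in pairs if p[1] is None]
    # list.sort() is stable, so ties keep their original relative order.
    scored.sort(key=lambda p: p[1], reverse=True)
    return scored + unscored


ranked = rank_by_complexity(
    ["Say hi.", "Prove the theorem.", "Sort a list."],
    [1.0, 5.0, None],
)
```

Keeping unscored instructions at the end, rather than dropping them, preserves every input so a later step can decide whether to retry or discard them.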
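References
- [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685)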
Source code in src/distilabel/steps/tasks/complexity_scorer.py
class ComplexityScorer(Task):
    """This task is used to rank a list of instructions based on their complexity. It's
    an implementation of the complexity score task from the paper 'What Makes Good Data
    for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning'.

    Attributes:
        _template: The Jinja2 template used to format the input data.

    Input columns:
        - instructions (`List[str]`): The list of instructions to be scored.

    Output columns:
        - scores (`List[float]`): The complexity score for each instruction.

    References:
        - [`What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning`](https://arxiv.org/abs/2312.15685)
    """

    _template: Union[Template, None] = PrivateAttr(...)

    def load(self) -> None:
        super().load()
        # Compile the prompt template once, when the task is loaded.
        self._template = Template(_COMPLEXITY_SCORER_TEMPLATE)

    @property
    def inputs(self) -> List[str]:
        return ["instructions"]

    @property
    def outputs(self) -> List[str]:
        return ["scores"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        # Render the instructions into a single user message.
        return [{"role": "user", "content": self._template.render(**input)}]  # type: ignore

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        # No model output: return one None per instruction.
        if output is None:
            return {"scores": [None] * len(input["instructions"])}

        scores = []
        score_lines = output.split("\n")
        for i, line in enumerate(score_lines):
            # A line that does not match the expected format yields None.
            match = _PARSE_SCORE_LINE_REGEX.match(line)
            score = float(match.group(1)) if match else None
            scores.append(score)
            # Ignore any lines beyond the number of instructions.
            if i == len(input["instructions"]) - 1:
                break
        return {"scores": scores}
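The parsing loop in `format_output` can be tried in isolation. The sketch below mirrors that loop in a standalone function; `parse_scores` is a hypothetical name, and because the listing does not show `_PARSE_SCORE_LINE_REGEX`, the pattern used here (one line per instruction, such as `[1] Score: 3`) is an assumption rather than the library's actual regex.

```python
import re
from typing import List, Optional

# Assumed pattern: the listing does not show _PARSE_SCORE_LINE_REGEX.
ASSUMED_SCORE_LINE_REGEX = re.compile(r"\[\d+\] score: (\d+)", re.IGNORECASE)


def parse_scores(
    output: Optional[str], num_instructions: int
) -> List[Optional[float]]:
    """Hypothetical stand-in for the parsing done in format_output."""
    if output is None:
        return [None] * num_instructions
    scores: List[Optional[float]] = []
    for i, line in enumerate(output.split("\n")):
        match = ASSUMED_SCORE_LINE_REGEX.match(line)
        # Unparseable lines become None instead of raising.
        scores.append(float(match.group(1)) if match else None)
        # Stop once one line per instruction has been consumed.
        if i == num_instructions - 1:
            break
    return scores


parsed = parse_scores("[1] Score: 2\nnot a score\n[3] Score: 4\nextra", 3)
short = parse_scores("[1] Score: 2", 3)
missing = parse_scores(None, 2)
```

One consequence of this loop is that a reply with fewer lines than instructions produces a shorter list (`short` holds a single element here), so callers that need the scores aligned with the instructions may want to pad the result with `None`.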