Skip to content

prometheus

PrometheusTask dataclass

Bases: CritiqueTask

Source code in src/distilabel/tasks/critique/prometheus.py
@dataclass
class PrometheusTask(CritiqueTask):
    """Critique task that asks a model to grade a completion against a rubric.

    The prompt is rendered from the task template using the instruction, the
    completion to critique, a reference completion, and the rubric stored on
    this task. The model output is expected to look like
    ``<critique>. [RESULT] <score>``.

    Args:
        scoring_criteria: criteria text forwarded to the prompt template.
        score_descriptions: mapping from a score to its description, forwarded
            to the prompt template.
        system_prompt: system prompt attached to every generated `Prompt`.
    """

    # Rubric pieces handed to the template as-is when rendering the prompt.
    scoring_criteria: str
    score_descriptions: Dict[int, str]

    system_prompt: str = "You are a fair evaluator language model."

    __jinja2_template__: ClassVar[str] = _PROMETHEUS_TEMPLATE

    @property
    def input_args_names(self) -> List[str]:
        """Return the parent's input argument names plus ``"ref_completion"``."""
        return super().input_args_names + ["ref_completion"]

    def generate_prompt(
        self, input: str, generations: str, ref_completion: str, **_: Any
    ) -> Prompt:
        """Build the prompt used to critique a completion.

        Args:
            input: rendered in the template as ``instruction``.
            generations: rendered in the template as ``completion``.
            ref_completion: rendered in the template as ``ref_completion``.
            **_: extra keyword arguments, accepted and ignored.

        Returns:
            A `Prompt` holding this task's system prompt and the rendered
            template.
        """
        render_kwargs = {
            "instruction": input,
            "completion": generations,
            "ref_completion": ref_completion,
            "scoring_criteria": self.scoring_criteria,
            "score_descriptions": self.score_descriptions,
        }
        return Prompt(
            system_prompt=self.system_prompt,
            formatted_prompt=self.template.render(**render_kwargs),
        )

    def parse_output(self, output: str) -> CritiqueTaskOutput:  # type: ignore
        """Parse the raw model output into a score and a critique.

        Args:
            output: model output, expected to start with the critique and be
                followed by ``. [RESULT] <integer score>``.

        Returns:
            A `CritiqueTaskOutput` with the score converted to `float` and the
            whitespace-stripped critique, or `None` when the output does not
            follow the expected format.
        """
        # We use a regex instead of splitting by the delimiter because the
        # critique may contain the delimiter, and using the regex is safer.
        pattern = r"(.+?)\. \[RESULT\] (\d+)"
        # `re.DOTALL` lets `.` cross newlines; without it a critique spanning
        # several lines (or an output starting with a newline) never matched
        # and was silently dropped.
        match = re.match(pattern, output, flags=re.DOTALL)
        if match is None:
            return None
        return CritiqueTaskOutput(
            score=float(match.group(2)),
            critique=match.group(1).strip(),
        )

parse_output(output)

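Parses the output of the model into a `CritiqueTaskOutput` holding the score and the critique; nothing is returned when the output lacks the `. [RESULT] <score>` marker.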

Source code in src/distilabel/tasks/critique/prometheus.py
def parse_output(self, output: str) -> CritiqueTaskOutput:  # type: ignore
    """Parse the raw model output into a score and a critique.

    Args:
        output: model output, expected to start with the critique and be
            followed by ``. [RESULT] <integer score>``.

    Returns:
        A `CritiqueTaskOutput` with the score converted to `float` and the
        whitespace-stripped critique, or `None` when the output does not
        follow the expected format.
    """
    # We use a regex instead of splitting by the delimiter because the
    # critique may contain the delimiter, and using the regex is safer.
    pattern = r"(.+?)\. \[RESULT\] (\d+)"
    # `re.DOTALL` lets `.` cross newlines; without it a critique spanning
    # several lines (or an output starting with a newline) never matched
    # and was silently dropped.
    match = re.match(pattern, output, flags=re.DOTALL)
    if match is None:
        return None
    return CritiqueTaskOutput(
        score=float(match.group(2)),
        critique=match.group(1).strip(),
    )