Skip to content

ultracm

UltraCMTask dataclass

Bases: CritiqueTask

Source code in src/distilabel/tasks/critique/ultracm.py
@dataclass
class UltraCMTask(CritiqueTask):
    """Critique task that prompts for an overall score followed by feedback.

    The prompt is rendered from `_ULTRACM_TEMPLATE` and ends with
    ``"Overall Score: "``, so the completion is expected to begin with a
    numeric score and continue with the free-text critique.
    """

    __jinja2_template__: ClassVar[str] = _ULTRACM_TEMPLATE

    # Default system prompt passed along with every prompt built by
    # `generate_prompt`.
    system_prompt: str = (
        "User: A one-turn chat between a curious user and an artificial intelligence"
        " assistant. The assistant gives helpful, very detailed, and polite answers to"
        " the user's questions.</s>"
    )

    def generate_prompt(self, input: str, generations: str, **_: Any) -> Prompt:
        """Build the prompt asking for a critique of a completion.

        Args:
            input: Rendered into the template as ``instruction``.
            generations: Rendered into the template as ``completion``.
            **_: Extra keyword arguments; accepted and ignored.

        Returns:
            A `Prompt` holding the system prompt and the formatted user turn,
            which ends with ``"Overall Score: "`` so that the completion starts
            with the score.
        """
        render_kwargs = {
            "instruction": input,
            "completion": generations,
        }
        # NOTE(review): `self.template` is presumably provided by the base
        # class from `__jinja2_template__` -- confirm against `CritiqueTask`.
        return Prompt(
            system_prompt=self.system_prompt,
            formatted_prompt=f"User: {self.template.render(**render_kwargs)}</s>\nAssistant: ### Feedback\nOverall Score: ",
        )

    def parse_output(self, output: str) -> CritiqueTaskOutput:  # type: ignore
        """Parse a raw completion into a score and a critique.

        The completion must start (optionally after whitespace) with an integer
        or decimal number; everything after that number is taken as the
        critique, including any further lines.

        Args:
            output: Raw text produced by the model.

        Returns:
            A `CritiqueTaskOutput` with the score as a float and the stripped
            critique text. When `output` does not start with a number the
            method falls through and returns ``None``.
        """
        # Leading `\s*` tolerates completions that begin with a space/newline.
        pattern = r"\s*(\d+(?:\.\d+)?)\s*(.*)"
        # DOTALL lets `.` cross newlines; without it a multi-line critique is
        # truncated to its first line.
        match = re.match(pattern, output, flags=re.DOTALL)
        if match:
            return CritiqueTaskOutput(
                score=float(match.group(1)),
                critique=match.group(2).strip(),
            )

    def to_argilla_dataset(
        self,
        dataset_row: Dict[str, Any],
        generations_column: str = "generations",
        score_column: str = "score",
        critique_column: str = "critique",
        score_values: Optional[List[int]] = None,
    ) -> "FeedbackDataset":
        """Delegate to the parent implementation with a 1-10 default score scale.

        Args:
            dataset_row: Forwarded unchanged to the parent method.
            generations_column: Forwarded unchanged to the parent method.
            score_column: Forwarded unchanged to the parent method.
            critique_column: Forwarded unchanged to the parent method.
            score_values: Allowed score values. When ``None`` (or empty) the
                integers 1 through 10 are used.

        Returns:
            Whatever the parent `to_argilla_dataset` returns.
        """
        return super().to_argilla_dataset(
            dataset_row=dataset_row,
            generations_column=generations_column,
            score_column=score_column,
            critique_column=critique_column,
            score_values=score_values or [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        )

parse_output(output)

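Parses the output of the model into a `CritiqueTaskOutput`: the leading numeric score (as a float) and the critique text that follows it. Nothing is returned when the output does not start with a number.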

Source code in src/distilabel/tasks/critique/ultracm.py
def parse_output(self, output: str) -> CritiqueTaskOutput:  # type: ignore
    """Parses the output of the model into the desired format."""
    pattern = r"(\d+(?:\.\d+)?)\s*(.*)"
    match = re.match(pattern, output)
    if match:
        return CritiqueTaskOutput(
            score=float(match.group(1)),
            critique=match.group(2).strip(),
        )