class Task(ABC, _Serializable):
    """Abstract class used to define the methods required to create a `Task`, to be used
    within an `LLM`.

    Args:
        system_prompt (str): the system prompt to be used for generation.
        task_description (Union[str, None], optional): the description of the task. Defaults to `None`.

    Raises:
        ValueError: if the `template` property is accessed but the subclass has not defined a
            `__jinja2_template__` attribute.
    """
system_prompt: str
task_description: Union[str, None] = None
__jinja2_template__: Union[str, None] = None
__type__: Union[Literal["generation", "labelling"], None] = None
def __rich_repr__(self) -> Generator[Any, None, None]:
yield "system_prompt", self.system_prompt
yield "task_description", self.task_description
yield "input_args_names", self.input_args_names
yield "output_args_names", self.output_args_names
    @property
    def template(self) -> "Template":
        """Returns the Jinja2 `Template` used to generate the prompt for the task."""
        if self.__jinja2_template__ is None:
            raise ValueError(
                "You must provide a `__jinja2_template__` attribute to your Task subclass."
            )
        # Use a context manager so the template file handle is closed after reading it.
        with open(self.__jinja2_template__) as template_file:
            return Template(template_file.read())
    @abstractmethod
    def generate_prompt(self, **kwargs: Any) -> Prompt:
        """Generates a `Prompt` for the `LLM` from the provided input arguments."""
        pass

    @abstractmethod
    def parse_output(self, output: str) -> Any:
        """Parses the raw output of the `LLM` into the output arguments of the task."""
        pass

    @property
    @abstractmethod
    def input_args_names(self) -> List[str]:
        """Returns the names of the input arguments of the task."""
        pass

    @property
    @abstractmethod
    def output_args_names(self) -> List[str]:
        """Returns the names of the output arguments of the task."""
        pass
    def validate_dataset(self, columns_in_dataset: List[str]) -> None:
        """Validates that the dataset contains the required columns for the task.

        Args:
            columns_in_dataset (List[str]): the columns in the dataset.

        Raises:
            KeyError: if the dataset does not contain the required columns.
        """
for input_arg_name in self.input_args_names:
if input_arg_name not in columns_in_dataset:
raise KeyError(
f"LLM expects a column named '{input_arg_name}' in the provided"
" dataset, but it was not found."
)
    def to_argilla_dataset(
        self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any
    ) -> "FeedbackDataset":
        """Converts a `dataset_row` into an Argilla `FeedbackDataset` for the task. Subclasses
        that support exporting to Argilla must implement this method."""
        raise NotImplementedError(
            "`to_argilla_dataset` is not implemented. If you want to export your dataset as an"
            " Argilla `FeedbackDataset`, you will need to implement this method first."
        )
    def to_argilla_record(
        self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any
    ) -> Union["FeedbackRecord", List["FeedbackRecord"]]:
        """Converts a `dataset_row` into one or more Argilla `FeedbackRecord`s for the task.
        Subclasses that support exporting to Argilla must implement this method."""
        raise NotImplementedError(
            "`to_argilla_record` is not implemented. If you want to export your dataset as an"
            " Argilla `FeedbackDataset` you will need to implement this method first."
        )
    # A new `_to_argilla_record` method is defined instead of renaming `to_argilla_record` to be
    # protected, as renaming the public method would imply more breaking changes.
def _to_argilla_record( # noqa: C901
self, dataset_row: Dict[str, Any], *args: Any, **kwargs: Any
) -> Union["FeedbackRecord", List["FeedbackRecord"]]:
column_names = list(dataset_row.keys())
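        # The columns the task works with in the row: generation tasks use both the input and
        # the output args, while labelling tasks only use the output args.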
if self.__type__ is None or self.__type__ == "generation":
required_column_names = self.input_args_names + self.output_args_names
elif self.__type__ == "labelling":
required_column_names = self.output_args_names
        else:
            raise ValueError(f"The task type '{self.__type__}' is not supported.")
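        # Start from the original row; it may be split below into one row per model if it
        # contains generations or labels coming from more than one model.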
dataset_rows = [dataset_row]
if "generation_model" in dataset_row and isinstance(
dataset_row["generation_model"], list
):
generation_columns = column_names[
column_names.index("generation_model") : column_names.index(
"labelling_model"
)
if "labelling_model" in column_names
else None
]
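            # Only unwrap if any task-specific generation column (i.e. excluding the model,
            # prompt and raw response metadata columns) contains nested lists, meaning that
            # each generation model produced its own list of values for that column.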
if any(
isinstance(nested, list)
for column_name in list(
set(generation_columns)
- {
"generation_model",
"generation_prompt",
"raw_generation_response",
}
)
for nested in dataset_row[column_name]
):
if any(
generation_column in required_column_names
for generation_column in generation_columns
):
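                    # Unwrap: build one copy of the row per generation model, taking the idx-th
                    # element of every generation column and keeping the remaining columns as-is.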
unwrapped_dataset_rows = []
for row in dataset_rows:
for idx in range(len(dataset_row["generation_model"])):
unwrapped_dataset_row = {}
for key, value in row.items():
if key in generation_columns:
unwrapped_dataset_row[key] = value[idx]
else:
unwrapped_dataset_row[key] = value
unwrapped_dataset_rows.append(unwrapped_dataset_row)
dataset_rows = unwrapped_dataset_rows
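        # Apply the same unwrapping to the labelling columns, which go from `labelling_model`
        # to the end of the row.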
if "labelling_model" in dataset_row and isinstance(
dataset_row["labelling_model"], list
):
labelling_columns = column_names[column_names.index("labelling_model") :]
if any(
isinstance(nested, list)
for column_name in list(
set(labelling_columns)
- {
"labelling_model",
"labelling_prompt",
"raw_labelling_response",
}
)
for nested in dataset_row[column_name]
):
if any(
labelling_column in required_column_names
for labelling_column in labelling_columns
):
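                    # Same unwrapping as for the generation columns, but iterating over the
                    # number of labelling models.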
unwrapped_dataset_rows = []
for row in dataset_rows:
for idx in range(len(dataset_row["labelling_model"])):
unwrapped_dataset_row = {}
for key, value in row.items():
if key in labelling_columns:
unwrapped_dataset_row[key] = value[idx]
else:
unwrapped_dataset_row[key] = value
unwrapped_dataset_rows.append(unwrapped_dataset_row)
dataset_rows = unwrapped_dataset_rows
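        # If no unwrapping happened there is a single row, so its record(s) can be returned
        # directly; otherwise, flatten the records produced for every unwrapped row into a
        # single list.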
if len(dataset_rows) == 1:
return self.to_argilla_record(dataset_rows[0], *args, **kwargs)
        records = []
        for unwrapped_row in dataset_rows:
            generated_records = self.to_argilla_record(unwrapped_row, *args, **kwargs)
            if isinstance(generated_records, list):
                records.extend(generated_records)
            else:
                records.append(generated_records)
        return records