
Task Gallery

This section contains the existing Task subclasses implemented in distilabel.

tasks

APIGenExecutionChecker

Bases: Step

Executes the generated function calls.

This step checks if a given answer from a model as generated by APIGenGenerator can be executed against the given library (given by libpath, which is a string pointing to a python .py file with functions).

Attributes:

Name Type Description
libpath str

The path to the library where we will retrieve the functions. It can also point to a folder with the functions. In this case, the folder layout should be a folder with .py files, each containing a single function, the name of the function being the same as the filename.

check_is_dangerous bool

Whether to exclude potentially dangerous functions, based on some heuristics found while testing. These functions may run subprocesses, interact with the OS, or perform other potentially dangerous operations. Defaults to True.

Input columns
  • answers (str): List with arguments to be passed to the function, dumped as a string from a list of dictionaries. Should be loaded using json.loads.
Output columns
  • keep_row_after_execution_check (bool): Whether the function should be kept or not.
  • execution_result (str): The result from executing the function.
Categories
  • filtering
  • execution
References
  • APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets (https://arxiv.org/abs/2406.18518)
  • Salesforce/xlam-function-calling-60k (https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)

Examples:

Execute a function from a given library with the answer from an LLM:

```python
from distilabel.steps.tasks import APIGenExecutionChecker

# For the libpath you can use as an example the file at the tests folder:
# ../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py
task = APIGenExecutionChecker(
    libpath="../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py",
)
task.load()

res = next(
    task.process(
        [
            {
                "answers": [
                    {
                        "arguments": {
                            "initial_velocity": 0.2,
                            "acceleration": 0.1,
                            "time": 0.5,
                        },
                        "name": "final_velocity",
                    }
                ],
            }
        ]
    )
)
res
#[{'answers': [{'arguments': {'initial_velocity': 0.2, 'acceleration': 0.1, 'time': 0.5}, 'name': 'final_velocity'}], 'keep_row_after_execution_check': True, 'execution_result': ['0.25']}]
```
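
The example above relies on a small Python module with the functions to execute. As a rough sketch of what such a module can look like (the `final_velocity` function and its formula are inferred from the arguments and result shown above; the actual `_sample_module.py` used in the distilabel tests may differ):

```python
# my_functions.py -- hypothetical library file that `libpath` could point to.

def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
    """Calculates the final velocity of an object given its initial velocity,
    acceleration and elapsed time."""
    return initial_velocity + acceleration * time
```

If `libpath` points to a folder instead of a single file, each function goes in its own `.py` file named after the function (e.g. a hypothetical `final_velocity.py` containing `final_velocity`).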
Source code in src/distilabel/steps/tasks/apigen/execution_checker.py
class APIGenExecutionChecker(Step):
    """Executes the generated function calls.

    This step checks if a given answer from a model as generated by `APIGenGenerator`
    can be executed against the given library (given by `libpath`, which is a string
    pointing to a python .py file with functions).

    Attributes:
        libpath: The path to the library where we will retrieve the functions.
            It can also point to a folder with the functions. In this case, the folder
            layout should be a folder with .py files, each containing a single function,
            the name of the function being the same as the filename.
        check_is_dangerous: Whether to exclude potentially dangerous functions, based on
            some heuristics found while testing. These functions may run subprocesses, deal with
            the OS, or have other potentially dangerous operations. Defaults to True.

    Input columns:
        - answers (`str`): List with arguments to be passed to the function,
            dumped as a string from a list of dictionaries. Should be loaded using
            `json.loads`.

    Output columns:
        - keep_row_after_execution_check (`bool`): Whether the function should be kept or not.
        - execution_result (`str`): The result from executing the function.

    Categories:
        - filtering
        - execution

    References:
        - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
        - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)

    Examples:
        Execute a function from a given library with the answer from an LLM:

        ```python
        from distilabel.steps.tasks import APIGenExecutionChecker

        # For the libpath you can use as an example the file at the tests folder:
        # ../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py
        task = APIGenExecutionChecker(
            libpath="../distilabel/tests/unit/steps/tasks/apigen/_sample_module.py",
        )
        task.load()

        res = next(
            task.process(
                [
                    {
                        "answers": [
                            {
                                "arguments": {
                                    "initial_velocity": 0.2,
                                    "acceleration": 0.1,
                                    "time": 0.5,
                                },
                                "name": "final_velocity",
                            }
                        ],
                    }
                ]
            )
        )
        res
        #[{'answers': [{'arguments': {'initial_velocity': 0.2, 'acceleration': 0.1, 'time': 0.5}, 'name': 'final_velocity'}], 'keep_row_after_execution_check': True, 'execution_result': ['0.25']}]
        ```
    """

    libpath: str = Field(
        default=...,
        description=(
            "The path to the library where we will retrieve the functions, "
            "or a folder with python files named the same as the functions they contain."
        ),
    )
    check_is_dangerous: bool = Field(
        default=True,
        description=(
            "Whether to exclude potentially dangerous functions, based on some heuristics "
            "found while testing. These functions can run subprocesses, "
            "deal with the OS, or have other potentially dangerous operations."
        ),
    )

    _toolbox: Union["ModuleType", None] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the library where the functions will be extracted from."""
        super().load()
        if Path(self.libpath).suffix == ".py":
            self._toolbox = load_module_from_path(self.libpath)

    def unload(self) -> None:
        self._toolbox = None

    @property
    def inputs(self) -> "StepColumns":
        """The inputs for the task are those found in the original dataset."""
        return ["answers"]

    @property
    def outputs(self) -> "StepColumns":
        """The outputs are the columns required by `APIGenGenerator` task."""
        return ["keep_row_after_execution_check", "execution_result"]

    def _get_function(self, function_name: str) -> Callable:
        """Retrieves the function from the toolbox.

        Args:
            function_name: The name of the function to retrieve.

        Returns:
            Callable: The function to be executed.
        """
        if self._toolbox:
            return getattr(self._toolbox, function_name, None)
        try:
            toolbox = load_module_from_path(
                str(Path(self.libpath) / f"{function_name}.py")
            )
            return getattr(toolbox, function_name, None)
        except FileNotFoundError:
            return None
        except Exception as e:
            self._logger.warning(f"Error loading function '{function_name}': {e}")
            return None

    def _is_dangerous(self, function: Callable) -> bool:
        """Checks if a function is dangerous to remove it.
        Contains a list of heuristics to avoid executing possibly dangerous functions.
        """
        source_code = inspect.getsource(function)
        # We don't want to execute functions that use subprocess
        if (
            ("subprocess." in source_code)
            or ("os.system(" in source_code)
            or ("input(" in source_code)
            # Avoiding threading
            or ("threading.Thread(" in source_code)
            or ("exec(" in source_code)
            # Avoiding argparse (not sure why)
            or ("argparse.ArgumentParser(" in source_code)
            # Avoiding logging changing the levels to not mess with the logs
            or (".setLevel(" in source_code)
            # Don't run a test battery
            or ("unittest.main(" in source_code)
            # Avoid exiting the program
            or ("sys.exit(" in source_code)
            or ("exit(" in source_code)
            or ("raise SystemExit(" in source_code)
            or ("multiprocessing.Pool(" in source_code)
        ):
            return True
        return False

    @override
    def process(self, inputs: StepInput) -> "StepOutput":
        """Checks the answer to see if it can be executed.
        Captures the possible errors and returns them.

        If a single example is provided, it is copied to avoid raising an error.

        Args:
            inputs: A list of dictionaries with the input data.

        Yields:
            A list of dictionaries with the output data.
        """
        for input in inputs:
            output = []
            if input["answers"]:
                answers = json.loads(input["answers"])
            else:
                input.update(
                    **{
                        "keep_row_after_execution_check": False,
                        "execution_result": ["No answers were provided."],
                    }
                )
                continue
            for answer in answers:
                if answer is None:
                    output.append(
                        {
                            "keep": False,
                            "execution_result": "Nothing was generated for this answer.",
                        }
                    )
                    continue

                function_name = answer.get("name", None)
                arguments = answer.get("arguments", None)

                self._logger.debug(
                    f"Executing function '{function_name}' with arguments: {arguments}"
                )
                function = self._get_function(function_name)

                if self.check_is_dangerous:
                    if function and self._is_dangerous(function):
                        function = None

                if function is None:
                    output.append(
                        {
                            "keep": False,
                            "execution_result": f"Function '{function_name}' not found.",
                        }
                    )
                else:
                    execution = execute_from_response(function, arguments)
                    output.append(
                        {
                            "keep": execution["keep"],
                            "execution_result": execution["execution_result"],
                        }
                    )
            # We only consider a good response if all the answers were executed successfully,
            # but keep the reasons for further review if needed.
            input.update(
                **{
                    "keep_row_after_execution_check": all(
                        o["keep"] is True for o in output
                    ),
                    "execution_result": [o["execution_result"] for o in output],
                }
            )

        yield inputs
inputs: StepColumns property

The inputs for the task are those found in the original dataset.

outputs: StepColumns property

The outputs are the columns required by APIGenGenerator task.

load()

Loads the library where the functions will be extracted from.

Source code in src/distilabel/steps/tasks/apigen/execution_checker.py
def load(self) -> None:
    """Loads the library where the functions will be extracted from."""
    super().load()
    if Path(self.libpath).suffix == ".py":
        self._toolbox = load_module_from_path(self.libpath)
_get_function(function_name)

Retrieves the function from the toolbox.

Parameters:

Name Type Description Default
function_name str

The name of the function to retrieve.

required

Returns:

Name Type Description
Callable Callable

The function to be executed.

Source code in src/distilabel/steps/tasks/apigen/execution_checker.py
def _get_function(self, function_name: str) -> Callable:
    """Retrieves the function from the toolbox.

    Args:
        function_name: The name of the function to retrieve.

    Returns:
        Callable: The function to be executed.
    """
    if self._toolbox:
        return getattr(self._toolbox, function_name, None)
    try:
        toolbox = load_module_from_path(
            str(Path(self.libpath) / f"{function_name}.py")
        )
        return getattr(toolbox, function_name, None)
    except FileNotFoundError:
        return None
    except Exception as e:
        self._logger.warning(f"Error loading function '{function_name}': {e}")
        return None
_is_dangerous(function)

Checks whether a function is dangerous, so that it can be excluded. Applies a list of heuristics to avoid executing potentially dangerous functions.

Source code in src/distilabel/steps/tasks/apigen/execution_checker.py
def _is_dangerous(self, function: Callable) -> bool:
    """Checks if a function is dangerous to remove it.
    Contains a list of heuristics to avoid executing possibly dangerous functions.
    """
    source_code = inspect.getsource(function)
    # We don't want to execute functions that use subprocess
    if (
        ("subprocess." in source_code)
        or ("os.system(" in source_code)
        or ("input(" in source_code)
        # Avoiding threading
        or ("threading.Thread(" in source_code)
        or ("exec(" in source_code)
        # Avoiding argparse (not sure why)
        or ("argparse.ArgumentParser(" in source_code)
        # Avoiding logging changing the levels to not mess with the logs
        or (".setLevel(" in source_code)
        # Don't run a test battery
        or ("unittest.main(" in source_code)
        # Avoid exiting the program
        or ("sys.exit(" in source_code)
        or ("exit(" in source_code)
        or ("raise SystemExit(" in source_code)
        or ("multiprocessing.Pool(" in source_code)
    ):
        return True
    return False
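
To illustrate the heuristics above, here is a hedged sketch of two hypothetical functions: the first would be discarded because its source contains `subprocess.`, while the second contains none of the flagged patterns and would be executed normally.

```python
import subprocess


def list_directory(path: str) -> str:
    """Would be flagged as dangerous: its source contains 'subprocess.'."""
    return subprocess.run(["ls", path], capture_output=True, text=True).stdout


def final_velocity(initial_velocity: float, acceleration: float, time: float) -> float:
    """Passes the check: none of the flagged patterns appear in its source."""
    return initial_velocity + acceleration * time
```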
process(inputs)

Checks the answer to see if it can be executed. Captures the possible errors and returns them.

If a single example is provided, it is copied to avoid raising an error.

Parameters:

Name Type Description Default
inputs StepInput

A list of dictionaries with the input data.

required

Yields:

Type Description
StepOutput

A list of dictionaries with the output data.

Source code in src/distilabel/steps/tasks/apigen/execution_checker.py
@override
def process(self, inputs: StepInput) -> "StepOutput":
    """Checks the answer to see if it can be executed.
    Captures the possible errors and returns them.

    If a single example is provided, it is copied to avoid raising an error.

    Args:
        inputs: A list of dictionaries with the input data.

    Yields:
        A list of dictionaries with the output data.
    """
    for input in inputs:
        output = []
        if input["answers"]:
            answers = json.loads(input["answers"])
        else:
            input.update(
                **{
                    "keep_row_after_execution_check": False,
                    "execution_result": ["No answers were provided."],
                }
            )
            continue
        for answer in answers:
            if answer is None:
                output.append(
                    {
                        "keep": False,
                        "execution_result": "Nothing was generated for this answer.",
                    }
                )
                continue

            function_name = answer.get("name", None)
            arguments = answer.get("arguments", None)

            self._logger.debug(
                f"Executing function '{function_name}' with arguments: {arguments}"
            )
            function = self._get_function(function_name)

            if self.check_is_dangerous:
                if function and self._is_dangerous(function):
                    function = None

            if function is None:
                output.append(
                    {
                        "keep": False,
                        "execution_result": f"Function '{function_name}' not found.",
                    }
                )
            else:
                execution = execute_from_response(function, arguments)
                output.append(
                    {
                        "keep": execution["keep"],
                        "execution_result": execution["execution_result"],
                    }
                )
        # We only consider a good response if all the answers were executed successfully,
        # but keep the reasons for further review if needed.
        input.update(
            **{
                "keep_row_after_execution_check": all(
                    o["keep"] is True for o in output
                ),
                "execution_result": [o["execution_result"] for o in output],
            }
        )

    yield inputs
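
In practice, `APIGenExecutionChecker` is usually chained after `APIGenGenerator` inside a `Pipeline`, since the generator produces the `answers` column that the checker validates. The snippet below is a minimal, untested sketch of that wiring; the input data, the model and the `libpath` value are placeholders, not values taken from the distilabel repository.

```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import APIGenExecutionChecker, APIGenGenerator

with Pipeline(name="apigen-sketch") as pipeline:
    loader = LoadDataFromDicts(
        data=[
            {
                "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
                "func_name": "final_velocity",
                "func_desc": "Calculates the final velocity of an object.",
            }
        ]
    )
    generator = APIGenGenerator(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={"temperature": 0.7, "max_new_tokens": 1024},
        ),
    )
    # `libpath` is a placeholder: point it to a module defining `final_velocity`.
    checker = APIGenExecutionChecker(libpath="path/to/my_functions.py")

    loader >> generator >> checker

if __name__ == "__main__":
    distiset = pipeline.run()
```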

APIGenGenerator

Bases: Task

Generate queries and answers for the given functions in JSON format.

The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
verifiable and diverse function-calling datasets. The task generates a set of diverse queries
and corresponding answers for the given functions in JSON format.

Attributes:
    system_prompt: The system prompt to guide the user in the generation of queries and answers.
    use_tools: Whether to use the tools available in the prompt to generate the queries and answers.
        In case the tools are given in the input, they will be added to the prompt.
    number: The number of queries to generate. It can be a list, where each number will be
        chosen randomly, or a dictionary with the number of queries and the probability of each.
        I.e: `number=1`, `number=[1, 2, 3]`, `number={1: 0.5, 2: 0.3, 3: 0.2}` are all valid inputs.
        It corresponds to the number of parallel queries to generate.
    use_default_structured_output: Whether to use the default structured output or not.

Input columns:
    - examples (`str`): Examples used as few shots to guide the model.
    - func_name (`str`): Name for the function to generate.
    - func_desc (`str`): Description of what the function should do.
    - tools (`str`): JSON formatted string containing the tool representation of the function.

Output columns:
    - query (`str`): The list of queries.
    - answers (`str`): JSON formatted string with the list of answers, containing the info as
        a dictionary to be passed to the functions.

Categories:
    - text-generation

References:
    - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
    - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)

Examples:
    Generate without structured output (original implementation):

    ```python
    from distilabel.steps.tasks import ApiGenGenerator
    from distilabel.llms import InferenceEndpointsLLM

    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 1024,
        },
    )
    apigen = ApiGenGenerator(
        use_default_structured_output=False,
        llm=llm
    )
    apigen.load()

    res = next(
        apigen.process(
            [
                {
                    "examples": 'QUERY:

What is the binary sum of 10010 and 11101? ANSWER: [{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]', "func_name": "getrandommovie", "func_desc": "Returns a list of random movies from a database by calling an external API." } ] ) ) res # [{'examples': 'QUERY: What is the binary sum of 10010 and 11101? ANSWER: [{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]', # 'number': 1, # 'func_name': 'getrandommovie', # 'func_desc': 'Returns a list of random movies from a database by calling an external API.', # 'queries': ['I want to watch a movie tonight, can you recommend a random one from your database?', # 'Give me 5 random movie suggestions from your database to plan my weekend.'], # 'answers': [[{'name': 'getrandommovie', 'arguments': {}}], # [{'name': 'getrandommovie', 'arguments': {}}, # {'name': 'getrandommovie', 'arguments': {}}, # {'name': 'getrandommovie', 'arguments': {}}, # {'name': 'getrandommovie', 'arguments': {}}, # {'name': 'getrandommovie', 'arguments': {}}]], # 'raw_input_api_gen_generator_0': [{'role': 'system', # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.

Construct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.

Ensure the query: - Is clear and concise - Demonstrates typical use cases - Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words - Across a variety level of difficulties, ranging from beginner and advanced use cases - The corresponding result's parameter types and ranges match with the function's descriptions

Ensure the answer: - Is a list of function calls in JSON format - The length of the answer list should be equal to the number of requests in the query - Can solve all the requests in the query effectively"}, # {'role': 'user', # 'content': 'Here are examples of queries and the corresponding answers for similar functions: QUERY: What is the binary sum of 10010 and 11101? ANSWER: [{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]

Note that the query could be interpreted as a combination of several independent requests. Based on these examples, generate 2 diverse query and answer pairs for the function getrandommovie The detailed function description is the following: Returns a list of random movies from a database by calling an external API.

The output MUST strictly adhere to the following JSON format, and NO other text MUST be included:

[
   {
       "query": "The generated query.",
       "answers": [
           {
               "name": "api_name",
               "arguments": {
                   "arg_name": "value"
                   ... (more arguments as required)
               }
           },
           ... (more API calls as required)
       ]
   }
]

Now please generate 2 diverse query and answer pairs following the above format.'}]}, # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}] ```

    Generate with structured output:

    ```python
    from distilabel.steps.tasks import ApiGenGenerator
    from distilabel.llms import InferenceEndpointsLLM

    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 1024,
        },
    )
    apigen = ApiGenGenerator(
        use_default_structured_output=True,
        llm=llm
    )
    apigen.load()

    res_struct = next(
        apigen.process(
            [
                {
                    "examples": 'QUERY:

What is the binary sum of 10010 and 11101? ANSWER: [{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]', "func_name": "getrandommovie", "func_desc": "Returns a list of random movies from a database by calling an external API." } ] ) ) res_struct # [{'examples': 'QUERY: What is the binary sum of 10010 and 11101? ANSWER: [{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]', # 'number': 1, # 'func_name': 'getrandommovie', # 'func_desc': 'Returns a list of random movies from a database by calling an external API.', # 'queries': ["I'm bored and want to watch a movie. Can you suggest some movies?", # "My family and I are planning a movie night. We can't decide on what to watch. Can you suggest some random movie titles?"], # 'answers': [[{'arguments': {}, 'name': 'getrandommovie'}], # [{'arguments': {}, 'name': 'getrandommovie'}]], # 'raw_input_api_gen_generator_0': [{'role': 'system', # 'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.

Construct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.

Ensure the query: - Is clear and concise - Demonstrates typical use cases - Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words - Across a variety level of difficulties, ranging from beginner and advanced use cases - The corresponding result's parameter types and ranges match with the function's descriptions

Ensure the answer: - Is a list of function calls in JSON format - The length of the answer list should be equal to the number of requests in the query - Can solve all the requests in the query effectively"}, # {'role': 'user', # 'content': 'Here are examples of queries and the corresponding answers for similar functions: QUERY: What is the binary sum of 10010 and 11101? ANSWER: [{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]

Note that the query could be interpreted as a combination of several independent requests. Based on these examples, generate 2 diverse query and answer pairs for the function getrandommovie The detailed function description is the following: Returns a list of random movies from a database by calling an external API.

Now please generate 2 diverse query and answer pairs following the above format.'}]}, # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}] ```

Source code in src/distilabel/steps/tasks/apigen/generator.py
class APIGenGenerator(Task):
    """Generate queries and answers for the given functions in JSON format.

    The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
    verifiable and diverse function-calling datasets. The task generates a set of diverse queries
    and corresponding answers for the given functions in JSON format.

    Attributes:
        system_prompt: The system prompt to guide the user in the generation of queries and answers.
        use_tools: Whether to use the tools available in the prompt to generate the queries and answers.
            In case the tools are given in the input, they will be added to the prompt.
        number: The number of queries to generate. It can be a list, where each number will be
            chosen randomly, or a dictionary with the number of queries and the probability of each.
            I.e: `number=1`, `number=[1, 2, 3]`, `number={1: 0.5, 2: 0.3, 3: 0.2}` are all valid inputs.
            It corresponds to the number of parallel queries to generate.
        use_default_structured_output: Whether to use the default structured output or not.

    Input columns:
        - examples (`str`): Examples used as few shots to guide the model.
        - func_name (`str`): Name for the function to generate.
        - func_desc (`str`): Description of what the function should do.
        - tools (`str`): JSON formatted string containing the tool representation of the function.

    Output columns:
        - query (`str`): The list of queries.
        - answers (`str`): JSON formatted string with the list of answers, containing the info as
            a dictionary to be passed to the functions.

    Categories:
        - text-generation

    References:
        - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
        - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)

    Examples:
        Generate without structured output (original implementation):

        ```python
        from distilabel.steps.tasks import ApiGenGenerator
        from distilabel.llms import InferenceEndpointsLLM

        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={
                "temperature": 0.7,
                "max_new_tokens": 1024,
            },
        )
        apigen = ApiGenGenerator(
            use_default_structured_output=False,
            llm=llm
        )
        apigen.load()

        res = next(
            apigen.process(
                [
                    {
                        "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
                        "func_name": "getrandommovie",
                        "func_desc": "Returns a list of random movies from a database by calling an external API."
                    }
                ]
            )
        )
        res
        # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
        # 'number': 1,
        # 'func_name': 'getrandommovie',
        # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
        # 'queries': ['I want to watch a movie tonight, can you recommend a random one from your database?',
        # 'Give me 5 random movie suggestions from your database to plan my weekend.'],
        # 'answers': [[{'name': 'getrandommovie', 'arguments': {}}],
        # [{'name': 'getrandommovie', 'arguments': {}},
        #     {'name': 'getrandommovie', 'arguments': {}},
        #     {'name': 'getrandommovie', 'arguments': {}},
        #     {'name': 'getrandommovie', 'arguments': {}},
        #     {'name': 'getrandommovie', 'arguments': {}}]],
        # 'raw_input_api_gen_generator_0': [{'role': 'system',
        #     'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
        #     {'role': 'user',
        #     'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n```json\n[\n   {\n       "query": "The generated query.",\n       "answers": [\n           {\n               "name": "api_name",\n               "arguments": {\n                   "arg_name": "value"\n                   ... (more arguments as required)\n               }\n           },\n           ... (more API calls as required)\n       ]\n   }\n]\n```\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
        ```

        Generate with structured output:

        ```python
        from distilabel.steps.tasks import ApiGenGenerator
        from distilabel.llms import InferenceEndpointsLLM

        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={
                "temperature": 0.7,
                "max_new_tokens": 1024,
            },
        )
        apigen = ApiGenGenerator(
            use_default_structured_output=True,
            llm=llm
        )
        apigen.load()

        res_struct = next(
            apigen.process(
                [
                    {
                        "examples": 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
                        "func_name": "getrandommovie",
                        "func_desc": "Returns a list of random movies from a database by calling an external API."
                    }
                ]
            )
        )
        res_struct
        # [{'examples': 'QUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]',
        # 'number': 1,
        # 'func_name': 'getrandommovie',
        # 'func_desc': 'Returns a list of random movies from a database by calling an external API.',
        # 'queries': ["I'm bored and want to watch a movie. Can you suggest some movies?",
        # "My family and I are planning a movie night. We can't decide on what to watch. Can you suggest some random movie titles?"],
        # 'answers': [[{'arguments': {}, 'name': 'getrandommovie'}],
        # [{'arguments': {}, 'name': 'getrandommovie'}]],
        # 'raw_input_api_gen_generator_0': [{'role': 'system',
        #     'content': "You are a data labeler. Your responsibility is to generate a set of diverse queries and corresponding answers for the given functions in JSON format.\n\nConstruct queries and answers that exemplify how to use these functions in a practical scenario. Include in each query specific, plausible values for each parameter. For instance, if the function requires a date, use a typical and reasonable date.\n\nEnsure the query:\n- Is clear and concise\n- Demonstrates typical use cases\n- Includes all necessary parameters in a meaningful way. For numerical parameters, it could be either numbers or words\n- Across a variety level of difficulties, ranging from beginner and advanced use cases\n- The corresponding result's parameter types and ranges match with the function's descriptions\n\nEnsure the answer:\n- Is a list of function calls in JSON format\n- The length of the answer list should be equal to the number of requests in the query\n- Can solve all the requests in the query effectively"},
        #     {'role': 'user',
        #     'content': 'Here are examples of queries and the corresponding answers for similar functions:\nQUERY:\nWhat is the binary sum of 10010 and 11101?\nANSWER:\n[{"name": "binary_addition", "arguments": {"a": "10010", "b": "11101"}}]\n\nNote that the query could be interpreted as a combination of several independent requests.\nBased on these examples, generate 2 diverse query and answer pairs for the function `getrandommovie`\nThe detailed function description is the following:\nReturns a list of random movies from a database by calling an external API.\n\nNow please generate 2 diverse query and answer pairs following the above format.'}]},
        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
        ```
    """

    system_prompt: str = SYSTEM_PROMPT_API_GEN
    use_default_structured_output: bool = False
    number: Union[int, List[int], Dict[int, float]] = 1
    use_tools: bool = True

    _number: Union[int, None] = PrivateAttr(None)
    _fn_parallel_queries: Union[Callable[[], str], None] = PrivateAttr(None)
    _format_inst: Union[str, None] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the template for the generator prompt."""
        super().load()
        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "apigen"
            / "generator.jinja2"
        )
        self._template = Template(open(_path).read())
        self._format_inst = self._set_format_inst()

    def _parallel_queries(self, number: int) -> Callable[[int], str]:
        """Prepares the function to update the parallel queries guide in the prompt.

        Raises:
            ValueError: if `is_parallel` is not a boolean or a list of floats.

        Returns:
            The function to generate the parallel queries guide.
        """
        if number > 1:
            return (
                "It can contain multiple parallel queries in natural language for the given functions. "
                "They could use either the same function with different arguments or different functions.\n"
            )
        return ""

    def _get_number(self) -> int:
        """Generates the number of queries to generate in a single call.
        The number must be set to `_number` to avoid changing the original value
        when calling `_default_error`.
        """
        if isinstance(self.number, list):
            self._number = random.choice(self.number)
        elif isinstance(self.number, dict):
            self._number = random.choices(
                list(self.number.keys()), list(self.number.values())
            )[0]
        else:
            self._number = self.number
        return self._number

    def _set_format_inst(self) -> str:
        """Prepares the function to generate the formatted instructions for the prompt.

        If the default structured output is used, returns an empty string because nothing
        else is needed, otherwise, returns the original addition to the prompt to guide the model
        to generate a formatted JSON.
        """
        return (
            "\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n"
            "```\n"
            "[\n"
            "   {\n"
            '       "query": "The generated query.",\n'
            '       "answers": [\n'
            "           {\n"
            '               "name": "api_name",\n'
            '               "arguments": {\n'
            '                   "arg_name": "value"\n'
            "                   ... (more arguments as required)\n"
            "               }\n"
            "           },\n"
            "           ... (more API calls as required)\n"
            "       ]\n"
            "   }\n"
            "]\n"
            "```\n"
        )

    def _get_func_desc(self, input: Dict[str, Any]) -> str:
        """If available and required, will use the info from the tools in the
        prompt for extra information. Otherwise will use just the function description.
        """
        if not self.use_tools:
            return input["func_desc"]
        extra = ""  # Extra information from the tools (if available will be added)
        if "tools" in input:
            extra = f"\n\nThis is the available tool to guide you (respect the order of the parameters):\n{input['tools']}"
        return input["func_desc"] + extra

    @property
    def inputs(self) -> "StepColumns":
        """The inputs for the task."""
        return {
            "examples": True,
            "func_name": True,
            "func_desc": True,
            "tools": False,
        }

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """The input is formatted as a `ChatType`."""
        number = self._get_number()
        parallel_queries = self._parallel_queries(number)
        return [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": self._template.render(
                    examples=input["examples"],
                    parallel_queries=parallel_queries,
                    number=number,
                    func_name=input["func_name"],
                    func_desc=self._get_func_desc(input),
                    format_inst=self._format_inst,
                ),
            },
        ]

    @property
    def outputs(self) -> "StepColumns":
        """The output for the task are the queries and corresponding answers."""
        return ["query", "answers", "model_name"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """The output is formatted as a list with the score of each instruction.

        Args:
            output: the raw output of the LLM.
            input: the input to the task. Used for obtaining the number of responses.

        Returns:
            A dict with the queries and answers pairs.
            The answers are an array of answers corresponding to the query.
            Each answer is represented as an object with the following properties:
                - name (string): The name of the tool used to generate the answer.
                - arguments (object): An object representing the arguments passed to the tool to generate the answer.
            Each argument is represented as a key-value pair, where the key is the parameter name and the
            value is the corresponding value.
        """
        if output is None:
            return self._default_error(input)

        if not self.use_default_structured_output:
            output = remove_fences(output)

        try:
            pairs = orjson.loads(output)
        except orjson.JSONDecodeError:
            return self._default_error(input)

        pairs = pairs["pairs"] if self.use_default_structured_output else pairs

        return self._format_output(pairs, input)

    def _format_output(
        self, pairs: Dict[str, Any], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Parses the response, returning a dictionary with queries and answers.

        Args:
            pairs: The parsed dictionary from the LLM's output.
            input: The input from the `LLM`.

        Returns:
            Formatted output, where the `queries` are a list of strings, and the `answers`
            are a list of objects.
        """
        try:
            input.update(
                **{
                    "query": pairs[0]["query"],
                    "answers": json.dumps(pairs[0]["answers"]),
                }
            )
            return input
        except Exception as e:
            self._logger.error(f"Error formatting output: {e}, pairs: '{pairs}'")
            return self._default_error(input)

    def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Returns a default error output, to fill the responses in case of failure."""
        input.update(
            **{
                "query": None,
                "answers": json.dumps([None] * self._number),
            }
        )
        return input

    @override
    def get_structured_output(self) -> Dict[str, Any]:
        """Creates the json schema to be passed to the LLM, to enforce generating
        a dictionary with the output which can be directly parsed as a python dictionary.

        The schema corresponds to the following:

        ```python
        from typing import Dict, List
        from pydantic import BaseModel


        class Answer(BaseModel):
            name: str
            arguments: Dict[str, str]

        class QueryAnswer(BaseModel):
            query: str
            answers: List[Answer]

        class QueryAnswerPairs(BaseModel):
            pairs: List[QueryAnswer]

        json.dumps(QueryAnswerPairs.model_json_schema(), indent=4)
        ```

        Returns:
            JSON Schema of the response to enforce.
        """
        return {
            "$defs": {
                "Answer": {
                    "properties": {
                        "name": {"title": "Name", "type": "string"},
                        "arguments": {
                            "additionalProperties": {"type": "string"},
                            "title": "Arguments",
                            "type": "object",
                        },
                    },
                    "required": ["name", "arguments"],
                    "title": "Answer",
                    "type": "object",
                },
                "QueryAnswer": {
                    "properties": {
                        "query": {"title": "Query", "type": "string"},
                        "answers": {
                            "items": {"$ref": "#/$defs/Answer"},
                            "title": "Answers",
                            "type": "array",
                        },
                    },
                    "required": ["query", "answers"],
                    "title": "QueryAnswer",
                    "type": "object",
                },
            },
            "properties": {
                "pairs": {
                    "items": {"$ref": "#/$defs/QueryAnswer"},
                    "title": "Pairs",
                    "type": "array",
                }
            },
            "required": ["pairs"],
            "title": "QueryAnswerPairs",
            "type": "object",
        }
inputs: StepColumns property

The inputs for the task.

outputs: StepColumns property

The output for the task are the queries and corresponding answers.

load()

Loads the template for the generator prompt.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def load(self) -> None:
    """Loads the template for the generator prompt."""
    super().load()
    _path = str(
        importlib_resources.files("distilabel")
        / "steps"
        / "tasks"
        / "templates"
        / "apigen"
        / "generator.jinja2"
    )
    self._template = Template(open(_path).read())
    self._format_inst = self._set_format_inst()
_parallel_queries(number)

Prepares the function to update the parallel queries guide in the prompt.

Raises:

Type Description
ValueError

if is_parallel is not a boolean or a list of floats.

Returns:

Type Description
Callable[[int], str]

The function to generate the parallel queries guide.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def _parallel_queries(self, number: int) -> Callable[[int], str]:
    """Prepares the function to update the parallel queries guide in the prompt.

    Raises:
        ValueError: if `is_parallel` is not a boolean or a list of floats.

    Returns:
        The function to generate the parallel queries guide.
    """
    if number > 1:
        return (
            "It can contain multiple parallel queries in natural language for the given functions. "
            "They could use either the same function with different arguments or different functions.\n"
        )
    return ""
_get_number()

Generates the number of queries to generate in a single call. The number must be set to _number to avoid changing the original value when calling _default_error.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def _get_number(self) -> int:
    """Generates the number of queries to generate in a single call.
    The number must be set to `_number` to avoid changing the original value
    when calling `_default_error`.
    """
    if isinstance(self.number, list):
        self._number = random.choice(self.number)
    elif isinstance(self.number, dict):
        self._number = random.choices(
            list(self.number.keys()), list(self.number.values())
        )[0]
    else:
        self._number = self.number
    return self._number
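
To make the three accepted forms of `number` concrete, the lines below mirror the sampling logic shown above; they are purely illustrative.

```python
import random

number = 2                            # int: always generate 2 queries per call
number = [1, 2, 3]                    # list: random.choice([1, 2, 3]) picks one uniformly
number = {1: 0.5, 2: 0.3, 3: 0.2}     # dict: keys weighted by their probabilities
sampled = random.choices(list(number.keys()), list(number.values()))[0]
```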
_set_format_inst()

Prepares the function to generate the formatted instructions for the prompt.

If the default structured output is used, returns an empty string because nothing else is needed, otherwise, returns the original addition to the prompt to guide the model to generate a formatted JSON.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def _set_format_inst(self) -> str:
    """Prepares the function to generate the formatted instructions for the prompt.

    If the default structured output is used, returns an empty string because nothing
    else is needed, otherwise, returns the original addition to the prompt to guide the model
    to generate a formatted JSON.
    """
    return (
        "\nThe output MUST strictly adhere to the following JSON format, and NO other text MUST be included:\n"
        "```\n"
        "[\n"
        "   {\n"
        '       "query": "The generated query.",\n'
        '       "answers": [\n'
        "           {\n"
        '               "name": "api_name",\n'
        '               "arguments": {\n'
        '                   "arg_name": "value"\n'
        "                   ... (more arguments as required)\n"
        "               }\n"
        "           },\n"
        "           ... (more API calls as required)\n"
        "       ]\n"
        "   }\n"
        "]\n"
        "```\n"
    )
_get_func_desc(input)

If available and required, will use the info from the tools in the prompt for extra information. Otherwise will use just the function description.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def _get_func_desc(self, input: Dict[str, Any]) -> str:
    """If available and required, will use the info from the tools in the
    prompt for extra information. Otherwise will use just the function description.
    """
    if not self.use_tools:
        return input["func_desc"]
    extra = ""  # Extra information from the tools (if available will be added)
    if "tools" in input:
        extra = f"\n\nThis is the available tool to guide you (respect the order of the parameters):\n{input['tools']}"
    return input["func_desc"] + extra
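
As a hedged illustration of the optional `tools` input column, the snippet below builds a hypothetical JSON tool representation; when `use_tools=True` and the column is present, that string is appended to `func_desc` in the prompt (the exact tool schema comes from the dataset, it is not fixed by distilabel).

```python
import json

# Hypothetical tool representation for the optional `tools` column.
tools = json.dumps(
    {
        "name": "final_velocity",
        "description": "Calculates the final velocity of an object.",
        "parameters": {
            "initial_velocity": {"type": "number", "description": "Initial velocity in m/s."},
            "acceleration": {"type": "number", "description": "Acceleration in m/s^2."},
            "time": {"type": "number", "description": "Elapsed time in seconds."},
        },
    }
)

row = {
    "func_desc": "Calculates the final velocity of an object.",
    "tools": tools,
}
# With `use_tools=True`, `_get_func_desc(row)` returns the description followed by the tool JSON.
```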
format_input(input)

The input is formatted as a ChatType.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def format_input(self, input: Dict[str, Any]) -> "ChatType":
    """The input is formatted as a `ChatType`."""
    number = self._get_number()
    parallel_queries = self._parallel_queries(number)
    return [
        {"role": "system", "content": self.system_prompt},
        {
            "role": "user",
            "content": self._template.render(
                examples=input["examples"],
                parallel_queries=parallel_queries,
                number=number,
                func_name=input["func_name"],
                func_desc=self._get_func_desc(input),
                format_inst=self._format_inst,
            ),
        },
    ]
format_output(output, input)

The output is formatted as a dict with the query and the corresponding answers.

Parameters:

Name Type Description Default
output Union[str, None]

the raw output of the LLM.

required
input Dict[str, Any]

the input to the task. Used for obtaining the number of responses.

required

Returns:

Type Description
Dict[str, Any]

A dict with the queries and answers pairs.

Dict[str, Any]

The answers are an array of answers corresponding to the query.

Dict[str, Any]

Each answer is represented as an object with the following properties: - name (string): The name of the tool used to generate the answer. - arguments (object): An object representing the arguments passed to the tool to generate the answer.

Dict[str, Any]

Each argument is represented as a key-value pair, where the key is the parameter name and the

Dict[str, Any]

value is the corresponding value.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def format_output(
    self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
    """The output is formatted as a list with the score of each instruction.

    Args:
        output: the raw output of the LLM.
        input: the input to the task. Used for obtaining the number of responses.

    Returns:
        A dict with the queries and answers pairs.
        The answers are an array of answers corresponding to the query.
        Each answer is represented as an object with the following properties:
            - name (string): The name of the tool used to generate the answer.
            - arguments (object): An object representing the arguments passed to the tool to generate the answer.
        Each argument is represented as a key-value pair, where the key is the parameter name and the
        value is the corresponding value.
    """
    if output is None:
        return self._default_error(input)

    if not self.use_default_structured_output:
        output = remove_fences(output)

    try:
        pairs = orjson.loads(output)
    except orjson.JSONDecodeError:
        return self._default_error(input)

    pairs = pairs["pairs"] if self.use_default_structured_output else pairs

    return self._format_output(pairs, input)
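
The parsing path for the non-structured case can be reproduced with plain `json`, mirroring the method above; the raw output string is invented for illustration.

```python
import json

# Hypothetical raw LLM output once any surrounding code fences have been removed.
raw = '[{"query": "Recommend a random movie for tonight.", "answers": [{"name": "getrandommovie", "arguments": {}}]}]'

pairs = json.loads(raw)  # the method uses orjson, which behaves the same here
formatted = {
    "query": pairs[0]["query"],
    "answers": json.dumps(pairs[0]["answers"]),
}
# {'query': 'Recommend a random movie for tonight.',
#  'answers': '[{"name": "getrandommovie", "arguments": {}}]'}
```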
_format_output(pairs, input)

Parses the response, returning a dictionary with queries and answers.

Parameters:

Name Type Description Default
pairs Dict[str, Any]

The parsed dictionary from the LLM's output.

required
input Dict[str, Any]

The input from the LLM.

required

Returns:

Type Description
Dict[str, Any]

Formatted output, where the queries are a list of strings, and the answers

Dict[str, Any]

are a list of objects.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def _format_output(
    self, pairs: Dict[str, Any], input: Dict[str, Any]
) -> Dict[str, Any]:
    """Parses the response, returning a dictionary with queries and answers.

    Args:
        pairs: The parsed dictionary from the LLM's output.
        input: The input from the `LLM`.

    Returns:
        Formatted output, where the `queries` are a list of strings, and the `answers`
        are a list of objects.
    """
    try:
        input.update(
            **{
                "query": pairs[0]["query"],
                "answers": json.dumps(pairs[0]["answers"]),
            }
        )
        return input
    except Exception as e:
        self._logger.error(f"Error formatting output: {e}, pairs: '{pairs}'")
        return self._default_error(input)
_default_error(input)

Returns a default error output, to fill the responses in case of failure.

Source code in src/distilabel/steps/tasks/apigen/generator.py
def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
    """Returns a default error output, to fill the responses in case of failure."""
    input.update(
        **{
            "query": None,
            "answers": json.dumps([None] * self._number),
        }
    )
    return input
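For context, a failed generation simply pads the row with `None` values; assuming the task was configured to produce two query/answer pairs (so the private `_number` attribute is 2), the defaults would look like this (illustrative only):

```python
import json

number = 2  # stands in for the task's private `_number` attribute
row = {"func_desc": "Return the final velocity of an object."}
row.update(query=None, answers=json.dumps([None] * number))
print(row)  # {'func_desc': ..., 'query': None, 'answers': '[null, null]'}
```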
get_structured_output()

Creates the json schema to be passed to the LLM, to enforce generating a dictionary with the output which can be directly parsed as a python dictionary.

The schema corresponds to the following:

```python
from typing import Dict, List
from pydantic import BaseModel


class Answer(BaseModel):
    name: str
    arguments: Dict[str, str]

class QueryAnswer(BaseModel):
    query: str
    answers: List[Answer]

class QueryAnswerPairs(BaseModel):
    pairs: List[QueryAnswer]

json.dumps(QueryAnswerPairs.model_json_schema(), indent=4)
```

Returns:

Type Description
Dict[str, Any]

JSON Schema of the response to enforce.

Source code in src/distilabel/steps/tasks/apigen/generator.py
@override
def get_structured_output(self) -> Dict[str, Any]:
    """Creates the json schema to be passed to the LLM, to enforce generating
    a dictionary with the output which can be directly parsed as a python dictionary.

    The schema corresponds to the following:

    ```python
    from typing import Dict, List
    from pydantic import BaseModel


    class Answer(BaseModel):
        name: str
        arguments: Dict[str, str]

    class QueryAnswer(BaseModel):
        query: str
        answers: List[Answer]

    class QueryAnswerPairs(BaseModel):
        pairs: List[QueryAnswer]

    json.dumps(QueryAnswerPairs.model_json_schema(), indent=4)
    ```

    Returns:
        JSON Schema of the response to enforce.
    """
    return {
        "$defs": {
            "Answer": {
                "properties": {
                    "name": {"title": "Name", "type": "string"},
                    "arguments": {
                        "additionalProperties": {"type": "string"},
                        "title": "Arguments",
                        "type": "object",
                    },
                },
                "required": ["name", "arguments"],
                "title": "Answer",
                "type": "object",
            },
            "QueryAnswer": {
                "properties": {
                    "query": {"title": "Query", "type": "string"},
                    "answers": {
                        "items": {"$ref": "#/$defs/Answer"},
                        "title": "Answers",
                        "type": "array",
                    },
                },
                "required": ["query", "answers"],
                "title": "QueryAnswer",
                "type": "object",
            },
        },
        "properties": {
            "pairs": {
                "items": {"$ref": "#/$defs/QueryAnswer"},
                "title": "Pairs",
                "type": "array",
            }
        },
        "required": ["pairs"],
        "title": "QueryAnswerPairs",
        "type": "object",
    }
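To see what the enforced schema accepts, the pydantic models from the docstring above can be used directly to validate a candidate response (a hedged sketch; the sample payload is illustrative):

```python
from typing import Dict, List
from pydantic import BaseModel, ValidationError

class Answer(BaseModel):
    name: str
    arguments: Dict[str, str]

class QueryAnswer(BaseModel):
    query: str
    answers: List[Answer]

class QueryAnswerPairs(BaseModel):
    pairs: List[QueryAnswer]

candidate = {
    "pairs": [
        {"query": "What is the final velocity?",
         "answers": [{"name": "final_velocity",
                      "arguments": {"initial_velocity": "0.2", "acceleration": "0.1", "time": "0.5"}}]}
    ]
}

try:
    QueryAnswerPairs.model_validate(candidate)  # passes: matches the enforced schema
except ValidationError as err:
    print(err)
```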

APIGenSemanticChecker

Bases: Task

Checks if the generated function calls are semantically aligned with the user query.

The APIGenSemanticChecker is inspired by the APIGen pipeline, which was designed to generate verifiable and diverse function-calling datasets. This task evaluates whether a query, the function calls generated for it, and their execution results are consistent with each other, producing a reasoning (thought) and a keep/discard decision in JSON format.

Attributes:

Name Type Description
system_prompt str

System prompt for the task. Has a default one.

exclude_failed_execution str

Whether to exclude failed executions (won't run on those rows that have a False in keep_row_after_execution_check column, which comes from running APIGenExecutionChecker). Defaults to True.

Input columns
  • func_desc (str): Description of what the function should do.
  • query (str): Instruction from the user.
  • answers (str): JSON encoded list with arguments to be passed to the function/API. Should be loaded using json.loads.
  • execution_result (str): Result of the function/API executed.
Output columns
  • thought (str): Reasoning for the output on whether to keep this output or not.
  • keep_row_after_semantic_check (bool): True or False, can be used to filter afterwards.
Categories
  • filtering
  • text-generation
References

Examples:

Semantic checker for generated function calls (original implementation):

```python
from distilabel.steps.tasks import APIGenSemanticChecker
from distilabel.llms import InferenceEndpointsLLM

llm=InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    generation_kwargs={
        "temperature": 0.7,
        "max_new_tokens": 1024,
    },
)
semantic_checker = APIGenSemanticChecker(
    use_default_structured_output=False,
    llm=llm
)
semantic_checker.load()

res = next(
    semantic_checker.process(
        [
            {
                "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
                "query": "What information can be obtained about the Maine Coon cat breed?",
                "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
                "execution_result": "The Maine Coon is a big and hairy breed of cat",
            }
        ]
    )
)
res
# [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
# 'query': 'What information can be obtained about the Maine Coon cat breed?',
# 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
# 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
# 'thought': '',
# 'keep_row_after_semantic_check': True,
# 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
#     'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
#     {'role': 'user',
#     'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n   "thought": "Concisely describe your reasoning here",\n   "pass": "yes" or "no"\n}\n```\n'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```

Semantic checker for generated function calls (structured output):

```python
from distilabel.steps.tasks import APIGenSemanticChecker
from distilabel.llms import InferenceEndpointsLLM

llm=InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    generation_kwargs={
        "temperature": 0.7,
        "max_new_tokens": 1024,
    },
)
semantic_checker = APIGenSemanticChecker(
    use_default_structured_output=True,
    llm=llm
)
semantic_checker.load()

res = next(
    semantic_checker.process(
        [
            {
                "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
                "query": "What information can be obtained about the Maine Coon cat breed?",
                "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
                "execution_result": "The Maine Coon is a big and hairy breed of cat",
            }
        ]
    )
)
res
# [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
# 'query': 'What information can be obtained about the Maine Coon cat breed?',
# 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
# 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
# 'keep_row_after_semantic_check': True,
# 'thought': '',
# 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
#     'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
#     {'role': 'user',
#     'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
```
Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
class APIGenSemanticChecker(Task):
    r"""Generate queries and answers for the given functions in JSON format.

    The `APIGenGenerator` is inspired by the APIGen pipeline, which was designed to generate
    verifiable and diverse function-calling datasets. The task generates a set of diverse queries
    and corresponding answers for the given functions in JSON format.

    Attributes:
        system_prompt: System prompt for the task. Has a default one.
        exclude_failed_execution: Whether to exclude failed executions (won't run on those
            rows that have a False in `keep_row_after_execution_check` column, which
            comes from running `APIGenExecutionChecker`). Defaults to True.

    Input columns:
        - func_desc (`str`): Description of what the function should do.
        - query (`str`): Instruction from the user.
        - answers (`str`): JSON encoded list with arguments to be passed to the function/API.
            Should be loaded using `json.loads`.
        - execution_result (`str`): Result of the function/API executed.

    Output columns:
        - thought (`str`): Reasoning for the output on whether to keep this output or not.
        - keep_row_after_semantic_check (`bool`): True or False, can be used to filter
            afterwards.

    Categories:
        - filtering
        - text-generation

    References:
        - [APIGen: Automated Pipeline for Generating Verifiable and Diverse Function-Calling Datasets](https://arxiv.org/abs/2406.18518)
        - [Salesforce/xlam-function-calling-60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)

    Examples:

        Semantic checker for generated function calls (original implementation):

        ```python
        from distilabel.steps.tasks import APIGenSemanticChecker
        from distilabel.llms import InferenceEndpointsLLM

        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={
                "temperature": 0.7,
                "max_new_tokens": 1024,
            },
        )
        semantic_checker = APIGenSemanticChecker(
            use_default_structured_output=False,
            llm=llm
        )
        semantic_checker.load()

        res = next(
            semantic_checker.process(
                [
                    {
                        "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
                        "query": "What information can be obtained about the Maine Coon cat breed?",
                        "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
                        "execution_result": "The Maine Coon is a big and hairy breed of cat",
                    }
                ]
            )
        )
        res
        # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
        # 'query': 'What information can be obtained about the Maine Coon cat breed?',
        # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
        # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
        # 'thought': '',
        # 'keep_row_after_semantic_check': True,
        # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
        #     'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
        #     {'role': 'user',
        #     'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n```\n{\n   "thought": "Concisely describe your reasoning here",\n   "pass": "yes" or "no"\n}\n```\n'}]},
        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
        ```

        Semantic checker for generated function calls (structured output):

        ```python
        from distilabel.steps.tasks import APIGenSemanticChecker
        from distilabel.llms import InferenceEndpointsLLM

        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={
                "temperature": 0.7,
                "max_new_tokens": 1024,
            },
        )
        semantic_checker = APIGenSemanticChecker(
            use_default_structured_output=True,
            llm=llm
        )
        semantic_checker.load()

        res = next(
            semantic_checker.process(
                [
                    {
                        "func_desc": "Fetch information about a specific cat breed from the Cat Breeds API.",
                        "query": "What information can be obtained about the Maine Coon cat breed?",
                        "answers": json.dumps([{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]),
                        "execution_result": "The Maine Coon is a big and hairy breed of cat",
                    }
                ]
            )
        )
        res
        # [{'func_desc': 'Fetch information about a specific cat breed from the Cat Breeds API.',
        # 'query': 'What information can be obtained about the Maine Coon cat breed?',
        # 'answers': [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}],
        # 'execution_result': 'The Maine Coon is a big and hairy breed of cat',
        # 'keep_row_after_semantic_check': True,
        # 'thought': '',
        # 'raw_input_a_p_i_gen_semantic_checker_0': [{'role': 'system',
        #     'content': 'As a data quality evaluator, you must assess the alignment between a user query, corresponding function calls, and their execution results.\nThese function calls and results are generated by other models, and your task is to ensure these results accurately reflect the user’s intentions.\n\nDo not pass if:\n1. The function call does not align with the query’s objective, or the input arguments appear incorrect.\n2. The function call and arguments are not properly chosen from the available functions.\n3. The number of function calls does not correspond to the user’s intentions.\n4. The execution results are irrelevant and do not match the function’s purpose.\n5. The execution results contain errors or reflect that the function calls were not executed successfully.\n'},
        #     {'role': 'user',
        #     'content': 'Given Information:\n- All Available Functions:\nFetch information about a specific cat breed from the Cat Breeds API.\n- User Query: What information can be obtained about the Maine Coon cat breed?\n- Generated Function Calls: [{"name": "get_breed_information", "arguments": {"breed": "Maine Coon"}}]\n- Execution Results: The Maine Coon is a big and hairy breed of cat\n\nNote: The query may have multiple intentions. Functions may be placeholders, and execution results may be truncated due to length, which is acceptable and should not cause a failure.\n\nThe main decision factor is wheather the function calls accurately reflect the query\'s intentions and the function descriptions.\nProvide your reasoning in the thought section and decide if the data passes (answer yes or no).\nIf not passing, concisely explain your reasons in the thought section; otherwise, leave this section blank.\n'}]},
        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
        ```
    """

    system_prompt: str = SYSTEM_PROMPT_SEMANTIC_CHECKER
    use_default_structured_output: bool = False

    _format_inst: Union[str, None] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the template for the generator prompt."""
        super().load()
        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "apigen"
            / "semantic_checker.jinja2"
        )

        self._template = Template(open(_path).read())
        self._format_inst = self._set_format_inst()

    def _set_format_inst(self) -> str:
        """Prepares the function to generate the formatted instructions for the prompt.

        If the default structured output is used, returns an empty string because nothing
        else is needed, otherwise, returns the original addition to the prompt to guide the model
        to generate a formatted JSON.
        """
        return (
            "\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n"
            "```\n"
            "{\n"
            '   "thought": "Concisely describe your reasoning here",\n'
            '   "passes": "yes" or "no"\n'
            "}\n"
            "```\n"
        )

    @property
    def inputs(self) -> "StepColumns":
        """The inputs for the task."""
        return {
            "func_desc": True,
            "query": True,
            "answers": True,
            "execution_result": True,
            "keep_row_after_execution_check": True,
        }

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """The input is formatted as a `ChatType`."""
        return [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": self._template.render(
                    func_desc=input["func_desc"],
                    query=input["query"] or "",
                    func_call=input["answers"] or "",
                    execution_result=input["execution_result"],
                    format_inst=self._format_inst,
                ),
            },
        ]

    @property
    def outputs(self) -> "StepColumns":
        """The output for the task are the queries and corresponding answers."""
        return ["keep_row_after_semantic_check", "thought"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """The output is formatted as a list with the score of each instruction.

        Args:
            output: the raw output of the LLM.
            input: the input to the task. Used for obtaining the number of responses.

        Returns:
            A dict with the queries and answers pairs.
            The answers are an array of answers corresponding to the query.
            Each answer is represented as an object with the following properties:
                - name (string): The name of the tool used to generate the answer.
                - arguments (object): An object representing the arguments passed to the tool to generate the answer.
            Each argument is represented as a key-value pair, where the key is the parameter name and the
            value is the corresponding value.
        """
        if output is None:
            return self._default_error(input)

        output = remove_fences(output)

        try:
            result = orjson.loads(output)
            # Update the column name and change to bool
            result["keep_row_after_semantic_check"] = (
                result.pop("passes").lower() == "yes"
            )
            input.update(**result)
            return input
        except orjson.JSONDecodeError:
            return self._default_error(input)

    def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Default error message for the task."""
        input.update({"thought": None, "keep_row_after_semantic_check": None})
        return input

    @override
    def get_structured_output(self) -> Dict[str, Any]:
        """Creates the json schema to be passed to the LLM, to enforce generating
        a dictionary with the output which can be directly parsed as a python dictionary.

        The schema corresponds to the following:

        ```python
        from typing import Literal
        from pydantic import BaseModel
        import json

        class Checker(BaseModel):
            thought: str
            passes: Literal["yes", "no"]

        json.dumps(Checker.model_json_schema(), indent=4)
        ```

        Returns:
            JSON Schema of the response to enforce.
        """
        return {
            "properties": {
                "thought": {"title": "Thought", "type": "string"},
                "passes": {"enum": ["yes", "no"], "title": "Passes", "type": "string"},
            },
            "required": ["thought", "passes"],
            "title": "Checker",
            "type": "object",
        }
inputs: StepColumns property

The inputs for the task.

outputs: StepColumns property

The output for the task are the queries and corresponding answers.

load()

Loads the template for the generator prompt.

Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
def load(self) -> None:
    """Loads the template for the generator prompt."""
    super().load()
    _path = str(
        importlib_resources.files("distilabel")
        / "steps"
        / "tasks"
        / "templates"
        / "apigen"
        / "semantic_checker.jinja2"
    )

    self._template = Template(open(_path).read())
    self._format_inst = self._set_format_inst()
_set_format_inst()

Prepares the function to generate the formatted instructions for the prompt.

If the default structured output is used, returns an empty string because nothing else is needed, otherwise, returns the original addition to the prompt to guide the model to generate a formatted JSON.

Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
def _set_format_inst(self) -> str:
    """Prepares the function to generate the formatted instructions for the prompt.

    If the default structured output is used, returns an empty string because nothing
    else is needed, otherwise, returns the original addition to the prompt to guide the model
    to generate a formatted JSON.
    """
    return (
        "\nYour response MUST strictly adhere to the following JSON format, and NO other text MUST be included.\n"
        "```\n"
        "{\n"
        '   "thought": "Concisely describe your reasoning here",\n'
        '   "passes": "yes" or "no"\n'
        "}\n"
        "```\n"
    )
format_input(input)

The input is formatted as a ChatType.

Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
def format_input(self, input: Dict[str, Any]) -> "ChatType":
    """The input is formatted as a `ChatType`."""
    return [
        {"role": "system", "content": self.system_prompt},
        {
            "role": "user",
            "content": self._template.render(
                func_desc=input["func_desc"],
                query=input["query"] or "",
                func_call=input["answers"] or "",
                execution_result=input["execution_result"],
                format_inst=self._format_inst,
            ),
        },
    ]
format_output(output, input)

The output is formatted as a list with the score of each instruction.

Parameters:

Name Type Description Default
output Union[str, None]

the raw output of the LLM.

required
input Dict[str, Any]

the input to the task. Used for obtaining the number of responses.

required

Returns:

Type Description
Dict[str, Any]

A dict with the queries and answers pairs. The answers are an array of answers corresponding to the query. Each answer is represented as an object with the following properties:
  • name (string): The name of the tool used to generate the answer.
  • arguments (object): An object representing the arguments passed to the tool to generate the answer.
Each argument is represented as a key-value pair, where the key is the parameter name and the value is the corresponding value.

Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
def format_output(
    self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
    """The output is formatted as a list with the score of each instruction.

    Args:
        output: the raw output of the LLM.
        input: the input to the task. Used for obtaining the number of responses.

    Returns:
        A dict with the queries and answers pairs.
        The answers are an array of answers corresponding to the query.
        Each answer is represented as an object with the following properties:
            - name (string): The name of the tool used to generate the answer.
            - arguments (object): An object representing the arguments passed to the tool to generate the answer.
        Each argument is represented as a key-value pair, where the key is the parameter name and the
        value is the corresponding value.
    """
    if output is None:
        return self._default_error(input)

    output = remove_fences(output)

    try:
        result = orjson.loads(output)
        # Update the column name and change to bool
        result["keep_row_after_semantic_check"] = (
            result.pop("passes").lower() == "yes"
        )
        input.update(**result)
        return input
    except orjson.JSONDecodeError:
        return self._default_error(input)
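A hedged sketch of the column renaming performed above, using an illustrative LLM response (not a real pipeline trace):

```python
import orjson

raw_output = '{"thought": "", "passes": "yes"}'  # illustrative structured response

result = orjson.loads(raw_output)
# Mirrors the rename above: "passes" becomes a boolean keep/drop flag.
result["keep_row_after_semantic_check"] = result.pop("passes").lower() == "yes"
print(result)  # {'thought': '', 'keep_row_after_semantic_check': True}
```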
_default_error(input)

Default error message for the task.

Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
def _default_error(self, input: Dict[str, Any]) -> Dict[str, Any]:
    """Default error message for the task."""
    input.update({"thought": None, "keep_row_after_semantic_check": None})
    return input
get_structured_output()

Creates the json schema to be passed to the LLM, to enforce generating a dictionary with the output which can be directly parsed as a python dictionary.

The schema corresponds to the following:

```python
from typing import Literal
from pydantic import BaseModel
import json

class Checker(BaseModel):
    thought: str
    passes: Literal["yes", "no"]

json.dumps(Checker.model_json_schema(), indent=4)
```

Returns:

Type Description
Dict[str, Any]

JSON Schema of the response to enforce.

Source code in src/distilabel/steps/tasks/apigen/semantic_checker.py
@override
def get_structured_output(self) -> Dict[str, Any]:
    """Creates the json schema to be passed to the LLM, to enforce generating
    a dictionary with the output which can be directly parsed as a python dictionary.

    The schema corresponds to the following:

    ```python
    from typing import Literal
    from pydantic import BaseModel
    import json

    class Checker(BaseModel):
        thought: str
        passes: Literal["yes", "no"]

    json.dumps(Checker.model_json_schema(), indent=4)
    ```

    Returns:
        JSON Schema of the response to enforce.
    """
    return {
        "properties": {
            "thought": {"title": "Thought", "type": "string"},
            "passes": {"enum": ["yes", "no"], "title": "Passes", "type": "string"},
        },
        "required": ["thought", "passes"],
        "title": "Checker",
        "type": "object",
    }
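The `Checker` model from the docstring above can also be used to check a candidate response against this schema (a minimal sketch; the payloads are illustrative):

```python
from typing import Literal
from pydantic import BaseModel, ValidationError

class Checker(BaseModel):
    thought: str
    passes: Literal["yes", "no"]

Checker.model_validate({"thought": "Arguments match the query.", "passes": "yes"})  # ok

try:
    Checker.model_validate({"thought": "Missing decision."})  # fails: "passes" is required
except ValidationError as err:
    print(err)
```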

ArgillaLabeller

Bases: Task

Annotate Argilla records based on input fields, example records and question settings.

This task is designed to facilitate the annotation of Argilla records by leveraging a pre-trained LLM. It uses a system prompt that guides the LLM to understand the input fields, the question type, and the question settings. The task then formats the input data and generates a response based on the question. The response is validated against the question's value model, and the final suggestion is prepared for annotation.

Attributes:

Name Type Description
_template Union[Template, None]

a Jinja2 template used to format the input for the LLM.

Input columns
  • record (argilla.Record): The record to be annotated.
  • fields (Optional[List[Dict[str, Any]]]): The list of field settings for the input fields.
  • question (Optional[Dict[str, Any]]): The question settings for the question to be answered.
  • example_records (Optional[List[Dict[str, Any]]]): The few shot example records with responses to be used to answer the question.
  • guidelines (Optional[str]): The guidelines for the annotation task.
Output columns
  • suggestion (Dict[str, Any]): The final suggestion for annotation.
Categories
  • text-classification
  • scorer
  • text-generation
References

Examples:

Annotate a record with the same dataset and question:

import argilla as rg
from argilla import Suggestion
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM

# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
pending_records_filter = rg.Filter(("status", "==", "pending"))
completed_records_filter = rg.Filter(("status", "==", "completed"))
pending_records = list(
    dataset.records(
        query=rg.Query(filter=pending_records_filter),
        limit=5,
    )
)
example_records = list(
    dataset.records(
        query=rg.Query(filter=completed_records_filter),
        limit=5,
    )
)
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]

# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    ),
    fields=[field],
    question=question,
    example_records=example_records,
    guidelines=dataset.guidelines
)
labeller.load()

# Process the pending records
result = next(
    labeller.process(
        [
            {
                "record": record
            } for record in pending_records
        ]
    )
)

# Add the suggestions to the records
for record, suggestion in zip(pending_records, result):
    record.suggestions.add(Suggestion(**suggestion["suggestion"]))

# Log the updated records
dataset.records.log(pending_records)

Annotate a record with alternating datasets and questions:

import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM

# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
field = dataset.settings.fields["text"]
question = dataset.settings.questions["label"]
question2 = dataset.settings.questions["label2"]

# Initialize the labeller with the model and fields
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    )
)
labeller.load()

# Process the record
record = next(dataset.records())
result = next(
    labeller.process(
        [
            {
                "record": record,
                "fields": [field],
                "question": question,
            },
            {
                "record": record,
                "fields": [field],
                "question": question2,
            }
        ]
    )
)

# Add the suggestions to the record
for suggestion in result:
    record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))

# Log the updated record
dataset.records.log([record])

Overwrite default prompts and instructions:

import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM

# Overwrite default prompts and instructions
labeller = ArgillaLabeller(
    llm=InferenceEndpointsLLM(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
    ),
    system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
    question_to_label_instruction={
        "label_selection": "Select the appropriate label from the list of provided labels.",
        "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
        "text": "Provide a text response to the question.",
        "rating": "Provide a rating for the question.",
    },
)
labeller.load()
Source code in src/distilabel/steps/tasks/argilla_labeller.py
class ArgillaLabeller(Task):
    """
    Annotate Argilla records based on input fields, example records and question settings.

    This task is designed to facilitate the annotation of Argilla records by leveraging a pre-trained LLM.
    It uses a system prompt that guides the LLM to understand the input fields, the question type,
    and the question settings. The task then formats the input data and generates a response based on the question.
    The response is validated against the question's value model, and the final suggestion is prepared for annotation.

    Attributes:
        _template: a Jinja2 template used to format the input for the LLM.

    Input columns:
        - record (`argilla.Record`): The record to be annotated.
        - fields (`Optional[List[Dict[str, Any]]]`): The list of field settings for the input fields.
        - question (`Optional[Dict[str, Any]]`): The question settings for the question to be answered.
        - example_records (`Optional[List[Dict[str, Any]]]`): The few shot example records with responses to be used to answer the question.
        - guidelines (`Optional[str]`): The guidelines for the annotation task.

    Output columns:
        - suggestion (`Dict[str, Any]`): The final suggestion for annotation.

    Categories:
        - text-classification
        - scorer
        - text-generation

    References:
        - [`Argilla: Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets`](https://github.com/argilla-io/argilla/)

    Examples:
        Annotate a record with the same dataset and question:

        ```python
        import argilla as rg
        from argilla import Suggestion
        from distilabel.steps.tasks import ArgillaLabeller
        from distilabel.llms.huggingface import InferenceEndpointsLLM

        # Get information from Argilla dataset definition
        dataset = rg.Dataset("my_dataset")
        pending_records_filter = rg.Filter(("status", "==", "pending"))
        completed_records_filter = rg.Filter(("status", "==", "completed"))
        pending_records = list(
            dataset.records(
                query=rg.Query(filter=pending_records_filter),
                limit=5,
            )
        )
        example_records = list(
            dataset.records(
                query=rg.Query(filter=completed_records_filter),
                limit=5,
            )
        )
        field = dataset.settings.fields["text"]
        question = dataset.settings.questions["label"]

        # Initialize the labeller with the model and fields
        labeller = ArgillaLabeller(
            llm=InferenceEndpointsLLM(
                model_id="mistralai/Mistral-7B-Instruct-v0.2",
            ),
            fields=[field],
            question=question,
            example_records=example_records,
            guidelines=dataset.guidelines
        )
        labeller.load()

        # Process the pending records
        result = next(
            labeller.process(
                [
                    {
                        "record": record
                    } for record in pending_records
                ]
            )
        )

        # Add the suggestions to the records
        for record, suggestion in zip(pending_records, result):
            record.suggestions.add(Suggestion(**suggestion["suggestion"]))

        # Log the updated records
        dataset.records.log(pending_records)
        ```

        Annotate a record with alternating datasets and questions:

        ```python
        import argilla as rg
        from distilabel.steps.tasks import ArgillaLabeller
        from distilabel.llms.huggingface import InferenceEndpointsLLM

        # Get information from Argilla dataset definition
        dataset = rg.Dataset("my_dataset")
        field = dataset.settings.fields["text"]
        question = dataset.settings.questions["label"]
        question2 = dataset.settings.questions["label2"]

        # Initialize the labeller with the model and fields
        labeller = ArgillaLabeller(
            llm=InferenceEndpointsLLM(
                model_id="mistralai/Mistral-7B-Instruct-v0.2",
            )
        )
        labeller.load()

        # Process the record
        record = next(dataset.records())
        result = next(
            labeller.process(
                [
                    {
                        "record": record,
                        "fields": [field],
                        "question": question,
                    },
                    {
                        "record": record,
                        "fields": [field],
                        "question": question2,
                    }
                ]
            )
        )

        # Add the suggestions to the record
        for suggestion in result:
            record.suggestions.add(rg.Suggestion(**suggestion["suggestion"]))

        # Log the updated record
        dataset.records.log([record])
        ```

        Overwrite default prompts and instructions:

        ```python
        import argilla as rg
        from distilabel.steps.tasks import ArgillaLabeller
        from distilabel.llms.huggingface import InferenceEndpointsLLM

        # Overwrite default prompts and instructions
        labeller = ArgillaLabeller(
            llm=InferenceEndpointsLLM(
                model_id="mistralai/Mistral-7B-Instruct-v0.2",
            ),
            system_prompt="You are an expert annotator and labelling assistant that understands complex domains and natural language processing.",
            question_to_label_instruction={
                "label_selection": "Select the appropriate label from the list of provided labels.",
                "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
                "text": "Provide a text response to the question.",
                "rating": "Provide a rating for the question.",
            },
        )
        labeller.load()
        ```
    """

    system_prompt: str = (
        "You are an expert annotator and labelling assistant that understands complex domains and natural language processing. "
        "You are given input fields and a question. "
        "You should create a valid JSON object as an answer to the question based on the input fields. "
        "1. Understand the input fields and optional guidelines. "
        "2. Understand the question type and the question settings. "
        "3. Reason through your response step-by-step. "
        "4. Provide a valid JSON object as an answer to the question."
    )
    question_to_label_instruction: Dict[str, str] = {
        "label_selection": "Select the appropriate label from the list of provided labels.",
        "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
        "text": "Provide a text response to the question.",
        "rating": "Provide a rating for the question.",
    }
    example_records: Optional[
        RuntimeParameter[Union[List[Union[Dict[str, Any], BaseModel]], None]]
    ] = Field(
        default=None,
        description="The few shot serialized example records or `BaseModel`s with responses to be used to answer the question.",
    )
    fields: Optional[
        RuntimeParameter[Union[List[Union[BaseModel, Dict[str, Any]]], None]]
    ] = Field(
        default=None,
        description="The field serialized field settings or `BaseModel` for the fields to be used to answer the question.",
    )
    question: Optional[
        RuntimeParameter[
            Union[
                Dict[str, Any],
                BaseModel,
                None,
            ]
        ]
    ] = Field(
        default=None,
        description="The question serialized question settings or `BaseModel` for the question to be answered.",
    )
    guidelines: Optional[RuntimeParameter[str]] = Field(
        default=None,
        description="The guidelines for the annotation task.",
    )

    _template: Union[Template, None] = PrivateAttr(...)
    _client: Optional[Any] = PrivateAttr(None)

    def load(self) -> None:
        """Loads the Jinja2 template."""
        super().load()

        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "argillalabeller.jinja2"
        )

        self._template = Template(open(_path).read())

    @property
    def inputs(self) -> Dict[str, bool]:
        return {
            "record": True,
            "fields": False,
            "question": False,
            "example_records": False,
            "guidelines": False,
        }

    def _format_record(
        self, record: Dict[str, Any], fields: List[Dict[str, Any]]
    ) -> str:
        """Format the record fields into a string.

        Args:
            record (Dict[str, Any]): The record to format.
            fields (List[Dict[str, Any]]): The fields to format.

        Returns:
            str: The formatted record fields.
        """
        output = []
        for field in fields:
            if title := field.get("title"):
                output.append(f"title: {title}")
            if description := field.get("description"):
                output.append(f"description: {description}")
            output.append(record.get("fields", {}).get(field.get("name", "")))
        return "\n".join(output)

    def _get_label_instruction(self, question: Dict[str, Any]) -> str:
        """Get the label instruction for the question.

        Args:
            question (Dict[str, Any]): The question to get the label instruction for.

        Returns:
            str: The label instruction for the question.
        """
        question_type = question["settings"]["type"]
        return self.question_to_label_instruction[question_type]

    def _format_question(self, question: Dict[str, Any]) -> str:
        """Format the question settings into a string.

        Args:
            question (Dict[str, Any]): The question to format.

        Returns:
            str: The formatted question.
        """
        output = [
            f"title: {question.get('title', '')}",
            f"description: {question.get('description', '')}",
            f"label_instruction: {self._get_label_instruction(question)}",
        ]
        settings = question.get("settings", {})
        if "options" in settings:
            output.append(
                f"labels: {[option['value'] for option in settings.get('options', [])]}"
            )
        return "\n".join(output)

    def _format_example_records(
        self,
        records: List[Dict[str, Any]],
        fields: List[Dict[str, Any]],
        question: Dict[str, Any],
    ) -> str:
        """Format the example records into a string.

        Args:
            records (List[Dict[str, Any]]): The records to format.
            fields (List[Dict[str, Any]]): The fields to format.
            question (Dict[str, Any]): The question to format.

        Returns:
            str: The formatted example records.
        """
        base = []
        for record in records:
            responses = record.get("responses", {})
            if responses.get(question["name"]):
                base.append(self._format_record(record, fields))
                value = responses[question["name"]][0]["value"]
                formatted_value = self._assign_value_to_question_value_model(
                    value, question
                )
                base.append(f"Response: {formatted_value}")
                base.append("")
            else:
                warnings.warn(
                    f"Record {record} has no response for question {question['name']}. Skipping example record.",
                    stacklevel=2,
                )
        return "\n".join(base)

    def format_input(
        self,
        input: Dict[
            str,
            Union[
                Dict[str, Any],
                "Record",
                "TextField",
                "MultiLabelQuestion",
                "LabelQuestion",
                "RatingQuestion",
                "TextQuestion",
            ],
        ],
    ) -> "ChatType":
        """Format the input into a chat message.

        Args:
            input: The input to format.

        Returns:
            The formatted chat message.

        Raises:
            ValueError: If question or fields are not provided.
        """
        input_keys = list(self.inputs.keys())
        record = input[input_keys[0]]
        fields = input.get(input_keys[1], self.fields)
        question = input.get(input_keys[2], self.question)
        examples = input.get(input_keys[3], self.example_records)
        guidelines = input.get(input_keys[4], self.guidelines)

        if question is None:
            raise ValueError("Question must be provided.")
        if fields is None or any(field is None for field in fields):
            raise ValueError("Fields must be provided.")

        record = record.to_dict() if not isinstance(record, dict) else record
        question = question.serialize() if not isinstance(question, dict) else question
        fields = [
            field.serialize() if not isinstance(field, dict) else field
            for field in fields
        ]
        examples = (
            [
                example.to_dict() if not isinstance(example, dict) else example
                for example in examples
            ]
            if examples
            else None
        )

        formatted_fields = self._format_record(record, fields)
        formatted_question = self._format_question(question)
        formatted_examples = (
            self._format_example_records(examples, fields, question)
            if examples
            else False
        )

        prompt = self._template.render(
            fields=formatted_fields,
            question=formatted_question,
            examples=formatted_examples,
            guidelines=guidelines,
        )

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt})
        return messages

    @property
    def outputs(self) -> List[str]:
        return ["suggestion"]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Format the output into a dictionary.

        Args:
            output (Union[str, None]): The output to format.
            input (Dict[str, Any]): The input to format.

        Returns:
            Dict[str, Any]: The formatted output.
        """
        from argilla import Suggestion

        question: Union[
            Any,
            Dict[str, Any],
            LabelQuestion,
            MultiLabelQuestion,
            RatingQuestion,
            TextQuestion,
            None,
        ] = input.get(list(self.inputs.keys())[2], self.question) or self.question
        question = question.serialize() if not isinstance(question, dict) else question
        model = self._get_pydantic_model_of_structured_output(question)
        validated_output = model(**json.loads(output))
        value = self._get_value_from_question_value_model(validated_output)
        suggestion = Suggestion(
            value=value,
            question_name=question["name"],
            type="model",
            agent=self.llm.model_name,
        ).serialize()
        return {
            self.outputs[0]: {
                k: v
                for k, v in suggestion.items()
                if k in ["value", "question_name", "type", "agent"]
            }
        }

    def _set_llm_structured_output_for_question(self, question: Dict[str, Any]) -> None:
        runtime_parameters = self.llm._runtime_parameters
        runtime_parameters.update(
            {
                "structured_output": {
                    "format": "json",
                    "schema": self._get_pydantic_model_of_structured_output(question),
                },
            }
        )
        self.llm.set_runtime_parameters(runtime_parameters)

    @override
    def process(self, inputs: StepInput) -> "StepOutput":
        """Process the input through the task.

        Args:
            inputs (StepInput): The input to process.

        Returns:
            StepOutput: The output of the task.
        """

        question_list = [input.get("question", self.question) for input in inputs]
        fields_list = [input.get("fields", self.fields) for input in inputs]
        # check if any field for the field in fields is None
        for fields in fields_list:
            if any(field is None for field in fields):
                raise ValueError(
                    "Fields must be provided during init or through `process` method."
                )
        # check if any question is None
        if any(question is None for question in question_list):
            raise ValueError(
                "Question must be provided during init or through `process` method."
            )
        question_list = [
            question.serialize() if not isinstance(question, dict) else question
            for question in question_list
        ]
        if not all(question == question_list[0] for question in question_list):
            warnings.warn(
                "Not all questions are the same. Processing each question separately by setting the structured output for each question. This may impact performance.",
                stacklevel=2,
            )
            for input, question in zip(inputs, question_list):
                self._set_llm_structured_output_for_question(question)
                yield from super().process([input])
        else:
            question = question_list[0]
            self._set_llm_structured_output_for_question(question)
            yield from super().process(inputs)

    def _get_value_from_question_value_model(
        self, question_value_model: BaseModel
    ) -> Any:
        """Get the value from the question value model.

        Args:
            question_value_model (BaseModel): The question value model to get the value from.

        Returns:
            Any: The value from the question value model.
        """
        for attr in ["label", "labels", "rating", "text"]:
            if hasattr(question_value_model, attr):
                return getattr(question_value_model, attr)
        raise ValueError(f"Unsupported question type: {question_value_model}")

    def _assign_value_to_question_value_model(
        self, value: Any, question: Dict[str, Any]
    ) -> BaseModel:
        """Assign the value to the question value model.

        Args:
            value (Any): The value to assign.
            question (Dict[str, Any]): The question to assign the value to.

        Returns:
            BaseModel: The question value model with the assigned value.
        """
        question_value_model = self._get_pydantic_model_of_structured_output(question)
        for attr in ["label", "labels", "rating", "text"]:
            try:
                model_dict = {attr: value}
                question_value_model = question_value_model(**model_dict)
                return question_value_model.model_dump_json()
            except AttributeError:
                pass
        return value

    def _get_pydantic_model_of_structured_output(
        self,
        question: Dict[str, Any],
    ) -> BaseModel:
        """Get the Pydantic model of the structured output.

        Args:
            question (Dict[str, Any]): The question to get the Pydantic model of the structured output for.

        Returns:
            BaseModel: The Pydantic model of the structured output.
        """

        question_type = question["settings"]["type"]

        if question_type == "multi_label_selection":

            class QuestionValueModel(BaseModel):
                labels: Optional[List[str]] = Field(default_factory=list)

        elif question_type == "label_selection":

            class QuestionValueModel(BaseModel):
                label: str

        elif question_type == "text":

            class QuestionValueModel(BaseModel):
                text: str

        elif question_type == "rating":

            class QuestionValueModel(BaseModel):
                rating: int
        else:
            raise ValueError(f"Unsupported question type: {question}")

        return QuestionValueModel
load()

Loads the Jinja2 template.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def load(self) -> None:
    """Loads the Jinja2 template."""
    super().load()

    _path = str(
        importlib_resources.files("distilabel")
        / "steps"
        / "tasks"
        / "templates"
        / "argillalabeller.jinja2"
    )

    self._template = Template(open(_path).read())
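
The same packaged template can also be resolved with the standard library's importlib.resources. This is only an equivalent sketch (assuming distilabel is installed); the class itself uses the importlib_resources backport as shown above:

```python
from importlib import resources

from jinja2 import Template

# Equivalent way to locate and load the packaged Jinja2 template
# (the class above uses the `importlib_resources` backport instead).
template_path = (
    resources.files("distilabel")
    / "steps"
    / "tasks"
    / "templates"
    / "argillalabeller.jinja2"
)
template = Template(template_path.read_text())
```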
_format_record(record, fields)

Format the record fields into a string.

Parameters:

Name Type Description Default
record Dict[str, Any]

The record to format.

required
fields List[Dict[str, Any]]

The fields to format.

required

Returns:

Name Type Description
str str

The formatted record fields.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _format_record(
    self, record: Dict[str, Any], fields: List[Dict[str, Any]]
) -> str:
    """Format the record fields into a string.

    Args:
        record (Dict[str, Any]): The record to format.
        fields (List[Dict[str, Any]]): The fields to format.

    Returns:
        str: The formatted record fields.
    """
    output = []
    for field in fields:
        if title := field.get("title"):
            output.append(f"title: {title}")
        if description := field.get("description"):
            output.append(f"description: {description}")
        output.append(record.get("fields", {}).get(field.get("name", "")))
    return "\n".join(output)
_get_label_instruction(question)

Get the label instruction for the question.

Parameters:

Name Type Description Default
question Dict[str, Any]

The question to get the label instruction for.

required

Returns:

Name Type Description
str str

The label instruction for the question.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _get_label_instruction(self, question: Dict[str, Any]) -> str:
    """Get the label instruction for the question.

    Args:
        question (Dict[str, Any]): The question to get the label instruction for.

    Returns:
        str: The label instruction for the question.
    """
    question_type = question["settings"]["type"]
    return self.question_to_label_instruction[question_type]
_format_question(question)

Format the question settings into a string.

Parameters:

Name Type Description Default
question Dict[str, Any]

The question to format.

required

Returns:

Name Type Description
str str

The formatted question.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _format_question(self, question: Dict[str, Any]) -> str:
    """Format the question settings into a string.

    Args:
        question (Dict[str, Any]): The question to format.

    Returns:
        str: The formatted question.
    """
    output = [
        f"title: {question.get('title', '')}",
        f"description: {question.get('description', '')}",
        f"label_instruction: {self._get_label_instruction(question)}",
    ]
    settings = question.get("settings", {})
    if "options" in settings:
        output.append(
            f"labels: {[option['value'] for option in settings.get('options', [])]}"
        )
    return "\n".join(output)
_format_example_records(records, fields, question)

Format the example records into a string.

Parameters:

Name Type Description Default
records List[Dict[str, Any]]

The records to format.

required
fields List[Dict[str, Any]]

The fields to format.

required
question Dict[str, Any]

The question to format.

required

Returns:

Name Type Description
str str

The formatted example records.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _format_example_records(
    self,
    records: List[Dict[str, Any]],
    fields: List[Dict[str, Any]],
    question: Dict[str, Any],
) -> str:
    """Format the example records into a string.

    Args:
        records (List[Dict[str, Any]]): The records to format.
        fields (List[Dict[str, Any]]): The fields to format.
        question (Dict[str, Any]): The question to format.

    Returns:
        str: The formatted example records.
    """
    base = []
    for record in records:
        responses = record.get("responses", {})
        if responses.get(question["name"]):
            base.append(self._format_record(record, fields))
            value = responses[question["name"]][0]["value"]
            formatted_value = self._assign_value_to_question_value_model(
                value, question
            )
            base.append(f"Response: {formatted_value}")
            base.append("")
        else:
            warnings.warn(
                f"Record {record} has no response for question {question['name']}. Skipping example record.",
                stacklevel=2,
            )
    return "\n".join(base)
format_input(input)

Format the input into a chat message.

Parameters:

Name Type Description Default
input Dict[str, Union[Dict[str, Any], Record, TextField, MultiLabelQuestion, LabelQuestion, RatingQuestion, TextQuestion]]

The input to format.

required

Returns:

Type Description
ChatType

The formatted chat message.

Raises:

Type Description
ValueError

If question or fields are not provided.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def format_input(
    self,
    input: Dict[
        str,
        Union[
            Dict[str, Any],
            "Record",
            "TextField",
            "MultiLabelQuestion",
            "LabelQuestion",
            "RatingQuestion",
            "TextQuestion",
        ],
    ],
) -> "ChatType":
    """Format the input into a chat message.

    Args:
        input: The input to format.

    Returns:
        The formatted chat message.

    Raises:
        ValueError: If question or fields are not provided.
    """
    input_keys = list(self.inputs.keys())
    record = input[input_keys[0]]
    fields = input.get(input_keys[1], self.fields)
    question = input.get(input_keys[2], self.question)
    examples = input.get(input_keys[3], self.example_records)
    guidelines = input.get(input_keys[4], self.guidelines)

    if question is None:
        raise ValueError("Question must be provided.")
    if fields is None or any(field is None for field in fields):
        raise ValueError("Fields must be provided.")

    record = record.to_dict() if not isinstance(record, dict) else record
    question = question.serialize() if not isinstance(question, dict) else question
    fields = [
        field.serialize() if not isinstance(field, dict) else field
        for field in fields
    ]
    examples = (
        [
            example.to_dict() if not isinstance(example, dict) else example
            for example in examples
        ]
        if examples
        else None
    )

    formatted_fields = self._format_record(record, fields)
    formatted_question = self._format_question(question)
    formatted_examples = (
        self._format_example_records(examples, fields, question)
        if examples
        else False
    )

    prompt = self._template.render(
        fields=formatted_fields,
        question=formatted_question,
        examples=formatted_examples,
        guidelines=guidelines,
    )

    messages = []
    if self.system_prompt:
        messages.append({"role": "system", "content": self.system_prompt})
    messages.append({"role": "user", "content": prompt})
    return messages
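
The result is a plain chat message list. The sketch below assembles the same shape with a stand-in template string; the real prompt is rendered from the packaged argillalabeller.jinja2 template with the formatted fields, question, example records and guidelines:

```python
from jinja2 import Template

# Stand-in template, not the packaged `argillalabeller.jinja2` one.
template = Template(
    "Fields:\n{{ fields }}\n\nQuestion:\n{{ question }}\n\nGuidelines:\n{{ guidelines }}"
)
prompt = template.render(
    fields="title: Text\nI love this product!",
    question="title: Sentiment\nlabels: ['positive', 'negative']",
    guidelines="Label the sentiment of the record.",
)

system_prompt = "You are an expert annotator."  # only added when configured
messages = []
if system_prompt:
    messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
print(messages)
```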
format_output(output, input)

Format the output into a dictionary.

Parameters:

Name Type Description Default
output Union[str, None]

The output to format.

required
input Dict[str, Any]

The input to format.

required

Returns:

Type Description
Dict[str, Any]

Dict[str, Any]: The formatted output.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def format_output(
    self, output: Union[str, None], input: Dict[str, Any]
) -> Dict[str, Any]:
    """Format the output into a dictionary.

    Args:
        output (Union[str, None]): The output to format.
        input (Dict[str, Any]): The input to format.

    Returns:
        Dict[str, Any]: The formatted output.
    """
    from argilla import Suggestion

    question: Union[
        Any,
        Dict[str, Any],
        LabelQuestion,
        MultiLabelQuestion,
        RatingQuestion,
        TextQuestion,
        None,
    ] = input.get(list(self.inputs.keys())[2], self.question) or self.question
    question = question.serialize() if not isinstance(question, dict) else question
    model = self._get_pydantic_model_of_structured_output(question)
    validated_output = model(**json.loads(output))
    value = self._get_value_from_question_value_model(validated_output)
    suggestion = Suggestion(
        value=value,
        question_name=question["name"],
        type="model",
        agent=self.llm.model_name,
    ).serialize()
    return {
        self.outputs[0]: {
            k: v
            for k, v in suggestion.items()
            if k in ["value", "question_name", "type", "agent"]
        }
    }
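
Because the LLM is constrained to the question's structured output schema, the raw output is a small JSON document. The sketch below shows the parsing for a hypothetical rating question, with a plain dict standing in for the serialized argilla Suggestion:

```python
import json

from pydantic import BaseModel


class RatingValue(BaseModel):  # value model for a rating question
    rating: int


raw_output = '{"rating": 4}'  # structured JSON returned by the LLM
validated = RatingValue(**json.loads(raw_output))

# Plain-dict stand-in for the serialized `argilla.Suggestion`.
suggestion = {
    "value": validated.rating,
    "question_name": "quality",  # hypothetical question name
    "type": "model",
    "agent": "meta-llama/Meta-Llama-3.1-8B-Instruct",  # whatever `llm.model_name` reports
}
# The step emits this dict under its single output column.
print(suggestion)
```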
process(inputs)

Process the input through the task.

Parameters:

Name Type Description Default
inputs StepInput

The input to process.

required

Returns:

Name Type Description
StepOutput StepOutput

The output of the task.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
@override
def process(self, inputs: StepInput) -> "StepOutput":
    """Process the input through the task.

    Args:
        inputs (StepInput): The input to process.

    Returns:
        StepOutput: The output of the task.
    """

    question_list = [input.get("question", self.question) for input in inputs]
    fields_list = [input.get("fields", self.fields) for input in inputs]
    # check that none of the fields in any input row is None
    for fields in fields_list:
        if any(field is None for field in fields):
            raise ValueError(
                "Fields must be provided during init or through `process` method."
            )
    # check if any question is None
    if any(question is None for question in question_list):
        raise ValueError(
            "Question must be provided during init or through `process` method."
        )
    question_list = [
        question.serialize() if not isinstance(question, dict) else question
        for question in question_list
    ]
    if not all(question == question_list[0] for question in question_list):
        warnings.warn(
            "Not all questions are the same. Processing each question separately by setting the structured output for each question. This may impact performance.",
            stacklevel=2,
        )
        for input, question in zip(inputs, question_list):
            self._set_llm_structured_output_for_question(question)
            yield from super().process([input])
    else:
        question = question_list[0]
        self._set_llm_structured_output_for_question(question)
        yield from super().process(inputs)
_get_value_from_question_value_model(question_value_model)

Get the value from the question value model.

Parameters:

Name Type Description Default
question_value_model BaseModel

The question value model to get the value from.

required

Returns:

Name Type Description
Any Any

The value from the question value model.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _get_value_from_question_value_model(
    self, question_value_model: BaseModel
) -> Any:
    """Get the value from the question value model.

    Args:
        question_value_model (BaseModel): The question value model to get the value from.

    Returns:
        Any: The value from the question value model.
    """
    for attr in ["label", "labels", "rating", "text"]:
        if hasattr(question_value_model, attr):
            return getattr(question_value_model, attr)
    raise ValueError(f"Unsupported question type: {question_value_model}")
_assign_value_to_question_value_model(value, question)

Assign the value to the question value model.

Parameters:

Name Type Description Default
value Any

The value to assign.

required
question Dict[str, Any]

The question to assign the value to.

required

Returns:

Name Type Description
BaseModel BaseModel

The question value model with the assigned value.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _assign_value_to_question_value_model(
    self, value: Any, question: Dict[str, Any]
) -> BaseModel:
    """Assign the value to the question value model.

    Args:
        value (Any): The value to assign.
        question (Dict[str, Any]): The question to assign the value to.

    Returns:
        BaseModel: The question value model with the assigned value.
    """
    question_value_model = self._get_pydantic_model_of_structured_output(question)
    for attr in ["label", "labels", "rating", "text"]:
        try:
            model_dict = {attr: value}
            question_value_model = question_value_model(**model_dict)
            return question_value_model.model_dump_json()
        except AttributeError:
            pass
    return value
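
In practice this turns an existing human response into the same JSON shape the LLM is asked to produce, so that few-shot example records and model outputs look alike. A minimal sketch for a rating question:

```python
from pydantic import BaseModel


class RatingValue(BaseModel):  # value model for a rating question
    rating: int


# A human response value from an example record, serialized the same way
# the LLM's structured output will look.
formatted = RatingValue(rating=4).model_dump_json()
print(formatted)  # '{"rating": 4}'
```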
_get_pydantic_model_of_structured_output(question)

Get the Pydantic model of the structured output.

Parameters:

Name Type Description Default
question Dict[str, Any]

The question to get the Pydantic model of the structured output for.

required

Returns:

Name Type Description
BaseModel BaseModel

The Pydantic model of the structured output.

Source code in src/distilabel/steps/tasks/argilla_labeller.py
def _get_pydantic_model_of_structured_output(
    self,
    question: Dict[str, Any],
) -> BaseModel:
    """Get the Pydantic model of the structured output.

    Args:
        question (Dict[str, Any]): The question to get the Pydantic model of the structured output for.

    Returns:
        BaseModel: The Pydantic model of the structured output.
    """

    question_type = question["settings"]["type"]

    if question_type == "multi_label_selection":

        class QuestionValueModel(BaseModel):
            labels: Optional[List[str]] = Field(default_factory=list)

    elif question_type == "label_selection":

        class QuestionValueModel(BaseModel):
            label: str

    elif question_type == "text":

        class QuestionValueModel(BaseModel):
            text: str

    elif question_type == "rating":

        class QuestionValueModel(BaseModel):
            rating: int
    else:
        raise ValueError(f"Unsupported question type: {question}")

    return QuestionValueModel
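
These per-question-type models double as the JSON schema passed to the LLM's structured output and as validators for whatever the LLM returns, for example:

```python
import json
from typing import List, Optional

from pydantic import BaseModel, Field


class MultiLabelValue(BaseModel):  # mirrors the multi_label_selection case
    labels: Optional[List[str]] = Field(default_factory=list)


class RatingValue(BaseModel):  # mirrors the rating case
    rating: int


print(MultiLabelValue.model_json_schema())         # schema handed to the LLM
print(RatingValue(**json.loads('{"rating": 3}')))  # validation of the LLM output
```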

CLAIR

Bases: Task

Contrastive Learning from AI Revisions (CLAIR).

CLAIR uses an AI system to minimally revise a solution A→A´ so that the resulting preference pair (A versus A´) is much more contrastive and precise.

Input columns
  • task (str): The task or instruction.
  • student_solution (str): An answer to the task that is to be revised.
Output columns
  • revision (str): The revised text.
  • rational (str): The rationale for the provided revision.
  • model_name (str): The name of the model used to generate the revision and rationale.
Categories
  • preference
  • text-generation
References
  • Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment: https://arxiv.org/abs/2408.06266v1
  • APO and CLAIR - GitHub Repository: https://github.com/ContextualAI/CLAIR_and_APO

Examples:

Create contrastive preference pairs:

from distilabel.steps.tasks import CLAIR
from distilabel.llms.huggingface import InferenceEndpointsLLM

llm=InferenceEndpointsLLM(
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
    generation_kwargs={
        "temperature": 0.7,
        "max_new_tokens": 4096,
    },
)
clair_task = CLAIR(llm=llm)

clair_task.load()

result = next(
    clair_task.process(
        [
            {
                "task": "How many gaps are there between the earth and the moon?",
                "student_solution": 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon's orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.'
            }
        ]
    )
)
# result
# [{'task': 'How many gaps are there between the earth and the moon?',
# 'student_solution': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.',
# 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
# 'rational': 'The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.',
# 'distilabel_metadata': {'raw_output_c_l_a_i_r_0': '{teacher_reasoning}: The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.\n\n{corrected_student_solution}: There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
# 'raw_input_c_l_a_i_r_0': [{'role': 'system',
#     'content': "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."},
#     {'role': 'user',
#     'content': '{task}: How many gaps are there between the earth and the moon?\n\n{student_solution}: There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.\n\n-----------------\n\nLet\'s first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.'}]},
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]

Citations:

```
@misc{doosterlinck2024anchoredpreferenceoptimizationcontrastive,
    title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment},
    author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri},
    year={2024},
    eprint={2408.06266},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2408.06266},
}
```
Source code in src/distilabel/steps/tasks/clair.py
class CLAIR(Task):
    r"""Contrastive Learning from AI Revisions (CLAIR).

    CLAIR uses an AI system to minimally revise a solution A→A´ such that the resulting
    preference A `preferred` A’ is much more contrastive and precise.

    Input columns:
        - task (`str`): The task or instruction.
        - student_solution (`str`): An answer to the task that is to be revised.

    Output columns:
        - revision (`str`): The revised text.
        - rational (`str`): The rational for the provided revision.
        - model_name (`str`): The name of the model used to generate the revision and rational.

    Categories:
        - preference
        - text-generation

    References:
        - [`Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment`](https://arxiv.org/abs/2408.06266v1)
        - [`APO and CLAIR - GitHub Repository`](https://github.com/ContextualAI/CLAIR_and_APO)

    Examples:
        Create contrastive preference pairs:

        ```python
        from distilabel.steps.tasks import CLAIR
        from distilabel.llms.huggingface import InferenceEndpointsLLM

        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
            generation_kwargs={
                "temperature": 0.7,
                "max_new_tokens": 4096,
            },
        )
        clair_task = CLAIR(llm=llm)

        clair_task.load()

        result = next(
            clair_task.process(
                [
                    {
                        "task": "How many gaps are there between the earth and the moon?",
                        "student_solution": 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon's orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.'
                    }
                ]
            )
        )
        # result
        # [{'task': 'How many gaps are there between the earth and the moon?',
        # 'student_solution': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.',
        # 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
        # 'rational': 'The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.',
        # 'distilabel_metadata': {'raw_output_c_l_a_i_r_0': '{teacher_reasoning}: The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.\n\n{corrected_student_solution}: There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.',
        # 'raw_input_c_l_a_i_r_0': [{'role': 'system',
        #     'content': "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."},
        #     {'role': 'user',
        #     'content': '{task}: How many gaps are there between the earth and the moon?\n\n{student_solution}: There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.\n\n-----------------\n\nLet\'s first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.'}]},
        # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}]
        ```

    Citations:

        ```
        @misc{doosterlinck2024anchoredpreferenceoptimizationcontrastive,
            title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment},
            author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri},
            year={2024},
            eprint={2408.06266},
            archivePrefix={arXiv},
            primaryClass={cs.LG},
            url={https://arxiv.org/abs/2408.06266},
        }
        ```
    """

    system_prompt: str = SYSTEM_PROMPT
    _template: Union[Template, None] = PrivateAttr(...)

    def load(self) -> None:
        super().load()
        _path = str(
            importlib_resources.files("distilabel")
            / "steps"
            / "tasks"
            / "templates"
            / "clair.jinja2"
        )
        with open(_path, "r") as f:
            self._template = Template(f.read())

    @property
    def inputs(self) -> "StepColumns":
        return ["task", "student_solution"]

    @property
    def outputs(self) -> "StepColumns":
        return ["revision", "rational", "model_name"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """The input is formatted as a `ChatType` assuming that the instruction
        is the first interaction from the user within a conversation."""
        return [
            {"role": "system", "content": self.system_prompt},
            {
                "role": "user",
                "content": self._template.render(
                    task=input["task"], student_solution=input["student_solution"]
                ),
            },
        ]

    def format_output(
        self, output: Union[str, None], input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """The output is formatted as a list with the score of each instruction-response pair.

        Args:
            output: the raw output of the LLM.
            input: the input to the task. Used for obtaining the number of responses.

        Returns:
            A dict with the key `scores` containing the scores for each instruction-response pair.
        """
        if output is None:
            return self._default_error()

        return self._format_output(output)

    def _format_output(self, output: Union[str, None]) -> Dict[str, Any]:
        if "**Corrected Student Solution:**" in output:
            splits = output.split("**Corrected Student Solution:**")
        elif "{corrected_student_solution}:" in output:
            splits = output.split("{corrected_student_solution}:")
        elif "{corrected_student_solution}" in output:
            splits = output.split("{corrected_student_solution}")
        elif "**Worsened Student Solution:**" in output:
            splits = output.split("**Worsened Student Solution:**")
        elif "{worsened_student_solution}:" in output:
            splits = output.split("{worsened_student_solution}:")
        elif "{worsened_student_solution}" in output:
            splits = output.split("{worsened_student_solution}")
        else:
            splits = None

        # Safety check when the output doesn't follow the expected format
        if not splits:
            return self._default_error()

        if len(splits) >= 2:
            revision = splits[1]
            revision = revision.strip("\n\n").strip()  # noqa: B005

            rational = splits[0]
            if "{teacher_reasoning}" in rational:
                rational = rational.split("{teacher_reasoning}")[1].strip(":").strip()
            rational = rational.strip("\n\n").strip()  # noqa: B005
        else:
            return self._default_error()
        return {"revision": revision, "rational": rational}

    def _default_error(self) -> Dict[str, None]:
        return {"revision": None, "rational": None}
format_input(input)

The input is formatted as a ChatType assuming that the instruction is the first interaction from the user within a conversation.

Source code in src/distilabel/steps/tasks/clair.py
def format_input(self, input: Dict[str, Any]) -> "ChatType":
    """The input is formatted as a `ChatType` assuming that the instruction
    is the first interaction from the user within a conversation."""
    return [
        {"role": "system", "content": self.system_prompt},
        {
            "role": "user",
            "content": self._template.render(
                task=input["task"], student_solution=input["student_solution"]
            ),
        },
    ]
format_output(output, input)

The output is parsed into the revised solution and the rationale extracted from the raw LLM output.

Parameters:

Name Type Description Default
output Union[str, None]

the raw output of the LLM.

required
input Dict[str, Any]

the input to the task. Not used here, kept for the common signature.

required

Returns:

Type Description
Dict[str, Any]

A dict with the keys revision and rational, with None values if the output could not be parsed.

Source code in src/distilabel/steps/tasks/clair.py
def format_output(
    self, output: Union[str, None], input: Dict[