FormatTextGenerationDPO

Format the output of your LLMs for Direct Preference Optimization (DPO).

FormatTextGenerationDPO is a Step that formats the output of the combination of a TextGeneration task with a preference Task (i.e. a task generating ratings), so that the ratings can be used to rank the existing generations and provide the chosen and rejected generations.

Use this step to transform the output of a combination of a TextGeneration task and a preference task such as UltraFeedback, following the standard formatting of frameworks such as axolotl or the alignment-handbook.
Note

The generations column should contain at least two generations, and the ratings column should contain the same number of ratings as there are generations.
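As a quick sanity check for that constraint, you can assert the shape of your own rows before running the step; the row below is a hypothetical example:

# Hypothetical row used only to illustrate the constraint above.
row = {"generations": ["4", "5"], "ratings": [1, 0]}
assert len(row["generations"]) >= 2
assert len(row["ratings"]) == len(row["generations"])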
Input & Output Columns

Inputs

- system_prompt (str, optional): The system prompt used within the LLM to generate the generations, if available.
- instruction (str): The instruction used to generate the generations with the LLM.
- generations (List[str]): The generations produced by the LLM.
- generation_models (List[str], optional): The model names used to generate the generations. Only available if the model_name from the TextGeneration task/s is combined into a single column named this way; otherwise it will be ignored.
- ratings (List[float]): The ratings for each of the generations, produced by a preference task such as UltraFeedback.
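For reference, a fully populated input row (with both optional columns present) could look like the sketch below; the system prompt and model names are placeholders, not values mandated by the step:

# Illustrative input row; "model-a"/"model-b" are hypothetical model names.
row = {
    "system_prompt": "You are a helpful assistant.",  # optional
    "instruction": "What's 2+2?",
    "generations": ["4", "5"],
    "generation_models": ["model-a", "model-b"],  # optional
    "ratings": [1, 0],
}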
Outputs

- prompt (str): The instruction used to generate the generations with the LLM.
- prompt_id (str): The SHA256 hash of the prompt.
- chosen (List[Dict[str, str]]): The chosen generation based on the ratings.
- chosen_model (str, optional): The model name used to generate the chosen generation, if the generation_models are available.
- chosen_rating (float): The rating of the chosen generation.
- rejected (List[Dict[str, str]]): The rejected generation based on the ratings.
- rejected_model (str, optional): The model name used to generate the rejected generation, if the generation_models are available.
- rejected_rating (float): The rating of the rejected generation.
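Since prompt_id is the SHA256 hash of the prompt, it can be reproduced with Python's standard library. A minimal sketch, assuming the step hashes the UTF-8 bytes of the prompt:

import hashlib

prompt = "What's 2+2?"
# Hex digest of the SHA256 hash; under the UTF-8 assumption above, this
# should match the 'prompt_id' produced by the step for the same prompt.
prompt_id = hashlib.sha256(prompt.encode("utf-8")).hexdigest()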
Examples

Format your dataset for DPO fine-tuning
from distilabel.steps import FormatTextGenerationDPO

format_dpo = FormatTextGenerationDPO()
format_dpo.load()

# NOTE: Both "system_prompt" and "generation_models" can be added optionally.
result = next(
    format_dpo.process(
        [
            {
                "instruction": "What's 2+2?",
                "generations": ["4", "5", "6"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# >>> result
# [
#     {
#         'instruction': "What's 2+2?",
#         'generations': ['4', '5', '6'],
#         'ratings': [1, 0, -1],
#         'prompt': "What's 2+2?",
#         'prompt_id': '7762ecf17ad41479767061a8f4a7bfa3b63d371672af5180872f9b82b4cd4e29',
#         'chosen': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '4'}],
#         'chosen_rating': 1,
#         'rejected': [{'role': 'user', 'content': "What's 2+2?"}, {'role': 'assistant', 'content': '6'}],
#         'rejected_rating': -1
#     }
# ]
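The optional columns can be passed alongside the required ones. Below is a variant of the example above with hypothetical system_prompt and generation_models values; when generation_models is present, the output row is also expected to include chosen_model and rejected_model, matching the highest- and lowest-rated generations:

# "model-a"/"model-b"/"model-c" are hypothetical model names.
result = next(
    format_dpo.process(
        [
            {
                "system_prompt": "You are a helpful assistant.",
                "instruction": "What's 2+2?",
                "generations": ["4", "5", "6"],
                "generation_models": ["model-a", "model-b", "model-c"],
                "ratings": [1, 0, -1],
            }
        ]
    )
)
# Expected to additionally include 'chosen_model': 'model-a' and
# 'rejected_model': 'model-c' in the formatted row.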