add kto
@@ -405,3 +405,66 @@ def tokenize_rlhf(
        "rejected_loss_mask": rejected_loss_mask,
        "rejected_label_decode": rejected_label_decode,
    }


def tokenize_kto(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Conversation = None,
    ignore_index: int = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
"""
|
||||
Tokenize a dataset for KTO training
|
||||
The raw input data is conversation that have the following format
|
||||
{
|
||||
"prompt": [{"from": "human", "content": "xxx"}...],
|
||||
"completion": {"from": "assistant", "content": "xxx"},
|
||||
"label": true/false
|
||||
}
|
||||
It returns three fields
|
||||
The context, which contain the query and the assistant start,
|
||||
the completion, which only contains the assistance's answer,
|
||||
and a binary label, which indicates if the sample is prefered or not
|
||||
"""
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    prompt = data_point["prompt"]
    completion = data_point["completion"]
    template = deepcopy(conversation_template)
    template.clear()

    if prompt[0].get("from", None) != "human":
        raise ValueError("conversation should start with human")
    if completion.get("from", None) != "assistant":
        raise ValueError("conversation should end with assistant")

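    # Replay the multi-turn prompt into the conversation template ("human" -> user turns).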
    for mess in prompt:
        if mess.get("from", None) == "human":
            template.append_message("user", mess["content"])
        elif mess.get("from", None) == "assistant":
            template.append_message("assistant", mess["content"])
        else:
            raise ValueError(f"Unsupported role {mess.get('from', None)}")
    # Render the context (ending at the assistant turn start) and the full conversation.
    generation_prompt = template.get_prompt(len(prompt), add_generation_prompt=True)
    template.append_message("assistant", completion["content"])
    full_prompt = template.get_prompt(len(prompt) + 1, add_generation_prompt=False)
    tokenized_full_prompt = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
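    # Length guard: the +1 presumably reserves room for the BOS token that may be
    # prepended below (inferred; the commit itself does not state the reason).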
    if len(tokenized_full_prompt) + 1 > max_length:
        return dict(prompt=None, completion=None, label=None, input_id_decode=None, completion_decode=None)
    # The completion ids are the suffix of the full prompt beyond the generation prompt.
    tokenized_generation_prompt = tokenizer(generation_prompt, add_special_tokens=False)["input_ids"]
    tokenized_completion = tokenized_full_prompt[len(tokenized_generation_prompt) :]
    tokenized_completion = deepcopy(tokenized_completion)
    # Prepend BOS to the context if the tokenizer defines one and it is missing.
    if tokenizer.bos_token_id is not None and tokenized_generation_prompt[0] != tokenizer.bos_token_id:
        tokenized_generation_prompt = [tokenizer.bos_token_id] + tokenized_generation_prompt
    decoded_full_prompt = tokenizer.decode(tokenized_full_prompt, skip_special_tokens=False)
    decoded_completion = tokenizer.decode(tokenized_completion, skip_special_tokens=False)

    return {
        "prompt": tokenized_generation_prompt,
        "completion": tokenized_completion,
        "label": data_point["label"],
        "input_id_decode": decoded_full_prompt,
        "completion_decode": decoded_completion,
    }
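
For context, a minimal usage sketch. The checkpoint name, the sample data point, and the way the Conversation template is obtained are illustrative assumptions, not part of this commit:

# Hypothetical smoke test for tokenize_kto; names below are assumptions.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed model
template = load_conversation_template()  # hypothetical helper; use the repo's actual Conversation setup

data_point = {
    "prompt": [{"from": "human", "content": "What does KTO stand for?"}],
    "completion": {"from": "assistant", "content": "Kahneman-Tversky Optimization."},
    "label": True,
}
out = tokenize_kto(data_point, tokenizer, conversation_template=template)
# out["prompt"]     -> context ids: the query plus the assistant turn start (BOS-prefixed if needed)
# out["completion"] -> ids of the assistant answer only
# out["label"]      -> True, i.e. a preferred sample
# out["input_id_decode"] / out["completion_decode"] -> decoded strings for inspection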