mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-22 15:26:57 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			415 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			415 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/env python3
 | |
| # -*- coding: utf-8 -*-
 | |
| """
 | |
| tokenization utils for constructing dataset for ppo, dpo, sft, rm
 | |
| """
 | |
| 
 | |
| import warnings
 | |
| from copy import deepcopy
 | |
| from typing import Any, Dict, List, Union
 | |
| 
 | |
| from coati.dataset.conversation import Conversation
 | |
| from coati.dataset.utils import split_templated_prompt_into_chunks, tokenize_and_concatenate
 | |
| from datasets import dataset_dict
 | |
| from torch.utils.data import ConcatDataset, Dataset
 | |
| from transformers import PreTrainedTokenizer
 | |
| 
 | |
| from colossalai.logging import get_dist_logger
 | |
| 
 | |
# Module-level logger obtained from ColossalAI's distributed logging helper.
logger = get_dist_logger()

# Default label value for positions that must not contribute to the loss;
# the tokenization functions fall back to it when `ignore_index` is None.
IGNORE_INDEX = -100

# Alias covering the dataset types named below (torch datasets and the
# `datasets.dataset_dict.Dataset` type).
DSType = Union[Dataset, ConcatDataset, dataset_dict.Dataset]
 | |
| 
 | |
| 
 | |
def supervised_tokenize_sft(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Conversation = None,
    ignore_index: int = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize an original pretraining data point as following
         and calculate corresponding labels for sft training:
        "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line end]Something here"
                                            ^
                                end_of_system_line_position

    Args:
        data_point: the data point of the following format
            {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
            An optional "category" entry is copied into the result as `seq_category`.
        tokenizer: the tokenizer used to encode the templated conversation
        conversation_template: the conversation template to apply (a deep copy is used,
            the caller's template is not modified)
        ignore_index: the ignore index when calculate loss during training,
            defaults to `IGNORE_INDEX` when None
        max_length: the maximum context length

    Returns:
        A dict with the keys `input_ids`, `labels`, `inputs_decode`, `labels_decode`,
        `seq_length` and `seq_category`. Every value is None when the sample has to be
        skipped (no complete turn fits into `max_length - 1` tokens, or no token
        requires loss).

    Raises:
        ValueError: if a message has a missing or unsupported `from` role.
        TypeError: if the tokenizer cannot decode the produced ids.
    """

    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    # Result returned whenever the data point cannot be turned into a training sample.
    skipped = dict(
        input_ids=None,
        labels=None,
        inputs_decode=None,
        labels_decode=None,
        seq_length=None,
        seq_category=None,
    )

    messages = data_point["messages"]
    template = deepcopy(conversation_template)
    template.messages = []

    for mess in messages:
        from_str = mess["from"]
        if from_str is None:
            # Fail with an explicit error instead of crashing on `None.lower()`.
            raise ValueError(f"Message is missing its `from` role: {mess}")
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")

        template.append_message(from_str, mess["content"])

    # Drop a trailing unpaired message so the conversation consists of whole turns.
    if len(template.messages) % 2 != 0:
        template.messages = template.messages[0:-1]

    # Binary-search the first turn whose prompt no longer fits into `max_length - 1`
    # tokens; `get_prompt(2 * turn - 1)` presumably renders the conversation up to that
    # turn's user line, so the measured length grows with the turn number.
    turns = list(range(1, len(messages) // 2 + 1))

    lo, hi = 0, len(turns)
    while lo < hi:
        mid = (lo + hi) // 2
        prompt_ids = tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
        if max_length - 1 < len(prompt_ids):
            hi = mid
        else:
            lo = mid + 1
    target_turn_index = lo

    # The tokenized length for first turn already exceeds `max_length - 1`.
    if target_turn_index - 1 < 0:
        warnings.warn("The tokenized length for first turn already exceeds `max_length - 1`.")
        return skipped

    target_turn = turns[target_turn_index - 1]
    prompt = template.get_prompt(2 * target_turn)
    chunks, require_loss = split_templated_prompt_into_chunks(
        template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
    )
    tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)

    # Copy the token ids into the labels for every span that requires loss.
    labels = [ignore_index] * len(tokenized)
    for start, end in zip(starts, ends):
        if end == len(tokenized):
            tokenized = tokenized + [tokenizer.eos_token_id]
            labels = labels + [ignore_index]
        labels[start:end] = tokenized[start:end]

    # truncate the sequence at the last token that requires loss calculation
    keep = len(labels)
    while keep > 0 and labels[keep - 1] == ignore_index:
        keep -= 1
    tokenized = tokenized[:keep]
    labels = labels[:keep]

    # Nothing requires loss: skip here, before `tokenized[0]` / `tokenized[-1]` below
    # would raise IndexError on the now-empty sequence.
    if not tokenized:
        return skipped

    if tokenizer.bos_token_id is not None:
        if tokenized[0] != tokenizer.bos_token_id:
            tokenized = [tokenizer.bos_token_id] + tokenized
            labels = [ignore_index] + labels

    if tokenizer.eos_token_id is not None:
        # Force to add eos token at the end of the tokenized sequence
        if tokenized[-1] != tokenizer.eos_token_id:
            tokenized = tokenized + [tokenizer.eos_token_id]
            labels = labels + [tokenizer.eos_token_id]
        else:
            labels[-1] = tokenizer.eos_token_id

    # For some model without bos/eos may raise the following errors
    try:
        inputs_decode = tokenizer.decode(tokenized)
        # `start` is the index of the most recent ignored label, so a labeled span is
        # decoded from `start + 1`.
        # NOTE(review): when the labels begin with a non-ignored token, index 0 is left
        # out of the decoded text; this only affects the debugging output.
        start = 0
        end = 0
        label_decode = []
        for i, label in enumerate(labels):
            if label == ignore_index:
                if start != end:
                    label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
                start = i
                end = i
            else:
                end = i
                if i == len(labels) - 1:
                    label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))

    except TypeError as e:
        raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}") from e

    return dict(
        input_ids=tokenized,
        labels=labels,
        inputs_decode=inputs_decode,
        labels_decode=label_decode,
        seq_length=len(tokenized),
        seq_category=data_point["category"] if "category" in data_point else "None",
    )
 | |
| 
 | |
| 
 | |
def tokenize_prompt_dataset(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Conversation = None,
    ignore_index: int = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize an original pretraining data point as following for ppo training:
        "Something here can be system message[user_line_start]User line[User line end][Assistant line start]Assistant line[Assistant line end]...[Assistant line start]"
    Args:
        data_point: the data point of the following format
            {"messages": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}]}
            An optional "category" entry is copied into the result as `seq_category`.
        tokenizer: the tokenizer used to encode the templated prompt
        conversation_template: the conversation template to apply (a deep copy is used,
            the caller's template is not modified)
        ignore_index: not used by this function; accepted so that the signature matches
            the other tokenization functions
        max_length: the maximum context length

    Returns:
        A dict with the keys `input_ids`, `inputs_decode`, `seq_length` and
        `seq_category`. Every value is None when the data point is skipped (it has no
        messages, or the prompt is longer than `max_length - 1` tokens).

    Raises:
        ValueError: if a message has an unsupported `from` role.
    """
    # Result returned whenever the data point cannot be used as a prompt.
    skipped = dict(
        input_ids=None,
        inputs_decode=None,
        seq_length=None,
        seq_category=None,
    )

    messages = data_point["messages"]
    template = deepcopy(conversation_template)
    template.messages = []

    for mess in messages:
        from_str = mess["from"]
        if from_str.lower() == "human":
            from_str = "user"
        elif from_str.lower() == "assistant":
            from_str = "assistant"
        else:
            raise ValueError(f"Unsupported role {from_str.lower()}")

        template.append_message(from_str, mess["content"])

    # Use an odd number of messages as the prompt.
    target_turn = len(template.messages)
    if target_turn % 2 != 1:
        # exclude the answer if provided. keep only the prompt
        target_turn = target_turn - 1

    # With no messages `target_turn` would be -1 and be used as a negative prompt
    # length / slice bound below, so skip such a data point explicitly.
    if target_turn < 1:
        warnings.warn("The data point contains no messages, skipping it.")
        return skipped

    # Prepare data
    prompt = template.get_prompt(target_turn, add_generation_prompt=True)
    chunks, require_loss = split_templated_prompt_into_chunks(
        template.messages[:target_turn], prompt, conversation_template.end_of_assistant
    )
    # Only the token ids are needed here; the loss spans are irrelevant for prompts.
    tokenized, _, _ = tokenize_and_concatenate(tokenizer, chunks, require_loss)
    if tokenizer.bos_token_id is not None:
        if tokenized[0] != tokenizer.bos_token_id:
            tokenized = [tokenizer.bos_token_id] + tokenized

    # Skip overlength data
    if max_length - 1 < len(tokenized):
        return skipped

    # `inputs_decode` can be used to check whether the tokenization method is true.
    return dict(
        input_ids=tokenized,
        inputs_decode=tokenizer.decode(tokenized),
        seq_length=len(tokenized),
        seq_category=data_point["category"] if "category" in data_point else "None",
    )
 | |
| 
 | |
| 
 | |
def apply_rlhf_data_format(
    template: Conversation, tokenizer: Any, context_len: int, mask_out_target_assistant_line_end=False
):
    """
    Tokenize a conversation and build a loss mask that covers only its last loss span.

    Args:
        template: the conversation to format; only its complete
            (two-message) turns are used
        tokenizer: the tokenizer used to encode and decode the conversation
        context_len: not used by this function; kept for backward compatibility
        mask_out_target_assistant_line_end: not used by this function; kept for
            backward compatibility

    Returns:
        A dict with
            "input_ids": the token ids, with bos prepended / eos appended when the
                tokenizer defines them and they are not already present,
            "loss_mask": a list of 0/1 of the same length, 1 marking the last loss span
                reported by `tokenize_and_concatenate` and the final eos token,
            "label_decode": the decoded text of that last loss span.
    """
    target_turn = len(template.messages) // 2
    prompt = template.get_prompt(target_turn * 2)
    chunks, require_loss = split_templated_prompt_into_chunks(
        template.messages[: 2 * target_turn], prompt, template.end_of_assistant
    )
    tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
    loss_mask = [0] * len(tokenized)
    eos_token_id = tokenizer.eos_token_id

    label_decode = []
    for start, end in zip(starts[-1:], ends[-1:]):
        # only the last round (chosen/rejected) counts
        if end == len(tokenized) and eos_token_id is not None:
            # Terminate the span with eos; skipped for tokenizers without an eos token
            # so that no `None` ends up inside `input_ids`.
            tokenized = tokenized + [eos_token_id]
            loss_mask = loss_mask + [1]
        loss_mask[start:end] = [1] * len(loss_mask[start:end])
        label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))

    if tokenizer.bos_token_id is not None:
        if tokenized[0] != tokenizer.bos_token_id:
            tokenized = [tokenizer.bos_token_id] + tokenized
            loss_mask = [0] + loss_mask

    if eos_token_id is not None:
        # Force to add eos token at the end of the tokenized sequence
        if tokenized[-1] != eos_token_id:
            tokenized = tokenized + [eos_token_id]
            loss_mask = loss_mask + [1]
        else:
            loss_mask[-1] = 1

    return {"input_ids": tokenized, "loss_mask": loss_mask, "label_decode": label_decode}
 | |
| 
 | |
| 
 | |
def _normalize_rlhf_role(from_str: str) -> str:
    """Map the raw `from` field of a message to the role name used by the template."""
    role = from_str.lower()
    if role == "human":
        return "user"
    if role == "assistant":
        return "assistant"
    raise ValueError(f"Unsupported role {role}")


def tokenize_rlhf(
    data_point: Dict[str, str],
    tokenizer: PreTrainedTokenizer,
    conversation_template: Conversation = None,
    ignore_index: int = None,
    max_length: int = 4096,
) -> Dict[str, Union[int, str, List[int]]]:
    """
    A tokenization function to tokenize a preference data point of the following format:
        {"context": [{"from": "human", "content": "xxx"}, {"from": "assistant", "content": "xxx"}],
        "chosen": [{"from": "assistant", "content": "xxx"}], "rejected": [{"from": "assistant", "content": "xxx"}]}

    Adjacent context messages from the same role are merged (joined with a space). The
    merged context must contain an odd number of messages and end with a human line.

    Args:
        data_point: the data point in the format shown above; "chosen" and "rejected"
            are lists of messages appended after the shared context
        tokenizer: the tokenizer used to encode the conversations
        conversation_template: the conversation template to apply (a deep copy is used,
            the caller's template is not modified)
        ignore_index: not used by this function; accepted so that the signature matches
            the other tokenization functions
        max_length: the maximum context length

    Returns:
        A dict with the keys `chosen_input_ids`, `chosen_loss_mask`,
        `chosen_label_decode`, `rejected_input_ids`, `rejected_loss_mask` and
        `rejected_label_decode`. Every value is None when the data point is skipped
        (malformed context, a conversation longer than `max_length - 1` tokens, or a
        loss mask that is all zeros).

    Raises:
        ValueError: if a message has an unsupported `from` role.
        AssertionError: if the last context message is not from human.
    """
    # Result returned whenever the data point cannot be used.
    skipped = dict(
        chosen_input_ids=None,
        chosen_loss_mask=None,
        chosen_label_decode=None,
        rejected_input_ids=None,
        rejected_loss_mask=None,
        rejected_label_decode=None,
    )

    context = data_point["context"]
    template = deepcopy(conversation_template)
    template.clear()

    for mess in context:
        from_str = _normalize_rlhf_role(mess["from"])

        if len(template.messages) > 0 and from_str == template.messages[-1]["role"]:
            # Concate adjacent message from the same role
            template.messages[-1]["content"] = str(template.messages[-1]["content"] + " " + mess["content"])
        else:
            template.append_message(from_str, mess["content"])

    if len(template.messages) % 2 != 1:
        warnings.warn(
            "Please make sure leading context starts and ends with a line from human\nLeading context: "
            + str(template.messages)
        )
        return skipped
    round_of_context = (len(template.messages) - 1) // 2

    assert context[-1]["from"].lower() == "human", "The last message in context should be from human."

    # Both candidates share the context and differ only in the appended answer messages.
    chosen = deepcopy(template)
    rejected = deepcopy(template)
    for conversation, key in ((chosen, "chosen"), (rejected, "rejected")):
        for mess in data_point[key]:
            conversation.append_message(_normalize_rlhf_role(mess["from"]), mess["content"])

    def templated_length(conversation):
        # Token count of the fully rendered conversation, without special tokens.
        rendered = conversation.get_prompt(len(conversation.messages))
        return len(tokenizer([rendered], add_special_tokens=False)["input_ids"][0])

    # Skip the pair as soon as one of the two conversations is too long.
    if templated_length(chosen) > max_length - 1 or templated_length(rejected) > max_length - 1:
        return skipped

    chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
    rejected_data_packed = apply_rlhf_data_format(
        rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
    )
    chosen_loss_mask = chosen_data_packed["loss_mask"]
    rejected_loss_mask = rejected_data_packed["loss_mask"]

    # Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
    if not any(chosen_loss_mask) or not any(rejected_loss_mask):
        return skipped

    return {
        "chosen_input_ids": chosen_data_packed["input_ids"],
        "chosen_loss_mask": chosen_loss_mask,
        "chosen_label_decode": chosen_data_packed["label_decode"],
        "rejected_input_ids": rejected_data_packed["input_ids"],
        "rejected_loss_mask": rejected_loss_mask,
        "rejected_label_decode": rejected_data_packed["label_decode"],
    }
 |