[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporarily
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
applications/ColossalChat/coati/dataset/conversation.py (new executable file, +143 lines)
@@ -0,0 +1,143 @@
import dataclasses
import json
import os
from typing import Any, Dict, List

import torch.distributed as dist
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


@dataclasses.dataclass
class Conversation:
    tokenizer: PreTrainedTokenizer
    system_message: str
    chat_template: str
    stop_ids: List[int]

    @classmethod
    def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
        """
        Set up the conversation template from a config dict.
        """
        tokenizer.chat_template = config["chat_template"]
        conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
        conv.clear()
        return conv

    def clear(self):
        self.messages = []

    @classmethod
    def get_conversation_template_keys(cls):
        return ["system_message", "chat_template"]

    def __str__(self):
        return json.dumps(
            {k: self.__dict__[k] for k in self.__dict__ if k not in ["tokenizer", "messages"]},
            ensure_ascii=False,
            indent=4,
        )

    def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
        """
        Retrieve the prompt for the conversation.

        Args:
            length (int, optional): The number of messages to include in the prompt. Defaults to None (all messages).
            add_generation_prompt (bool, optional): Whether to append the assistant line-start token so the model can
                continue the conversation (generation only). Defaults to False.

        Returns:
            str: The prompt string rendered with the tokenizer's chat template.
        """
        if length is None:
            length = len(self.messages)

        assert length <= len(self.messages)
        if self.system_message is not None:
            messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
        else:
            messages = self.messages[:length]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        return prompt

    def save_prompt(self):
        return self.get_prompt()

    def append_message(self, role: str, message: str):
        """
        Append a message to the conversation.

        Args:
            role (str): The role of the message sender. Must be either 'user' or 'assistant'.
            message (str): The content of the message.

        Raises:
            AssertionError: If the role is not 'user' or 'assistant'.
        """
        assert role in ["user", "assistant"]
        self.messages.append({"role": role, "content": message})

    def copy(self):
        # Copy all template fields as well as the accumulated messages.
        conv = Conversation(
            tokenizer=self.tokenizer,
            system_message=self.system_message,
            chat_template=self.chat_template,
            stop_ids=self.stop_ids,
        )
        conv.messages = list(self.messages)
        return conv


def setup_conversation_template(
    tokenizer: PreTrainedTokenizer, chat_template_config: Dict = None, save_path: str = None
) -> Conversation:
    """
    Set up the conversation template. If a chat_template is given, it replaces the tokenizer's default chat_template;
    otherwise the tokenizer's default chat_template is used. If the tokenizer doesn't have a default chat_template,
    an error is raised to remind the user to set it manually.

    Args:
        tokenizer: The tokenizer to use.
        chat_template_config:
            {
                "system_message": str. The system message to use.
                "chat_template": str. The chat template to use; it can be a chat template string, a HuggingFace model
                    path, or a local model path. For a HuggingFace or local model path, the chat template is loaded
                    from that model's tokenizer's default chat template.
                "stop_ids": List[int]. The token ids used to terminate generation. Required for PPO training and
                    generation.
            }
        save_path: Where to save the resolved conversation template config as JSON.
    """
    if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
        # Try to set up the conversation template automatically; if that fails, raise an error asking the user
        # to configure it manually.
        if "system_message" not in chat_template_config:
            logger.warning("No system message is provided, will not use system message.")
        if "chat_template" not in chat_template_config:
            logger.warning("No chat_template is provided, will try to load it from the tokenizer.")
            if tokenizer.chat_template is not None:
                chat_template_config["chat_template"] = tokenizer.chat_template
            else:
                raise ValueError(
                    "The given tokenizer doesn't have a default chat template, please set one manually."
                )
        else:
            try:
                tokenizer = AutoTokenizer.from_pretrained(chat_template_config["chat_template"])
                if tokenizer.chat_template is not None:
                    chat_template_config["chat_template"] = tokenizer.chat_template
                else:
                    raise ValueError(
                        f"Loaded a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set one manually."
                    )
                logger.warning(
                    f"chat_template is provided as a local model path or huggingface model path, loaded chat_template from \"{chat_template_config['chat_template']}\"."
                )
            except OSError:
                pass
            except ValueError as e:
                raise ValueError(e)
    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w", encoding="utf8") as f:
            logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
            json.dump(chat_template_config, f, indent=4, ensure_ascii=False)
    return Conversation.from_config(tokenizer, chat_template_config)
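
For orientation, a minimal usage sketch of the API above (not part of the committed file): it assumes the ColossalChat `coati` package is importable and uses placeholder values — the `gpt2` tokenizer, a toy Jinja chat template, and an example save path; a real config would use the target model's own template and stop ids.

```python
from transformers import AutoTokenizer

from coati.dataset.conversation import setup_conversation_template

# Placeholder tokenizer and template for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
chat_template_config = {
    "system_message": "You are a helpful assistant.",
    # A toy Jinja chat template; real configs usually ship the target model's own template.
    "chat_template": (
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant: {% endif %}"
    ),
    "stop_ids": [tokenizer.eos_token_id],  # required for PPO training and generation
}

conv = setup_conversation_template(
    tokenizer,
    chat_template_config=chat_template_config,
    save_path="config/conversation_template.json",  # resolved config is written here on rank 0
)
conv.append_message("user", "What does RLHF stand for?")
conv.append_message("assistant", "Reinforcement learning from human feedback.")
conv.append_message("user", "Why is a chat template needed?")
# Renders system message + all messages, ending with the assistant line start.
print(conv.get_prompt(add_generation_prompt=True))
```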