ColossalAI/applications/ColossalChat/coati/dataset/conversation.py
YeAnbang df5e9c53cf
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
2024-03-29 14:12:29 +08:00

import dataclasses
import json
import os
from typing import Any, Dict, List

import torch.distributed as dist
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()
@dataclasses.dataclass
class Conversation:
    tokenizer: PreTrainedTokenizer
    system_message: str
    chat_template: str
    stop_ids: List[int]

    @classmethod
    def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
        """
        Set up the conversation template from a config dict.
        """
        tokenizer.chat_template = config["chat_template"]
        conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
        conv.clear()
        return conv

    def clear(self):
        self.messages = []

    @classmethod
    def get_conversation_template_keys(cls):
        return ["system_message", "chat_template"]

    def __str__(self):
        return json.dumps(
            {k: self.__dict__[k] for k in self.__dict__ if k not in ["tokenizer", "messages"]},
            ensure_ascii=False,
            indent=4,
        )

    def get_prompt(self, length: int = None, add_generation_prompt=False) -> Any:
        """
        Retrieves the prompt for the conversation.

        Args:
            length (int, optional): The number of messages to include in the prompt. Defaults to None, which includes all messages.
            add_generation_prompt (bool, optional): Whether to append the assistant line-start token so the prompt is ready for generation (generation only). Defaults to False.

        Returns:
            str: The prompt string rendered with the tokenizer's chat template.
        """
        if length is None:
            length = len(self.messages)

        assert length <= len(self.messages)
        if self.system_message is not None:
            messages = [{"role": "system", "content": self.system_message}] + self.messages[:length]
        else:
            messages = self.messages[:length]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt
        )
        return prompt

    def save_prompt(self):
        return self.get_prompt()

    def append_message(self, role: str, message: str):
        """
        Append a message to the conversation.

        Args:
            role (str): The role of the message sender. Must be either 'user' or 'assistant'.
            message (str): The content of the message.

        Raises:
            AssertionError: If the role is not 'user' or 'assistant'.
        """
        assert role in ["user", "assistant"]
        self.messages.append({"role": role, "content": message})

    def copy(self):
        # Pass all required dataclass fields so the copy can be constructed.
        return Conversation(
            tokenizer=self.tokenizer,
            system_message=self.system_message,
            chat_template=self.chat_template,
            stop_ids=self.stop_ids,
        )
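
# A minimal illustrative sketch (comment only, with placeholder values) of how Conversation might be used:
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/chat/model")  # placeholder tokenizer path
#   config = {
#       "system_message": "You are a helpful assistant.",  # placeholder system prompt
#       "chat_template": tokenizer.chat_template,           # reuse the tokenizer's default template
#       "stop_ids": [tokenizer.eos_token_id],               # token ids that terminate generation
#   }
#   conv = Conversation.from_config(tokenizer, config)
#   conv.append_message("user", "Hello!")
#   prompt = conv.get_prompt(add_generation_prompt=True)
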
def setup_conversation_template(
    tokenizer: PreTrainedTokenizer, chat_template_config: Dict = None, save_path: str = None
) -> Conversation:
    """
    Set up the conversation template. If a chat_template is given, it replaces the tokenizer's default
    chat_template; otherwise the tokenizer's default chat_template is used. If the tokenizer doesn't have
    a default chat_template, an error is raised to remind the user to set one manually.

    Args:
        tokenizer: The tokenizer to use.
        chat_template_config:
            {
                "system_message": str. The system message to use.
                "chat_template": str. The chat template to use; it can be a chat template string, a HuggingFace
                    model path, or a local model path. If a HuggingFace model path or a local model path is given,
                    the chat template is loaded from that model's tokenizer's default chat template.
                "stop_ids": List[int]. The token ids used to terminate generation. Required for PPO training and generation.
            }
        save_path: Where to save the resulting conversation template config as JSON.
    """
    if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
        # Try to set up the conversation template automatically; if that fails, raise an error asking the user to set it manually
        if "system_message" not in chat_template_config:
            logger.warning("No system message is provided, will not use system message.")
        if "chat_template" not in chat_template_config:
            logger.warning("No chat_template is provided, will try to load it from the tokenizer.")
            if tokenizer.chat_template is not None:
                chat_template_config["chat_template"] = tokenizer.chat_template
            else:
                raise ValueError(
                    "The tokenizer doesn't have a default chat template, please set one manually."
                )
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(chat_template_config["chat_template"])
            if tokenizer.chat_template is not None:
                logger.warning(
                    f"chat_template is provided as a local model path or HuggingFace model path, loaded chat_template from \"{chat_template_config['chat_template']}\"."
                )
                chat_template_config["chat_template"] = tokenizer.chat_template
            else:
                raise ValueError(
                    f"Loaded a tokenizer from {chat_template_config['chat_template']}, which doesn't have a default chat template, please set one manually."
                )
        except OSError:
            # chat_template is a chat template string rather than a model path; keep it as-is
            pass
        except ValueError as e:
            raise ValueError(e)

    if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w", encoding="utf8") as f:
            logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
            json.dump(chat_template_config, f, indent=4, ensure_ascii=False)

    return Conversation.from_config(tokenizer, chat_template_config)
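
A minimal usage sketch of the entry point above; this is an illustration, not part of the file: the tokenizer path, system message, and save path are placeholders, and the import assumes the coati package is installed.

from transformers import AutoTokenizer

from coati.dataset.conversation import setup_conversation_template

# Placeholder: any tokenizer whose model ships a default chat template.
tokenizer = AutoTokenizer.from_pretrained("path/to/chat/model")

chat_template_config = {
    "system_message": "You are a helpful assistant.",  # placeholder system prompt
    "chat_template": tokenizer.chat_template,          # or a model path to copy the template from
    "stop_ids": [tokenizer.eos_token_id],              # token ids that end generation (needed for PPO/generation)
}

# Replaces the tokenizer's chat template and saves the resolved config as JSON.
conv = setup_conversation_template(
    tokenizer, chat_template_config, save_path="./config/conversation_template_config.json"
)
conv.append_message("user", "What is RLHF?")
print(conv.get_prompt(add_generation_prompt=True))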