mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
upgrade colossal-chat support tp_group>1, add sp for sft
This commit is contained in:
@@ -17,6 +17,7 @@ class Conversation:
|
||||
system_message: str
|
||||
chat_template: str
|
||||
stop_ids: List[int]
|
||||
end_of_assistant: str
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, tokenizer: PreTrainedTokenizer, config: Dict):
|
||||
@@ -24,7 +25,7 @@ class Conversation:
|
||||
Setup the conversation template from config
|
||||
"""
|
||||
tokenizer.chat_template = config["chat_template"]
|
||||
conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"])
|
||||
conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"])
|
||||
conv.clear()
|
||||
return conv
|
||||
|
||||
@@ -109,6 +110,8 @@ def setup_conversation_template(
|
||||
"""
|
||||
if any([s not in chat_template_config.keys() for s in Conversation.get_conversation_template_keys()]):
|
||||
# Try to automatically set up conversation template, if fail, it throws an error that you need to do it manually
|
||||
if "end_of_assistant" not in chat_template_config:
|
||||
raise ValueError("Please set the end of assistant token.")
|
||||
if "system_message" not in chat_template_config:
|
||||
logger.warning("No system message is provided, will not use system message.")
|
||||
if "chat_template" not in chat_template_config:
|
||||
|
Reference in New Issue
Block a user