Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-31 15:25:21 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
commit 1b880ce095
parent b1031f7244
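Every hunk in this commit is a style normalization rather than a behavioral change: long calls are wrapped, slice spacing and quote style are unified, and missing end-of-file newlines are added. As a rough sketch of what produces such edits, assuming black is among the configured hooks (the hook list itself is not part of this diff), the same normalization can be reproduced with black's Python API:

# Sketch only: reproduce the kind of reformatting applied in this commit with
# black's API. Assumes `pip install black`; the 120-character line length is an
# assumption, not something stated in this diff.
import black

src = "labels[start : end] = tokenized[start : end]\n"
print(black.format_str(src, mode=black.Mode(line_length=120)), end="")
# -> labels[start:end] = tokenized[start:end]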
@@ -25,7 +25,9 @@ class Conversation:
         Setup the conversation template from config
         """
         tokenizer.chat_template = config["chat_template"]
-        conv = cls(tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"])
+        conv = cls(
+            tokenizer, config["system_message"], config["chat_template"], config["stop_ids"], config["end_of_assistant"]
+        )
         conv.clear()
         return conv
 
@@ -97,8 +97,9 @@ def supervised_tokenize_sft(
 
     target_turn = turns[target_turn_index - 1]
     prompt = template.get_prompt(2 * target_turn)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
-                                                              conversation_template.end_of_assistant)
+    chunks, require_loss = split_templated_prompt_into_chunks(
+        template.messages[: 2 * target_turn], prompt, conversation_template.end_of_assistant
+    )
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
 
     labels = [ignore_index] * len(tokenized)
@@ -106,7 +107,7 @@ def supervised_tokenize_sft(
         if end == len(tokenized):
             tokenized = tokenized + [tokenizer.eos_token_id]
             labels = labels + [ignore_index]
-        labels[start : end] = tokenized[start : end]
+        labels[start:end] = tokenized[start:end]
 
     # truncate the sequence at the last token that requires loss calculation
     to_truncate_len = 0
@@ -139,14 +140,14 @@ def supervised_tokenize_sft(
         label_decode = []
         for i in range(len(labels)):
             if labels[i] == ignore_index:
-                if start!=end:
-                    label_decode.append(tokenizer.decode(labels[start+1:i], skip_special_tokens=False))
+                if start != end:
+                    label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
                 start = i
                 end = i
             else:
                 end = i
                 if i == len(labels) - 1:
-                    label_decode.append(tokenizer.decode(labels[start+1:], skip_special_tokens=False))
+                    label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
 
     except TypeError as e:
         raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
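The slice edits in the hunks above follow the usual black/PEP 8 rule: when both bounds are plain names, the colon takes no surrounding spaces; when a bound is a compound expression such as start + 1, the colon is treated like a binary operator and gets a space on each side. A standalone illustration (the values are made up):

# Illustration of the slice-spacing convention used in the hunks above.
labels = list(range(10))
start, end, i = 2, 5, 7

simple = labels[start:end]        # plain-name bounds: no spaces around ":"
compound = labels[start + 1 : i]  # expression bound: spaces around ":"
print(simple, compound)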
@@ -216,8 +217,9 @@ def tokenize_prompt_dataset(
 
     # Prepare data
     prompt = template.get_prompt(target_turn, add_generation_prompt=True)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: target_turn], prompt,
-                                                              conversation_template.end_of_assistant)
+    chunks, require_loss = split_templated_prompt_into_chunks(
+        template.messages[:target_turn], prompt, conversation_template.end_of_assistant
+    )
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
@@ -246,8 +248,9 @@ def apply_rlhf_data_format(
 ):
     target_turn = int(len(template.messages) / 2)
     prompt = template.get_prompt(target_turn * 2)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
-                                                              template.end_of_assistant)
+    chunks, require_loss = split_templated_prompt_into_chunks(
+        template.messages[: 2 * target_turn], prompt, template.end_of_assistant
+    )
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     loss_mask = [0] * len(tokenized)
     mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
@@ -260,8 +263,8 @@ def apply_rlhf_data_format(
         if end == len(tokenized):
             tokenized = tokenized + [tokenizer.eos_token_id]
             loss_mask = loss_mask + [1]
-        loss_mask[start : end] = [1] * len(loss_mask[start : end])
-        label_decode.append(tokenizer.decode(tokenized[start : end], skip_special_tokens=False))
+        loss_mask[start:end] = [1] * len(loss_mask[start:end])
+        label_decode.append(tokenizer.decode(tokenized[start:end], skip_special_tokens=False))
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
             tokenized = [tokenizer.bos_token_id] + tokenized
@@ -121,8 +121,10 @@ def split_templated_prompt_into_chunks(messages: List[Dict[str, str]], prompt: s
     for line in messages:
         content_length = len(line["content"])
         first_occur = prompt.find(line["content"], start_idx)
-        if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length:]:
-            content_length = prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
+        if line["role"].lower() == "assistant" and end_of_assistant in prompt[first_occur + content_length :]:
+            content_length = (
+                prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
+            )
         if prompt[first_occur - 1] != " ":
             chunks.append(prompt[start_idx:first_occur])
         chunks.append(prompt[first_occur : first_occur + content_length])
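For orientation, a toy, self-contained rerun of the string arithmetic in the hunk above: when the turn is an assistant turn, content_length is stretched so the extracted chunk also covers the end_of_assistant marker. All strings below are invented for the example; only the arithmetic mirrors the diff:

# Toy example of the assistant-chunk extension shown above (strings are made up).
prompt = "<|user|>Hi<|im_end|><|assistant|>Hello<|im_end|>"
end_of_assistant = "<|im_end|>"
content = "Hello"

first_occur = prompt.find(content)
content_length = len(content)
if end_of_assistant in prompt[first_occur + content_length :]:
    content_length = (
        prompt.find(end_of_assistant, first_occur + content_length) + len(end_of_assistant) - first_occur
    )

print(prompt[first_occur : first_occur + content_length])  # Hello<|im_end|>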
@@ -37,4 +37,4 @@ class Critic(BaseModel):
         return self.model.get_input_embeddings()
 
     def get_output_embeddings(self):
-        return self.model.get_output_embeddings()
\ No newline at end of file
+        return self.model.get_output_embeddings()
@@ -5,4 +5,4 @@
         7
     ],
     "end_of_assistant": "<|im_end|>"
-}
\ No newline at end of file
+}
@@ -6,4 +6,4 @@
         151643
     ],
     "end_of_assistant": "<|im_end|>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         2
     ],
     "end_of_assistant": "<|im_end|>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         2
     ],
     "end_of_assistant": "<|user|>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         2
     ],
     "end_of_assistant": "<|im_end|>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         2
     ],
     "end_of_assistant": "</s>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         100001
     ],
     "end_of_assistant": "<|end▁of▁sentence|>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         2
     ],
     "end_of_assistant": "</s>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         50256
     ],
     "end_of_assistant": "<|im_end|>"
-}
\ No newline at end of file
+}
@@ -5,4 +5,4 @@
         2
     ],
     "end_of_assistant": "</s>"
-}
\ No newline at end of file
+}
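The conversation-template configs above change only in their trailing newline. For context, the fields they carry are the ones consumed in the Conversation hunk at the top of this commit (chat_template, system_message, stop_ids, end_of_assistant). A minimal sketch of such a config, where the chat_template and system_message values are placeholders rather than the repository's real strings:

# Sketch of a conversation-template config; only the field names and the
# stop_ids / end_of_assistant shapes come from this diff, the other values
# are placeholders.
import json

config = {
    "chat_template": "{% for message in messages %}...{% endfor %}",  # placeholder Jinja template
    "system_message": "You are a helpful assistant.",  # placeholder
    "stop_ids": [2],
    "end_of_assistant": "<|im_end|>",
}
print(json.dumps(config, indent=4, ensure_ascii=False))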
@@ -226,7 +226,7 @@ def main():
             "max_length": args.max_length,
         },
         keep_in_memory=False,
-        num_proc= min(len(dataset), cpu_count()),
+        num_proc=min(len(dataset), cpu_count()),
     )
 
     dataset = dataset.filter(
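The num_proc change above only drops a stray space, but the surrounding pattern of capping the worker count at the dataset size is worth noting: mapping a tiny dataset cannot usefully use more processes than it has examples. A minimal, self-contained sketch of the same call shape (assumes the datasets package is installed; the toy data and the add_length helper are hypothetical):

# Minimal sketch of the num_proc-capping pattern from the hunk above.
from multiprocessing import cpu_count

from datasets import Dataset


def add_length(example):
    return {"length": len(example["text"])}


dataset = Dataset.from_dict({"text": ["a", "bb", "ccc"]})  # toy data
dataset = dataset.map(
    add_length,
    keep_in_memory=False,
    num_proc=min(len(dataset), cpu_count()),
)
print(dataset[0])  # {'text': 'a', 'length': 1}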
@@ -6,4 +6,4 @@
         2
     ],
     "end_of_assistant": "</s>"
-}
\ No newline at end of file
+}
@@ -1,36 +1,41 @@
-from coati.dataset import setup_conversation_template
-from coati.dataset.conversation import Conversation
-from coati.dataset.tokenization_utils import supervised_tokenize_sft
-from transformers import AutoTokenizer
 import json
 import os
 
+from coati.dataset import setup_conversation_template
+from coati.dataset.tokenization_utils import supervised_tokenize_sft
+from transformers import AutoTokenizer
+
 model_data_mapping = {
-    'THUDM/chatglm2-6b': 'THUDM_chatglm2-6b.json',
-    'THUDM/chatglm3-6b': 'THUDM_chatglm3-6b.json',
-    'baichuan-inc/Baichuan2-13B-Chat': 'baichuan-inc_Baichuan2-13B-Chat.json',
-    '01-ai/Yi-1.5-9B-Chat': '01-ai_Yi-1.5-9B-Chat.json',
-    '01-ai/Yi-34B': '01-ai_Yi-34B.json',
-    'deepseek-ai/DeepSeek-V2-Lite': 'deepseek-ai_DeepSeek-V2-Lite.json',
-    'microsoft/phi-2': 'microsoft_phi-2.json',
-    'mistralai/Mixtral-8x7B-Instruct-v0.1': 'mistralai_Mixtral-8x7B-Instruct-v0.1.json'
+    "THUDM/chatglm2-6b": "THUDM_chatglm2-6b.json",
+    "THUDM/chatglm3-6b": "THUDM_chatglm3-6b.json",
+    "baichuan-inc/Baichuan2-13B-Chat": "baichuan-inc_Baichuan2-13B-Chat.json",
+    "01-ai/Yi-1.5-9B-Chat": "01-ai_Yi-1.5-9B-Chat.json",
+    "01-ai/Yi-34B": "01-ai_Yi-34B.json",
+    "deepseek-ai/DeepSeek-V2-Lite": "deepseek-ai_DeepSeek-V2-Lite.json",
+    "microsoft/phi-2": "microsoft_phi-2.json",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai_Mixtral-8x7B-Instruct-v0.1.json",
 }
-chat_template_config_path = './config/conversation_template'
+chat_template_config_path = "./config/conversation_template"
 
+
 def test_tokenization_sft():
     for model in model_data_mapping:
         print(f"#############{model}#############")
         conversation_template_config = os.path.join(chat_template_config_path, model_data_mapping[model])
-        messages = [{"from": "human", "content": "What are the three primary colors?"},
+        messages = [
+            {"from": "human", "content": "What are the three primary colors?"},
             {"from": "assistant", "content": "The three primary colors are red, blue, and yellow."},
             {"from": "human", "content": "解释个人电脑和服务器之间的区别。"},
-            {"from": "assistant", "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。"}]
+            {
+                "from": "assistant",
+                "content": "个人电脑和服务器是两种不同类型的计算机系统,它们的主要区别在于用途、硬件配置和性能。 个人电脑,顾名思义,是为个人使用而设计的计算机。它们通常用于日常的工作、娱乐和学习,可以运行各种各样的应用程序和游戏。个人电脑的硬件配置一般是按照标准配置来设计的,不过也可以根据个人需求进行定制。 而服务器是为了满足大量用户的需求而设计的计算机系统,它们通常用于为用户提供各种网络服务,如网站、电子邮件和文件传输等。服务器通常需要高性能的硬件配置,并且可以承受高负载和长时间的运行。由于服务器需要支持大量用户的访问,它们通常配备多核处理器、大容量内存和大容量硬盘驱动器,以提高系统的运行速度和稳定性。 总之,个人电脑和服务器之间的主要区别在于它们的用途、硬件配置和性能。个人电脑用于个人使用,而服务器用于支持大量用户的访问。服务器的硬件配置通常比个人电脑更高,以保证系统的性能和稳定性。",
+            },
+        ]
         chat_template_config = json.load(open(conversation_template_config, "r", encoding="utf8"))
         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True)
         conversation_template = setup_conversation_template(
             tokenizer, chat_template_config=chat_template_config, save_path=conversation_template_config
         )
 
         output = supervised_tokenize_sft({"messages": messages}, tokenizer, conversation_template)
         with open(f"./tests/test_data/chat_template/{model_data_mapping[model]}", "r", encoding="utf8") as f:
@@ -582,4 +582,4 @@
     ],
     "seq_length": 286,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -604,4 +604,4 @@
     ],
     "seq_length": 297,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -600,4 +600,4 @@
     ],
     "seq_length": 295,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -712,4 +712,4 @@
     ],
     "seq_length": 351,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -582,4 +582,4 @@
     ],
     "seq_length": 286,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -694,4 +694,4 @@
     ],
     "seq_length": 342,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -578,4 +578,4 @@
     ],
     "seq_length": 284,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -2006,4 +2006,4 @@
     ],
     "seq_length": 998,
     "seq_category": "None"
-}
\ No newline at end of file
+}
@@ -916,4 +916,4 @@
     ],
     "seq_length": 453,
     "seq_category": "None"
-}
\ No newline at end of file
+}