mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
Support overall loss, update KTO logging
This commit is contained in:
@@ -49,6 +49,10 @@ def tokenize_sft(
|
||||
|
||||
messages = data_point["messages"]
|
||||
template = deepcopy(conversation_template)
|
||||
|
||||
if messages[0]["from"] == "system":
|
||||
template.system_message = str(messages[0]["content"])
|
||||
messages.pop(0)
|
||||
template.messages = []
|
||||
for idx, mess in enumerate(messages):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
@@ -148,11 +152,14 @@ def tokenize_prompt(
|
||||
template = deepcopy(conversation_template)
|
||||
template.messages = []
|
||||
|
||||
if messages[0]["from"] == "system":
|
||||
template.system_message = str(messages[0]["content"])
|
||||
messages.pop(0)
|
||||
|
||||
for idx, mess in enumerate(messages):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
f"Message should iterate between user and assistant and starts with a \
|
||||
line from the user. Got the following data:\n{messages}"
|
||||
f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
|
||||
)
|
||||
template.append_message(mess["from"], mess["content"])
|
||||
|
||||
@@ -225,6 +232,10 @@ def tokenize_rlhf(
|
||||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
if context[0]["from"] == "system":
|
||||
template.system_message = str(context[0]["content"])
|
||||
context.pop(0)
|
||||
|
||||
for idx, mess in enumerate(context):
|
||||
if mess["from"] != template.roles[idx % 2]:
|
||||
raise ValueError(
|
||||
@@ -345,6 +356,10 @@ def tokenize_kto(
|
||||
template = deepcopy(conversation_template)
|
||||
template.clear()
|
||||
|
||||
if prompt[0]["from"] == "system":
|
||||
template.system_message = str(prompt[0]["content"])
|
||||
prompt.pop(0)
|
||||
|
||||
if prompt[0].get("from", None) != "user":
|
||||
raise ValueError("conversation should start with user")
|
||||
if completion.get("from", None) != "assistant":
|
||||
|
Reference in New Issue
Block a user