upgrade colossal-chat: support tp_group > 1, add SP for SFT

YeAnbang
2024-05-27 05:55:57 +00:00
parent 73e88a5553
commit 7a7e86987d
33 changed files with 7574 additions and 105 deletions

@@ -55,6 +55,8 @@ def supervised_tokenize_sft(
     for mess in messages:
         from_str = mess["from"]
+        if from_str is None:
+            print(mess)
         if from_str.lower() == "human":
             from_str = "user"
         elif from_str.lower() == "assistant":
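
The two added lines only print the malformed message; the `from_str.lower()` call on the next line still raises AttributeError when `from_str` is None, and that traceback is what actually surfaces the bad sample. Below is a standalone sketch of the role normalization this loop performs, failing loudly instead of printing; the helper name and the pass-through for other roles are illustrative assumptions, not the repo's code.

# Hypothetical helper, not from the repo: normalize a message's "from" field
# the way the loop above does, but raise instead of printing and crashing.
def normalize_role(mess: dict) -> str:
    from_str = mess.get("from")
    if from_str is None:
        print(mess)  # mirrors the diff; kept so the bad sample is visible
        raise ValueError(f"message has no 'from' field: {mess}")
    role = from_str.lower()
    if role == "human":
        return "user"
    if role == "assistant":
        return "assistant"
    return role  # assumption: other roles pass through unchanged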
@@ -95,17 +97,26 @@ def supervised_tokenize_sft(
     target_turn = turns[target_turn_index - 1]
     prompt = template.get_prompt(2 * target_turn)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
+    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
+                                                              conversation_template.end_of_assistant)
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
 
     labels = [ignore_index] * len(tokenized)
-    label_decode = []
     for start, end in zip(starts, ends):
         if end == len(tokenized):
             tokenized = tokenized + [tokenizer.eos_token_id]
             labels = labels + [ignore_index]
-        labels[start : end + 1] = tokenized[start : end + 1]
-        label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
+        labels[start : end] = tokenized[start : end]
+
+    # truncate the sequence at the last token that requires loss calculation
+    to_truncate_len = 0
+    for i in range(len(tokenized) - 1, -1, -1):
+        if labels[i] == ignore_index:
+            to_truncate_len += 1
+        else:
+            break
+    tokenized = tokenized[: len(tokenized) - to_truncate_len]
+    labels = labels[: len(labels) - to_truncate_len]
 
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
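
Two things change above: the (start, end) pairs from tokenize_and_concatenate are now treated as half-open spans, so the label copy uses [start : end] rather than [start : end + 1], and a new pass trims trailing tokens whose label is ignore_index so a sample never ends on positions that contribute nothing to the loss. A toy run of that truncation, assuming ignore_index = -100 as is conventional for Hugging Face loss masking:

# Toy reproduction of the truncation pass; values are made up for the demo.
ignore_index = -100
tokenized = [1, 5, 9, 7, 3, 2, 4]
labels = [-100, 5, 9, -100, -100, -100, -100]

to_truncate_len = 0
for i in range(len(tokenized) - 1, -1, -1):
    if labels[i] == ignore_index:
        to_truncate_len += 1  # count trailing ignored positions
    else:
        break  # stop at the last token that requires loss
tokenized = tokenized[: len(tokenized) - to_truncate_len]
labels = labels[: len(labels) - to_truncate_len]
assert tokenized == [1, 5, 9] and labels == [-100, 5, 9]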
@@ -123,6 +134,20 @@ def supervised_tokenize_sft(
     # For some model without bos/eos may raise the following errors
     try:
         inputs_decode = tokenizer.decode(tokenized)
+        start = 0
+        end = 0
+        label_decode = []
+        for i in range(len(labels)):
+            if labels[i] == ignore_index:
+                if start != end:
+                    label_decode.append(tokenizer.decode(labels[start + 1 : i], skip_special_tokens=False))
+                start = i
+                end = i
+            else:
+                end = i
+                if i == len(labels) - 1:
+                    label_decode.append(tokenizer.decode(labels[start + 1 :], skip_special_tokens=False))
     except TypeError as e:
         raise TypeError(str(e) + f"\nUnable to decode input_ids: {tokenized}")
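
With `label_decode = []` removed from the masking loop, this block rebuilds it by scanning labels for contiguous non-ignored spans. Note that `start + 1` skips the boundary position, which assumes every trained span is preceded by at least one ignore_index token (true for SFT, where prompt tokens are masked). A self-contained trace with a stand-in for tokenizer.decode:

# Trace of the span-recovery loop above on toy labels; decode() is a
# stand-in for tokenizer.decode and just returns the ids it is given.
ignore_index = -100
labels = [-100, 11, 12, -100, -100, 21, 22, 23]

def decode(ids):
    return ids

start = 0
end = 0
label_decode = []
for i in range(len(labels)):
    if labels[i] == ignore_index:
        if start != end:  # close the span that ended before this position
            label_decode.append(decode(labels[start + 1 : i]))
        start = i
        end = i
    else:
        end = i
        if i == len(labels) - 1:  # flush a span that runs to the end
            label_decode.append(decode(labels[start + 1 :]))
assert label_decode == [[11, 12], [21, 22, 23]]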
@@ -191,7 +216,9 @@ def tokenize_prompt_dataset(
     # Prepare data
     prompt = template.get_prompt(target_turn, add_generation_prompt=True)
-    tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
+    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: target_turn], prompt,
+                                                              conversation_template.end_of_assistant)
+    tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
             tokenized = [tokenizer.bos_token_id] + tokenized
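
tokenize_prompt_dataset now goes through the same chunk pipeline as supervised_tokenize_sft instead of tokenizing the raw prompt string, so end_of_assistant markers are handled consistently across both paths. The sketch below is the contract inferred from the call sites, not the repo's actual implementation of tokenize_and_concatenate:

# Inferred contract (assumption): each chunk is tokenized separately and
# concatenated; for chunks flagged in require_loss, the half-open token
# span [start, end) of that chunk is recorded.
def tokenize_and_concatenate_sketch(tokenizer, chunks, require_loss):
    tokenized, starts, ends = [], [], []
    for text, need_loss in zip(chunks, require_loss):
        ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if need_loss:
            starts.append(len(tokenized))
            ends.append(len(tokenized) + len(ids))
        tokenized += ids
    return tokenized, starts, ends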
@@ -219,7 +246,8 @@ def apply_rlhf_data_format(
 ):
     target_turn = int(len(template.messages) / 2)
     prompt = template.get_prompt(target_turn * 2)
-    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt)
+    chunks, require_loss = split_templated_prompt_into_chunks(template.messages[: 2 * target_turn], prompt,
+                                                              template.end_of_assistant)
     tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
     loss_mask = [0] * len(tokenized)
     mask_token = tokenizer.eos_token_id or tokenizer.pad_token_id
@@ -232,8 +260,8 @@ def apply_rlhf_data_format(
         if end == len(tokenized):
             tokenized = tokenized + [tokenizer.eos_token_id]
             loss_mask = loss_mask + [1]
-        loss_mask[start : end + 1] = [1] * len(loss_mask[start : end + 1])
-        label_decode.append(tokenizer.decode(tokenized[start : end + 1], skip_special_tokens=False))
+        loss_mask[start : end] = [1] * len(loss_mask[start : end])
+        label_decode.append(tokenizer.decode(tokenized[start : end], skip_special_tokens=False))
     if tokenizer.bos_token_id is not None:
         if tokenized[0] != tokenizer.bos_token_id:
             tokenized = [tokenizer.bos_token_id] + tokenized
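
Same half-open-span fix as in supervised_tokenize_sft: the mask now covers exactly [start, end) instead of spilling one token past the span, and a span that runs to the end of the sequence still gets an appended eos that is trained on. A toy check of that branch:

# Toy check of the masking above; 2 stands in for tokenizer.eos_token_id.
tokenized = [10, 11, 12, 13, 14]
loss_mask = [0] * len(tokenized)
start, end = 2, 5  # half-open span over the final three tokens

if end == len(tokenized):
    tokenized = tokenized + [2]  # append eos when the span reaches the end
    loss_mask = loss_mask + [1]  # the appended eos is also trained on
loss_mask[start:end] = [1] * len(loss_mask[start:end])
assert loss_mask == [0, 0, 1, 1, 1, 1]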