mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
fix dataloader
This commit is contained in:
@@ -73,9 +73,12 @@ def supervised_tokenize_sft(
|
||||
lo, hi = 0, len(turns)
|
||||
while lo < hi:
|
||||
mid = (lo + hi) // 2
|
||||
if max_length - 1 < len(
|
||||
tokenizer([template.get_prompt(2 * turns[mid] - 1)], add_special_tokens=False)["input_ids"][0]
|
||||
):
|
||||
prompt = template.get_prompt(2 * turns[mid] - 1)
|
||||
chunks, require_loss = split_templated_prompt_into_chunks(
|
||||
template.messages[: 2 * turns[mid] - 1], prompt, conversation_template.end_of_assistant
|
||||
)
|
||||
tokenized, starts, ends = tokenize_and_concatenate(tokenizer, chunks, require_loss)
|
||||
if max_length - 1 < len(tokenized):
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 1
|
||||
@@ -114,6 +117,7 @@ def supervised_tokenize_sft(
|
||||
to_truncate_len += 1
|
||||
else:
|
||||
break
|
||||
to_truncate_len = max(len(tokenized) - max_length, to_truncate_len)
|
||||
tokenized = tokenized[: len(tokenized) - to_truncate_len]
|
||||
labels = labels[: len(labels) - to_truncate_len]
|
||||
|
||||
@@ -356,48 +360,24 @@ def tokenize_rlhf(
|
||||
rejected_loss_mask,
|
||||
rejected_label_decode,
|
||||
) = (None, None, None, None, None, None)
|
||||
if (
|
||||
len(tokenizer([chosen.get_prompt(len(chosen.messages))], add_special_tokens=False)["input_ids"][0])
|
||||
<= max_length - 1
|
||||
and len(tokenizer([rejected.get_prompt(len(rejected.messages))], add_special_tokens=False)["input_ids"][0])
|
||||
<= max_length - 1
|
||||
):
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
rejected_data_packed = apply_rlhf_data_format(
|
||||
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
|
||||
)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
rejected_data_packed["label_decode"],
|
||||
)
|
||||
chosen_data_packed = apply_rlhf_data_format(chosen, tokenizer, round_of_context)
|
||||
(chosen_input_ids, chosen_loss_mask, chosen_label_decode) = (
|
||||
chosen_data_packed["input_ids"],
|
||||
chosen_data_packed["loss_mask"],
|
||||
chosen_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask.count(0) == len(chosen_loss_mask) or rejected_loss_mask.count(0) == len(rejected_loss_mask):
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
chosen_label_decode=None,
|
||||
rejected_input_ids=None,
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
rejected_data_packed = apply_rlhf_data_format(
|
||||
rejected, tokenizer, round_of_context, mask_out_target_assistant_line_end=True
|
||||
)
|
||||
(rejected_input_ids, rejected_loss_mask, rejected_label_decode) = (
|
||||
rejected_data_packed["input_ids"],
|
||||
rejected_data_packed["loss_mask"],
|
||||
rejected_data_packed["label_decode"],
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_input_ids,
|
||||
"chosen_loss_mask": chosen_loss_mask,
|
||||
"chosen_label_decode": chosen_label_decode,
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_loss_mask": rejected_loss_mask,
|
||||
"rejected_label_decode": rejected_label_decode,
|
||||
}
|
||||
else:
|
||||
if len(chosen_input_ids) > max_length or len(rejected_input_ids) > max_length:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
@@ -406,3 +386,22 @@ def tokenize_rlhf(
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
# Check if loss mask is all 0s (no loss), this may happen when the tokenized length is too long
|
||||
if chosen_loss_mask[1:].count(1) == 0 or rejected_loss_mask[1:].count(1) == 0:
|
||||
return dict(
|
||||
chosen_input_ids=None,
|
||||
chosen_loss_mask=None,
|
||||
chosen_label_decode=None,
|
||||
rejected_input_ids=None,
|
||||
rejected_loss_mask=None,
|
||||
rejected_label_decode=None,
|
||||
)
|
||||
|
||||
return {
|
||||
"chosen_input_ids": chosen_input_ids,
|
||||
"chosen_loss_mask": chosen_loss_mask,
|
||||
"chosen_label_decode": chosen_label_decode,
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_loss_mask": rejected_loss_mask,
|
||||
"rejected_label_decode": rejected_label_decode,
|
||||
}
|
||||
|
Reference in New Issue
Block a user