Modify data loader: remove dead commented-out right-padding branch; encode gt_answer with truncation=True and max_length=128 (was 64)

This commit is contained in:
Tong Li 2025-03-06 10:49:44 +08:00
parent 070907dd7f
commit c15225bc52

View File

@@ -387,11 +387,6 @@ def apply_chat_template_and_mask(
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
# Left padding for generation.
-# if tokenizer.padding_side == "right":
-# tokens.extend([tokenizer.pad_token_id] * to_pad)
-# assistant_mask.extend([False] * to_pad)
-# attention_mask.extend([0] * to_pad)
-# else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
@@ -405,7 +400,9 @@ def apply_chat_template_and_mask(
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
if gt_answer is not None:
-gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
+gt_answer = tokenizer.encode(
+    gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
+)
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}