add simple grpo

This commit is contained in:
Tong Li
2025-02-23 22:54:26 +08:00
parent 8e6c9a4ab3
commit ffd3878a1e
8 changed files with 253 additions and 21 deletions

View File

@@ -356,6 +356,12 @@ def apply_chat_template_and_mask(
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
# Format for RL.
gt_answer = None
if "messages" in chat and "gt_answer" in chat:
gt_answer = chat["gt_answer"]
chat = [chat["messages"]]
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
@@ -389,6 +395,11 @@ def apply_chat_template_and_mask(
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
if gt_answer is not None:
gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
gt_answer = gt_answer.squeeze(1)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
return {
"input_ids": input_ids,
"attention_mask": attention_mask,