mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
add simple grpo
This commit is contained in:
@@ -356,6 +356,12 @@ def apply_chat_template_and_mask(
|
||||
truncation: bool = True,
|
||||
ignore_idx: int = -100,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# Format for RL.
|
||||
gt_answer = None
|
||||
if "messages" in chat and "gt_answer" in chat:
|
||||
gt_answer = chat["gt_answer"]
|
||||
chat = [chat["messages"]]
|
||||
|
||||
tokens = []
|
||||
assistant_mask = []
|
||||
for i, msg in enumerate(chat):
|
||||
@@ -389,6 +395,11 @@ def apply_chat_template_and_mask(
|
||||
labels = input_ids.clone()
|
||||
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
|
||||
|
||||
if gt_answer is not None:
|
||||
gt_answer = tokenizer.encode(gt_answer, padding="max_length", max_length=64, return_tensors="pt")
|
||||
gt_answer = gt_answer.squeeze(1)
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
|
Reference in New Issue
Block a user