mirror of https://github.com/hpcaitech/ColossalAI.git
support code generation tasks
@@ -367,9 +367,9 @@ def apply_chat_template_and_mask(
     }
 
     # Format for RL.
-    gt_answer = None
-    if "messages" in chat and "gt_answer" in chat:
-        gt_answer = chat["gt_answer"]
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
         chat = [chat["messages"]]
 
     tokens = []
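With this hunk, a raw RL sample may carry either a gt_answer string (e.g. math tasks) or a test_cases field (code generation tasks); both are now read with .get(), so either may be absent. A minimal sketch of the two sample shapes the parser accepts; only the keys come from the diff, the values are illustrative:

    # Hypothetical samples; only the keys "messages", "gt_answer", "test_cases"
    # are established by this commit. Values are made up for illustration.
    math_sample = {
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "gt_answer": "4",
    }
    code_sample = {
        "messages": [{"role": "user", "content": "Write add(a, b) returning their sum."}],
        "test_cases": ["assert add(1, 2) == 3"],  # exact format is an assumption
    }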
@@ -402,12 +402,14 @@ def apply_chat_template_and_mask(
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
 
     if gt_answer is not None:
-        gt_answer = tokenizer.encode(
-            gt_answer, padding="max_length", truncation=True, max_length=128, return_tensors="pt"
-        )
-        gt_answer = gt_answer.squeeze(1)
         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
-
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
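This hunk stops encoding gt_answer into a fixed 128-token tensor (the old tokenizer.encode(..., max_length=128) path) and returns the raw string instead, and it adds a parallel branch that passes test_cases through untokenized. A hedged usage sketch, assuming the positional argument order visible at the call site in the next hunk (tokenizer, chat, max_length, system_prompt); the module path, model name, and argument values are assumptions:

    # Module path, model name, and max_length/system_prompt values are assumptions.
    from transformers import AutoTokenizer
    from coati.dataset.loader import apply_chat_template_and_mask

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    sample = {
        "messages": [{"role": "user", "content": "Implement add(a, b)."}],
        "test_cases": ["assert add(1, 2) == 3"],
    }
    out = apply_chat_template_and_mask(tokenizer, sample, 512, None)
    # Per this diff, out now carries "test_cases" as the raw Python object
    # (and "gt_answer" stays a string when present), not token ids.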
@@ -440,3 +442,20 @@ class RawConversationDataset(Dataset):
             tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length,
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
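The new collate_fn_grpo stacks tensors that are assumed to be pre-padded to a common length (true when apply_chat_template_and_mask pads to max_length) and forwards gt_answer / test_cases as plain Python lists rather than tensors, presumably so downstream reward code can consume them directly. A sketch of wiring it into a DataLoader, plus the pad_sequence fallback that the inline comment points to; the batch size, pad id, and the -100 ignore index are assumptions:

    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    # `dataset` stands in for a RawConversationDataset instance (assumption):
    # loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn_grpo)
    # batch = next(iter(loader))
    # batch["input_ids"], batch["attention_mask"], batch["labels"] -> stacked tensors
    # batch["gt_answer"], batch["test_cases"] -> plain lists, only if present

    # Variable-length variant following the comment inside collate_fn_grpo:
    def collate_fn_grpo_padded(batch, pad_token_id=0, ignore_idx=-100):  # defaults are assumptions
        return {
            "input_ids": pad_sequence([b["input_ids"] for b in batch], batch_first=True, padding_value=pad_token_id),
            "attention_mask": pad_sequence([b["attention_mask"] for b in batch], batch_first=True, padding_value=0),
            "labels": pad_sequence([b["labels"] for b in batch], batch_first=True, padding_value=ignore_idx),
        }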