mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-20 09:01:06 +00:00
Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr * add grpo, support rlvr * tested deepseek r1 pipeline * add ci * verify grpo r1 * verify grpo r1 * update readme, remove unused code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove path * clean code * fix circular import * fix ci OOM * fix ci OOM * skip kto tp, fix qwen generation --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -147,7 +147,6 @@ def tokenize_prompt(
|
||||
ignore_index: the ignore index when calculate loss during training
|
||||
max_length: the maximum context length
|
||||
"""
|
||||
|
||||
messages = data_point["messages"]
|
||||
template = deepcopy(conversation_template)
|
||||
template.messages = []
|
||||
@@ -167,7 +166,6 @@ def tokenize_prompt(
|
||||
if len(template.messages) % 2 != 1:
|
||||
# exclude the answer if provided. keep only the prompt
|
||||
template.messages = template.messages[:-1]
|
||||
|
||||
# Prepare data
|
||||
prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
|
||||
tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
|
||||
@@ -185,12 +183,21 @@ def tokenize_prompt(
|
||||
)
|
||||
|
||||
# `inputs_decode` can be used to check whether the tokenization method is true.
|
||||
return dict(
|
||||
input_ids=tokenized,
|
||||
inputs_decode=prompt,
|
||||
seq_length=len(tokenized),
|
||||
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||
)
|
||||
if "gt_answer" in data_point:
|
||||
return dict(
|
||||
input_ids=tokenized,
|
||||
inputs_decode=prompt,
|
||||
seq_length=len(tokenized),
|
||||
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||
gt_answer=data_point["gt_answer"],
|
||||
)
|
||||
else:
|
||||
return dict(
|
||||
input_ids=tokenized,
|
||||
inputs_decode=prompt,
|
||||
seq_length=len(tokenized),
|
||||
seq_category=data_point["category"] if "category" in data_point else "None",
|
||||
)
|
||||
|
||||
|
||||
def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
|
||||
|
Reference in New Issue
Block a user