Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: YeAnbang
Date: 2025-02-18 09:43:36 +08:00
Committed by: GitHub
Parent: ce0ec40811
Commit: d20c8ffd97
39 changed files with 1995 additions and 277 deletions
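For orientation before the diff: the PR adds GRPO alongside the existing PPO trainer and rule-verified rewards (RLVR) for the DeepSeek-R1-style pipeline mentioned in the commit messages. A minimal sketch of the group-relative advantage that distinguishes GRPO from PPO, assuming the usual formulation (illustrative names, not this PR's code):

# Minimal sketch of GRPO-style group-relative advantages (illustrative, not this PR's code).
# GRPO samples a group of responses per prompt and normalizes each response's scalar
# reward by the group's mean and std, so no learned critic/value model is needed.
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: shape (num_prompts, group_size), one scalar reward per sampled response
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)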


@@ -147,7 +147,6 @@ def tokenize_prompt(
         ignore_index: the ignore index when calculate loss during training
         max_length: the maximum context length
     """
     messages = data_point["messages"]
     template = deepcopy(conversation_template)
     template.messages = []
@@ -167,7 +166,6 @@ def tokenize_prompt(
     if len(template.messages) % 2 != 1:
         # exclude the answer if provided. keep only the prompt
         template.messages = template.messages[:-1]
     # Prepare data
     prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
     tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
@@ -185,12 +183,21 @@ def tokenize_prompt(
     )
     # `inputs_decode` can be used to check whether the tokenization method is true.
-    return dict(
-        input_ids=tokenized,
-        inputs_decode=prompt,
-        seq_length=len(tokenized),
-        seq_category=data_point["category"] if "category" in data_point else "None",
-    )
+    if "gt_answer" in data_point:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+            gt_answer=data_point["gt_answer"],
+        )
+    else:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+        )
 def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
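The `gt_answer` field carried through tokenization above is what enables RLVR: the reward can be computed by a rule that checks the generated answer against the ground truth instead of a learned reward model. A hedged sketch of such a verifier; the function name and the \boxed{...} answer convention are illustrative assumptions for a math-style R1 pipeline, not this PR's actual matcher:

import re

def verifiable_reward(response: str, gt_answer: str) -> float:
    # Rule-based reward: 1.0 if the final \boxed{...} answer matches gt_answer, else 0.0.
    # The \boxed{} extraction is an illustrative assumption, not the PR's actual logic.
    match = re.search(r"\\boxed\{([^}]*)\}", response)
    if match is None:
        return 0.0
    return 1.0 if match.group(1).strip() == gt_answer.strip() else 0.0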