Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: YeAnbang
Date: 2025-02-18 09:43:36 +08:00
Committed by: GitHub
Parent: ce0ec40811
Commit: d20c8ffd97

39 changed files with 1995 additions and 277 deletions
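Context for the diff below: GRPO (Group Relative Policy Optimization) drops PPO's learned value model and instead baselines each completion's reward against the other completions sampled for the same prompt, which pairs naturally with RLVR (reinforcement learning with verifiable rewards), where a rule-based checker scores each answer. As rough orientation only — this is not code from the diff, and the function name is hypothetical — a minimal sketch of the group-relative advantage:

import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: (num_groups, group_size) — one scalar (e.g. verifier-based)
    # reward per sampled completion; each row holds completions of one prompt.
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    # Group-relative advantage: normalize each completion's reward against
    # the other completions for the same prompt (no learned critic needed).
    return (rewards - mean) / (std + eps)

With rewards coming from a verifier (e.g. exact-match checking of final answers in a DeepSeek-R1-style pipeline), these advantages stand in for the critic-based advantage term of PPO.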


@@ -1,3 +1,4 @@
+import copy
 from typing import Any, Callable, List, Optional
 
 import torch
@@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
     return model_kwargs
 
 
-def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
+def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
     model_kwargs["input_ids"] = input_ids
     return model_kwargs
 
 
 def _sample(
     model: Any,
+    tokenizer: Any,
     input_ids: torch.Tensor,
     max_length: int,
     early_stopping: bool = True,
@@ -137,8 +139,8 @@ def _sample(
     if max_new_tokens is None:
         max_new_tokens = max_length - context_length
     if context_length + max_new_tokens > max_length or max_new_tokens == 0:
-        print("Exeeded length limitation")
         return input_ids
+
     logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     past = None
@@ -183,18 +185,14 @@
         if stop_token_ids is not None:
             # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
-            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
-            unfinished_sequences = unfinished_sequences.mul(
-                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
-            )
+            for stop_token_id in stop_token_ids:
+                tokens_to_check = input_ids[:, -len(stop_token_id) :]
+                unfinished_sequences = unfinished_sequences.mul(
+                    torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
+                )
 
         # Stop when each sentence is finished if early_stopping=True
         if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
-            if i == context_length + max_new_tokens - 1:
-                # Force to end with stop token ids
-                input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
-                    torch.LongTensor(stop_token_ids).to(input_ids.device).long()
-                )
             return input_ids
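The hunk above generalizes the stop check in `_sample`: `stop_token_ids` is now a list of stop sequences, and each candidate sequence is compared against the tail of `input_ids`. The same tail-matching idea as a self-contained sketch (function and argument names are mine, not from the diff):

import torch

def mark_finished(input_ids: torch.Tensor, unfinished: torch.Tensor,
                  stop_sequences: list[list[int]]) -> torch.Tensor:
    # unfinished: (batch,) tensor of 1s/0s; a row becomes 0 (finished) once
    # its generated tail exactly matches any one of the stop sequences.
    for stop_seq in stop_sequences:
        tail = input_ids[:, -len(stop_seq):]
        stop = torch.LongTensor(stop_seq).to(input_ids.device)
        # A row stays unfinished only if its tail differs from this sequence.
        unfinished = unfinished.mul(torch.any(tail != stop, dim=1).long())
    return unfinished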
@@ -237,8 +235,10 @@ def generate(
         raise NotImplementedError
     elif is_sample_gen_mode:
         # Run sample
+        generation_kwargs = copy.deepcopy(model_kwargs)
         res = _sample(
             model,
+            tokenizer,
             input_ids,
             max_length,
             early_stopping=early_stopping,
@@ -249,8 +249,9 @@
             temperature=temperature,
             prepare_inputs_fn=prepare_inputs_fn,
             update_model_kwargs_fn=update_model_kwargs_fn,
-            **model_kwargs,
+            **generation_kwargs,
         )
+        del generation_kwargs
         return res
     elif is_beam_gen_mode:
         raise NotImplementedError
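The two hunks above also show `generate` deep-copying `model_kwargs` before sampling. A plausible reason: `_sample` mutates the kwargs in place through `update_model_kwargs_fn` (which, per the earlier hunk, writes into and returns `model_kwargs`), so repeated `generate` calls would otherwise inherit stale state such as a grown attention mask. A toy illustration of that hazard (hypothetical names, not from this PR):

import copy

def update_kwargs(kwargs: dict) -> dict:
    # Mimics update_model_kwargs_fn: mutates the shared dict in place.
    kwargs["attention_mask"] = kwargs.get("attention_mask", []) + [1]
    return kwargs

model_kwargs = {"attention_mask": [1, 1]}
update_kwargs(model_kwargs)                  # caller's dict changed: [1, 1, 1]
update_kwargs(copy.deepcopy(model_kwargs))   # the copy absorbs the mutation
assert model_kwargs["attention_mask"] == [1, 1, 1]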
@@ -350,11 +351,17 @@ def _sample_streaming(
         unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
 
         if stop_token_ids is not None:
-            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
             tokens_to_check = input_ids[:, -len(stop_token_ids) :]
-            unfinished_sequences = unfinished_sequences.mul(
-                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
-            )
+            if isinstance(stop_token_ids[0], int):
+                # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
+                unfinished_sequences = unfinished_sequences.mul(
+                    torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
+                )
+            else:
+                for stop_token_id in stop_token_ids:
+                    unfinished_sequences = unfinished_sequences.mul(
+                        torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
+                    )
 
         # Stop when each sentence is finished if early_stopping=True
         if (
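For backward compatibility, this streaming path branches on element type: a flat list of ints is treated as a single stop sequence, while a list of lists is checked sequence by sequence. One way to normalize such input up front — a sketch under that assumption, not a helper from this PR:

def normalize_stop_token_ids(stop_token_ids):
    # Accept either one stop sequence ([id, id, ...]) or a list of stop
    # sequences ([[...], [...]]) and always return the nested form.
    if stop_token_ids and isinstance(stop_token_ids[0], int):
        return [stop_token_ids]
    return stop_token_ids

assert normalize_stop_token_ids([2]) == [[2]]
assert normalize_stop_token_ids([[2], [13, 2]]) == [[2], [13, 2]]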