Add GRPO and Support RLVR for PPO (#6186)

* add grpo, support rlvr

* add grpo, support rlvr

* tested deepseek r1 pipeline

* add ci

* verify grpo r1

* verify grpo r1

* update readme, remove unused code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove path

* clean code

* fix circular import

* fix ci OOM

* fix ci OOM

* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
YeAnbang authored 2025-02-18 09:43:36 +08:00, committed by GitHub
parent ce0ec40811
commit d20c8ffd97
39 changed files with 1995 additions and 277 deletions

View File: applications/ColossalChat/coati/models/__init__.py

@@ -4,12 +4,14 @@ from .generation import generate, generate_streaming, prepare_inputs_fn, update_
 from .lora import LoraConfig, convert_to_lora_module, lora_manager
 from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
 from .reward_model import RewardModel
+from .rlvr_reward_model import RLVRRewardModel
 from .utils import disable_dropout
 
 __all__ = [
     "BaseModel",
     "Critic",
     "RewardModel",
+    "RLVRRewardModel",
     "PolicyLoss",
     "ValueLoss",
     "LogSigLoss",

View File: applications/ColossalChat/coati/models/generation.py

@@ -1,3 +1,4 @@
+import copy
 from typing import Any, Callable, List, Optional
 
 import torch
@@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
     return model_kwargs
 
 
-def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
+def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
     model_kwargs["input_ids"] = input_ids
     return model_kwargs
 
 
 def _sample(
     model: Any,
+    tokenizer: Any,
     input_ids: torch.Tensor,
     max_length: int,
     early_stopping: bool = True,
@@ -137,8 +139,8 @@ def _sample(
     if max_new_tokens is None:
         max_new_tokens = max_length - context_length
     if context_length + max_new_tokens > max_length or max_new_tokens == 0:
-        print("Exeeded length limitation")
+        print("Exceeded length limitation")
         return input_ids
 
     logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
     past = None
@@ -183,18 +185,14 @@ def _sample(
         if stop_token_ids is not None:
             # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
-            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
-            unfinished_sequences = unfinished_sequences.mul(
-                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
-            )
+            for stop_token_id in stop_token_ids:
+                tokens_to_check = input_ids[:, -len(stop_token_id) :]
+                unfinished_sequences = unfinished_sequences.mul(
+                    torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
+                )
 
         # Stop when each sentence is finished if early_stopping=True
-        if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
+        if i == context_length + max_new_tokens - 1:
             # Force to end with stop token ids
             input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
                 torch.LongTensor(stop_token_ids).to(input_ids.device).long()
             )
             return input_ids
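
The rewritten check treats stop_token_ids as a list of stop sequences (each itself a list of token ids) rather than a single flat sequence, so multi-token stop markers can be matched one sequence at a time. A minimal sketch of the new per-sequence logic on toy ids (the values below are invented for illustration):

import torch

input_ids = torch.tensor([[5, 6, 7, 8], [5, 6, 1, 2]])
stop_token_ids = [[7, 8], [9]]  # two stop sequences

unfinished_sequences = input_ids.new_ones(input_ids.shape[0])
for stop_token_id in stop_token_ids:
    tokens_to_check = input_ids[:, -len(stop_token_id) :]
    # A row stays unfinished only if its trailing tokens differ from this stop sequence.
    unfinished_sequences = unfinished_sequences.mul(
        torch.any(tokens_to_check != torch.LongTensor(stop_token_id), dim=1).long()
    )
print(unfinished_sequences)  # tensor([0, 1]): row 0 ends with [7, 8], so it is finished
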
@@ -237,8 +235,10 @@ def generate(
         raise NotImplementedError
     elif is_sample_gen_mode:
         # Run sample
+        generation_kwargs = copy.deepcopy(model_kwargs)
         res = _sample(
             model,
+            tokenizer,
             input_ids,
             max_length,
             early_stopping=early_stopping,
@@ -249,8 +249,9 @@ def generate(
             temperature=temperature,
             prepare_inputs_fn=prepare_inputs_fn,
             update_model_kwargs_fn=update_model_kwargs_fn,
-            **model_kwargs,
+            **generation_kwargs,
         )
+        del generation_kwargs
         return res
     elif is_beam_gen_mode:
         raise NotImplementedError
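
Note that model_kwargs is deep-copied into generation_kwargs before sampling: _sample mutates its kwargs through update_model_kwargs_fn, so the copy keeps repeated generate() calls on the same model_kwargs independent, and it is deleted right after use to free memory.
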
@@ -350,11 +351,17 @@ def _sample_streaming(
             unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
 
         if stop_token_ids is not None:
-            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
             tokens_to_check = input_ids[:, -len(stop_token_ids) :]
-            unfinished_sequences = unfinished_sequences.mul(
-                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
-            )
+            if isinstance(stop_token_ids[0], int):
+                # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
+                unfinished_sequences = unfinished_sequences.mul(
+                    torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
+                )
+            else:
+                for stop_token_id in stop_token_ids:
+                    unfinished_sequences = unfinished_sequences.mul(
+                        torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
+                    )
 
         # Stop when each sentence is finished if early_stopping=True
         if (
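
The streaming variant stays backward compatible: a flat list of ints (the isinstance(stop_token_ids[0], int) branch) is still treated as a single stop sequence, while a list of lists takes the new per-sequence loop, matching the _sample change above.
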

View File: applications/ColossalChat/coati/models/reward_model.py

@@ -25,7 +25,9 @@ class RewardModel(BaseModel):
         self.value_head = nn.Linear(self.last_hidden_state_size, 1)
         self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))
 
-    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(
+        self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
+    ) -> torch.Tensor:
         outputs = self.model(input_ids, attention_mask=attention_mask)
         last_hidden_states = outputs["last_hidden_state"]
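
Accepting **kwargs here lets RewardModel.forward be called with the same keyword arguments as RLVRRewardModel.__call__ below (gt_answer, response_start, response_end); a plain reward model simply ignores them.
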

View File: applications/ColossalChat/coati/models/rlvr_reward_model.py (new file)

@@ -0,0 +1,50 @@
+"""
+reward model
+"""
+
+from typing import Callable, List, Optional
+
+import torch
+
+
+class RLVRRewardModel:
+    """
+    RLVR reward model class. Supports verifiable rewards.
+
+    Args:
+        reward_fn_list (List[Callable]): list of reward functions
+        **kwargs: all other kwargs as in reward functions
+    """
+
+    def __init__(self, reward_fn_list: List[Callable], **kwargs) -> None:
+        self.reward_fn_list = reward_fn_list
+        self.kwargs = kwargs
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        response_start: List = None,
+        response_end: List = None,
+        gt_answer: List = None,
+    ) -> torch.Tensor:
+        # apply verifiable reward functions to each sample in the batch
+        bs = input_ids.size(0)
+        rewards = torch.zeros(bs, device=input_ids.device)
+        for i in range(bs):
+            for reward_fn in self.reward_fn_list:
+                rewards[i] += reward_fn(
+                    input_ids[i],
+                    attention_mask[i],
+                    response_start=response_start[i],
+                    response_end=response_end[i],
+                    gt_answer=gt_answer[i],
+                    **self.kwargs,
+                )
+        return rewards
+
+    def to(self, device):
+        return self
+
+    def eval(self):
+        return self
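
To make the call pattern concrete, a minimal usage sketch (the reward function, ids, and answer below are invented for illustration; the R1 pipeline supplies its own verifiable-reward functions):

import torch

def exact_match_reward(input_ids, attention_mask, response_start=None, response_end=None, gt_answer=None, **kwargs):
    # Hypothetical verifiable reward: 1.0 iff the response slice equals the ground-truth ids.
    response = input_ids[response_start:response_end]
    if gt_answer is not None and response.size(0) == gt_answer.size(0) and torch.equal(response, gt_answer):
        return torch.tensor(1.0)
    return torch.tensor(0.0)

reward_model = RLVRRewardModel(reward_fn_list=[exact_match_reward])
rewards = reward_model(
    input_ids=torch.tensor([[1, 2, 3, 4, 0]]),
    attention_mask=torch.tensor([[1, 1, 1, 1, 0]]),
    response_start=[2],  # one entry per batch row; indexed with [i] inside __call__
    response_end=[4],
    gt_answer=[torch.tensor([3, 4])],
)
print(rewards)  # tensor([1.])
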

View File: applications/ColossalChat/coati/models/utils.py

@@ -142,3 +142,17 @@ def disable_dropout(model: torch.nn.Module):
     for module in model.modules():
         if isinstance(module, torch.nn.Dropout):
             module.p = 0.0
+
+
+def repad_to_left(tensor, tokenizer):
+    repadded_input_ids = []
+    max_non_padded_seq_len = 0
+    for i in range(tensor.size(0)):
+        non_pad_indices = (tensor[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
+        start, end = non_pad_indices.min(), non_pad_indices.max()
+        repadded_input_ids.append(tensor[i][start : end + 1])
+        max_non_padded_seq_len = max(max_non_padded_seq_len, repadded_input_ids[-1].size(0))
+    repadded_input_ids = [
+        F.pad(t, (max_non_padded_seq_len - t.size(0), 0), value=tokenizer.pad_token_id) for t in repadded_input_ids
+    ]
+    return torch.stack(repadded_input_ids)
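
repad_to_left strips each row's existing padding and re-pads on the left, the layout left-padded decoder-only generation expects; it assumes torch.nn.functional is already imported as F at the top of utils.py (the import is outside this hunk). A behavior sketch with a stand-in tokenizer (invented for illustration):

import torch
import torch.nn.functional as F

class _Tok:
    pad_token_id = 0  # stand-in: only pad_token_id is used by repad_to_left

right_padded = torch.tensor([
    [5, 6, 7, 0, 0],
    [8, 9, 0, 0, 0],
])
print(repad_to_left(right_padded, _Tok()))
# tensor([[5, 6, 7],
#         [0, 8, 9]])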