Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-12 20:54:35 +00:00
Add GRPO and Support RLVR for PPO (#6186)
* add grpo, support rlvr
* add grpo, support rlvr
* tested deepseek r1 pipeline
* add ci
* verify grpo r1
* verify grpo r1
* update readme, remove unused code
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* remove path
* clean code
* fix circular import
* fix ci OOM
* fix ci OOM
* skip kto tp, fix qwen generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
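For background on the headline feature: GRPO samples a group of responses for each prompt and scores every response relative to its own group, so no learned critic model is needed. A minimal sketch of the group-relative advantage in the general GRPO formulation (not this PR's implementation):

import torch


def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: (num_prompts, group_size) scalar rewards for responses sampled from the same prompt
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)  # each response is scored relative to its group


# e.g. four responses to one prompt, rewarded 1.0/0.0 by a verifiable (rule-based) checker
print(grpo_advantages(torch.tensor([[1.0, 0.0, 1.0, 0.0]])))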
@@ -4,12 +4,14 @@ from .generation import generate, generate_streaming, prepare_inputs_fn, update_
from .lora import LoraConfig, convert_to_lora_module, lora_manager
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .rlvr_reward_model import RLVRRewardModel
from .utils import disable_dropout

__all__ = [
    "BaseModel",
    "Critic",
    "RewardModel",
    "RLVRRewardModel",
    "PolicyLoss",
    "ValueLoss",
    "LogSigLoss",
@@ -1,3 +1,4 @@
import copy
from typing import Any, Callable, List, Optional

import torch
@@ -88,13 +89,14 @@ def update_model_kwargs_fn(outputs: dict, new_mask, **model_kwargs) -> dict:
    return model_kwargs


def prepare_inputs_fn(input_ids: torch.Tensor, pad_token_id: int, **model_kwargs) -> dict:
def prepare_inputs_fn(input_ids: torch.Tensor, **model_kwargs) -> dict:
    model_kwargs["input_ids"] = input_ids
    return model_kwargs


def _sample(
    model: Any,
    tokenizer: Any,
    input_ids: torch.Tensor,
    max_length: int,
    early_stopping: bool = True,
@@ -137,8 +139,8 @@ def _sample(
    if max_new_tokens is None:
        max_new_tokens = max_length - context_length
    if context_length + max_new_tokens > max_length or max_new_tokens == 0:
        print("Exeeded length limitation")
        return input_ids

    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    past = None
@@ -183,18 +185,14 @@ def _sample(

        if stop_token_ids is not None:
            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
            unfinished_sequences = unfinished_sequences.mul(
                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
            )
            for stop_token_id in stop_token_ids:
                tokens_to_check = input_ids[:, -len(stop_token_id) :]
                unfinished_sequences = unfinished_sequences.mul(
                    torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
                )

        # Stop when each sentence is finished if early_stopping=True
        if (early_stopping and _is_sequence_finished(unfinished_sequences)) or i == context_length + max_new_tokens - 1:
            if i == context_length + max_new_tokens - 1:
                # Force to end with stop token ids
                input_ids[input_ids[:, -1] != pad_token_id, -len(stop_token_ids) :] = (
                    torch.LongTensor(stop_token_ids).to(input_ids.device).long()
                )
            return input_ids
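Note: the loop above generalizes the stop check so stop_token_ids can hold several stop sequences of different lengths, and a sample counts as finished once its tail matches any of them. A standalone sketch of the same idea (function name and token ids below are illustrative, not the repository API):

from typing import List

import torch


def ends_with_any_stop(input_ids: torch.Tensor, stop_sequences: List[List[int]]) -> torch.Tensor:
    """Return a (batch,) bool tensor that is True where a row ends with any stop sequence."""
    finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
    for stop in stop_sequences:
        tail = input_ids[:, -len(stop) :]  # last len(stop) tokens of each row
        stop_t = torch.tensor(stop, device=input_ids.device)
        finished |= (tail == stop_t).all(dim=1)  # exact match against this stop sequence
    return finished


# stop on either a single eos-style id [2] or a two-token chat-template ending [151645, 198]
ids = torch.tensor([[5, 6, 151645, 198], [5, 6, 7, 2]])
print(ends_with_any_stop(ids, [[2], [151645, 198]]))  # tensor([True, True])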
@@ -237,8 +235,10 @@ def generate(
        raise NotImplementedError
    elif is_sample_gen_mode:
        # Run sample
        generation_kwargs = copy.deepcopy(model_kwargs)
        res = _sample(
            model,
            tokenizer,
            input_ids,
            max_length,
            early_stopping=early_stopping,
@@ -249,8 +249,9 @@ def generate(
            temperature=temperature,
            prepare_inputs_fn=prepare_inputs_fn,
            update_model_kwargs_fn=update_model_kwargs_fn,
            **model_kwargs,
            **generation_kwargs,
        )
        del generation_kwargs
        return res
    elif is_beam_gen_mode:
        raise NotImplementedError
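Note: the copy.deepcopy above presumably exists so that _sample, whose update_model_kwargs_fn mutates the kwargs dict it receives, cannot leak state back into the caller's model_kwargs across generate calls; generation_kwargs is deleted right after use. A toy illustration of the hazard it avoids (names are illustrative):

import copy


def update_kwargs(kwargs: dict) -> dict:
    kwargs["past_key_values"] = "cache from this rollout"  # in-place mutation, like update_model_kwargs_fn
    return kwargs


caller_kwargs = {"attention_mask": "original"}
update_kwargs(copy.deepcopy(caller_kwargs))  # only the copy is mutated, then discarded
assert "past_key_values" not in caller_kwargs  # the caller's kwargs stay clean for the next call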
@@ -350,11 +351,17 @@ def _sample_streaming(
        unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        if stop_token_ids is not None:
            # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
            tokens_to_check = input_ids[:, -len(stop_token_ids) :]
            unfinished_sequences = unfinished_sequences.mul(
                torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
            )
            if isinstance(stop_token_ids[0], int):
                # If the last len(stop_token_ids) tokens of input_ids are equal to stop_token_ids, set sentence to finished.
                unfinished_sequences = unfinished_sequences.mul(
                    torch.any(tokens_to_check != torch.LongTensor(stop_token_ids).to(input_ids.device), dim=1).long()
                )
            else:
                for stop_token_id in stop_token_ids:
                    unfinished_sequences = unfinished_sequences.mul(
                        torch.any(tokens_to_check != torch.LongTensor(stop_token_id).to(input_ids.device), dim=1).long()
                    )

        # Stop when each sentence is finished if early_stopping=True
        if (
@@ -25,7 +25,9 @@ class RewardModel(BaseModel):
        self.value_head = nn.Linear(self.last_hidden_state_size, 1)
        self.value_head.weight.data.normal_(mean=0.0, std=1 / (self.last_hidden_state_size + 1))

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    def forward(
        self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        outputs = self.model(input_ids, attention_mask=attention_mask)

        last_hidden_states = outputs["last_hidden_state"]
applications/ColossalChat/coati/models/rlvr_reward_model.py (new file, 50 lines)
@@ -0,0 +1,50 @@
"""
reward model
"""

from typing import Callable, List, Optional

import torch


class RLVRRewardModel:
    """
    RLVRReward model class. Support varifiable reward.

    Args:
        reward_fn_list List: list of reward functions
        **kwargs: all other kwargs as in reward functions
    """

    def __init__(self, reward_fn_list: List[Callable], **kwargs) -> None:
        self.reward_fn_list = reward_fn_list
        self.kwargs = kwargs

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        response_start: List = None,
        response_end: List = None,
        gt_answer: List = None,
    ) -> torch.Tensor:
        # apply varifiable reward
        bs = input_ids.size(0)
        rewards = torch.zeros(bs, device=input_ids.device)
        for i in range(bs):
            for reward_fn in self.reward_fn_list:
                rewards[i] += reward_fn(
                    input_ids[i],
                    attention_mask[i],
                    response_start=response_start[i],
                    response_end=response_end[i],
                    gt_answer=gt_answer[i],
                    **self.kwargs,
                )
        return rewards

    def to(self, device):
        return self

    def eval(self):
        return self
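Note: RLVRRewardModel mirrors the call signature of the learned RewardModel (whose forward now also accepts **kwargs, see the hunk above), so the trainer can swap in rule-based, verifiable rewards without changing the call site. A minimal usage sketch; the reward function below is illustrative, not the repository's verifier:

import torch

from coati.models import RLVRRewardModel


def length_reward(input_ids, attention_mask, response_start=None, response_end=None, gt_answer=None, **kwargs):
    # toy verifiable reward: 1.0 if the response stays within max_resp_len tokens, else 0.0
    return 1.0 if (response_end - response_start) <= kwargs.get("max_resp_len", 512) else 0.0


reward_model = RLVRRewardModel(reward_fn_list=[length_reward], max_resp_len=512)
ids = torch.randint(0, 100, (2, 16))
mask = torch.ones_like(ids)
rewards = reward_model(ids, mask, response_start=[4, 4], response_end=[16, 16], gt_answer=["42", "7"])
print(rewards)  # shape (2,): one summed reward per sample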
@@ -142,3 +142,17 @@ def disable_dropout(model: torch.nn.Module):
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


def repad_to_left(tensor, tokenizer):
    repadded_input_ids = []
    max_non_padded_seq_len = 0
    for i in range(tensor.size(0)):
        non_pad_indices = (tensor[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
        start, end = non_pad_indices.min(), non_pad_indices.max()
        repadded_input_ids.append(tensor[i][start : end + 1])
        max_non_padded_seq_len = max(max_non_padded_seq_len, repadded_input_ids[-1].size(0))
    repadded_input_ids = [
        F.pad(t, (max_non_padded_seq_len - t.size(0), 0), value=tokenizer.pad_token_id) for t in repadded_input_ids
    ]
    return torch.stack(repadded_input_ids)
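Note: repad_to_left trims each row of a padded batch to its non-padded span and re-pads on the left to the longest remaining length, the layout decoder-only generation expects. A small usage sketch; the stub tokenizer is illustrative, since the function only reads pad_token_id:

import torch

from coati.models.utils import repad_to_left


class StubTokenizer:
    pad_token_id = 0


batch = torch.tensor(
    [
        [5, 6, 7, 0, 0],  # right-padded row with 3 real tokens
        [8, 9, 0, 0, 0],  # right-padded row with 2 real tokens
    ]
)
print(repad_to_left(batch, StubTokenizer()))
# tensor([[5, 6, 7],
#         [0, 8, 9]])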