Merge pull request #6312 from hpcaitech/grpo-latest-dev

[feat] Move prompt-level-filtering to buffer side
YeAnbang 2025-06-05 15:51:38 +08:00 committed by GitHub
commit ceb7065d6d
5 changed files with 252 additions and 144 deletions
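
The change in one picture: prompt-level filtering now happens when rollouts are received into the consumer's buffer, not inside the training step. Below is a minimal sketch (not part of the commit) of that group-level filter, assuming toy accuracy values; filter_range and the five-element buffer entry mirror the diffs that follow.

# Sketch only (assumed values): group-level filtering at buffer-insertion time.
import torch

filter_range = (0.01, 0.99)  # drop groups whose mean accuracy is ~0 (too hard) or ~1 (too easy)

# ans_acc: [batch_size, num_generations], one row per prompt group (toy values)
ans_acc = torch.tensor([[1.0, 1.0, 1.0, 1.0],   # all correct -> no learning signal, filtered
                        [0.0, 1.0, 0.0, 1.0],   # mixed       -> kept
                        [0.0, 0.0, 0.0, 0.0]])  # all wrong   -> filtered
group_ans_acc_mean = ans_acc.mean(dim=1)
effective_group_mask = torch.logical_and(
    group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]
)
print(effective_group_mask.tolist())  # [False, True, False]

# Each buffer entry keeps [group_or_None, reward, format_acc, ans_acc, response_len];
# filtered groups are stored as None so their metrics can still be logged.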

@@ -117,26 +117,102 @@ class BaseConsumer:
             # receive data from producers
             for r in range(self.num_producers):
                 print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
-                self.buffer.extend(
-                    unbind_batch(
-                        ray_broadcast_tensor_dict(
-                            None, src=0, device=self.device, group_name=f"sync_data_{r}"
-                        )
-                    )
-                )
-            while len(self.buffer) >= self.dp_size * self.minibatch_size:
-                batches = self.buffer[
-                    self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
-                ]
-                batch = bind_batch(batches)
-                batch = post_recv(batch)
-                loss, excessive_prompts_idx = self.step(i, pbar, **batch)
-                if excessive_prompts_idx is not None:
-                    excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
-                    self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
-                else:
-                    self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
+                raw_batch = ray_broadcast_tensor_dict(
+                    None, src=0, device=self.device, group_name=f"sync_data_{r}"
+                )
+                # calculate group reward et al. filtering. As only the filtered group will be used for training (which is incomplete),
+                # we need to calculate the metrics before filtering here for logging
+                # [batch_size, num_generations, ...] -> [batch_size * num_generations, ...]
+                raw_batch_with_reward = self.calculate_reward(
+                    {k: v.view(-1, v.size(-1)) if k != "temperature" else v for k, v in raw_batch.items()}
+                )
+                raw_batch_with_reward = {
+                    k: v.view(-1, self.num_generations, v.size(-1)) if k != "temperature" else v
+                    for k, v in raw_batch_with_reward.items()
+                }
+                # [batch_size, num_generations] -> [batch_size]
+                reward = raw_batch_with_reward["reward"][:, :, 0]
+                format_acc = raw_batch_with_reward["format_acc"][:, :, 0]
+                ans_acc = raw_batch_with_reward["ans_acc"][:, :, 0]
+                response_len = (
+                    raw_batch_with_reward["response_idx"][:, :, 1]
+                    - raw_batch_with_reward["response_idx"][:, :, 0]
+                    + 1
+                ).type(torch.float32)
+                effective_group_mask = None
+                if self.filter_range is not None and self.grpo_config.get("dynamic_batching", True):
+                    # filter the group based on the reward and accuracy
+                    group_ans_acc_mean = ans_acc.mean(dim=1)
+                    effective_group_mask = torch.logical_and(
+                        group_ans_acc_mean > self.filter_range[0], group_ans_acc_mean < self.filter_range[1]
+                    )
+                raw_batch_with_reward = unbind_batch(raw_batch_with_reward) # List[Dict[str, torch.Tensor]]
+                for group_idx, group_with_reward in enumerate(raw_batch_with_reward):
+                    self.buffer.append(
+                        [
+                            (
+                                group_with_reward
+                                if effective_group_mask is None or effective_group_mask[group_idx]
+                                else None
+                            ),
+                            reward[group_idx],
+                            format_acc[group_idx],
+                            ans_acc[group_idx],
+                            response_len[group_idx],
+                        ]
+                    )
+                if effective_group_mask is not None:
+                    print(
+                        f"[T{dist.get_rank()}] Filter recv data: {len(raw_batch_with_reward)} -> {torch.sum(effective_group_mask).cpu().item()} effective groups"
+                    )
+            # mapping the effective group to the raw group for indexing
+            effective_group_to_raw_group_mapping = {}
+            for buffer_idx in range(len(self.buffer)):
+                if self.buffer[buffer_idx][0] is not None:
+                    effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                        buffer_idx
+                    )
+            print(
+                f"[T{dist.get_rank()}] Collect Effective Prompt: {len(effective_group_to_raw_group_mapping)}/{self.dp_size * self.minibatch_size}"
+            )
+            while len(effective_group_to_raw_group_mapping) >= self.dp_size * self.minibatch_size:
+                # on each dp_rank, we use minibatch_size effective samples to form a batch
+                batches = [
+                    self.buffer[effective_group_to_raw_group_mapping[i]]
+                    for i in range(
+                        self.dp_rank * self.minibatch_size, (self.dp_rank + 1) * self.minibatch_size
+                    )
+                ]
+                # every dp_rank will receive a complete mini-batch, no need to sync within step() later
+                # each mini-batch use the first self.dp_size * minibatch_size effective samples
+                raw_mini_batches = self.buffer[
+                    : effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1
+                ] # include the last effective sample
+                raw_mini_batches_metric_dict = {
+                    "raw_train_mini_batch_reward": [t[1] for t in raw_mini_batches],
+                    "raw_train_mini_batch_format_acc": [t[2] for t in raw_mini_batches],
+                    "raw_train_mini_batch_ans_acc": [t[3] for t in raw_mini_batches],
+                    "raw_train_mini_batch_response_len": [t[4] for t in raw_mini_batches],
+                }
+                batch = bind_batch([t[0] for t in batches])
+                batch = post_recv(batch)
+                loss = self.step(i, pbar, **batch, **raw_mini_batches_metric_dict)
+                self.buffer = self.buffer[
+                    effective_group_to_raw_group_mapping[self.dp_size * self.minibatch_size - 1] + 1 :
+                ]
+                # recalculate the effective group to raw group mapping
+                effective_group_to_raw_group_mapping_size_before = len(effective_group_to_raw_group_mapping)
+                effective_group_to_raw_group_mapping = {}
+                for buffer_idx in range(len(self.buffer)):
+                    if self.buffer[buffer_idx][0] is not None:
+                        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = (
+                            buffer_idx
+                        )
+                assert (
+                    len(effective_group_to_raw_group_mapping)
+                    == effective_group_to_raw_group_mapping_size_before - self.dp_size * self.minibatch_size
+                )
                 if loss is not None:
                     pbar.set_postfix({"loss": loss})
                 i += 1
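
A minimal illustration (values assumed, not from the commit) of the effective-to-raw index mapping built above: filtered-out prompt groups stay in the buffer as None placeholders so their raw metrics can still be logged, while training indexes only the effective groups through the mapping.

# Toy buffer: entry[0] is the prompt group (kept) or None (filtered out by the accuracy range)
buffer = [["g0"], [None], [None], ["g3"], ["g4"]]

effective_group_to_raw_group_mapping = {}
for buffer_idx in range(len(buffer)):
    if buffer[buffer_idx][0] is not None:
        effective_group_to_raw_group_mapping[len(effective_group_to_raw_group_mapping)] = buffer_idx

print(effective_group_to_raw_group_mapping)  # {0: 0, 1: 3, 2: 4}

Once the mapping holds at least dp_size * minibatch_size entries, each DP rank slices its minibatch_size effective entries through it, and the buffer is then trimmed up to the last raw index that was consumed.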

@@ -1,5 +1,5 @@
 from contextlib import nullcontext
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import ray
 import torch
@@ -9,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -72,21 +72,18 @@ class GRPOConsumer(BaseConsumer):
             self.policy_model.gradient_checkpointing_enable()
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
-        self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
-        self.accum_format_acc = torch.zeros(1, device=self.device)
-        self.accum_ans_acc = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
-        self.accum_response_length = torch.zeros(1, device=self.device)
+        self.raw_train_batch_reward = []
+        self.raw_train_batch_format_acc = []
+        self.raw_train_batch_ans_acc = []
+        self.raw_train_batch_response_len = []
         self.accum_count = 0
         self.generate_config = generate_config
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
-        self.total_sample_count = 0
-        self.overlength_samples = 0
-        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -122,16 +119,7 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = (
-            {
-                "think_start": {"text": "<think>", "num_occur": 1},
-                "think_end": {"text": "</think>", "num_occur": 1},
-                "answer_start": {"text": "<answer>", "num_occur": 1},
-                "answer_end": {"text": "</answer>", "num_occur": 1},
-            }
-            if grpo_config.get("reward_fn_type") == "think_answer_tags"
-            else None
-        )
+        response_format_tags = grpo_config.get("response_format_tags", None)
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
@@ -187,24 +175,21 @@ class GRPOConsumer(BaseConsumer):
         Format:
             [minibatch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
         """
         # Reshape to [minibatch_size x num_of_generation, prompt_length + response_length]
-        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
+        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items() if "raw_train_mini_batch_" not in k}
+        self.raw_train_batch_reward.extend(kwargs["raw_train_mini_batch_reward"])
+        self.raw_train_batch_format_acc.extend(kwargs["raw_train_mini_batch_format_acc"])
+        self.raw_train_batch_ans_acc.extend(kwargs["raw_train_mini_batch_ans_acc"])
+        self.raw_train_batch_response_len.extend(kwargs["raw_train_mini_batch_response_len"])
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
         train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))

-        reward_group = self.reward_model(
-            data["input_ids"],
-            gt_answer=data["gt_answer"],
-            response_idx=data["response_idx"],
-        )
-        reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
-        format_acc = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
-        ans_acc = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+        reward = data["reward"].view((-1))
+        format_acc = data["format_acc"].view((-1))
+        ans_acc = data["ans_acc"].view((-1))

         # [minibatch_size, num_generations]
@@ -216,67 +201,35 @@ class GRPOConsumer(BaseConsumer):
         reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
         # [minibatch_size x num_generations]
         advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # filter out the reward that is too high (all sample gets full score) or too low (all sample gets 0 score),
-        group_ans_acc = (
-            ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
-        )
         # [minibatch_size x num_of_generation]
-        loss_mask = (
-            torch.ones(action_mask.size(0), device=action_mask.device).bool()
-            if self.filter_range is None
-            else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
-        )
+        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
-            old_loss_mask = loss_mask.clone()
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-            self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
-            self.overlength_samples = all_reduce_sum(
-                torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
-            )
-            self.total_overlength_samples += self.overlength_samples.item()
-
-        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
-
-        # [minibatch_size] -> calculate the number of effective prompts
-        effective_prompts_mask = prompt_level_mask.any(dim=1)
-        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
-        self.effective_prompt_count += effective_prompts.item()
-        excessive_prompts_idx = None
+        if self.filter_range is not None and self.grpo_config.get("dynamic_batching", False) == False:
+            # filter out samples with reward outside the range
+            # if dynamic batching is enabled, we filter out out of range groups before training
+            group_ans_acc_mean = (
+                ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=-1)
+            )
+            loss_mask = torch.logical_and(
+                loss_mask,
+                torch.logical_and(
+                    group_ans_acc_mean > self.filter_range[0],
+                    group_ans_acc_mean < self.filter_range[1],
+                ),
+            )
+        self.effective_prompt_count += group_reward.size(0) * self.dp_size

         mean_kl, mean_loss = [], []

         if self.grpo_config.get("dynamic_batching", True):
             need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
-            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
-
-            if excessive_prompts > 0:
-                excessive_prompts_per_rank = excessive_prompts // self.dp_size
-                # Only count excessive prompts if they are greater than 1 per rank.
-                # TODO: customize excessive prompts calculation.
-                if excessive_prompts_per_rank != 0:
-                    # Mask excessive prompts to False
-                    true_indices = torch.nonzero(effective_prompts_mask)
-                    # Make sure the indices are not empty.
-                    if true_indices.numel() > 0:
-                        true_indices = true_indices.squeeze(-1)
-                        if excessive_prompts_per_rank <= len(true_indices):
-                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
-                        else:
-                            excessive_prompts_idx = true_indices
-                        effective_prompts_mask[excessive_prompts_idx] = False
-                        for mask_idx in range(len(effective_prompts_mask)):
-                            if effective_prompts_mask[mask_idx] == False:
-                                # Update loss mask.
-                                loss_mask[mask_idx] = False
-            else:
-                excessive_prompts_idx = torch.empty([0])
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
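
For the path where dynamic batching is disabled, here is a small worked example (toy values, not from the commit) of how the group mean accuracy is expanded back to sample level with repeat_interleave before being AND-ed into loss_mask:

import torch

num_generations = 2
filter_range = (0.01, 0.99)
ans_acc = torch.tensor([0.0, 1.0, 1.0, 1.0])  # [minibatch_size * num_generations], toy values

group_ans_acc_mean = (
    ans_acc.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=-1)
)
# group means [0.5, 1.0] -> per-sample [0.5, 0.5, 1.0, 1.0]
loss_mask = torch.logical_and(
    torch.ones_like(ans_acc, dtype=torch.bool),
    torch.logical_and(group_ans_acc_mean > filter_range[0], group_ans_acc_mean < filter_range[1]),
)
print(loss_mask.tolist())  # [True, True, False, False]: the all-correct group is masked out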
@@ -286,12 +239,10 @@ class GRPOConsumer(BaseConsumer):
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
         total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()

         pbar.set_postfix(
             {
                 "Global Step": self.global_step,
-                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
-                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Gradient Accumulation on": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size} effective prompts, {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations} effective samples",
             }
         )
@@ -483,22 +434,16 @@ class GRPOConsumer(BaseConsumer):
             self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
             if self.policy_loss_fn.beta > 0:
                 self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-            self.accum_reward.add_(reward.data)
-            self.accum_format_acc.add_(format_acc.data)
-            self.accum_ans_acc.add_(ans_acc.data)
             self.accum_advantages.add_(advantages.data)
-            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
-            sample_utilization = self.effective_sample_count / self.total_sample_count
-            overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
+            # no need to run all reduce as raw_train_batch_* are not splited across dp rank
+            sample_utilization = self.effective_sample_count / len(self.raw_train_batch_reward) / self.num_generations
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
-            self.total_sample_count = 0
-            self.total_overlength_samples = 0
             loss_scalar = self.accum_loss.item()
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@@ -506,27 +451,39 @@ class GRPOConsumer(BaseConsumer):
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
+                    raw_batch_reward_mean = torch.cat(self.raw_train_batch_reward, dim=0).mean().cpu().item()
+                    raw_batch_format_acc_mean = torch.cat(self.raw_train_batch_format_acc, dim=0).mean().cpu().item()
+                    raw_batch_ans_acc_mean = torch.cat(self.raw_train_batch_ans_acc, dim=0).mean().cpu().item()
+                    raw_batch_response_len = torch.cat(self.raw_train_batch_response_len, dim=0)
+                    raw_batch_response_len_mean = raw_batch_response_len.mean().cpu().item()
+                    overlength_samples_ratio = (
+                        (raw_batch_response_len >= action_mask.size(-1)).to(float).mean().cpu().item()
+                    ) # not an exact figure, but a close estimate
+                    self.raw_train_batch_reward = []
+                    self.raw_train_batch_format_acc = []
+                    self.raw_train_batch_ans_acc = []
+                    self.raw_train_batch_response_len = []
                     to_log_msg = [
                         f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
-                        f"Reward: {self.accum_reward.item() / self.accum_count:.4f}",
-                        f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}",
-                        f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
+                        f"Reward: {raw_batch_reward_mean:.4f}",
+                        f"format Reward: {raw_batch_format_acc_mean:.4f}",
+                        f"Acc Reward: {raw_batch_ans_acc_mean:.4f}",
                         f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
-                        f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                        f"Response Length: {raw_batch_response_len_mean:.4f}",
                         f"Sample_utilization: {sample_utilization:.4f}",
-                        f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
+                        f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
                     ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                     print("\n".join(to_log_msg))
                     metrics = {
-                        "metrics/reward": self.accum_reward.item() / self.accum_count,
-                        "metrics/format_acc": self.accum_format_acc.item() / self.accum_count,
-                        "metrics/ans_acc": self.accum_ans_acc.item() / self.accum_count,
-                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                        "metrics/reward": raw_batch_reward_mean,
+                        "metrics/format_acc": raw_batch_format_acc_mean,
+                        "metrics/ans_acc": raw_batch_ans_acc_mean,
+                        "metrics/response_length": raw_batch_response_len_mean,
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
                         "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                         "train/sample_utilization": sample_utilization,
-                        "train/percentage_overlength_samples": overlength_samples_percentage,
+                        "train/overlength_samples_ratio": overlength_samples_ratio,
                         "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                     if self.policy_loss_fn.beta > 0:
@@ -534,21 +491,46 @@ class GRPOConsumer(BaseConsumer):
                     if self.wandb_run is not None:
                         self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
-                self.accum_reward.zero_()
-                self.accum_ans_acc.zero_()
-                self.accum_format_acc.zero_()
                 self.accum_kl.zero_()
                 self.accum_advantages.zero_()
-                self.accum_response_length.zero_()
                 self.accum_count = 0
-                if excessive_prompts_idx is not None:
-                    # All gather excessive prompts index across DP ranks.
-                    excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
-                    excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
-                return loss_scalar, excessive_prompts_idx
+                return loss_scalar
         else:
-            return None, excessive_prompts_idx
+            return None
+
+    def calculate_reward(self, rollout: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Calculate the group reward for the given rollout group.
+
+        Args:
+            rollout_group (Dict[str, Any]):
+                a group of samples generated by the model from the same prompt
+                contain the following keys:
+                "input_ids": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "attention_mask": torch.Tensor, [num_of_generation, prompt_length + response_length]
+                "action_mask": torch.Tensor, [num_of_generation, response_length]
+                "action_log_probs": torch.Tensor, [num_of_generation, response_length]
+                "response_idx": int, torch.Tensor, [num_of_generation, 2]
+                "gt_answer": torch.Tensor, [num_of_generation, 128]
+                "temperature": torch.Tensor, [] (scalar)
+
+        Returns:
+            Dict[str, Any]: The new group data with calculated reward.
+        """
+        reward_model_output = self.reward_model(
+            rollout["input_ids"],
+            gt_answer=rollout["gt_answer"],
+            response_idx=rollout["response_idx"],
+        )
+        # [num_of_generation]
+        reward = torch.tensor([value[0] for value in reward_model_output]).to(rollout["input_ids"].device)
+        format_acc = torch.tensor([value[1] for value in reward_model_output]).to(rollout["input_ids"].device)
+        ans_acc = torch.tensor([value[2] for value in reward_model_output]).to(rollout["input_ids"].device)
+        rollout["reward"] = reward.view((-1, 1))
+        rollout["format_acc"] = format_acc.view((-1, 1))
+        rollout["ans_acc"] = ans_acc.view((-1, 1))
+        return rollout
+
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
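
A worked example (numbers assumed) of the new sample_utilization formula in the hunk above: raw_train_batch_reward holds one per-group reward tensor for every raw prompt received, so its length times num_generations is the pre-filtering sample count, and no cross-rank all-reduce is needed because the raw lists are not split across DP ranks.

num_generations = 8
raw_train_batch_reward = [None] * 64  # 64 raw prompt groups accumulated for this update (placeholders)
effective_sample_count = 384          # samples that survived filtering and loss masking

sample_utilization = effective_sample_count / len(raw_train_batch_reward) / num_generations
print(sample_utilization)  # 384 / (64 * 8) = 0.75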

@@ -100,6 +100,7 @@ def launch_distributed(
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
             evaluation_function_type=grpo_config["reward_fn_type"],
+            response_format_tags=grpo_config["response_format_tags"],
             eval_save_dir=eval_save_dir,
             eval_generation_config=eval_generation_config,
             project_name=project_name,

@@ -46,6 +46,7 @@ class BaseProducer:
         eval_dataset_config=None,
         eval_interval=-1, # disable evaluation
         evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         project_name: str = None,
         run_name: str = None,
@@ -148,6 +149,7 @@ class BaseProducer:
                 self.evaluation_function = boxed_math_reward_fn
             else:
                 raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+            self.response_format_tags = response_format_tags
         else:
             print("No eval dataset provided, skip eval")
         self.device = get_current_device()
@@ -217,6 +219,7 @@ class BaseProducer:
                                     eval_outputs["response_idx"][m][n],
                                     tokenizer=self.tokenizer,
                                     eval_mode=True,
+                                    tags=self.response_format_tags,
                                 )
                                 for m in range(eval_outputs["input_ids"].size(0))
                                 for n in range(eval_outputs["input_ids"].size(1))
@@ -245,11 +248,10 @@ class BaseProducer:
                         self.eval_mode = False
                         self.latest_eval_step = self.consumer_global_step
                 outputs = self.rollout(**batch)
-                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs["temperature"] = torch.tensor(
                     [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                 ).to(outputs["input_ids"].device)
+                print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
@@ -324,6 +326,7 @@ class SimpleProducer(BaseProducer):
         eval_dataset_config=None,
         eval_interval=-1, # disable evaluation
         evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -349,6 +352,7 @@ class SimpleProducer(BaseProducer):
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
+            response_format_tags=response_format_tags,
             eval_save_dir=eval_save_dir,
             project_name=project_name,
             run_name=run_name,

@@ -121,6 +121,34 @@ if __name__ == "__main__":
     parser.add_argument(
         "-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
     )
+    parser.add_argument(
+        "-tp",
+        "--tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-pp",
+        "--pipeline-parallel-size",
+        type=int,
+        default=1,
+        help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-zero",
+        "--zero-stage",
+        type=int,
+        default=0,
+        help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
+    )
+    parser.add_argument(
+        "-ptp",
+        "--producer-tensor-parallel-size",
+        type=int,
+        default=1,
+        help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
+    )
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -134,8 +162,8 @@ if __name__ == "__main__":
         and args.train_microbatch_size > 0
     ), "Train micro batch size must be greater than 0 less than train mini batch size * num generations"
     assert (
-        args.train_minibatch_size <= args.train_batch_size
-    ), "Train mini batch size must be less than or equals to train batch size"
+        args.train_minibatch_size <= args.train_batch_size and args.train_batch_size % args.train_minibatch_size == 0
+    ), "Train mini batch size must be less than or equals to train batch size and train batch size must be divisible by train mini batch size"

     if args.master_address is None:
         # Default settings: Using single machine
@@ -178,7 +206,7 @@ if __name__ == "__main__":
             enforce_eager=True,
             enable_chunked_prefill=True,
             max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-            tensor_parallel_size=1,
+            tensor_parallel_size=args.producer_tensor_parallel_size,
         )
     )
     generate_config.update(
@@ -203,6 +231,16 @@ if __name__ == "__main__":
             "reward_fn_type": args.reward_type,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -222,13 +260,23 @@ if __name__ == "__main__":
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")

     launch_distributed(
         num_producers=args.num_inferencer,
-        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
+        num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
         num_consumer_procs=args.num_trainers,
         num_episodes=args.num_episodes,
         inference_batch_size=args.inference_batch_size,
@@ -247,17 +295,14 @@ if __name__ == "__main__":
         train_model_config=train_model_config,
         grpo_config=grpo_config,
         plugin_config={
-            "zero_stage": 2,
-        }, # for zero
-        # plugin_config={
-        #     "tp_size": 2,
-        #     "pp_size": 2,
-        #     "microbatch_size": max(
-        #         1, args.train_microbatch_size // 2
-        #     ), # microbatch size should be set to train_microbatch_size // pp_size
-        #     "zero_stage": 0,
-        #     "max_norm": 1.0,
-        # }, # for pp, tp
+            "tp_size": args.tensor_parallel_size,
+            "pp_size": args.pipeline_parallel_size,
+            "microbatch_size": max(
+                1, args.train_microbatch_size // args.pipeline_parallel_size
+            ), # microbatch size should be set to train_microbatch_size // pp_size
+            "zero_stage": args.zero_stage,
+            "max_norm": 1.0,
+        }, # for pp, tp
         inference_backend=args.backend,
         master_addr="localhost",
         master_port=args.master_port,
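
A sketch of the plugin_config that the new -tp, -pp, and -zero flags produce, with hypothetical argument values; the keys and the microbatch arithmetic follow the hunk above.

tensor_parallel_size = 2    # args.tensor_parallel_size (-tp)
pipeline_parallel_size = 2  # args.pipeline_parallel_size (-pp)
zero_stage = 1              # args.zero_stage (-zero)
train_microbatch_size = 8   # args.train_microbatch_size

plugin_config = {
    "tp_size": tensor_parallel_size,
    "pp_size": pipeline_parallel_size,
    "microbatch_size": max(1, train_microbatch_size // pipeline_parallel_size),  # 8 // 2 = 4
    "zero_stage": zero_stage,
    "max_norm": 1.0,
}
print(plugin_config["microbatch_size"])  # 4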