[feat] Support prompt-level dynamic batch size (#6300)

* adjust to dynamic prompt batch size

* remove debug

* update pad seq (#6303)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* adjust to dynamic prompt batch size

* remove debug

* fix DP (data-parallel) issue

* fix

* fix default settings

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Tong Li authored on 2025-05-14 16:40:35 +08:00 (committed by GitHub)
parent b920af427b
commit aca547623f
4 changed files with 123 additions and 93 deletions


@@ -1,4 +1,3 @@
import warnings
from contextlib import nullcontext
from typing import Any, Optional
@@ -10,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs
from coati.trainer.utils import all_reduce_mean, all_reduce_sum
from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
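
The only import change pulls in all_gather_tensors, which the new return path below uses to share surplus-prompt indices across data-parallel ranks. As a rough sketch of what such a helper typically does (the real coati.trainer.utils implementation and its plugin argument may differ), an object all-gather over the DP group can be written as:

```python
# Hedged sketch, not the coati implementation: gather an arbitrary picklable
# object (e.g. a list of index tensors) from every rank in a process group.
from typing import Any, List, Optional

import torch.distributed as dist


def all_gather_objects(local_obj: Any, group: Optional[dist.ProcessGroup] = None) -> List[Any]:
    world_size = dist.get_world_size(group=group)
    gathered: List[Any] = [None] * world_size
    dist.all_gather_object(gathered, local_obj, group=group)
    return gathered
```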
@@ -42,13 +41,6 @@ class GRPOConsumer(BaseConsumer):
save_dir="./model",
):
print(f"Using GRPO config: {grpo_config}")
if grpo_config.get("loss_variation", "sample_level") == "token_level":
if batch_size != minibatch_size:
warnings.warn(
f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
UserWarning,
)
minibatch_size = batch_size
if (
plugin_config.get("pp_size", 1) > 1
and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ class GRPOConsumer(BaseConsumer):
self.grpo_config = grpo_config
self.project_name = project_name
self.effective_sample_count = 0
self.effective_prompt_count = 0
self.total_sample_count = 0
self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ class GRPOConsumer(BaseConsumer):
group_ans_acc = (
ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
)
# [minibatch_size x num_of_generation]
loss_mask = (
torch.ones(action_mask.size(0), device=action_mask.device).bool()
if self.filter_range is None
else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
)
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
loss_mask = torch.logical_and(
loss_mask,
action_mask[:, -1] == False,
)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
# [minibatch_size] -> calculate the number of effective prompts
effective_prompts_mask = prompt_level_mask.any(dim=1)
effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
self.effective_prompt_count += effective_prompts.item()
excessive_prompts_idx = None
mean_kl, mean_loss = [], []
if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
if excessive_prompts > 0:
excessive_prompts_per_rank = excessive_prompts // self.dp_size
# Only count excessive prompts if they are greater than 1 per rank.
# TODO: customize excessive prompts calculation.
if excessive_prompts_per_rank != 0:
# Mask excessive prompts to False
true_indices = torch.nonzero(effective_prompts_mask).squeeze()
if excessive_prompts_per_rank <= len(true_indices):
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
else:
excessive_prompts_idx = true_indices
effective_prompts_mask[excessive_prompts_idx] = False
for mask_idx in range(len(effective_prompts_mask)):
if effective_prompts_mask[mask_idx] == False:
# Update loss mask.
loss_mask[mask_idx] = False
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
self.total_sample_count += total_samples.item()
mean_kl, mean_loss = [], []
if self.grpo_config.get("dynamic_batching", True):
need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
# to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
num_excessive_samples = (
int(
(self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
/ self.num_generations
/ self.dp_size
)
* self.num_generations
)
if num_excessive_samples > 0:
data = {
k: (
v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
if k
in [
"input_ids",
"attention_mask",
"action_log_probs",
"action_mask",
"response_idx",
"gt_answer",
]
else v
)
for k, v in data.items()
}
action_mask = action_mask[
: -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
]
loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
else:
num_excessive_samples = 0
else:
# If dynamic batching is disabled, we need to use all samples for training.
need_update = (step_idx + 1) % self.num_microbatches == 0
num_excessive_samples = 0
pbar.set_postfix(
{
"Step": self.global_step + 1,
"Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
"Global Step": self.global_step,
"Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
"Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
}
)
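
Taken together, this hunk moves the dynamic-batching bookkeeping from the sample level to the prompt level: a prompt counts as effective if at least one of its num_generations rollouts survives the accuracy-range and truncation filters, the per-rank counts are all-reduced over the DP group, and an optimizer update fires once batch_size * dp_size effective prompts have accumulated. Surplus effective prompts are masked out of the loss and their indices are remembered for the next iteration, instead of slicing whole samples off the tail of the tensors as the old code did. The standalone sketch below (hypothetical mask_excessive_prompts helper, per-rank view, assuming the DP all-reduce and the per-rank split of the surplus have already happened) illustrates the masking step:

```python
# Toy, self-contained illustration of the prompt-level masking above;
# not the GRPOConsumer code.
import torch


def mask_excessive_prompts(loss_mask: torch.Tensor, num_generations: int, target_prompts: int):
    prompt_level_mask = loss_mask.view(-1, num_generations)   # [prompts, generations]
    effective_prompts_mask = prompt_level_mask.any(dim=1)     # [prompts]
    excess = int(effective_prompts_mask.sum().item()) - target_prompts
    excessive_idx = None
    if excess > 0:
        true_indices = torch.nonzero(effective_prompts_mask).squeeze(-1)
        excessive_idx = true_indices[-excess:]                 # set the last `excess` prompts aside
        effective_prompts_mask[excessive_idx] = False
        # propagate the prompt-level decision back to the sample-level loss mask
        loss_mask = loss_mask & effective_prompts_mask.repeat_interleave(num_generations)
    return loss_mask, effective_prompts_mask, excessive_idx


# toy usage: 4 prompts x 2 generations, target of 2 effective prompts on this rank
lm = torch.tensor([1, 0, 1, 1, 0, 0, 1, 1], dtype=torch.bool)
lm, eff, idx = mask_excessive_prompts(lm, num_generations=2, target_prompts=2)
print(eff)  # tensor([ True,  True, False, False])
print(idx)  # tensor([3])
```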
@@ -375,7 +364,7 @@ class GRPOConsumer(BaseConsumer):
kl.append(appox_kl.mean())
else:
per_token_kl = 0.0
kl.append(0.0)
kl.append(torch.tensor(0.0))
loss, _ = self.policy_loss_fn(
action_log_probs,
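
A small but necessary fix in this hunk: when KL regularization contributes nothing for a microbatch, the placeholder appended to kl is now a zero tensor rather than a Python float, so the list stays homogeneous when it is later reduced with tensor ops (the exact downstream call is assumed here, e.g. a stack-and-mean):

```python
# Why torch.tensor(0.0) instead of 0.0: stacking a mixed list fails.
import torch

kl = [torch.tensor(0.5), 0.0]
# torch.stack(kl)              # TypeError: expected Tensor as element 1 in argument 0
kl = [torch.tensor(0.5), torch.tensor(0.0)]
print(torch.stack(kl).mean())  # tensor(0.2500)
```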
@@ -479,6 +468,7 @@ class GRPOConsumer(BaseConsumer):
self.optimizer.zero_grad()
self.global_step += 1
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_prompt_count = 0
self.effective_sample_count = 0
self.total_sample_count = 0
loss_scalar = self.accum_loss.item()
@@ -495,6 +485,7 @@ class GRPOConsumer(BaseConsumer):
f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
f"Sample_utilization: {sample_utilization:.4f}",
] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
print("\n".join(to_log_msg))
metrics = {
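
The new Sample_utilization line reports, for the accumulation window just finished, the share of gathered rollouts that actually contributed to the loss. With made-up counts:

```python
# Toy numbers only: 8 prompts x 8 generations per rank x 2 DP ranks gathered,
# 96 rollouts surviving the accuracy/truncation filters.
effective_sample_count = 96
total_sample_count = 8 * 8 * 2  # 128
sample_utilization = effective_sample_count / total_sample_count
print(f"Sample_utilization: {sample_utilization:.4f}")  # Sample_utilization: 0.7500
```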
@@ -520,9 +511,15 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar, num_excessive_samples // self.num_generations
if excessive_prompts_idx is not None:
# All gather excessive prompts index across DP ranks.
excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
return loss_scalar, excessive_prompts_idx
else:
return None, num_excessive_samples // self.num_generations
return None, excessive_prompts_idx
def state_dict(self):
self.policy_model._force_wait_all_gather()
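
Finally, the second return value changes from a count of surplus samples (previously converted back to prompts with // self.num_generations) to the gathered indices of the surplus prompts themselves. Each rank shifts its local indices by dp_rank * minibatch_size so they address positions in the minibatch concatenated across DP ranks, then all-gathers them so every rank can keep the same prompts in its buffer for the next step. A hedged sketch of the offset step (helper and variable names are illustrative, not the coati API):

```python
import torch


def globalize_excessive_indices(local_idx: torch.Tensor, dp_rank: int, minibatch_size: int) -> torch.Tensor:
    # rank r owns global prompt slots [r * minibatch_size, (r + 1) * minibatch_size)
    return local_idx + dp_rank * minibatch_size


# e.g. rank 1 with minibatch_size=8 and local surplus prompts 6 and 7:
print(globalize_excessive_indices(torch.tensor([6, 7]), dp_rank=1, minibatch_size=8))
# tensor([14, 15])
```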