Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[feat] Support prompt level dynamic (#6300)
* adjust to dynamic prompt bs
* remove debug
* update pad seq (#6303)

  Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* adjust to dynamic prompt bs
* remove debug
* fix dp issue
* fix
* fix default settings

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
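In substance, the change swaps the sample-level stopping rule for a prompt-level one: the consumer now collects until enough prompts have at least one usable generation, instead of counting raw samples. A minimal sketch of the two conditions as they appear in the diff below, with illustrative placeholder numbers (the real values come from the training config):

```python
# Illustrative placeholder values only.
batch_size, dp_size, num_generations = 8, 2, 4

# Old rule (sample level): update once enough samples survive filtering.
effective_sample_count = 70
need_update = effective_sample_count >= batch_size * dp_size * num_generations  # 70 >= 64 -> True

# New rule (prompt level): update once enough prompts have >= 1 usable sample.
effective_prompt_count = 17
need_update = effective_prompt_count >= batch_size * dp_size  # 17 >= 16 -> True
```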
@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional
 
@@ -10,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
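`all_gather_tensors` is the project's own helper; the diff calls it near the end as `all_gather_tensors(excessive_prompts_idx, self.plugin)` to share excessive-prompt indices across data-parallel ranks. As a rough, hypothetical stand-in for readers unfamiliar with it, the same effect can be sketched with `torch.distributed.all_gather_object`; the real helper resolves the process group from the booster plugin rather than taking one directly:

```python
# Hypothetical stand-in sketch; assumes torch.distributed is already initialized.
import torch.distributed as dist

def gather_across_dp(local_obj, group=None):
    # Collect one picklable object (e.g., a list of prompt indices) from every rank.
    gathered = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(gathered, local_obj, group=group)
    return gathered
```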
@@ -42,13 +41,6 @@ class GRPOConsumer(BaseConsumer):
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ class GRPOConsumer(BaseConsumer):
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0
 
         self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ class GRPOConsumer(BaseConsumer):
         group_ans_acc = (
             ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
         )
         # [minibatch_size x num_of_generation]
         loss_mask = (
             torch.ones(action_mask.size(0), device=action_mask.device).bool()
             if self.filter_range is None
             else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
         )
 
         # filter out overlength samples
         if self.filter_truncated_response and action_mask.size(1) == self.max_length:
             loss_mask = torch.logical_and(
                 loss_mask,
                 action_mask[:, -1] == False,
             )
-        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
-        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
-        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
-        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
-        self.effective_sample_count += effective_samples.item()
-        self.total_sample_count += total_samples.item()
+        prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+        # [minibatch_size] -> calculate the number of effective prompts
+        effective_prompts_mask = prompt_level_mask.any(dim=1)
+        effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+        self.effective_prompt_count += effective_prompts.item()
+        excessive_prompts_idx = None
 
         mean_kl, mean_loss = [], []
 
         if self.grpo_config.get("dynamic_batching", True):
-            need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-            # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-            num_excessive_samples = (
-                int(
-                    (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                    / self.num_generations
-                    / self.dp_size
-                )
-                * self.num_generations
-            )
-            if num_excessive_samples > 0:
-                data = {
-                    k: (
-                        v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                        if k
-                        in [
-                            "input_ids",
-                            "attention_mask",
-                            "action_log_probs",
-                            "action_mask",
-                            "response_idx",
-                            "gt_answer",
-                        ]
-                        else v
-                    )
-                    for k, v in data.items()
-                }
-                action_mask = action_mask[
-                    : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                ]
-                loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-            else:
-                num_excessive_samples = 0
+            need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+            excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+            if excessive_prompts > 0:
+                excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                # Only count excessive prompts if they are greater than 1 per rank.
+                # TODO: customize excessive prompts calculation.
+                if excessive_prompts_per_rank != 0:
+                    # Mask excessive prompts to False
+                    true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                    if excessive_prompts_per_rank <= len(true_indices):
+                        excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                    else:
+                        excessive_prompts_idx = true_indices
+                    effective_prompts_mask[excessive_prompts_idx] = False
+
+                    for mask_idx in range(len(effective_prompts_mask)):
+                        if effective_prompts_mask[mask_idx] == False:
+                            # Update loss mask.
+                            loss_mask[mask_idx] = False
         else:
             # If dynamic batching is disabled, we need to use all samples for training.
             need_update = (step_idx + 1) % self.num_microbatches == 0
-            num_excessive_samples = 0
+
+        effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+        effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+        total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
+        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
+        self.effective_sample_count += effective_samples.item()
+        self.total_sample_count += total_samples.item()
 
         pbar.set_postfix(
             {
-                "Step": self.global_step + 1,
-                "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                "Global Step": self.global_step,
+                "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
             }
         )
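The heart of this hunk is the view/any trick (a prompt counts as effective if any of its generations survives filtering) plus trimming of overshoot beyond the target prompt budget. Below is a self-contained, vectorized rendition with made-up sizes; the committed code walks `effective_prompts_mask` with a Python loop instead, and `minibatch_size`, `num_generations`, and the mask values here are illustrative:

```python
import torch

minibatch_size, num_generations = 4, 2

# Per-sample keep/drop decisions, flattened as in the consumer.
loss_mask = torch.tensor([True, True, False, False, True, False, False, False])

# A prompt is "effective" if at least one of its generations survives filtering.
prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)  # tensor([True, False, True, False])
print(int(effective_prompts_mask.sum()))               # 2 effective prompts

# Trim overshoot: drop the last k effective prompts, as the consumer does.
excessive_prompts_per_rank = 1
true_indices = torch.nonzero(effective_prompts_mask).squeeze(-1)
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
effective_prompts_mask[excessive_prompts_idx] = False

# Propagate the prompt-level drop back to every sample of those prompts.
loss_mask = loss_mask.view(minibatch_size, num_generations)
loss_mask[~effective_prompts_mask] = False
loss_mask = loss_mask.view(-1)
print(loss_mask)  # only prompt 0's surviving samples remain True
```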
@@ -375,7 +364,7 @@ class GRPOConsumer(BaseConsumer):
                     kl.append(appox_kl.mean())
                 else:
                     per_token_kl = 0.0
-                    kl.append(0.0)
+                    kl.append(torch.tensor(0.0))
 
                 loss, _ = self.policy_loss_fn(
                     action_log_probs,
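The one-line fix above presumably guards downstream tensor aggregation: a Python float mixed into the `kl` list breaks operations such as `torch.stack`. A minimal repro of that failure mode, assuming the entries are later combined that way:

```python
import torch

kl = [torch.tensor(0.25), 0.0]  # mixing a Python float into the list
try:
    torch.stack(kl)             # tensor ops over the list now fail
except TypeError as e:
    print(e)                    # expected Tensor as element 1 in argument 0

kl = [torch.tensor(0.25), torch.tensor(0.0)]  # the fixed version
print(torch.stack(kl).mean())                 # tensor(0.1250)
```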
@@ -479,6 +468,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -495,6 +485,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -520,9 +511,15 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
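Before the gather, each rank shifts its local prompt slots into a global numbering so that indices from different data-parallel ranks do not collide. A tiny worked example with assumed values:

```python
# Assumed values: dp_size = 2, so ranks 0 and 1 each own 4 prompt slots.
dp_rank, minibatch_size = 1, 4
local_excessive_idx = [2, 3]  # prompt slots on this rank
global_idx = [i + dp_rank * minibatch_size for i in local_excessive_idx]
print(global_idx)             # [6, 7] -> unique slots across all ranks
```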