[feat] Support prompt-level dynamic batching (#6300)

* adjust to dynamic prompt bs

* remove debug

* update pad seq (#6303)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

* adjust to dynamic prompt bs

* remove debug

* fix dp issue

* fix

* fix default settings

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Authored by Tong Li on 2025-05-14 16:40:35 +08:00, committed by GitHub
parent b920af427b
commit aca547623f
4 changed files with 123 additions and 93 deletions


@@ -107,9 +107,14 @@ class BaseConsumer:
             f"Consumer{self.rank} num_update: {self.num_update_per_episode}, num_recv: {self.num_recv_per_update}, nmb: {self.num_microbatches}"
         )
         for episode in range(self.num_episodes):
-            with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
+            with tqdm(
+                range(self.num_update_per_episode),
+                desc=f"Episode {episode} with rollout step(s)",
+                disable=self.rank != 0,
+            ) as pbar:
                 for step in pbar:
                     i = 0
+                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -127,15 +132,15 @@ class BaseConsumer:
                             ]
                             batch = bind_batch(batches)
                             batch = post_recv(batch)
-                            loss, num_excessive_prompts = self.step(i, pbar, **batch)
-                            self.buffer = (
-                                self.buffer[
-                                    (self.dp_rank + 1) * self.minibatch_size
-                                    - num_excessive_prompts : (self.dp_rank + 1) * self.minibatch_size
-                                ]
-                                + self.buffer[self.dp_size * self.minibatch_size :]
-                            )
+                            loss, excessive_prompts_idx = self.step(i, pbar, **batch)
+
+                            if excessive_prompts_idx is not None:
+                                excessive_prompts = [self.buffer[idx] for idx in excessive_prompts_idx]
+                                self.buffer = excessive_prompts + self.buffer[self.dp_size * self.minibatch_size :]
+                            else:
+                                self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
+                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -149,6 +154,7 @@ class BaseConsumer:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
+                        if allow_sync_model:
                             if self.pp_size > 1:
                                 print(
                                     f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
@@ -172,6 +178,7 @@ class BaseConsumer:
                                     )
                             del state_dict
                             torch.cuda.empty_cache()
+                            allow_sync_model = False


 @ray.remote
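
To make the new consumer-side buffer bookkeeping easier to follow, here is a minimal self-contained sketch using plain Python lists (no Ray or torch). The names buffer, dp_size, minibatch_size and excessive_prompts_idx mirror the diff; the helper name recycle_buffer and the sample values are illustrative only, not part of the commit.

# Minimal sketch of the buffer recycling in BaseConsumer.loop (illustrative, not the repo's code).
def recycle_buffer(buffer, dp_size, minibatch_size, excessive_prompts_idx=None):
    """Drop the slice that was just trained on; re-queue over-collected prompts at the front."""
    if excessive_prompts_idx is not None:
        excessive_prompts = [buffer[idx] for idx in excessive_prompts_idx]
        return excessive_prompts + buffer[dp_size * minibatch_size :]
    return buffer[dp_size * minibatch_size :]

if __name__ == "__main__":
    buffer = [f"prompt_{i}" for i in range(10)]  # pretend 10 rollouts are queued
    # With dp_size=2 and minibatch_size=3, the first 6 entries were dispatched to step();
    # step() reported that entries 4 and 5 exceeded the global prompt budget.
    print(recycle_buffer(buffer, dp_size=2, minibatch_size=3, excessive_prompts_idx=[4, 5]))
    # -> ['prompt_4', 'prompt_5', 'prompt_6', 'prompt_7', 'prompt_8', 'prompt_9']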


@@ -1,4 +1,3 @@
-import warnings
 from contextlib import nullcontext
 from typing import Any, Optional

@@ -10,7 +9,7 @@ from coati.distributed.loss import PolicyLoss
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
 from coati.distributed.reward.verifiable_reward import VerifiableReward
 from coati.distributed.utils import calc_action_log_probs
-from coati.trainer.utils import all_reduce_mean, all_reduce_sum
+from coati.trainer.utils import all_gather_tensors, all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -42,13 +41,6 @@ class GRPOConsumer(BaseConsumer):
         save_dir="./model",
     ):
         print(f"Using GRPO config: {grpo_config}")
-        if grpo_config.get("loss_variation", "sample_level") == "token_level":
-            if batch_size != minibatch_size:
-                warnings.warn(
-                    f"Applied token_level loss, force overwrite mini-batch-size with batch-size: mini-batch-size: {minibatch_size}->{batch_size}",
-                    UserWarning,
-                )
-                minibatch_size = batch_size
         if (
             plugin_config.get("pp_size", 1) > 1
             and "num_microbatches" not in plugin_config
@@ -90,6 +82,7 @@ class GRPOConsumer(BaseConsumer):
         self.grpo_config = grpo_config
         self.project_name = project_name
         self.effective_sample_count = 0
+        self.effective_prompt_count = 0
         self.total_sample_count = 0

         self.policy_loss_fn = PolicyLoss(
@@ -213,70 +206,66 @@ class GRPOConsumer(BaseConsumer):
             group_ans_acc = (
                 ans_acc.view(-1, self.num_generations).mean(dim=1).repeat_interleave(self.num_generations, dim=0)
             )
+            # [minibatch_size x num_of_generation]
             loss_mask = (
                 torch.ones(action_mask.size(0), device=action_mask.device).bool()
                 if self.filter_range is None
                 else torch.logical_and(group_ans_acc > self.filter_range[0], group_ans_acc < self.filter_range[1])
             )

             # filter out overlength samples
             if self.filter_truncated_response and action_mask.size(1) == self.max_length:
                 loss_mask = torch.logical_and(
                     loss_mask,
                     action_mask[:, -1] == False,
                 )
-            effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
+
+            prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
+
+            # [minibatch_size] -> calculate the number of effective prompts
+            effective_prompts_mask = prompt_level_mask.any(dim=1)
+            effective_prompts = all_reduce_sum(torch.sum(effective_prompts_mask), self.plugin)
+            self.effective_prompt_count += effective_prompts.item()
+            excessive_prompts_idx = None
+
+            mean_kl, mean_loss = [], []
+
+            if self.grpo_config.get("dynamic_batching", True):
+                need_update = self.effective_prompt_count >= self.batch_size * self.dp_size
+                excessive_prompts = self.effective_prompt_count - self.batch_size * self.dp_size
+
+                if excessive_prompts > 0:
+                    excessive_prompts_per_rank = excessive_prompts // self.dp_size
+                    # Only count excessive prompts if they are greater than 1 per rank.
+                    # TODO: customize excessive prompts calculation.
+                    if excessive_prompts_per_rank != 0:
+                        # Mask excessive prompts to False
+                        true_indices = torch.nonzero(effective_prompts_mask).squeeze()
+                        if excessive_prompts_per_rank <= len(true_indices):
+                            excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]
+                        else:
+                            excessive_prompts_idx = true_indices
+                        effective_prompts_mask[excessive_prompts_idx] = False
+
+                        for mask_idx in range(len(effective_prompts_mask)):
+                            if effective_prompts_mask[mask_idx] == False:
+                                # Update loss mask.
+                                loss_mask[mask_idx] = False
+            else:
+                # If dynamic batching is disabled, we need to use all samples for training.
+                need_update = (step_idx + 1) % self.num_microbatches == 0
+
             effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
+            effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
             total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
             total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
             self.effective_sample_count += effective_samples.item()
             self.total_sample_count += total_samples.item()

-            mean_kl, mean_loss = [], []
-
-            if self.grpo_config.get("dynamic_batching", True):
-                need_update = self.effective_sample_count >= self.batch_size * self.dp_size * self.num_generations
-                # to exclude the excessive samples from the last batch, the last num_excessive_samples samples are not used for training and will be kept in buffer for the next iteration.
-                num_excessive_samples = (
-                    int(
-                        (self.effective_sample_count - self.batch_size * self.dp_size * self.num_generations)
-                        / self.num_generations
-                        / self.dp_size
-                    )
-                    * self.num_generations
-                )
-                if num_excessive_samples > 0:
-                    data = {
-                        k: (
-                            v[: -num_excessive_samples if num_excessive_samples != 0 else v.size(0)]
-                            if k
-                            in [
-                                "input_ids",
-                                "attention_mask",
-                                "action_log_probs",
-                                "action_mask",
-                                "response_idx",
-                                "gt_answer",
-                            ]
-                            else v
-                        )
-                        for k, v in data.items()
-                    }
-                    action_mask = action_mask[
-                        : -num_excessive_samples if num_excessive_samples != 0 else action_mask.size(0)
-                    ]
-                    loss_mask = loss_mask[: -num_excessive_samples if num_excessive_samples != 0 else loss_mask.size(0)]
-                    advantages = advantages[: -num_excessive_samples if num_excessive_samples != 0 else advantages.size(0)]
-                else:
-                    num_excessive_samples = 0
-            else:
-                # If dynamic batching is disabled, we need to use all samples for training.
-                need_update = (step_idx + 1) % self.num_microbatches == 0
-                num_excessive_samples = 0
-
             pbar.set_postfix(
                 {
-                    "Step": self.global_step + 1,
-                    "Status": f"Collecting: {self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
+                    "Global Step": self.global_step,
+                    "Effective prompts": f"{self.effective_prompt_count}/{self.batch_size * self.dp_size}",
+                    "Effective samples": f"{self.effective_sample_count}/{self.batch_size * self.dp_size * self.num_generations}",
                 }
             )
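
The heart of the hunk above is that filtering is now accounted per prompt rather than per sample: loss_mask is reshaped to [minibatch_size, num_generations], a prompt counts as effective if any of its generations survives the accuracy and truncation filters, and once the accumulated effective-prompt count exceeds batch_size * dp_size the trailing surplus prompts are masked out and their indices handed back to the consumer. A minimal torch sketch of that selection follows; the tensor names mirror the diff, the sizes and mask values are made up, and the final repeat_interleave line is a compact stand-in for the diff's explicit for-loop over effective_prompts_mask.

import torch

minibatch_size, num_generations = 4, 2
dp_size = 2
surplus = 4  # pretend the DP group over-collected by 4 prompts this step

# Per-sample mask after accuracy / truncation filtering (made-up values).
loss_mask = torch.tensor([1, 0, 0, 0, 1, 1, 0, 1], dtype=torch.bool)

prompt_level_mask = loss_mask.view(minibatch_size, num_generations)
effective_prompts_mask = prompt_level_mask.any(dim=1)           # [True, False, True, True]

excessive_prompts_per_rank = surplus // dp_size                 # 2 prompts to hand back per rank
true_indices = torch.nonzero(effective_prompts_mask).squeeze()  # tensor([0, 2, 3])
excessive_prompts_idx = true_indices[-excessive_prompts_per_rank:]  # tensor([2, 3])

# Exclude the surplus prompts from this update; the consumer re-queues them next iteration.
effective_prompts_mask[excessive_prompts_idx] = False
loss_mask = loss_mask & effective_prompts_mask.repeat_interleave(num_generations)
print(excessive_prompts_idx.tolist())  # [2, 3]
print(loss_mask.tolist())              # only the surviving generation of prompt 0 stays True
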
@@ -375,7 +364,7 @@ class GRPOConsumer(BaseConsumer):
                         kl.append(appox_kl.mean())
                     else:
                         per_token_kl = 0.0
-                        kl.append(0.0)
+                        kl.append(torch.tensor(0.0))

                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
@@ -479,6 +468,7 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.zero_grad()
             self.global_step += 1
             sample_utilization = self.effective_sample_count / self.total_sample_count
+            self.effective_prompt_count = 0
             self.effective_sample_count = 0
             self.total_sample_count = 0
             loss_scalar = self.accum_loss.item()
@@ -495,6 +485,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}",
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
+                    f"Sample_utilization: {sample_utilization:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -520,9 +511,15 @@ class GRPOConsumer(BaseConsumer):
             self.accum_advantages.zero_()
             self.accum_response_length.zero_()
             self.accum_count = 0
-            return loss_scalar, num_excessive_samples // self.num_generations
+
+            if excessive_prompts_idx is not None:
+                # All gather excessive prompts index across DP ranks.
+                excessive_prompts_idx = [idx + self.dp_rank * self.minibatch_size for idx in excessive_prompts_idx]
+                excessive_prompts_idx = all_gather_tensors(excessive_prompts_idx, self.plugin)
+
+            return loss_scalar, excessive_prompts_idx
         else:
-            return None, num_excessive_samples // self.num_generations
+            return None, excessive_prompts_idx

     def state_dict(self):
         self.policy_model._force_wait_all_gather()


@@ -144,3 +144,29 @@ def all_reduce_sum(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
     else:
         dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
     return tensor
+
+
+def all_gather_tensors(local_tensor_list: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
+    """
+    Gathers tensors from all processes and concatenates them along the first dimension.
+
+    Args:
+        tensor (torch.Tensor): The input tensor to be gathered.
+
+    Returns:
+        torch.Tensor: The gathered tensor.
+    """
+    # Gather tensors across DP group
+    if plugin is not None:
+        all_tensor_lists = [None] * plugin.dp_size
+        dist.all_gather_object(all_tensor_lists, local_tensor_list, group=plugin.dp_group)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    else:
+        all_tensor_lists = [None] * dist.get_world_size()
+        dist.all_gather_object(all_tensor_lists, local_tensor_list)
+        gathered_tensor_list = []
+        for tensors in all_tensor_lists:
+            gathered_tensor_list.extend(tensors)
+    return gathered_tensor_list
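
A hedged usage sketch for the new all_gather_tensors helper, importing it from coati.trainer.utils as the GRPOConsumer diff above does. Despite the torch.Tensor annotations it takes and returns plain Python lists, gathered with dist.all_gather_object. The single-process gloo group, the MASTER_ADDR/MASTER_PORT values and the index values below are illustrative so the sketch runs standalone; in training the booster plugin's DP group is used instead, and the index shift mirrors the idx + self.dp_rank * self.minibatch_size line in GRPOConsumer.step.

import os
import torch
import torch.distributed as dist

from coati.trainer.utils import all_gather_tensors  # as imported in the hunk above

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

dp_rank, minibatch_size = 0, 4                  # illustrative values
local_idx = [torch.tensor(2), torch.tensor(3)]  # surplus prompt indices found on this rank
# Shift local indices into the consumer's global buffer coordinates before gathering.
global_idx = [idx + dp_rank * minibatch_size for idx in local_idx]

gathered = all_gather_tensors(global_idx, plugin=None)  # plugin=None falls back to the world group
print([int(i) for i in gathered])  # [2, 3] when run with a single rank

dist.destroy_process_group()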


@@ -9,7 +9,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
-    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
     parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")

     # Distributed training parameters
@@ -20,7 +20,7 @@ if __name__ == "__main__":
         "-ibs",
         "--inference-batch-size",
         type=int,
-        default=None,
+        default=64,
         help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
     )
     parser.add_argument(
@@ -41,7 +41,7 @@ if __name__ == "__main__":
         "-tMbs",
         "--train-minibatch-size",
         type=int,
-        default=None,
+        default=8,
         help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
     )
     parser.add_argument(
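
A small worked example of the batch-size relationships spelled out in the two help strings above. ibs and tMbs use the new defaults introduced by this commit; tbs (--train-batch-size), tmbs (--train-microbatch-size), g (generations per prompt) and dp_size are assumptions chosen for illustration, not values set by this diff.

ibs, tMbs = 64, 8                     # new defaults: --inference-batch-size, --train-minibatch-size
tbs, tmbs, g, dp_size = 32, 8, 8, 4   # assumed values for illustration only

assert ibs % tbs == 0, "ibs should be divisible by tbs"
assert tMbs * g >= tmbs, "satisfy tMbs * g >= tmbs"

print("inference weights synced every", ibs // tbs, "training steps")  # 2
print("samples generated before each forward:", tMbs * g * dp_size)    # 256
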
@@ -58,7 +58,7 @@ if __name__ == "__main__":
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
     parser.add_argument(
-        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
+        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
     )

     # Sampling parameters
@@ -223,7 +223,7 @@ if __name__ == "__main__":
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #     "tp_size": 1,
+        #     "tp_size": 2,
         #     "pp_size": 2,
         #     "microbatch_size": max(
         #         1, args.train_microbatch_size // 2