fix pp+tp, fix dataloader (#6280)

This commit is contained in:
YeAnbang
2025-04-28 17:10:00 +08:00
committed by GitHub
parent 28795f560c
commit 2ca1e3c630
5 changed files with 17 additions and 8 deletions

View File

@@ -49,6 +49,12 @@ class GRPOConsumer(BaseConsumer):
UserWarning,
)
minibatch_size = batch_size
if (
plugin_config.get("pp_size", 1) > 1
and "num_microbatches" not in plugin_config
and "microbatch_size" not in plugin_config
):
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
super().__init__(
num_producers,
num_episodes,
@@ -373,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
loss_mask=inputs["loss_mask"],
total_effective_tokens_in_batch=total_effective_tokens_count,
)
return loss, num_excessive_samples // self.num_generations
return loss
policy_model_outputs = self.booster.execute_pipeline(
iter([data_policy_forward]),
@@ -468,10 +474,10 @@ class GRPOConsumer(BaseConsumer):
sample_utilization = self.effective_sample_count / self.total_sample_count
self.effective_sample_count = 0
self.total_sample_count = 0
loss_scalar = self.accum_loss.item()
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
loss_scalar = self.accum_loss.item()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
@@ -507,7 +513,7 @@ class GRPOConsumer(BaseConsumer):
self.accum_advantages.zero_()
self.accum_response_length.zero_()
self.accum_count = 0
return loss_scalar, num_excessive_samples // self.num_generations
return loss_scalar, num_excessive_samples // self.num_generations
else:
return None, num_excessive_samples // self.num_generations