mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
fix pp+tp, fix dataloader (#6280)
This commit is contained in:
@@ -49,6 +49,12 @@ class GRPOConsumer(BaseConsumer):
|
||||
UserWarning,
|
||||
)
|
||||
minibatch_size = batch_size
|
||||
if (
|
||||
plugin_config.get("pp_size", 1) > 1
|
||||
and "num_microbatches" not in plugin_config
|
||||
and "microbatch_size" not in plugin_config
|
||||
):
|
||||
plugin_config["microbatch_size"] = max(1, grpo_config.get("train_microbatch_size") // 2)
|
||||
super().__init__(
|
||||
num_producers,
|
||||
num_episodes,
|
||||
@@ -373,7 +379,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
loss_mask=inputs["loss_mask"],
|
||||
total_effective_tokens_in_batch=total_effective_tokens_count,
|
||||
)
|
||||
return loss, num_excessive_samples // self.num_generations
|
||||
return loss
|
||||
|
||||
policy_model_outputs = self.booster.execute_pipeline(
|
||||
iter([data_policy_forward]),
|
||||
@@ -468,10 +474,10 @@ class GRPOConsumer(BaseConsumer):
|
||||
sample_utilization = self.effective_sample_count / self.total_sample_count
|
||||
self.effective_sample_count = 0
|
||||
self.total_sample_count = 0
|
||||
loss_scalar = self.accum_loss.item()
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
loss_scalar = self.accum_loss.item()
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
@@ -507,7 +513,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
self.accum_advantages.zero_()
|
||||
self.accum_response_length.zero_()
|
||||
self.accum_count = 0
|
||||
return loss_scalar, num_excessive_samples // self.num_generations
|
||||
return loss_scalar, num_excessive_samples // self.num_generations
|
||||
else:
|
||||
return None, num_excessive_samples // self.num_generations
|
||||
|
||||
|
Reference in New Issue
Block a user