[hot-fix] Fix memory leak bug, support TP+PP (#6258)

* update help information

* update style

* fix

* minor fix

* support PP training

* add pp support

* remove unused code

* address conversation

* fix memory leak, support TP+PP

* move empty cache

* move empty cache

---------

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
Author: YeAnbang
Date: 2025-04-10 10:52:18 +08:00
Committed by: GitHub
Parent: ed43a4be04
Commit: 9467c10690
3 changed files with 12 additions and 8 deletions


@@ -109,7 +109,7 @@ class GRPOConsumer(BaseConsumer):
         super().setup()
         if self.use_wandb and (
             (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -282,10 +282,9 @@ class GRPOConsumer(BaseConsumer):
                 if self.booster.plugin.stage_manager.is_last_stage():
                     if len(kl) > 0:
-                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
-                    loss = all_reduce_mean(loss, self.plugin)
-                    mean_loss.append(loss.data)
+                    mean_loss.append(all_reduce_mean(loss, self.plugin).data)
         else:
             policy_model_logits = self.policy_model(
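This hunk is the memory-leak fix itself: `all_reduce_mean` returns a tensor that is still attached to the autograd graph, and appending it to `mean_kl`/`mean_loss` kept every step's graph reachable across the whole accumulation window. Taking `.data` stores a graph-free tensor instead. A minimal sketch of the failure mode, using a toy model in place of the policy model:

import torch

model = torch.nn.Linear(8, 1)
history = []

for step in range(1000):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    # Appending `loss` itself would keep each step's autograd graph (and the
    # activations it references) alive for the lifetime of `history`, so GPU
    # memory grows every step. The fix mirrored by this PR stores a detached
    # value instead:
    history.append(loss.data)  # equivalently loss.detach(); no graph retained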
@@ -336,7 +335,7 @@ class GRPOConsumer(BaseConsumer):
                 mean_kl.append(kl.data)
                 mean_loss.append(loss.data)
         if not self.plugin.pp_size > 1 or (
-            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
             reward = all_reduce_mean(reward.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
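For reference, `all_reduce_mean` averages a metric across parallel ranks; the ColossalChat helper takes the plugin to resolve the process group and may differ in detail. A minimal sketch of the conventional torch.distributed equivalent (signature simplified, group handling elided):

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(t: torch.Tensor) -> torch.Tensor:
    # Sum the tensor across all ranks in the default group, then divide by
    # the world size so every rank holds the mean.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t /= dist.get_world_size()
    return t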
@@ -355,11 +354,11 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 loss_scalar = self.accum_loss.item()
             if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 print(
                     "Loss:",