fix memory leakage, support tp+pp

YeAnbang 2025-04-09 17:11:55 +08:00
parent a40d82f629
commit 1ea3b72c22
3 changed files with 12 additions and 8 deletions


@@ -72,6 +72,8 @@ class BaseConsumer:
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
+        self.tp_rank = dist.get_rank(self.plugin.tp_group)
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
         self.buffer = []
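The new `self.tp_rank` attribute mirrors the existing data-parallel bookkeeping: each rank asks torch.distributed for its position inside the plugin's sub-groups. A minimal standalone sketch of that pattern (class and attribute names here are illustrative, not taken from the repository):

import torch.distributed as dist


class ParallelCoords:
    """Illustrative holder for a rank's coordinates in pre-built process groups."""

    def __init__(self, dp_group, tp_group):
        # Position of this process inside each sub-group (not the global rank).
        self.dp_rank = dist.get_rank(dp_group)
        self.tp_rank = dist.get_rank(tp_group)
        self.dp_size = dist.get_world_size(dp_group)
        self.tp_size = dist.get_world_size(tp_group)

    def is_tp_leader(self) -> bool:
        # Handy predicate for "act once per tensor-parallel group".
        return self.tp_rank == 0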
@@ -132,6 +134,8 @@ class BaseConsumer:
                 ray_broadcast_tensor_dict(
                     state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                 )
+                del state_dict
+                torch.cuda.empty_cache()

 @ray.remote
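The two added lines are the memory-leak fix itself: once the synced weights are no longer needed, the reference to the broadcast state dict is dropped and the CUDA caching allocator is flushed, so repeated model syncs do not accumulate GPU memory. A hedged sketch of the same pattern in isolation (helper names are hypothetical):

import torch


def sync_and_release(receive_state_dict, load_into_model):
    """Hypothetical helper: pull synced weights, apply them, then free the buffers."""
    state_dict = receive_state_dict()   # dict of (possibly CUDA) tensors
    load_into_model(state_dict)
    # Without these two lines the broadcast buffers stay referenced until the
    # next sync, which shows up as steadily growing reserved GPU memory.
    del state_dict
    if torch.cuda.is_available():
        torch.cuda.empty_cache()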


@@ -109,7 +109,7 @@ class GRPOConsumer(BaseConsumer):
         super().setup()
         if self.use_wandb and (
             (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
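The extra `self.tp_rank == 0` guard makes exactly one process own the wandb run when tensor parallelism is enabled, since every TP rank on the last pipeline stage would otherwise try to initialize it. The rule the condition encodes, written out as a small standalone predicate (the helper name is illustrative):

def owns_logging(pp_size: int, is_last_stage: bool, rank: int, tp_rank: int) -> bool:
    """Return True for the single process that should create the wandb run."""
    if pp_size <= 1:
        # No pipeline parallelism: global rank 0 logs.
        return rank == 0
    # With pipeline parallelism the loss lives on the last stage, and every TP
    # rank there holds the same value, so only that stage's TP rank 0 logs.
    return is_last_stage and tp_rank == 0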
@@ -282,10 +282,10 @@ class GRPOConsumer(BaseConsumer):
                 if self.booster.plugin.stage_manager.is_last_stage():
                     if len(kl) > 0:
-                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
-                    loss = all_reduce_mean(loss, self.plugin)
-                    mean_loss.append(loss.data)
+                    mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                    torch.cuda.empty_cache()
             else:
                 policy_model_logits = self.policy_model(
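Appending `.data` (a detached view) instead of the live reduced tensor, plus the explicit `torch.cuda.empty_cache()`, keeps the accumulation lists from pinning autograd state across steps. A self-contained illustration of the detach-before-accumulate idea (not repository code):

import torch

x = torch.randn(1024, requires_grad=True)
attached, detached = [], []
for _ in range(3):
    loss = (x * x).sum()
    attached.append(loss)            # keeps grad_fn, so the graph stays reachable
    detached.append(loss.detach())   # value only; `.data` in the patch is the older spelling

print(attached[0].grad_fn is not None)  # True  -> still wired into autograd
print(detached[0].grad_fn is None)      # True  -> safe to hold across many steps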
@@ -336,7 +336,7 @@ class GRPOConsumer(BaseConsumer):
                 mean_kl.append(kl.data)
                 mean_loss.append(loss.data)
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 reward = all_reduce_mean(reward.mean(), self.plugin)
                 format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
@@ -355,11 +355,11 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 loss_scalar = self.accum_loss.item()
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
                     print(
                         "Loss:",


@@ -121,7 +121,7 @@ if __name__ == "__main__":
         # plugin_config={},  # for zero
         plugin_config={
             "pp_size": 2,
-            "tp_size": 1,
+            "tp_size": 2,
             "microbatch_size": args.train_microbatch_size // 2,
             "zero_stage": 0,
             "max_norm": 1.0,