Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 21:22:04 +00:00)
fix memory leakage; support tp+pp
This commit is contained in:
parent a40d82f629
commit 1ea3b72c22
@@ -72,6 +72,8 @@ class BaseConsumer:
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
+        self.tp_rank = dist.get_rank(self.plugin.tp_group)
+
         self.dp_size = dist.get_world_size(self.plugin.dp_group)

         self.buffer = []
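Context for the added tp_rank: with HybridParallelPlugin the world is split into data-, tensor-, and pipeline-parallel process groups, and dist.get_rank(group) returns the caller's index within one sub-group rather than its global rank. A minimal sketch of that pattern using plain torch.distributed and a 2-way TP grid (the group construction here is illustrative, not the plugin's actual internals):

import torch.distributed as dist

# Minimal sketch: a world of N ranks arranged as a (dp, tp) grid with tp_size = 2.
# Launch with e.g. torchrun --nproc_per_node=4; illustrative only.
def build_sub_ranks(tp_size: int = 2):
    dist.init_process_group("gloo")
    world, rank = dist.get_world_size(), dist.get_rank()
    # new_group must be called collectively, so every rank builds every group.
    tp_groups = [dist.new_group(list(range(r, r + tp_size))) for r in range(0, world, tp_size)]
    dp_groups = [dist.new_group(list(range(c, world, tp_size))) for c in range(tp_size)]
    tp_group = tp_groups[rank // tp_size]
    dp_group = dp_groups[rank % tp_size]
    # Same calls as in BaseConsumer: the rank *within* each sub-group, not the global rank.
    return dist.get_rank(dp_group), dist.get_rank(tp_group)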
@@ -132,6 +134,8 @@ class BaseConsumer:
                     ray_broadcast_tensor_dict(
                         state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                     )
+                    del state_dict
+                    torch.cuda.empty_cache()


 @ray.remote
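The two added lines are the leak fix itself: after the broadcast, the consumer's reference to state_dict keeps every weight tensor alive; dropping it makes the tensors collectable, and torch.cuda.empty_cache() then returns the cached blocks to the driver. A minimal sketch of the pattern on a CUDA device, where sync_weights is a hypothetical stand-in for ray_broadcast_tensor_dict:

import torch

def sync_weights(state_dict):
    # Hypothetical stand-in for ray_broadcast_tensor_dict: touch each tensor.
    for t in state_dict.values():
        t.add_(0)

state_dict = {f"layer{i}.weight": torch.randn(1024, 1024, device="cuda") for i in range(8)}
sync_weights(state_dict)
print(torch.cuda.memory_allocated())  # ~32 MB still held via state_dict
del state_dict            # drop the last reference so the tensors can be freed
torch.cuda.empty_cache()  # hand the freed, cached blocks back to the driver
print(torch.cuda.memory_allocated())  # back near zero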
@@ -109,7 +109,7 @@ class GRPOConsumer(BaseConsumer):
         super().setup()
         if self.use_wandb and (
             (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
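Each tp_rank == 0 clause added in this commit (here and in the two hunks below) narrows the same gate: under pipeline parallelism the loss only materializes on the last stage, and without the extra clause every tensor-parallel rank of that stage would initialize wandb and emit duplicate logs. A minimal sketch of the predicate, with should_log as a hypothetical helper name:

# Minimal sketch of rank-gated logging, assuming attributes shaped like
# GRPOConsumer's; should_log is a hypothetical helper, not repo code.
def should_log(pp_size: int, is_last_stage: bool, rank: int, tp_rank: int) -> bool:
    if pp_size <= 1:
        return rank == 0                    # no pipeline: only global rank 0 logs
    return is_last_stage and tp_rank == 0   # pipeline: one TP rank on the last stage

# The loss lives on the last pipeline stage, so the gate cannot be plain rank == 0.
assert should_log(pp_size=2, is_last_stage=True, rank=5, tp_rank=0)
assert not should_log(pp_size=2, is_last_stage=True, rank=5, tp_rank=1)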
@@ -282,10 +282,10 @@ class GRPOConsumer(BaseConsumer):

                 if self.booster.plugin.stage_manager.is_last_stage():
                     if len(kl) > 0:
-                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
-                    loss = all_reduce_mean(loss, self.plugin)
-                    mean_loss.append(loss.data)
+                    mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                    torch.cuda.empty_cache()
             else:

                 policy_model_logits = self.policy_model(
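The .data changes are also part of the leak fix: all_reduce_mean returns a tensor still attached to the autograd graph, and accumulating it across steps keeps each step's graph reachable. Detaching before appending (via .data here, or the equivalent modern .detach()) breaks that chain. A minimal sketch of the difference with a toy model:

import torch

model = torch.nn.Linear(512, 512)
mean_loss = []
for _ in range(4):
    loss = model(torch.randn(32, 512)).pow(2).mean()
    loss.backward()
    # Appending `loss` directly would keep each iteration's grad_fn chain alive;
    # detaching (the commit uses the equivalent `.data`) stores a bare value.
    mean_loss.append(loss.detach())
print(torch.stack(mean_loss).mean())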
@@ -336,7 +336,7 @@ class GRPOConsumer(BaseConsumer):
                 mean_kl.append(kl.data)
                 mean_loss.append(loss.data)
         if not self.plugin.pp_size > 1 or (
-            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
         ):
             reward = all_reduce_mean(reward.mean(), self.plugin)
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
@@ -355,11 +355,11 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 loss_scalar = self.accum_loss.item()
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
                     print(
                         "Loss:",
@@ -121,7 +121,7 @@ if __name__ == "__main__":
         # plugin_config={}, # for zero
         plugin_config={
             "pp_size": 2,
-            "tp_size": 1,
+            "tp_size": 2,
             "microbatch_size": args.train_microbatch_size // 2,
             "zero_stage": 0,
             "max_norm": 1.0,
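One consequence of raising tp_size from 1 to 2 while keeping pp_size at 2: the launch now needs a world size divisible by pp_size * tp_size = 4, with the remainder becoming the data-parallel degree. A small sketch of that arithmetic, assuming the usual dp = world // (pp * tp) decomposition used by 3D-parallel plugins:

# Minimal sketch; the dict mirrors the plugin_config above, and the
# dp_size formula is a stated assumption, not quoted plugin code.
plugin_config = {"pp_size": 2, "tp_size": 2, "zero_stage": 0, "max_norm": 1.0}

def derive_dp_size(world_size: int, cfg: dict) -> int:
    grid = cfg["pp_size"] * cfg["tp_size"]
    assert world_size % grid == 0, "world size must be divisible by pp_size * tp_size"
    return world_size // grid

print(derive_dp_size(8, plugin_config))  # 8 GPUs -> dp_size 2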