hpcaitech/ColossalAI (https://github.com/hpcaitech/ColossalAI.git)
Commit 1ea3b72c22 (parent a40d82f629)

fix memory leakage; support tp+pp
@@ -72,6 +72,8 @@ class BaseConsumer:
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
+        self.tp_rank = dist.get_rank(self.plugin.tp_group)
+
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
 
         self.buffer = []
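The new self.tp_rank mirrors the existing self.dp_rank: the consumer asks torch.distributed for its rank inside the plugin's tensor-parallel group, so later hunks can elect a single TP rank for logging. A minimal sketch of the same bookkeeping with plain torch.distributed (the helper name is illustrative, not ColossalAI API; the groups are assumed to come from HybridParallelPlugin as plugin.dp_group / plugin.tp_group):

import torch.distributed as dist

def parallel_coords(dp_group, tp_group):
    """Return this process's position inside the data- and tensor-parallel groups.

    dist.get_rank(group) is the rank *within* that group, not the global rank.
    """
    dp_rank = dist.get_rank(dp_group)        # which data-parallel replica we belong to
    tp_rank = dist.get_rank(tp_group)        # which shard of the tensor-parallel group we hold
    dp_size = dist.get_world_size(dp_group)  # number of data-parallel replicas
    return dp_rank, tp_rank, dp_size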
@@ -132,6 +134,8 @@ class BaseConsumer:
                 ray_broadcast_tensor_dict(
                     state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
                 )
+                del state_dict
+                torch.cuda.empty_cache()
 
 
 @ray.remote
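This is the memory-leak fix itself: once the updated weights have been broadcast over the "sync_model" group, the consumer's temporary state_dict is dead weight, so the commit drops the reference and flushes the CUDA caching allocator before training resumes. A minimal sketch of the pattern, with the broadcast helper as a stand-in for ray_broadcast_tensor_dict:

import torch

def sync_weights_then_release(model, broadcast_fn):
    """Ship a model's weights to remote workers, then free the temporary copy.

    broadcast_fn stands in for the actual collective (ray_broadcast_tensor_dict
    in ColossalAI's distributed RL code). For sharded models, state_dict() can
    materialize a full gathered copy, which is exactly what must be released.
    """
    state_dict = model.state_dict()   # references (or a gathered copy of) every parameter
    broadcast_fn(state_dict)          # remote workers now hold the weights
    del state_dict                    # drop the last Python reference ...
    torch.cuda.empty_cache()          # ... and hand cached CUDA blocks back to the allocator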
@@ -109,7 +109,7 @@ class GRPOConsumer(BaseConsumer):
         super().setup()
         if self.use_wandb and (
             (not self.plugin.pp_size > 1 and self.rank == 0)
-            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
         ):
             # Initialize wandb.
             name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
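With tensor parallelism enabled, every TP rank on the last pipeline stage satisfies is_last_stage(), so each of them would open its own wandb run; requiring tp_rank == 0 keeps a single run per logging stage. A sketch of the gating predicate (the helper is illustrative; the attribute names follow the diff):

def should_init_wandb(pp_size: int, rank: int, is_last_pp_stage: bool, tp_rank: int) -> bool:
    """Pick the process that owns the wandb run.

    Without pipeline parallelism the global rank 0 logs; with pipeline
    parallelism the last stage holds the loss, and only its tp_rank-0
    process should create the run.
    """
    if pp_size <= 1:
        return rank == 0
    return is_last_pp_stage and tp_rank == 0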
@@ -282,10 +282,10 @@ class GRPOConsumer(BaseConsumer):
 
                 if self.booster.plugin.stage_manager.is_last_stage():
                     if len(kl) > 0:
-                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                        kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
-                    loss = all_reduce_mean(loss, self.plugin)
-                    mean_loss.append(loss.data)
+                    mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                    torch.cuda.empty_cache()
             else:
 
                 policy_model_logits = self.policy_model(
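Two small changes close the other leak on the last pipeline stage: the reduced kl and loss are stored as detached values (.data) rather than as graph-carrying tensors, and the cache is emptied once the metrics have been captured. A sketch of why detaching matters for metric buffers:

import torch

def record_metric(buffer: list, value: torch.Tensor) -> None:
    """Accumulate a metric for later logging without retaining its autograd graph.

    Appending `value` as-is keeps every intermediate tensor that produced it
    alive until the buffer is cleared; storing a detached copy keeps only the
    number. (`.data` in the diff and `.detach()` here have the same effect.)
    """
    buffer.append(value.detach())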
@@ -336,7 +336,7 @@ class GRPOConsumer(BaseConsumer):
                 mean_kl.append(kl.data)
                 mean_loss.append(loss.data)
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 reward = all_reduce_mean(reward.mean(), self.plugin)
                 format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
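The reward statistics are averaged across replicas with all_reduce_mean before logging, and the same tp_rank == 0 gate is applied so only one tensor-parallel rank per stage computes and reports them. A minimal stand-in for such a mean all-reduce, assuming a plain torch.distributed process group rather than ColossalAI's plugin:

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(t: torch.Tensor, group=None) -> torch.Tensor:
    """Average a tensor across all ranks of `group` (stand-in for coati's all_reduce_mean)."""
    out = t.clone()                                        # leave the caller's tensor untouched
    dist.all_reduce(out, op=dist.ReduceOp.SUM, group=group)
    out /= dist.get_world_size(group=group)
    return out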
@@ -355,11 +355,11 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             if not self.plugin.pp_size > 1 or (
-                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
             ):
                 loss_scalar = self.accum_loss.item()
                 if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
                 ):
                     print(
                         "Loss:",
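The same condition guards the end-of-step reporting: only the elected rank materializes the accumulated loss with .item() and prints it, so the other TP ranks stay silent and skip the extra host-device synchronization. A small sketch of that pattern (is_logging_rank and the optional wandb_run handle are illustrative):

import torch

def report_step(accum_loss: torch.Tensor, is_logging_rank: bool, wandb_run=None) -> None:
    """Convert the accumulated loss to a Python float and report it on one rank only.

    .item() forces a GPU->CPU sync, so it is deliberately skipped on every
    rank that is not responsible for logging.
    """
    if not is_logging_rank:
        return
    loss_scalar = accum_loss.item()
    print("Loss:", loss_scalar)
    if wandb_run is not None:          # hypothetical handle to an existing wandb run
        wandb_run.log({"train/loss": loss_scalar})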
@@ -121,7 +121,7 @@ if __name__ == "__main__":
         # plugin_config={}, # for zero
         plugin_config={
             "pp_size": 2,
-            "tp_size": 1,
+            "tp_size": 2,
             "microbatch_size": args.train_microbatch_size // 2,
             "zero_stage": 0,
             "max_norm": 1.0,
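The example script now exercises the TP+PP path: pp_size=2 together with tp_size=2 means each model replica spans four GPUs (two pipeline stages, each sharded across two tensor-parallel ranks), and whatever ranks remain form the data-parallel dimension. A sketch of how the layout composes (world_size and the helper are illustrative, not part of the script):

plugin_config = {
    "pp_size": 2,        # two pipeline stages
    "tp_size": 2,        # each stage sharded across two GPUs
    "zero_stage": 0,
    "max_norm": 1.0,
    # "microbatch_size" comes from args.train_microbatch_size // 2 in the script
}

def dp_degree(world_size: int, pp_size: int, tp_size: int) -> int:
    """Data-parallel degree left over once the PP x TP grid is filled."""
    assert world_size % (pp_size * tp_size) == 0, "world size must be divisible by pp*tp"
    return world_size // (pp_size * tp_size)

print(dp_degree(8, plugin_config["pp_size"], plugin_config["tp_size"]))  # -> 2 replicas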