Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 10:06:44 +00:00
[feat] Sync shard model (#6289)
* [feat] support hybrid parallel model sync
* update consumer and producer
* update files
* update producer
* remove print
* update

Co-authored-by: duanjunwen <935724073@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li35271158@gmail.com>
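Before this change, the consumer broadcast the full model over a single Ray collective group, "sync_model". Under hybrid tensor + pipeline parallelism each pipeline stage holds only its own shard of the weights, so the diff below creates one broadcast group per pipeline stage instead. As a minimal sketch of the naming scheme the diff implies (the helper name is ours, not the PR's):

def model_sync_group_name(pp_size: int, pp_rank: int) -> str:
    # One model-sync group per pipeline stage under pipeline parallelism;
    # the single legacy group otherwise (mirrors the branches in the diff).
    return f"sync_model_{pp_rank}" if pp_size > 1 else "sync_model"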
@@ -59,10 +59,6 @@ class BaseConsumer:
         self.lr_scheduler = None
 
     def setup(self) -> None:
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
-        if self.rank == 0:
-            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
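The four deleted lines are not dropped: they reappear later in setup(), after the HybridParallelPlugin and Booster have been created, because group membership now depends on the consumer's tensor-, data-, and pipeline-parallel coordinates, which only exist once the plugin has built its process groups. A self-contained sketch of that dependency (names assumed, not the PR's code):

from dataclasses import dataclass

@dataclass
class ConsumerCoords:
    rank: int      # global rank, known right after launch()
    tp_rank: int   # known only after the plugin builds tp_group
    dp_rank: int   # known only after the plugin builds dp_group
    pp_size: int   # known only after the plugin builds pp_group

def joins_model_sync_group(c: ConsumerCoords) -> bool:
    # Mirrors the new condition in setup(): with pipeline parallelism, one
    # representative rank per stage joins; otherwise only global rank 0 does.
    if c.pp_size > 1:
        return c.tp_rank == 0 and c.dp_rank == 0
    return c.rank == 0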
@@ -77,8 +73,24 @@ class BaseConsumer:
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
+        self.pp_rank = dist.get_rank(self.plugin.pp_group)
 
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
         self.tp_size = dist.get_world_size(self.plugin.tp_group)
+        self.pp_size = dist.get_world_size(self.plugin.pp_group)
+
+        # Init Hybrid ray process group
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+        if self.pp_size > 1:
+            # use hybrid tp + pp
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                cc.init_collective_group(
+                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                )
+        else:
+            if self.rank == 0:
+                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
 
         self.buffer = []
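Each call to cc.init_collective_group(world_size, rank, group_name=...) must be matched by exactly world_size processes joining under distinct ranks, so the arguments above pin down the expected layout: every sync_data_{i} group holds all consumers plus producer i, and every model-sync group holds the num_producers producers (ranks 0..num_producers-1) plus one consumer representative at rank num_producers. A small sanity-check sketch of that accounting (helper name is ours; the producer side is assumed to join symmetrically):

def expected_group_sizes(num_producers: int, consumer_world_size: int, pp_size: int) -> dict[str, int]:
    """Expected member count per Ray collective group, inferred from the diff."""
    groups = {f"sync_data_{i}": consumer_world_size + 1 for i in range(num_producers)}
    model_groups = [f"sync_model_{s}" for s in range(pp_size)] if pp_size > 1 else ["sync_model"]
    for name in model_groups:
        groups[name] = num_producers + 1  # producers + one consumer representative
    return groups

# e.g. 2 producers, 4 consumer ranks arranged with pp_size=2:
# {'sync_data_0': 5, 'sync_data_1': 5, 'sync_model_0': 3, 'sync_model_1': 3}
print(expected_group_sizes(2, 4, 2))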
@@ -140,13 +152,27 @@ class BaseConsumer:
                     print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    if self.pp_size > 1:
+                        print(
+                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        )
+                    else:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     torch.cuda.empty_cache()
                     state_dict = self.state_dict()
-                    if self.rank == 0:
-                        ray_broadcast_tensor_dict(
-                            state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                        )
+                    if self.pp_size > 1:
+                        if self.tp_rank == 0 and self.dp_rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict,
+                                src=self.num_producers,
+                                device=self.device,
+                                group_name=f"sync_model_{self.pp_rank}",
+                            )
+                    else:
+                        if self.rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                            )
                     del state_dict
                     torch.cuda.empty_cache()
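The producer-side counterpart is not part of this excerpt. Given the group names and the ray_broadcast_tensor_dict call shape used above, a plausible receive path (hypothetical sketch under those assumptions, not the PR's producer code) collects one shard dict per pipeline stage and merges them:

def receive_synced_state_dict(num_producers: int, pp_size: int, device) -> dict:
    # Hypothetical producer-side receive, mirroring the consumer's sends above.
    # Passing None as the tensor dict on a non-source rank is assumed to
    # return the received tensors; ray_broadcast_tensor_dict is the same
    # helper the consumer uses in this diff.
    if pp_size > 1:
        state_dict = {}
        for stage in range(pp_size):
            shard = ray_broadcast_tensor_dict(
                None, src=num_producers, device=device, group_name=f"sync_model_{stage}"
            )
            state_dict.update(shard)  # assumes per-stage dicts use disjoint keys
        return state_dict
    return ray_broadcast_tensor_dict(
        None, src=num_producers, device=device, group_name="sync_model"
    )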