Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)
* [legacy] refactor logger and clean up legacy codes (#4654)
* [legacy] make logger independent of gpc
* [legacy] make optim independent of registry
* [legacy] move test engine to legacy
* [legacy] move nn to legacy (#4656)
* [legacy] move nn to legacy
* [checkpointio] fix save hf config
* [test] remove useless rpc pp test
* [legacy] fix nn init
* [example] skip tutorial hybrid parallel example
* [devops] test doc check
* [devops] test doc check
colossalai/legacy/nn/layer/wrapper/pipeline_wrapper.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import List, Tuple, Union

import torch.distributed as dist
import torch.nn as nn

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class PipelineSharedModuleWrapper:

    def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
        assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
        self.pipeline_ranks = pipeline_ranks
        self.group = None
        self.ranks_in_group = None
        self._init_group()

    def _init_group(self):
        world_size = gpc.get_world_size(ParallelMode.GLOBAL)
        dp_size = gpc.get_world_size(ParallelMode.DATA)
        pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
        rank = gpc.get_global_rank()
        num_dp_groups = world_size // dp_size
        num_pp_stages = num_dp_groups // pp_size
        for i in range(dp_size):
            for j in range(num_pp_stages):
                pipeline_ranks = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
                sub_ranks = [pipeline_ranks[idx] for idx in self.pipeline_ranks]
                group = dist.new_group(sub_ranks)
                if rank in sub_ranks:
                    self.group = group
                    self.ranks_in_group = sub_ranks

    def register_module(self, module: nn.Module):
        assert self.ranks_in_group is not None, \
            f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
        src = self.ranks_in_group[self.pipeline_ranks[0]]
        for p in module.parameters():
            setattr(p, 'pipeline_shared_module_pg', self.group)
            dist.broadcast(p, src, group=self.group)

    def register_parameter(self, param: nn.Parameter):
        assert self.ranks_in_group is not None, \
            f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
        src = self.ranks_in_group[self.pipeline_ranks[0]]
        setattr(param, 'pipeline_shared_module_pg', self.group)
        dist.broadcast(param, src, group=self.group)
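For orientation, `_init_group` carves the global rank space into one pipeline axis per data-parallel replica (and per intra-replica index when one replica holds several pipeline axes), then creates a process group over the requested stage indices of each axis. The snippet below is a minimal sketch of just that arithmetic under assumed toy sizes (world size 8, data-parallel size 2, pipeline size 4); it makes no torch.distributed calls and the variable names are illustrative.

# Toy reproduction of the rank arithmetic in _init_group (assumed sizes, no process groups created).
world_size, dp_size, pp_size = 8, 2, 4
num_dp_groups = world_size // dp_size      # 4 global ranks per data-parallel replica
num_pp_stages = num_dp_groups // pp_size   # 1 pipeline axis per replica in this toy setup

stage_indices = [0, pp_size - 1]           # e.g. share a module between the first and last stage
for i in range(dp_size):
    for j in range(num_pp_stages):
        axis = list(range(i * num_dp_groups + j, (i + 1) * num_dp_groups, num_pp_stages))
        sub_ranks = [axis[idx] for idx in stage_indices]
        print(f'replica {i}: pipeline axis {axis} -> shared-module ranks {sub_ranks}')
# replica 0: pipeline axis [0, 1, 2, 3] -> shared-module ranks [0, 3]
# replica 1: pipeline axis [4, 5, 6, 7] -> shared-module ranks [4, 7]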
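A typical use of this wrapper is tying weights between the first and last pipeline stages, for example an input embedding and an output projection. The following is only a sketch: it assumes the legacy gpc context has already been initialized (for instance via colossalai.launch) with a pipeline size of at least 2 and a CUDA device set per rank, and the vocabulary/hidden sizes and module choices are illustrative rather than taken from the commit.

import torch.nn as nn

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.wrapper.pipeline_wrapper import PipelineSharedModuleWrapper

# Illustrative sizes; both shared weights must have the same shape for the broadcast to match.
VOCAB_SIZE, HIDDEN_SIZE = 50304, 1024

pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
stage = gpc.get_local_rank(ParallelMode.PIPELINE)

# All ranks build the wrapper so that dist.new_group is called collectively.
wrapper = PipelineSharedModuleWrapper([0, pp_size - 1])

if stage == 0:
    embed = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE).cuda()
    wrapper.register_module(embed)    # broadcasts the embedding weight from the first stage
elif stage == pp_size - 1:
    head = nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False).cuda()
    wrapper.register_module(head)     # receives the shared weight on the last stage

Only ranks that belong to the shared group call register_module here, which matches the assertion inside the wrapper; ranks on intermediate stages never join the broadcast.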