mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .pipeline_wrapper import PipelineSharedModuleWrapper
|
||||
|
||||
__all__ = ['PipelineSharedModuleWrapper']
|
||||
__all__ = ["PipelineSharedModuleWrapper"]
|
||||
|
@@ -8,9 +8,8 @@ from colossalai.legacy.core import global_context as gpc
|
||||
|
||||
|
||||
class PipelineSharedModuleWrapper:
|
||||
|
||||
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
|
||||
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
|
||||
assert len(pipeline_ranks) > 1, f"Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}"
|
||||
self.pipeline_ranks = pipeline_ranks
|
||||
self.group = None
|
||||
self.ranks_in_group = None
|
||||
@@ -33,16 +32,18 @@ class PipelineSharedModuleWrapper:
|
||||
self.ranks_in_group = sub_ranks
|
||||
|
||||
def register_module(self, module: nn.Module):
|
||||
assert self.ranks_in_group is not None,\
|
||||
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
assert (
|
||||
self.ranks_in_group is not None
|
||||
), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}"
|
||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||
for p in module.parameters():
|
||||
setattr(p, 'pipeline_shared_module_pg', self.group)
|
||||
setattr(p, "pipeline_shared_module_pg", self.group)
|
||||
dist.broadcast(p, src, group=self.group)
|
||||
|
||||
def register_parameter(self, param: nn.Parameter):
|
||||
assert self.ranks_in_group is not None,\
|
||||
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
assert (
|
||||
self.ranks_in_group is not None
|
||||
), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}"
|
||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||
setattr(param, 'pipeline_shared_module_pg', self.group)
|
||||
setattr(param, "pipeline_shared_module_pg", self.group)
|
||||
dist.broadcast(param, src, group=self.group)
|
||||
|
Reference in New Issue
Block a user