[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,3 +1,3 @@
from .pipeline_wrapper import PipelineSharedModuleWrapper
__all__ = ['PipelineSharedModuleWrapper']
__all__ = ["PipelineSharedModuleWrapper"]

View File

@@ -8,9 +8,8 @@ from colossalai.legacy.core import global_context as gpc
class PipelineSharedModuleWrapper:
def __init__(self, pipeline_ranks: Union[List[int], Tuple[int]]) -> None:
assert len(pipeline_ranks) > 1, f'Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}'
assert len(pipeline_ranks) > 1, f"Expect len(pipeline_ranks) > 1, got {len(pipeline_ranks)}"
self.pipeline_ranks = pipeline_ranks
self.group = None
self.ranks_in_group = None
@@ -33,16 +32,18 @@ class PipelineSharedModuleWrapper:
self.ranks_in_group = sub_ranks
def register_module(self, module: nn.Module):
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert (
self.ranks_in_group is not None
), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}"
src = self.ranks_in_group[self.pipeline_ranks[0]]
for p in module.parameters():
setattr(p, 'pipeline_shared_module_pg', self.group)
setattr(p, "pipeline_shared_module_pg", self.group)
dist.broadcast(p, src, group=self.group)
def register_parameter(self, param: nn.Parameter):
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert (
self.ranks_in_group is not None
), f"Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}"
src = self.ranks_in_group[self.pipeline_ranks[0]]
setattr(param, 'pipeline_shared_module_pg', self.group)
setattr(param, "pipeline_shared_module_pg", self.group)
dist.broadcast(param, src, group=self.group)