mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-19 09:51:18 +00:00
flake8 style (#352)
This commit is contained in:
parent
54ee8d1254
commit
7eb87f516d
@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
|
||||
self.ranks_in_group = sub_ranks
|
||||
|
||||
def register_module(self, module: nn.Module):
|
||||
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
assert self.ranks_in_group is not None,\
|
||||
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||
for p in module.parameters():
|
||||
setattr(p, 'pipeline_shared_module_pg', self.group)
|
||||
dist.broadcast(p, src, group=self.group)
|
||||
|
||||
def register_parameter(self, param: nn.Parameter):
|
||||
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
assert self.ranks_in_group is not None,\
|
||||
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
|
||||
src = self.ranks_in_group[self.pipeline_ranks[0]]
|
||||
setattr(param, 'pipeline_shared_module_pg', self.group)
|
||||
dist.broadcast(param, src, group=self.group)
|
||||
|
Loading…
Reference in New Issue
Block a user