flake8 style (#352)

This commit is contained in:
Liang Bowen 2022-03-09 17:34:43 +08:00 committed by Frank Lee
parent 54ee8d1254
commit 7eb87f516d
3 changed files with 6 additions and 4 deletions

View File

@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
self.ranks_in_group = sub_ranks
def register_module(self, module: nn.Module):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]]
for p in module.parameters():
setattr(p, 'pipeline_shared_module_pg', self.group)
dist.broadcast(p, src, group=self.group)
def register_parameter(self, param: nn.Parameter):
assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
assert self.ranks_in_group is not None,\
f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
src = self.ranks_in_group[self.pipeline_ranks[0]]
setattr(param, 'pipeline_shared_module_pg', self.group)
dist.broadcast(param, src, group=self.group)