Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-19 09:51:18 +00:00
flake8 style (#352)
commit 7eb87f516d (parent 54ee8d1254)
@@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
 def divide(numerator, denominator):
     """Only allow exact division

     :param numerator: Numerator of the division
     :param denominator: Denominator of the division
     """
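The hunk above only shows the docstring of the helper; its body is not part of this diff. As a minimal sketch of what an exact-division helper of this kind typically does (the body and assertion message below are assumptions, not the repository's actual code):

def divide(numerator, denominator):
    """Only allow exact division

    :param numerator: Numerator of the division
    :param denominator: Denominator of the division
    """
    # Assumed body: reject any division that leaves a remainder.
    assert numerator % denominator == 0, \
        f'{numerator} is not divisible by {denominator}'
    return numerator // denominator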
@@ -101,7 +101,7 @@ class WrappedDropPath(nn.Module):
 @LAYERS.register_module
 class VanillaPatchEmbedding(nn.Module):
     """
     2D Image to Patch Embedding

     :param img_size: image size
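Only the head of the VanillaPatchEmbedding docstring appears in this hunk. For background, a 2D patch embedding splits an image into non-overlapping patches and projects each patch to an embedding vector; a minimal sketch using a strided convolution follows (this is an illustration, not the ColossalAI implementation, and the parameter names are assumptions):

import torch.nn as nn

class PatchEmbeddingSketch(nn.Module):
    """Illustrative 2D image-to-patch embedding, not the actual VanillaPatchEmbedding."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, 'image size must be divisible by patch size'
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel == stride == patch_size projects each patch independently.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/ps, W/ps) -> (B, num_patches, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)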
@@ -33,14 +33,16 @@ class PipelineSharedModuleWrapper:
         self.ranks_in_group = sub_ranks

     def register_module(self, module: nn.Module):
-        assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
+        assert self.ranks_in_group is not None,\
+            f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
         src = self.ranks_in_group[self.pipeline_ranks[0]]
         for p in module.parameters():
             setattr(p, 'pipeline_shared_module_pg', self.group)
             dist.broadcast(p, src, group=self.group)

     def register_parameter(self, param: nn.Parameter):
-        assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
+        assert self.ranks_in_group is not None,\
+            f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
         src = self.ranks_in_group[self.pipeline_ranks[0]]
         setattr(param, 'pipeline_shared_module_pg', self.group)
         dist.broadcast(param, src, group=self.group)
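As the hunk shows, the wrapper tags each parameter with the shared process group (pipeline_shared_module_pg) and broadcasts it from the first pipeline rank in the group, which is how weights tied across pipeline stages (for example, shared input/output embeddings) stay in sync. A hedged usage sketch follows; the import path, constructor argument, and surrounding setup are assumptions and not taken from this diff:

import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.pipeline import PipelineSharedModuleWrapper  # import path is an assumption

# Assumed setup: a launched ColossalAI job with pipeline parallelism, where a
# word embedding is shared between the first and the last pipeline stage.
last_stage = gpc.get_world_size(ParallelMode.PIPELINE) - 1
wrapper = PipelineSharedModuleWrapper([0, last_stage])

embedding = nn.Embedding(50257, 768)
if gpc.get_local_rank(ParallelMode.PIPELINE) in (0, last_stage):
    # Broadcasts the parameters from the first registered rank and tags them
    # with `pipeline_shared_module_pg` so later gradient synchronization can
    # run over the same group.
    wrapper.register_module(embedding)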