[zero] ZeRO supports pipeline parallel (#477)

This commit is contained in:
ver217
2022-03-21 16:55:37 +08:00
committed by GitHub
parent 7f5e4592eb
commit 8d3250d74b
3 changed files with 113 additions and 95 deletions

View File

@@ -262,3 +262,15 @@ class ShardedModelV2(nn.Module):
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
raise NotImplementedError
def __getitem__(self, idx: int):
assert isinstance(self.module, nn.ModuleList)
return self.module[idx]
def __len__(self):
assert isinstance(self.module, nn.ModuleList)
return len(self.module)
def __iter__(self):
assert isinstance(self.module, nn.ModuleList)
return iter(self.module)