[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)

This commit is contained in:
HELSON
2022-07-19 14:15:28 +08:00
committed by GitHub
parent f3ce7b8336
commit f92c100ddd
6 changed files with 209 additions and 91 deletions

View File

@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
def size_global(self, args: Optional[int] = None) -> torch.Size:
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns: