mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-13 11:34:37 +00:00
[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)
This commit is contained in:
@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
|
||||
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
|
||||
return replicated_t.view(*args)
|
||||
|
||||
def size_global(self, args: Optional[int] = None):
|
||||
def size_global(self, args: Optional[int] = None) -> torch.Size:
|
||||
"""override the torch buildin size()
|
||||
the shape passed in must be in a replicate placement.
|
||||
Returns:
|
||||
|
||||
Reference in New Issue
Block a user