[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)

This commit is contained in:
HELSON
2022-07-19 14:15:28 +08:00
committed by GitHub
parent f3ce7b8336
commit f92c100ddd
6 changed files with 209 additions and 91 deletions

View File

@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
def size_global(self, args: Optional[int] = None) -> torch.Size:
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns:

View File

@@ -141,9 +141,18 @@ class ProcessGroup:
def rank(self):
    """Return the rank stored for this process (``self._rank``)."""
    current = self._rank
    return current
def ranks_in_group(self):
    """Return the stored list of member ranks (``self._rank_list``)."""
    members = self._rank_list
    return members
def world_size(self):
    """Return the stored world size of the group (``self._world_size``)."""
    size = self._world_size
    return size
def tp_rank_list(self):
    """Return the stored tensor-parallel rank list (``self._tp_rank_list``)."""
    tp_ranks = self._tp_rank_list
    return tp_ranks
def dp_rank_list(self):
    """Return the stored data-parallel rank list (``self._dp_rank_list``)."""
    dp_ranks = self._dp_rank_list
    return dp_ranks
def tp_local_rank(self):
    """Return ``self._rank % self._tp_degree`` — the rank reduced modulo
    the tensor-parallel degree (local position within a TP group)."""
    # divmod's second element is identical to the % operator in Python.
    _, remainder = divmod(self._rank, self._tp_degree)
    return remainder