mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)
This commit is contained in:
@@ -262,7 +262,7 @@ class ColoTensor(torch.Tensor):
|
||||
replicated_t = self.redistribute(dist_spec=ReplicaSpec())
|
||||
return replicated_t.view(*args)
|
||||
|
||||
def size_global(self, args: Optional[int] = None):
|
||||
def size_global(self, args: Optional[int] = None) -> torch.Size:
|
||||
"""override the torch built-in size()
|
||||
the shape passed in must be in a replicate placement.
|
||||
Returns:
|
||||
|
@@ -141,9 +141,18 @@ class ProcessGroup:
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
def ranks_in_group(self):
|
||||
return self._rank_list
|
||||
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
def tp_rank_list(self):
|
||||
return self._tp_rank_list
|
||||
|
||||
def dp_rank_list(self):
|
||||
return self._dp_rank_list
|
||||
|
||||
def tp_local_rank(self):
|
||||
return self._rank % self._tp_degree
|
||||
|
||||
|
Reference in New Issue
Block a user