[tensor] redistribute among different process groups (#1247)

* make it faster

* [tensor] rename convert_to_dist -> redistribute

* [tensor] ShardSpec and ReplicaSpec

* [tensor] redistribute among diff pgs

* polish code
Author: Jiarui Fang
Date:   2022-07-12 10:24:05 +08:00
Committed by: GitHub
Parent: 9bcd2fd4af
Commit: 1aad903c15
8 changed files with 48 additions and 17 deletions


@@ -186,7 +186,28 @@ class ColoTensor(torch.Tensor):
         self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec
-    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
+        """redistribute
+        Redistribute the tensor among processes. The rules are as follows:
+        1. If pg is None, redistribute the tensor payload among the TP process group,
+           keeping the DP process group replicated.
+        2. If pg is not None and differs from the current process group:
+           first, convert the tensor to be replicated among the TP process group;
+           second, reset the process group to pg;
+           third, convert the tensor (now replicated among both the TP and DP
+           process groups) to the new dist_spec.
+        Args:
+            dist_spec (_DistSpec): the new dist spec.
+            pg (Optional[ProcessGroup], optional): the new process group. Defaults to None.
+        Returns:
+            ColoTensor: a redistributed ColoTensor
+        """
+        if pg is not None and pg != self.get_process_group():
+            # the target pg differs from the current one: convert to replicated first
+            self._redistribute(ReplicaSpec())
+            self.process_group = pg
         ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
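The two rules are easiest to see in use. Below is a minimal sketch, not taken from this commit: it assumes a 4-rank job already initialized via colossalai.launch, and uses the ProcessGroup(tp_degree=..., dp_degree=...), ShardSpec([dims], [num_partitions]), and ReplicaSpec() constructors named in this PR; tp_world_size() and the tensor shape are illustrative assumptions.

import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ProcessGroup,
                               ReplicaSpec, ShardSpec)

# assumed setup: 4 ranks already initialized via colossalai.launch(...)
pg = ProcessGroup(tp_degree=4)    # all four ranks form one TP group
t = ColoTensor.from_torch_tensor(torch.randn(8, 8),
                                 ColoTensorSpec(pg, dist_attr=ReplicaSpec()))

# rule 1: pg argument omitted -> reshard only inside the current TP group
t = t.redistribute(ShardSpec([0], [pg.tp_world_size()]))    # dim 0, 4 shards

# rule 2: a different process group -> replicate, reset pg, then reshard
new_pg = ProcessGroup(tp_degree=2, dp_degree=2)             # 2-way TP x 2-way DP
t = t.redistribute(ShardSpec([-1], [new_pg.tp_world_size()]), new_pg)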
@@ -202,7 +223,6 @@ class ColoTensor(torch.Tensor):
         """
         return self.redistribute(ReplicaSpec())
-
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
         tensor = tensor.as_subclass(ColoTensor)
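As a follow-up to the second hunk, a one-line sketch of to_replicate(), which the diff shows is simply sugar for redistribute(ReplicaSpec()); the shape check assumes the 8x8 tensor from the sketch above.

# gather the sharded tensor from the sketch above back into a full copy per rank
full = t.to_replicate()                    # same as t.redistribute(ReplicaSpec())
assert full.shape == torch.Size([8, 8])    # every rank now holds the whole tensor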