[rename] convert_to_dist -> redistribute (#1243)

Jiarui Fang
2022-07-11 13:05:44 +08:00
committed by GitHub
parent f6add9b720
commit 2699dfbbfd
7 changed files with 18 additions and 18 deletions


@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
-        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
+        Tensor.data.__get__,    # make .data returns torch.Tensor rather than ColoTensor
     }
@@ -140,7 +140,7 @@ class ColoTensor(torch.Tensor):
         """
         assert isinstance(dist_spec, _DistSpec)
         assert self.process_group is not None
-        self._convert_to_dist_spec(dist_spec)
+        self._redistribute(dist_spec)
 
     def set_tensor_spec(self, dist_spec, compute_spec):
         if dist_spec:
@@ -174,8 +174,8 @@ class ColoTensor(torch.Tensor):
     def __repr__(self):
         return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
 
-    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
-        """_convert_to_dist_spec
+    def _redistribute(self, dist_spec: _DistSpec) -> None:
+        """_redistribute
         Note the function will not handle the logic of backward propagation!
         It is used during model tensor initializations as an internal function.
         Args:
@@ -186,7 +186,7 @@ class ColoTensor(torch.Tensor):
            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
        self.dist_spec = dist_spec
 
-    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
        ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
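As a reference for callers, here is a minimal usage sketch of the renamed public method. It assumes an already-launched distributed environment; the import path and the bare ProcessGroup() constructor are assumptions for illustration rather than part of this commit, while redistribute/_redistribute, from_torch_tensor, ColoTensorSpec and distspec.replicate() come from the diff above.

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

pg = ProcessGroup()    # default process group; constructor signature is an assumption
t = ColoTensor.from_torch_tensor(torch.randn(4, 4),
                                 ColoTensorSpec(pg, dist_attr=distspec.replicate()))

# Out-of-place: returns a new ColoTensor carrying the requested dist spec.
replicated = t.redistribute(distspec.replicate())

# In-place counterpart: mutates the tensor, skips backward handling, and is
# meant for internal use during model tensor initialization.
t._redistribute(distspec.replicate())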
@@ -194,13 +194,13 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._convert_to_dist_spec(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=distspec.replicate())
 
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.convert_to_dist_spec(distspec.replicate())
+        return self.redistribute(distspec.replicate())
 
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
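The replicate helpers in this hunk are thin wrappers over the two entry points renamed above; a short sketch, reusing the hypothetical setup from the previous example:

replica = t.to_replicate()    # out-of-place, wraps self.redistribute(distspec.replicate())
t.to_replicate_()             # in-place, wraps self._redistribute(distspec.replicate())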
@@ -234,7 +234,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=distspec.replicate())
         return replicated_t.view(*args)
 
     def size_global(self, args: Optional[int] = None):
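Per this hunk, view takes the fast path only when the tensor is already replicated; otherwise it first redistributes to a replicated layout and then applies the plain torch view. A sketch reusing the setup above, assuming a 2-rank group and the distspec.shard(dims, num_partitions) signature of this API generation:

sharded = t.redistribute(distspec.shard([0], [2]))    # shard dim 0 across 2 ranks (assumed)
flat = sharded.view(-1)    # internally: redistribute(distspec.replicate()) first, then torch view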
@@ -280,4 +280,4 @@ class ColoTensor(torch.Tensor):
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
     def is_sharded(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD
+        return self.dist_spec.placement == DistPlacementPattern.SHARD
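The placement predicates touched by the last hunk are unchanged in behavior; a hypothetical check against the tensors from the sketches above:

assert replicated.is_replicate() and not replicated.is_sharded()
assert sharded.is_sharded()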