mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[rename] convert_to_dist -> redistribute (#1243)
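This patch renames ColoTensor's distribution-conversion API: the internal in-place helper _convert_to_dist_spec becomes _redistribute, and the public convert_to_dist_spec becomes redistribute; wrappers such as to_replicate/to_replicate_ and the view override are updated to call the new names. A minimal, hypothetical migration sketch for call sites (the variable names below are illustrative and not part of the patch):

    # old call sites (before this commit)
    # replica = colo_tensor.convert_to_dist_spec(distspec.replicate())
    # colo_tensor._convert_to_dist_spec(dist_spec)            # internal, in-place

    # new call sites (after this commit)
    replica = colo_tensor.redistribute(distspec.replicate())   # returns a new ColoTensor
    colo_tensor._redistribute(dist_spec)                        # internal, in-place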
@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
         Tensor._base.__get__,
         Tensor.grad.__get__,
         Tensor._grad.__get__,
         Tensor.data.__get__,  # make .data returns torch.Tensor rather than ColoTensor
     }
@@ -140,7 +140,7 @@ class ColoTensor(torch.Tensor):
         """
         assert isinstance(dist_spec, _DistSpec)
         assert self.process_group is not None
-        self._convert_to_dist_spec(dist_spec)
+        self._redistribute(dist_spec)

     def set_tensor_spec(self, dist_spec, compute_spec):
         if dist_spec:
@@ -174,8 +174,8 @@ class ColoTensor(torch.Tensor):
     def __repr__(self):
         return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'

-    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
-        """_convert_to_dist_spec
+    def _redistribute(self, dist_spec: _DistSpec) -> None:
+        """_redistribute
         Note the function will not handle the logic of backward propagation!
         It is used during model tensor initializations as an internal function.
         Args:
@@ -186,7 +186,7 @@ class ColoTensor(torch.Tensor):
         self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
         self.dist_spec = dist_spec

-    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
+    def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
         ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
         return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
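The two hunks above keep the original split between an internal, in-place conversion and a public, functional one: _redistribute rewrites self.data and self.dist_spec directly (and, per its docstring, does not handle backward propagation), while redistribute returns a fresh ColoTensor built via from_torch_tensor. A self-contained toy sketch of that pattern, in plain Python rather than ColossalAI code (ToyTensor and trans are invented for illustration):

    from dataclasses import dataclass

    def trans(data, src_spec, dst_spec):
        # Stand-in for DistSpecManager.handle_trans_spec; here it only copies the payload.
        return list(data)

    @dataclass
    class ToyTensor:
        data: list       # stands in for the tensor payload
        dist_spec: str   # stands in for the _DistSpec placement ("SHARD" / "REPLICATE")

        def _redistribute(self, dist_spec: str) -> None:
            # In-place, internal: rewrite payload and spec on self.
            self.data = trans(self.data, self.dist_spec, dist_spec)
            self.dist_spec = dist_spec

        def redistribute(self, dist_spec: str) -> "ToyTensor":
            # Functional, public: build and return a new object, leave self untouched.
            return ToyTensor(trans(self.data, self.dist_spec, dist_spec), dist_spec)

    t = ToyTensor([1, 2], "SHARD")
    r = t.redistribute("REPLICATE")
    print(t.dist_spec, r.dist_spec)   # SHARD REPLICATE: the original is unchanged
    t._redistribute("REPLICATE")
    print(t.dist_spec)                # REPLICATE: converted in place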
@@ -194,13 +194,13 @@ class ColoTensor(torch.Tensor):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self._convert_to_dist_spec(dist_spec=distspec.replicate())
+        self._redistribute(dist_spec=distspec.replicate())

     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.convert_to_dist_spec(distspec.replicate())
+        return self.redistribute(distspec.replicate())

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +234,7 @@ class ColoTensor(torch.Tensor):
         """
         if self.is_replicate():
             return super().view(*args)
-        replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+        replicated_t = self.redistribute(dist_spec=distspec.replicate())
         return replicated_t.view(*args)

     def size_global(self, args: Optional[int] = None):
@@ -280,4 +280,4 @@ class ColoTensor(torch.Tensor):
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def is_sharded(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD
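For context, a rough usage sketch of the renamed public API. It assumes a ColossalAI build from around this commit with torch.distributed already initialized (e.g. via colossalai.launch); the import layout and the ProcessGroup/ColoTensorSpec constructor defaults are assumptions, not something this patch specifies:

    import torch
    # Assumed import paths; they may differ between ColossalAI versions.
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

    pg = ProcessGroup()                            # default process group (assumption)
    t = ColoTensor.from_torch_tensor(torch.randn(4, 4), ColoTensorSpec(pg))

    # Public API after the rename: returns a new ColoTensor with the target dist spec.
    replicated = t.redistribute(distspec.replicate())

    # Convenience wrappers shown in the diff route through the same calls.
    r2 = t.to_replicate()                          # functional, returns a new ColoTensor
    t.to_replicate_()                              # in-place, delegates to _redistribute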