Mirror of https://github.com/hpcaitech/ColossalAI.git
[Tensor] distributed view supports inter-process hybrid parallel (#1169)
@@ -8,6 +8,7 @@ from colossalai.tensor import TensorSpec
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
from typing import Optional


def _convert_output(output):
@@ -60,6 +61,12 @@ class ColoTensor(torch.Tensor):
    def tensor_spec(self) -> TensorSpec:
        return self._tensor_spec

    @tensor_spec.setter
    def tensor_spec(self, tensor_spec: TensorSpec):
        spec = copy(tensor_spec)
        self._convert_to_dist_spec(spec.dist_spec)
        self._tensor_spec = spec

    def set_tensor_spec(self, spec: TensorSpec) -> None:
        spec = copy(spec)
        self._convert_to_dist_spec(spec.dist_spec)
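For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of what the new @tensor_spec.setter does. ToySpec and ToyTensor are made-up stand-ins, not the ColossalAI API: assigning to tensor_spec copies the incoming spec, converts the payload to its dist spec, and stores the copy, mirroring set_tensor_spec().

# Toy sketch only (hypothetical ToySpec/ToyTensor, not the ColossalAI API).
from copy import copy
from dataclasses import dataclass


@dataclass
class ToySpec:
    dist_spec: str  # stands in for a distspec such as replicate() or shard(...)


class ToyTensor:

    def __init__(self, spec: ToySpec):
        self._tensor_spec = copy(spec)

    @property
    def tensor_spec(self) -> ToySpec:
        return self._tensor_spec

    @tensor_spec.setter
    def tensor_spec(self, tensor_spec: ToySpec):
        spec = copy(tensor_spec)                  # never alias the caller's spec
        self._convert_to_dist_spec(spec.dist_spec)
        self._tensor_spec = spec

    def _convert_to_dist_spec(self, dist_spec: str) -> None:
        # In ColoTensor this redistributes self.data; here we only record the request.
        print(f"converting payload to dist spec: {dist_spec}")


t = ToyTensor(ToySpec("shard"))
t.tensor_spec = ToySpec("replicate")              # triggers the conversion above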
@@ -136,4 +143,52 @@ class ColoTensor(torch.Tensor):
        data = self.data.clone()
        tensor = ColoTensor(data, spec=copy(self.tensor_spec))
        memo[id(self)] = tensor
        return tensor

    ##### Override built-in functions that must use the tensor in replicate placement ####

    def view_base(self, *args) -> 'ColoTensor':
        return super().view(*args)

    def size_base(self, *args, **kwargs) -> torch.Size:
        return super().size(*args, **kwargs)

    def view(self, *args) -> 'ColoTensor':
        """Override the torch built-in view().
        The view arguments must describe the tensor in replicate placement (its global shape).
        Returns:
            ColoTensor: the viewed tensor.
        """
        if self.tensor_spec.is_replicate():
            return super().view(*args)
        # TODO(jiaruifang) check why this does not work
        # self.data = self.to_replicate()
        self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
        self._tensor_spec.dist_spec = distspec.replicate()
        return super().view(*args)

    def size(self, args: Optional[int] = None):
        """Override the torch built-in size().
        The size reported is the size in replicate placement, i.e. the global shape.
        Returns:
            torch.Size or int: the global shape, or the global size of the given dimension.
        """
        if self.tensor_spec.is_replicate():
            if args is not None:
                return super().size(args)
            else:
                return super().size()

        spec = self.tensor_spec.dist_spec
        dims = spec.dims
        num_partitions = spec.num_partitions

        size_list = list(super().size())
        for dim, num_partition in zip(dims, num_partitions):
            size_list[dim] *= num_partition
        if args is not None:
            return size_list[args]
        else:
            return torch.Size(size_list)
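To make the size() arithmetic concrete, here is a small plain-PyTorch illustration; the shapes and partition counts are made-up example values and no process group is involved. A local shard of shape (2, 16), sharded along dim 0 over 4 partitions, reports a global shape of (8, 16).

# Illustration of the global-shape reconstruction performed by size() above.
import torch

local_shard = torch.empty(2, 16)      # what one rank actually holds
dims, num_partitions = [0], [4]       # dim 0 is sharded across 4 ranks

size_list = list(local_shard.size())
for dim, num_partition in zip(dims, num_partitions):
    size_list[dim] *= num_partition   # scale each sharded dim back to its global size

assert torch.Size(size_list) == torch.Size([8, 16])
print(torch.Size(size_list))          # torch.Size([8, 16])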
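Likewise, a rough sketch of the view() path for a sharded tensor, with torch.cat standing in for the DistSpecManager.handle_trans_spec(..., replicate()) redistribution: the shards are first restored to the replicated (global) tensor, then the ordinary view is applied with global-shape arguments.

# Emulates view() on a sharded ColoTensor without a process group.
import torch

shards = [torch.randn(2, 16) for _ in range(4)]   # per-rank shards along dim 0
replicated = torch.cat(shards, dim=0)             # stand-in for the redistribution to replicate placement
flat = replicated.view(-1)                        # view arguments refer to the global shape

assert replicated.shape == torch.Size([8, 16])
assert flat.shape == torch.Size([128])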