[Tensor] distributed view supports inter-process hybrid parallel (#1169)

This commit is contained in:
Jiarui Fang
2022-06-27 09:45:26 +08:00
committed by GitHub
parent 9e1daa63d2
commit aa7bef73d4
13 changed files with 101 additions and 19 deletions

View File

@@ -114,7 +114,7 @@ class Chunk:
# if the process owns the rank, then copy the tensor to its chunk buffer
# otherwise set its storage size to 0 to reduce memory consumption
if self.is_src_rank:
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
tensor_state = TensorState.HOLD
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
else:

View File

@@ -101,3 +101,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
# TODO(jzy) we don't support object reflection now.
# distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
raise NotImplementedError
#### the ColoParameter should use the torch.Tensor's builtin methodes ###
def view(self, *args) -> 'ColoTensor':
return super().view_base(*args)
def size(self, *args, **kwargs) -> torch.Size:
# import inspect
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
return super().size_base(*args, **kwargs)

View File

@@ -8,6 +8,7 @@ from colossalai.tensor import TensorSpec
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
from typing import Optional
def _convert_output(output):
@@ -60,6 +61,12 @@ class ColoTensor(torch.Tensor):
def tensor_spec(self) -> TensorSpec:
return self._tensor_spec
@tensor_spec.setter
def tensor_spec(self, tenseor_spec: TensorSpec):
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec
def set_tensor_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
@@ -136,4 +143,52 @@ class ColoTensor(torch.Tensor):
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
memo[id(self)] = tensor
return tensor
return tensor
##### override builtin functions which must use tensor in replicate placement ####
def view_base(self, *args) -> 'ColoTensor':
return super().view(*args)
def size_base(self, *args, **kwargs) -> torch.Size:
return super().size(*args, **kwargs)
def view(self, *args) -> 'ColoTensor':
"""override the torch buildin view()
the args passed in must be in a replicate placement.
Returns:
ColoTensor: a tensor after viewed.
"""
if self.tensor_spec.is_replicate():
return super().view(*args)
# TODO(jiaruifang) check why this not work
# self.data = self.to_replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
return super().view(*args)
def size(self, args: Optional[int] = None):
"""override the torch buildin size()
the shape passed in must be in a replicate placement.
Returns:
ColoTensor: a tensor after viewed.
"""
if self.tensor_spec.is_replicate():
if args is not None:
return super().size(args)
else:
return super().size()
spec = self.tensor_spec.dist_spec
dims = spec.dims
num_partitions = spec.num_partitions
# import inspect
# print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
size_list = list(super().size())
for dim, num_partition in zip(dims, num_partitions):
size_list[dim] *= num_partition
if args is not None:
return size_list[args]
else:
return torch.Size(size_list)

View File

@@ -68,6 +68,7 @@ class DistSpecManager:
num_parts = prod(dist_spec.num_partitions)
for i, dim in enumerate(dist_spec.dims):
num_parts //= dist_spec.num_partitions[i]
chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
idx %= num_parts

View File

@@ -26,7 +26,7 @@ class TensorSpec(object):
def get_placement(self):
return self.dist_spec.placement
def is_gathered(self):
def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \