[tensor] add ColoTensor 1Dcol (#888)

Author: Ziyue Jiang
Date: 2022-04-27 14:13:55 +08:00
Committed by: GitHub
parent a0e5971692
commit 1d0aba4153
4 changed files with 166 additions and 28 deletions

@@ -121,18 +121,25 @@ class ColoTensor(object):
         assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
         if self._shard_spec.num_action == 1:
             if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
-                parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-                num_partition = gpc.get_world_size(parallel_action.parallel_mode)
-                local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
-                dim = -1
-                chunk_size = divide(self._size[dim], num_partition)
-                # Reshape to get shard for this rank and we don't want autograd
-                # recording here for the narrow op and 'local_shard' should be a
-                # leaf variable in the autograd graph.
-                self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
-                ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
-                self._torch_tensor.requires_grad = self._requires_grad
-                self._size = self._torch_tensor.size()
+                parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                    ComputePattern.TP1DRow)
+                self._shard_1d(parallel_action=parallel_action, dim=-1)
+            elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
+                parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                    ComputePattern.TP1DCol)
+                self._shard_1d(parallel_action=parallel_action, dim=0)
+
+    def _shard_1d(self, parallel_action, dim=-1):
+        num_partition = gpc.get_world_size(parallel_action.parallel_mode)
+        local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+        chunk_size = divide(self._size[dim], num_partition)
+        # Narrow to get the shard for this rank. We don't want autograd
+        # recording the narrow op here; the local shard should be a
+        # leaf variable in the autograd graph.
+        self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
+        ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
+        self._torch_tensor.requires_grad = self._requires_grad
+        self._size = self._torch_tensor.size()
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
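For context, the extracted _shard_1d helper just narrows the full tensor to the local rank's chunk along one dimension (dim=-1 for TP1DRow, dim=0 for TP1DCol). Below is a minimal standalone sketch of the same idea, with plain rank/world-size arguments standing in for Colossal-AI's gpc lookups; the shard_1d function and the example tensor are hypothetical illustrations, not part of this commit:

    import torch

    def shard_1d(tensor: torch.Tensor, world_size: int, rank: int, dim: int = -1) -> torch.Tensor:
        # Evenly split `tensor` along `dim` and return this rank's chunk,
        # detached so the shard is a leaf in the autograd graph.
        size = tensor.size(dim)
        assert size % world_size == 0, 'shard dim must be divisible by world size'
        chunk_size = size // world_size
        shard = tensor.narrow(dim, rank * chunk_size, chunk_size).detach()
        # contiguous() gives the shard its own packed storage when the
        # narrowed view is not already contiguous (e.g. for dim=-1).
        shard = shard.contiguous()
        shard.requires_grad = tensor.requires_grad
        return shard

    # A 4x6 tensor split along dim 0 (as TP1DCol does above) across 2 ranks:
    full = torch.arange(24, dtype=torch.float32).view(4, 6)
    print(shard_1d(full, world_size=2, rank=1, dim=0))    # rows 2 and 3

Narrowing instead of copying means each rank only materializes its own chunk; the detach() keeps the narrow op out of the recorded graph so the shard can serve as an independent leaf parameter.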