mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[tensor] add ColoTensor 1Dcol (#888)
This commit is contained in:
@@ -121,18 +121,25 @@ class ColoTensor(object):
|
||||
assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
|
||||
if self._shard_spec.num_action == 1:
|
||||
if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
|
||||
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
|
||||
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||
dim = -1
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
|
||||
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(
|
||||
ComputePattern.TP1DRow)
|
||||
self._shard_1d(parallel_action=parallel_action, dim=-1)
|
||||
elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
|
||||
parallel_action = self._shard_spec.get_action_by_compute_pattern(
|
||||
ComputePattern.TP1DCol)
|
||||
self._shard_1d(parallel_action=parallel_action, dim=0)
|
||||
|
||||
def _shard_1d(self, parallel_action, dim=-1):
|
||||
num_partition = gpc.get_world_size(parallel_action.parallel_mode)
|
||||
local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
|
||||
chunk_size = divide(self._size[dim], num_partition)
|
||||
# Reshape to get shard for this rank and we don't want autograd
|
||||
# recording here for the narrow op and 'local_shard' should be a
|
||||
# leaf variable in the autograd graph.
|
||||
self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
|
||||
).contiguous() # TODO Shall we clone() here since detach() will point to the old tensor?
|
||||
self._torch_tensor.requires_grad = self._requires_grad
|
||||
self._size = self._torch_tensor.size()
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
|
Reference in New Issue
Block a user