mirror of https://github.com/hpcaitech/ColossalAI.git
[tensor] revert local view back (#1178)
commit 1b657f9ce1 (parent 0dd4e2bbfb)
@@ -52,7 +52,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    num_embeddings_per_partition = weight.size_base(0)
+    num_embeddings_per_partition = weight.size_local(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition
 
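The hunk above derives each rank's vocabulary slice from the row count of its local weight shard. A minimal sketch of that index arithmetic, assuming the embedding weight is split evenly along dim 0 across the tensor-parallel ranks (the helper below is illustrative, not part of the library):

def local_vocab_range(tp_rank: int, local_rows: int) -> tuple:
    # local_rows plays the role of weight.size_local(0) in the diff above
    vocab_start_index = tp_rank * local_rows
    vocab_end_index = vocab_start_index + local_rows
    return vocab_start_index, vocab_end_index

# e.g. an 8-way shard of a 50304-row embedding: rank 3 owns rows [18864, 25152)
print(local_vocab_range(3, 50304 // 8))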
@@ -101,13 +101,3 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         # TODO(jzy) we don't support object reflection now.
         # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
         raise NotImplementedError
-
-    #### the ColoParameter should use the torch.Tensor's builtin methodes ###
-
-    def view(self, *args) -> 'ColoTensor':
-        return super().view_base(*args)
-
-    def size(self, *args, **kwargs) -> torch.Size:
-        # import inspect
-        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
-        return super().size_base(*args, **kwargs)
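With those overrides removed, ColoParameter no longer redirects view()/size() to the *_base helpers; it simply inherits whatever ColoTensor and, ultimately, torch.Tensor provide, so the plain calls describe the local shard again. A toy illustration of that fallback (stand-in classes, not the library's actual hierarchy):

class FakeTensor:                        # stand-in for torch.Tensor
    def size(self):
        return "local shape"

class FakeColoTensor(FakeTensor):        # stand-in for ColoTensor after this commit
    def size_global(self):
        return "global shape"

class FakeColoParameter(FakeColoTensor): # no view()/size() overrides any more
    pass

p = FakeColoParameter()
print(p.size())          # "local shape"  -- falls through to the builtin
print(p.size_global())   # "global shape" -- explicit global query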
@@ -147,13 +147,13 @@ class ColoTensor(torch.Tensor):
 
     ##### override builtin functions which must use tensor in replicate placement ####
 
-    def view_base(self, *args) -> 'ColoTensor':
+    def view_local(self, *args) -> 'ColoTensor':
         return super().view(*args)
 
-    def size_base(self, *args, **kwargs) -> torch.Size:
+    def size_local(self, *args, **kwargs) -> torch.Size:
         return super().size(*args, **kwargs)
 
-    def view(self, *args) -> 'ColoTensor':
+    def view_global(self, *args) -> 'ColoTensor':
         """override the torch buildin view()
         the args passed in must be in a replicate placement.
         Returns:
@@ -167,7 +167,7 @@ class ColoTensor(torch.Tensor):
         self._tensor_spec.dist_spec = distspec.replicate()
         return super().view(*args)
 
-    def size(self, args: Optional[int] = None):
+    def size_global(self, args: Optional[int] = None):
         """override the torch buildin size()
         the shape passed in must be in a replicate placement.
         Returns:
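Taken together, the two ColoTensor hunks make the unqualified names local-by-default and reserve the _global variants for the logical, replicate-placement tensor (view_global/size_global convert the placement before reshaping or reporting a shape). A hedged sketch of that contract on a dim-0 shard, using plain torch tensors rather than ColoTensor itself:

import torch

world_size = 2                                   # assumption: 2-way shard on dim 0
full = torch.arange(8 * 5).reshape(8, 5)         # the logical, replicated tensor
shard = full.chunk(world_size, dim=0)[0]         # what rank 0 actually holds

print(shard.size())                                              # torch.Size([4, 5]) -> local size
print(torch.Size([shard.size(0) * world_size, shard.size(1)]))   # torch.Size([8, 5]) -> global size
print(full.view(4 * 5 * world_size).shape)       # view_global-style reshape of the full tensor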
@@ -67,14 +67,14 @@ def _run_view(world_size):
                    TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
                                              num_partitions=[world_size])))
 
-    assert t.size()[0] == 4 * world_size
-    assert t.size(1) == 5
-    assert t.size() == torch.Size([4 * world_size, 5])
+    assert t.size_global()[0] == 4 * world_size
+    assert t.size_global(1) == 5
+    assert t.size_global() == torch.Size([4 * world_size, 5])
 
-    t.view_base(4 * 5)
+    t.view_local(4 * 5)
     assert t.tensor_spec.dist_spec.placement.value == 's'
 
-    t = t.view(4 * 5 * world_size)
+    t = t.view_global(4 * 5 * world_size)
     assert t.tensor_spec.dist_spec.placement.value == 'r'
     assert t.shape == torch.Size([4 * 5 * world_size])
 
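As a quick sanity check of what the updated test expects, take world_size = 2 as an example: each rank holds a 4x5 shard, size_global() should report the full [8, 5] shape, view_local(4 * 5) reshapes only the 20-element shard and leaves the placement sharded ('s'), and view_global(4 * 5 * world_size) first converts to replicate ('r') and then reshapes all 40 elements:

world_size = 2
local_shape = (4, 5)                                              # each rank's shard
assert (local_shape[0] * world_size, local_shape[1]) == (8, 5)    # size_global()
assert local_shape[0] * local_shape[1] == 20                      # view_local(4 * 5)
assert 4 * 5 * world_size == 40                                   # view_global(4 * 5 * world_size)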