From 0dd4e2bbfbef3e704fbfc8e4c9bc957864ceccd6 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 27 Jun 2022 15:56:11 +0800
Subject: [PATCH] [Tensor] rename some APIs in TensorSpec and Polish view unittest (#1176)
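
The TensorSpec predicates is_1D_row()/is_1D_col() are renamed to
is_shard_1drow()/is_shard_1dcol(), so the name states both the placement
(shard) and the partitioned dimension. As a minimal illustrative sketch
(not taken verbatim from the patch; `spec` stands for the TensorSpec of a
hypothetical 1-D tensor-parallel weight), callers dispatch on the renamed
predicates like this:

    # Hypothetical caller, mirroring the dispatch pattern in the _ops files.
    if spec.is_shard_1drow():      # sharded along dim 0
        mode = 'row'
    elif spec.is_shard_1dcol():    # sharded along dim -1
        mode = 'col'
    else:
        raise NotImplementedError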
---
 colossalai/nn/_ops/addmm.py         |  6 +++---
 colossalai/nn/_ops/embedding.py     |  5 +++--
 colossalai/nn/_ops/embedding_bag.py |  2 +-
 colossalai/nn/_ops/linear.py        |  6 +++---
 colossalai/nn/_ops/loss.py          |  2 +-
 colossalai/tensor/chunk.py          |  4 +++-
 colossalai/tensor/distspec.py       |  2 +-
 colossalai/tensor/tensor_spec.py    |  4 ++--
 tests/test_tensor/test_tensor.py    | 13 +++++++++----
 9 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py
index 2091f3247..0acd7486c 100644
--- a/colossalai/nn/_ops/addmm.py
+++ b/colossalai/nn/_ops/addmm.py
@@ -72,10 +72,10 @@ def colo_addmm(input_tensor: GeneralTensor,
         assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
+        if mat2.tensor_spec.is_shard_1drow() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
-        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
-                                               or input_tensor.tensor_spec.is_1D_row()):
+        elif mat2.tensor_spec.is_shard_1dcol() and (input_tensor.tensor_spec.is_shard_1dcol()
+                                                    or input_tensor.tensor_spec.is_shard_1drow()):
             mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 5f41b0c6e..1e392c04d 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -32,6 +32,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
 
     compute_spec = weight.tensor_spec.compute_spec
+
     if compute_spec.output_replicate:
         return output.to_replicate()
     else:
@@ -125,9 +126,9 @@ def colo_embedding(input_tensor: GeneralTensor,
                        scale_grad_by_freq=scale_grad_by_freq,
                        sparse=sparse))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_row():
+        if weight.tensor_spec.is_shard_1drow():
             mode = 'row'
-        elif weight.tensor_spec.is_1D_col():
+        elif weight.tensor_spec.is_shard_1dcol():
             mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py
index 825dd8d92..e8c21c047 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/nn/_ops/embedding_bag.py
@@ -104,7 +104,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                            include_last_offset=include_last_offset,
                            padding_idx=padding_idx))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col():
+        if weight.tensor_spec.is_shard_1dcol():
             tp_mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 01dcef6a6..45637c5a1 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -71,10 +71,10 @@ def colo_linear_imp(input_tensor: GeneralTensor,
         assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
+        if weight.tensor_spec.is_shard_1dcol() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
-        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
-                                                 or bias.tensor_spec.is_1D_col()):
+        elif weight.tensor_spec.is_shard_1drow() and (bias is None or bias.tensor_spec.is_shard_1drow()
+                                                      or bias.tensor_spec.is_shard_1dcol()):
             mode = 'col'
         else:
             raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py
index 1fc814937..0550193bc 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/nn/_ops/loss.py
@@ -29,7 +29,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
-        if input_tensor.tensor_spec.is_1D_col():
+        if input_tensor.tensor_spec.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
         else:
diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index bb7a17ae5..b66612088 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -116,6 +116,7 @@ class Chunk:
         if self.is_src_rank:
             self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
+            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
             tensor.storage().resize_(0)
@@ -131,6 +132,7 @@ class Chunk:
         self._update_tensors_state(TensorState.FREE)
 
     def _update_tensors_ptr(self) -> None:
+        assert type(self._payload) == torch.Tensor
         for tensor, tensor_info in self.tensors_info.items():
             tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
@@ -228,7 +230,7 @@ class Chunk:
             data_slice (torch.Tensor): the tensor to be copied to the chunk
         """
         tensor_info = self.tensors_info[tensor]
-        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
+        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
         tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
 
     @property
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 99f041fd3..714a135bf 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -54,5 +54,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
     assert process_group is not None
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size()
+    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py
index f847ad62f..d38ff2aac 100644
--- a/colossalai/tensor/tensor_spec.py
+++ b/colossalai/tensor/tensor_spec.py
@@ -32,11 +32,11 @@ class TensorSpec(object):
                 and self.dist_spec.num_partitions[0] == 1) \
             or (self.dist_spec.process_group.size() == 1)
 
-    def
 is_1D_col(self):
+    def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
 
-    def is_1D_row(self):
+    def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
 
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index e28a94bd8..836d7c16e 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -63,13 +63,19 @@ def test_operand():
 def _run_view(world_size):
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(
-        t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[2])))
+        t_ref,
+        TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0],
+                                  num_partitions=[world_size])))
 
     assert t.size()[0] == 4 * world_size
     assert t.size(1) == 5
     assert t.size() == torch.Size([4 * world_size, 5])
 
+    t.view_base(4 * 5)
+    assert t.tensor_spec.dist_spec.placement.value == 's'
+
     t = t.view(4 * 5 * world_size)
+    assert t.tensor_spec.dist_spec.placement.value == 'r'
     assert t.shape == torch.Size([4 * 5 * world_size])
 
 
@@ -100,11 +106,10 @@ def run_dist_tests(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
-def _test_dist_cases(world_size):
+def test_dist_cases(world_size):
     run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
-    # _test_dist_init(4)
-    _test_dist_cases(2)
+    test_dist_cases(2)
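
Note on the polished view unittest (an illustrative sketch under stated
assumptions, not part of the patch itself): the test asserts that view()
on a sharded ColoTensor returns a replicated result, i.e. the dist spec
placement flips from shard ('s') to replicate ('r'). Assuming a 2-rank
data-parallel group pg = gpc.get_group(ParallelMode.DATA), the asserted
semantics look like:

    # Sketch of what _run_view checks (assumes world_size == 2; pg is the
    # data-parallel process group, as in the test above).
    t = ColoTensor.from_torch_tensor(
        torch.randn(4, 5),
        TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[2])))
    assert t.tensor_spec.dist_spec.placement.value == 's'    # still sharded
    t = t.view(4 * 5 * 2)                                    # gather, then reshape
    assert t.tensor_spec.dist_spec.placement.value == 'r'    # now replicated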