From 2eca4cd376918e6cd7b085b87483af92acf067bf Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 14 Mar 2023 16:25:47 +0800
Subject: [PATCH] [DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components

* polish
---
 colossalai/tensor/d_tensor/d_tensor.py        | 44 ++++++-------------
 .../tensor/d_tensor/layout_converter.py       |  6 +--
 .../test_tensor/test_dtensor/test_dtensor.py  | 11 ++---
 3 files changed, 20 insertions(+), 41 deletions(-)

diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py
index e311eb3ba..c1fe9d50a 100644
--- a/colossalai/tensor/d_tensor/d_tensor.py
+++ b/colossalai/tensor/d_tensor/d_tensor.py
@@ -3,12 +3,11 @@ from typing import Optional
 import torch
 from torch.utils._pytree import tree_map

-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.tensor.sharding_spec import ShardingSpec
+from .layout import Layout
+from .layout_converter import LayoutConverter, to_global
+from .sharding_spec import ShardingSpec

-shape_consistency_manager = ShapeConsistencyManager()
+layout_converter = LayoutConverter()


 class DTensor(torch.Tensor):
@@ -17,8 +16,6 @@ class DTensor(torch.Tensor):
         self.local_tensor = local_tensor
         self.data_type = local_tensor.dtype
         self.entire_shape = local_tensor.shape
-        if dist_layout.entire_shape is None:
-            dist_layout.entire_shape = self.entire_shape
         self.dist_layout = dist_layout
         self._apply_layout()

@@ -36,20 +33,19 @@ class DTensor(torch.Tensor):
         '''
         Convert the layout of the tensor from source_spec to target_spec.
         '''
-        source_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        target_spec = convert_layout_to_sharding_spec(target_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
         self.dist_layout = target_layout

     def _apply_layout(self):
         '''
         Apply the layout to the local tensor during initializing process.
         '''
-        source_spec = construct_default_sharding_spec(self.local_tensor, self.device_mesh)
-        target_spec = convert_layout_to_sharding_spec(self.dist_layout)
-        self.local_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(
-            self.local_tensor, source_spec, target_spec)
+        source_spec = construct_default_sharding_spec(self.local_tensor)
+        source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+                               device_type=self.dist_layout.device_type,
+                               sharding_spec=source_spec,
+                               entire_shape=self.entire_shape)
+        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -108,7 +104,7 @@ class DTensor(torch.Tensor):
         will not change the layout of the DTensor. This function is mainly used for debugging or check the correctness
         of the distributed tensor.
         '''
-        return to_global(self.local_tensor, convert_layout_to_sharding_spec(self.dist_layout))
+        return to_global(self.local_tensor, self.dist_layout)


 def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
@@ -139,20 +135,8 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
     return module


-def convert_layout_to_sharding_spec(layout: Layout) -> ShardingSpec:
-    '''
-    Convert the layout from Layout class to ShardingSpec class.
-    '''
-    return ShardingSpec(device_mesh=layout.device_mesh,
-                        entire_shape=layout.entire_shape,
-                        dim_partition_dict=layout.sharding_spec.dim_partition_dict)
-
-
-def construct_default_sharding_spec(
-    tensor: torch.Tensor,
-    device_mesh: DeviceMesh,
-) -> ShardingSpec:
+def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
     '''
     Construct the default sharding specification for the tensor.
     '''
-    return ShardingSpec(device_mesh=device_mesh, entire_shape=tensor.shape, dim_partition_dict={})
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index 22bbb1d2f..a4f4c9c2d 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -22,21 +22,21 @@ __all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_o
 @dataclass
 class LayoutConverterOptions:
     """
-    LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency.
+    LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
     """
     # TODO: layout converter option is not implemented yet
     pass


 def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    shape_consistency_manager = LayoutConverter()
+    layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
                            device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
                            entire_shape=layout.entire_shape)
     with torch.no_grad():
-        global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout)
+        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor


diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py
index 80e275d97..a99ac6e41 100644
--- a/tests/test_tensor/test_dtensor/test_dtensor.py
+++ b/tests/test_tensor/test_dtensor/test_dtensor.py
@@ -4,12 +4,11 @@ import torch
 import torch.multiprocessing as mp

 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
 from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.utils import free_port


@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
     compare_output = test_model(original_tensor)

     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-    target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                        entire_shape=original_tensor.shape,
-                                        dim_partition_dict={0: [0]})
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
     layout = Layout(device_mesh=device_mesh,
                     device_type=torch.device('cuda'),
                     sharding_spec=target_sharding_spec,
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

-    new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-                                     entire_shape=original_tensor.shape,
-                                     dim_partition_dict={0: [0, 1]})
+    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
     new_layout = Layout(device_mesh=device_mesh,
                         device_type=torch.device('cuda'),
                         sharding_spec=new_sharding_spec,
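For readers following the refactor, the sketch below pieces together how the new components fit, using only the calls visible in this patch (DeviceMesh, the new ShardingSpec and Layout constructors, distribute_tensor, and the module-level to_global). It is an illustrative sketch, not part of the patch: the 4-rank CUDA launch, the tensor shape, and the helper name shard_and_check are assumptions.

# Usage sketch (assumes colossalai is installed and this runs inside a 4-rank
# distributed job with CUDA, launched the same way as the test above).
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import to_global
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec


def shard_and_check(original_tensor: torch.Tensor) -> torch.Tensor:
    # 2x2 device mesh over ranks 0-3, as in check_dtensor().
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

    # ShardingSpec is now built from the tensor rank and a dim_partition_dict only;
    # here dim 0 is sharded along mesh axis 0.
    sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})

    # Layout bundles the mesh, device type, sharding spec and global (entire) shape.
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=sharding_spec,
                    entire_shape=original_tensor.shape)

    # distribute_tensor() wraps the tensor in a DTensor; internally _apply_layout()
    # uses LayoutConverter.apply() to produce the local shard.
    d_tensor = distribute_tensor(original_tensor, layout)

    # The module-level to_global() converts the shard back to the full tensor,
    # mainly for checking correctness.
    return to_global(d_tensor.local_tensor, layout)

The design change this exercises: device_mesh and entire_shape move off ShardingSpec and onto Layout, and LayoutConverter.apply replaces ShapeConsistencyManager.apply_for_autoparallel_runtime both when a DTensor is first sharded (_apply_layout) and when its layout is converted later (layout_convert).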