[DTensor] refactor dtensor with new components (#3089)

* [DTensor] refactor dtensor with new components

* polish
YuliangLiu0306
2023-03-14 16:25:47 +08:00
committed by GitHub
parent ed8f60b93b
commit 2eca4cd376
3 changed files with 20 additions and 41 deletions

@@ -4,12 +4,11 @@ import torch
import torch.multiprocessing as mp
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.utils import free_port
@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
compare_output = test_model(original_tensor)
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-entire_shape=original_tensor.shape,
-dim_partition_dict={0: [0]})
+target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
else:
raise ValueError(f'rank {rank} is not in the device mesh')
-new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
-entire_shape=original_tensor.shape,
-dim_partition_dict={0: [0, 1]})
+new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
new_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=new_sharding_spec,
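
For readers following the refactor, a minimal sketch of the new construction path, mirroring check_dtensor above. It assumes a distributed worker has already been launched on four devices; the entire_shape keyword of Layout and the distribute_tensor(tensor, layout) signature are inferred from this test and are not shown in full in the hunks.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# Example global tensor on the local device (shape chosen for illustration).
original_tensor = torch.rand(4, 4).cuda()

# 2 x 2 mesh over ranks 0-3, as in the test above.
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

# Old API (removed in this commit): the spec carried the mesh and global shape itself.
#   ShardingSpec(device_mesh=device_mesh,
#                entire_shape=original_tensor.shape,
#                dim_partition_dict={0: [0]})

# New API: the spec only maps tensor dims to mesh dims; mesh and device info move to Layout.
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
                device_type=torch.device('cuda'),
                sharding_spec=target_sharding_spec,
                entire_shape=original_tensor.shape)   # assumed final kwarg, truncated in the hunk

d_tensor = distribute_tensor(original_tensor, layout)   # shards tensor dim 0 across mesh axis 0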