mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[DTensor] refactor dtensor with new components (#3089)
* [DTensor] refactor dtensor with new components * polish
This commit is contained in:
@@ -4,12 +4,11 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
@@ -34,9 +33,7 @@ def check_dtensor(rank, world_size, port):
|
||||
compare_output = test_model(original_tensor)
|
||||
|
||||
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
|
||||
target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=original_tensor.shape,
|
||||
dim_partition_dict={0: [0]})
|
||||
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=target_sharding_spec,
|
||||
@@ -62,9 +59,7 @@ def check_dtensor(rank, world_size, port):
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=original_tensor.shape,
|
||||
dim_partition_dict={0: [0, 1]})
|
||||
new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
|
||||
new_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=new_sharding_spec,
|
||||
|
Reference in New Issue
Block a user