mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[DTensor] implementation of dtensor (#2946)
* [DTensor] implementation of dtensor * test layout convert * polish
This commit is contained in:
104
tests/test_tensor/test_dtensor.py
Normal file
104
tests/test_tensor/test_dtensor.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.tracer import ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_features, out_features):
|
||||
super().__init__()
|
||||
self.linear_1 = torch.nn.Linear(in_features, out_features)
|
||||
self.linear_2 = torch.nn.Linear(out_features, in_features)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
def check_dtensor(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
test_model = TestModel(8, 8).to('cuda')
|
||||
original_tensor = torch.rand(4, 8).to('cuda')
|
||||
compare_output = test_model(original_tensor)
|
||||
|
||||
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
|
||||
target_sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=original_tensor.shape,
|
||||
dim_partition_dict={0: [0]})
|
||||
layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec)
|
||||
d_tensor = DTensor(original_tensor, layout)
|
||||
|
||||
assert d_tensor.entire_shape == original_tensor.shape
|
||||
assert d_tensor.data_type == original_tensor.dtype
|
||||
|
||||
if rank in (0, 1):
|
||||
assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
|
||||
elif rank in (2, 3):
|
||||
assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
assert d_tensor.to_global().equal(original_tensor)
|
||||
output = test_model(d_tensor)
|
||||
|
||||
if rank in (0, 1):
|
||||
assert output.equal(compare_output.narrow(0, 0, 2))
|
||||
elif rank in (2, 3):
|
||||
assert output.equal(compare_output.narrow(0, 2, 2))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
new_sharding_spec = ShardingSpec(device_mesh=device_mesh,
|
||||
entire_shape=original_tensor.shape,
|
||||
dim_partition_dict={0: [0, 1]})
|
||||
new_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=new_sharding_spec,
|
||||
entire_shape=original_tensor.shape)
|
||||
|
||||
d_tensor.layout_convert(new_layout)
|
||||
|
||||
if rank == 0:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
|
||||
elif rank == 1:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
|
||||
elif rank == 2:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
|
||||
elif rank == 3:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
dtensor_from_local = distribute_tensor(original_tensor, new_layout)
|
||||
|
||||
if rank == 0:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
|
||||
elif rank == 1:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
|
||||
elif rank == 2:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
|
||||
elif rank == 3:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
|
||||
def test_dtensor():
|
||||
world_size = 4
|
||||
run_func = partial(check_dtensor, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dtensor()
|
Reference in New Issue
Block a user