Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading
* polish code
@@ -1,14 +1,11 @@
 import pytest
 import torch
 import torch.distributed as dist
 from torch.distributed import ReduceOp
 
-from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
-
@@ -3,9 +3,7 @@ import torch
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 
@@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port):
 
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=target_sharding_spec,
-                    entire_shape=original_tensor.shape)
-    d_tensor = DTensor(original_tensor, layout)
+    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
 
-    assert d_tensor.entire_shape == original_tensor.shape
-    assert d_tensor.data_type == original_tensor.dtype
+    assert get_global_shape(d_tensor) == original_tensor.shape
+    assert d_tensor.dtype == original_tensor.dtype
 
     if rank in (0, 1):
-        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
+        assert d_tensor.equal(original_tensor.narrow(0, 0, 2))
     elif rank in (2, 3):
-        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
+        assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
-    assert d_tensor.to_global().equal(original_tensor)
+    assert to_global(d_tensor).equal(original_tensor)
     output = test_model(d_tensor)
 
     if rank in (0, 1):
@@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')
 
     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    new_layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
-                        sharding_spec=new_sharding_spec,
-                        entire_shape=original_tensor.shape)
-
-    d_tensor.layout_convert(new_layout)
+    d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
 
     if rank == 0:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 0, 1))
     elif rank == 1:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 1, 1))
     elif rank == 2:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 2, 1))
     elif rank == 3:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
-    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
+    dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
 
     if rank == 0:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
     elif rank == 1:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1))
     elif rank == 2:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1))
     elif rank == 3:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
 
@@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
-from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 entire_shape = torch.Size((64, 32, 16))
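Taken together, the hunks above replace the wrapper-style API (`DTensor`, `Layout`, `layout_convert`, the `to_local`/`to_global` methods) with the functional helpers exported from `colossalai.tensor.d_tensor`. Below is a minimal sketch of the new-style calls the updated test exercises; the function name, the tensor shape, and the assumption that a 4-rank process group has already been initialized (the test does this with `colossalai` launch/spawn) are illustrative and not part of the diff:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import (
    ShardingSpec,
    distribute_tensor,
    get_global_shape,
    redistribute,
    to_global,
)


def demo_new_dtensor_api():
    # Hypothetical helper: assumes the caller already launched a 4-rank
    # distributed group, as the test driver does via launch/spawn.
    original_tensor = torch.rand(4, 8).cuda()    # illustrative shape

    # 2x2 logical mesh over ranks 0-3, as in check_dtensor above.
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

    # Shard dim 0 along the first mesh axis and materialize the local shard.
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

    assert get_global_shape(d_tensor) == original_tensor.shape    # global metadata is preserved
    assert to_global(d_tensor).equal(original_tensor)             # gather the full tensor back

    # Re-shard dim 0 across both mesh axes; this replaces the old
    # Layout(...) + d_tensor.layout_convert(new_layout) pattern.
    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
    d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
    return d_tensor
```

After `distribute_tensor`, the returned object is compared per rank directly against `original_tensor.narrow(...)` in the updated assertions, rather than through a `to_local()` call as before.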