[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
Author: Frank Lee
Date: 2023-06-22 11:42:11 +08:00
Commit: 8eb09a4c69
Parent: 7740c55c55
19 changed files with 493 additions and 102 deletions
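The commit title refers to saving and loading modules whose parameters have been sharded into DTensors. As a rough, hypothetical sketch of that idea (not the code added by this PR), a full checkpoint can be produced by gathering each sharded parameter back to its global shape with to_global before calling torch.save; the is_distributed_tensor helper used below is an assumed name, not confirmed by this diff.

import torch
from colossalai.tensor.d_tensor import to_global
# Assumed helper (not confirmed by this diff): tells sharded DTensor parameters apart from plain tensors.
from colossalai.tensor.d_tensor import is_distributed_tensor

def save_sharded_module(module: torch.nn.Module, path: str) -> None:
    # Hypothetical sketch only: gather every sharded parameter to its global
    # (unsharded) form so the checkpoint can be loaded without a device mesh.
    state_dict = {}
    for name, tensor in module.state_dict().items():
        state_dict[name] = to_global(tensor) if is_distributed_tensor(tensor) else tensor
    torch.save(state_dict, path)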

@@ -1,14 +1,11 @@
import pytest
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn

@@ -3,9 +3,7 @@ import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global
from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port):
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=target_sharding_spec,
                    entire_shape=original_tensor.shape)
    d_tensor = DTensor(original_tensor, layout)
    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
    assert d_tensor.entire_shape == original_tensor.shape
    assert d_tensor.data_type == original_tensor.dtype
    assert get_global_shape(d_tensor) == original_tensor.shape
    assert d_tensor.dtype == original_tensor.dtype
    if rank in (0, 1):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
        assert d_tensor.equal(original_tensor.narrow(0, 0, 2))
    elif rank in (2, 3):
        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
        assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
    assert d_tensor.to_global().equal(original_tensor)
    assert to_global(d_tensor).equal(original_tensor)
    output = test_model(d_tensor)
    if rank in (0, 1):
@@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port):
        raise ValueError(f'rank {rank} is not in the device mesh')
    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)
    d_tensor.layout_convert(new_layout)
    d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
    if rank == 0:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
        assert d_tensor.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
        assert d_tensor.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
        assert d_tensor.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
        assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
    if rank == 0:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
        assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
    elif rank == 1:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
        assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1))
    elif rank == 2:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
        assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1))
    elif rank == 3:
        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
        assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
    else:
        raise ValueError(f'rank {rank} is not in the device mesh')
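Read together, the rewritten assertions above trace the new functional d_tensor workflow: shard a tensor with distribute_tensor, check its metadata with get_global_shape, re-shard it with redistribute, and gather it back with to_global. A condensed sketch of that flow, assuming the 4-rank CUDA setup that the test's launch()/spawn() harness provides:

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global

def shard_and_gather(original_tensor: torch.Tensor) -> torch.Tensor:
    # 2x2 mesh over ranks 0-3, matching the test above (requires torch.distributed to be initialized).
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

    # Shard dim 0 over mesh axis 0: ranks (0, 1) and (2, 3) each hold half of the rows.
    spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    d_tensor = distribute_tensor(original_tensor, device_mesh, spec)
    assert get_global_shape(d_tensor) == original_tensor.shape

    # Re-shard dim 0 over both mesh axes: each rank now holds a quarter of the rows.
    new_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
    d_tensor = redistribute(d_tensor, device_mesh, new_spec)

    # Gather the shards back into the full tensor on every rank.
    return to_global(d_tensor)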

@@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
entire_shape = torch.Size((64, 32, 16))
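For context, the ShardingSpec used throughout these tests only records which tensor dimensions are split over which device-mesh axes. A small sketch, using only the constructor arguments visible in this diff, of a spec for the (64, 32, 16) entire_shape above on a 2x2 mesh:

import torch
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

entire_shape = torch.Size((64, 32, 16))

# Split tensor dim 0 over mesh axis 0 and tensor dim 1 over mesh axis 1; dim 2 stays
# replicated, so on a 2x2 mesh each rank would hold a local block of shape (32, 16, 16).
spec = ShardingSpec(dim_size=len(entire_shape), dim_partition_dict={0: [0], 1: [1]})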