[tensor] build sharding spec to replace DistSpec in the future (#1405)

YuliangLiu0306 authored on 2022-08-08 11:15:57 +08:00; committed by GitHub
commit 7c96055c68 (parent 12b4887097)
2 changed files with 116 additions and 0 deletions


@@ -0,0 +1,24 @@
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh


def test_sharding_spec():
    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
    mesh_shape = (4, 4)
    # logical device mesh layout:
    # [[0,  1,  2,  3],
    #  [4,  5,  6,  7],
    #  [8,  9,  10, 11],
    #  [12, 13, 14, 15]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 8, 6))
    # shard tensor dim 0 across logical mesh dims 0 and 1; the other dims stay replicated
    dim_partition_dict = {0: [0, 1]}
    # DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4)
    sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
    assert str(sharding_spec.sharding_sequence) == "[S01, R, R]"


if __name__ == '__main__':
    test_sharding_spec()
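
For context, the 4 x 4 matrix in the test's comment is simply the row-major numbering of the 16 physical device ids under the logical mesh_shape. A minimal standalone sketch in plain PyTorch, assuming DeviceMesh flattens the physical ids and reshapes them into mesh_shape (as the test's comment suggests; this is not DeviceMesh's actual internals):

import torch

# hypothetical check of the logical layout, independent of ColossalAI
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
logical_mesh = physical_mesh_id.reshape(4, 4)
# reproduces the matrix from the test's comment, e.g. row 2 is [8, 9, 10, 11]
assert logical_mesh[2].tolist() == [8, 9, 10, 11]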
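
The asserted string contains one dimension spec per tensor dimension: R means replicated, and S followed by mesh axis ids means sharded across those logical mesh axes (S01 shards a tensor dimension over both axes of the 4 x 4 mesh at once). A minimal sketch of that rendering, where render_sharding_sequence is a hypothetical helper, not ColossalAI's implementation:

def render_sharding_sequence(num_dims, dim_partition_dict):
    # one entry per tensor dimension: "S" + mesh axis ids if sharded, else "R"
    specs = []
    for dim in range(num_dims):
        if dim in dim_partition_dict:
            specs.append("S" + "".join(str(axis) for axis in dim_partition_dict[dim]))
        else:
            specs.append("R")
    return "[" + ", ".join(specs) + "]"

# matches the assertion in test_sharding_spec above
assert render_sharding_sequence(3, {0: [0, 1]}) == "[S01, R, R]"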