[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
Author: Frank Lee
Date: 2022-10-19 12:53:06 +08:00
Committed by: GitHub
Parent: cbe9a4cb45
Commit: eee84908d4
36 changed files with 459 additions and 303 deletions


@@ -1,6 +1,7 @@
import torch
-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def test_sharding_spec():
@@ -11,7 +12,7 @@ def test_sharding_spec():
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-entire_shape = torch.Size((4, 8, 6))
+entire_shape = torch.Size((16, 8, 6))
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
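
Why the test shape changed: with dim_partition_dict = {0: [0, 1]}, dimension 0 is sharded along both axes of the 4 x 4 device mesh (S01), i.e. split across 4 * 4 = 16 devices, so it must be divisible by 16. The old shape (4, 8, 6) made that strategy illegal, which is the kind of case this PR now handles; the test keeps the spec legal by using (16, 8, 6). Below is a minimal sketch of how such a spec is built, mirroring the calls visible in this diff; the physical_mesh_id and mesh_shape values are assumptions inferred from the 4 x 4 mesh comment in the test, and the ShardingSpec constructor is used as it appears in this test file.

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

# 16 devices arranged as a 4 x 4 logical mesh (values assumed from the comment
# block in the test: [[0..3], [4..7], [8..11], [12..15]]).
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

# Shard dim 0 along both mesh axes (shard_sequence S01,R,R): dim 0 is split over
# 4 * 4 = 16 devices, so entire_shape[0] must be divisible by 16.
# (16, 8, 6) is legal; the previous (4, 8, 6) was not.
entire_shape = torch.Size((16, 8, 6))
dim_partition_dict = {0: [0, 1]}
sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
print(sharding_spec)  # expected: DistSpec with shard_sequence S01,R,R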