[autoparallel] fixed wrong generated strategy for dot op (#1746)

* [autoparallel] fixed wrong generated strategy for dot op

* polish code
Frank Lee
2022-10-20 15:18:16 +08:00
committed by GitHub
parent 993b8875b6
commit 8b8937d901
13 changed files with 187 additions and 116 deletions


@@ -36,9 +36,10 @@ def ignore_sharding_exception(func):
 def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
     """
     This function checks whether the ShardingSpec is valid for the physical tensor.
-    This check includes 2 items:
+    This check includes 3 items:
     1. the sharding spec covers all dimensions of the physical tensor
     2. the sharding spec for each dimension is divisible by the number of devices.
+    3. the sharding spec's entire shape must match the tensor shape
     #
     """
     # make sure all dims are covered in sharding spec
@@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
         assert dim_size >= num_devices and dim_size % num_devices == 0, \
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+
+    # make sure the entire shape matches the physical tensor shape
+    assert sharding_spec.entire_shape == tensor.shape
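
To make the intent of the new third check concrete, below is a minimal, self-contained Python sketch of the three validity checks listed in the updated docstring. It deliberately avoids the real ShardingSpec API: the sharding layout is modelled as a hypothetical per-dimension sequence ('R' for a replicated dimension, or a list of device-mesh axes it is sharded over) together with an entire_shape, and mesh_dim_sizes maps each mesh axis to its device count; these names are illustrative, not the library's.

from typing import Dict, List, Tuple, Union

def check_sharding_validity_sketch(sharding_sequence: List[Union[str, List[int]]],
                                   entire_shape: Tuple[int, ...],
                                   mesh_dim_sizes: Dict[int, int],
                                   tensor_shape: Tuple[int, ...]) -> None:
    # 1. the sharding spec must cover every dimension of the physical tensor
    assert len(sharding_sequence) == len(tensor_shape), \
        f'sharding spec covers {len(sharding_sequence)} dims, but the tensor has {len(tensor_shape)}'

    # 2. each sharded dimension must be divisible by the number of devices it is sharded over
    for i, (entry, dim_size) in enumerate(zip(sharding_sequence, tensor_shape)):
        if entry == 'R':    # replicated dimension, nothing to check
            continue
        num_devices = 1
        for mesh_dim in entry:
            num_devices *= mesh_dim_sizes[mesh_dim]
        assert dim_size >= num_devices and dim_size % num_devices == 0, \
            f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'

    # 3. the sharding spec's entire shape must match the physical tensor shape
    assert entire_shape == tensor_shape

For example, on a 2 x 2 device mesh (mesh_dim_sizes = {0: 2, 1: 2}), a tensor of shape (4, 8) with sharding_sequence ['R', [0, 1]] and entire_shape (4, 8) passes all three checks, since dimension 1 of size 8 is divisible by the 4 devices along mesh axes 0 and 1.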