[autoparallel] fixed wrong sharding strategy in conv handler (#1747)

* [autoparallel] fixed wrong sharding strategy in conv handler

* polish code
Frank Lee
2022-10-20 16:12:39 +08:00
committed by GitHub
parent 8b8937d901
commit 474111ecb5
6 changed files with 75 additions and 60 deletions


@@ -68,4 +68,5 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
     # make sure the entire shape matches the physical tensor shape
-    assert sharding_spec.entire_shape == tensor.shape
+    assert sharding_spec.entire_shape == tensor.shape, \
+        f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
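For context, here is a minimal, self-contained sketch of the assertion pattern this hunk introduces (failing with a descriptive message instead of a bare `AssertionError`). The `ToySpec` class, function name, and shapes below are hypothetical stand-ins for illustration, not ColossalAI's actual `ShardingSpec` API.

```python
import torch
from dataclasses import dataclass


@dataclass
class ToySpec:
    # hypothetical stand-in holding a sharding spec's logical (entire) shape
    entire_shape: torch.Size


def check_entire_shape(spec: ToySpec, tensor: torch.Tensor):
    # mirrors the new assert: report both shapes when they disagree
    assert spec.entire_shape == tensor.shape, \
        f'The entire_shape of the sharding spec {spec.entire_shape} does not match the tensor shape {tensor.shape}'


if __name__ == '__main__':
    x = torch.randn(4, 8)
    check_entire_shape(ToySpec(entire_shape=torch.Size([4, 8])), x)  # passes silently

    try:
        check_entire_shape(ToySpec(entire_shape=torch.Size([4, 16])), x)
    except AssertionError as e:
        print(e)  # message now names both the spec's entire_shape and the tensor's shape
```

The behavior of the check is unchanged; only the failure message is enriched, which makes mismatches between a sharding spec and the physical tensor easier to diagnose.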