mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] fixed wrong generated strategy for dot op (#1746)
* [autoparallel] fixed wrong generated strategy for dot op

* polish code
@@ -36,9 +36,10 @@ def ignore_sharding_exception(func):
 def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tensor):
     """
     This function checks whether the ShardingSpec is valid for the physical tensor.
-    This check includes 2 items:
+    This check includes 3 items:
     1. the sharding spec covers all dimensions of the physical tensor
     2. the sharding spec for each dimension is divisible by the number of devices.
+    3. the sharding spec's entire shape must match the tensor shape
     #
     """
     # make sure all dims are covered in sharding spec
@@ -65,3 +66,6 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
 
         assert dim_size >= num_devices and dim_size % num_devices == 0, \
             f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
+
+    # make sure the entire shape matches the physical tensor shape
+    assert sharding_spec.entire_shape == tensor.shape
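For context, below is a minimal, self-contained Python sketch of the three validity checks that the updated docstring lists. The ToyShardingSpec class, its field names (sharding_sequence, entire_shape, mesh_shape), and the helper name check_sharding_spec_validity_sketch are illustrative assumptions, not the real colossalai API; only the three assertions themselves mirror the checks described in the diff above.

import torch

# Hypothetical, simplified stand-in for colossalai's ShardingSpec,
# used only to illustrate the three validity checks described above.
class ToyShardingSpec:
    def __init__(self, sharding_sequence, entire_shape, mesh_shape):
        self.sharding_sequence = sharding_sequence   # e.g. ['R', 'S0'] -> dim 1 sharded over mesh axis 0
        self.entire_shape = torch.Size(entire_shape)
        self.mesh_shape = mesh_shape                 # e.g. (2, 2) for a 2x2 device mesh


def check_sharding_spec_validity_sketch(spec: ToyShardingSpec, tensor: torch.Tensor) -> None:
    # 1. the sharding spec covers all dimensions of the physical tensor
    assert len(spec.sharding_sequence) == tensor.dim(), \
        f'spec describes {len(spec.sharding_sequence)} dims, but the tensor has {tensor.dim()}'

    # 2. every sharded dimension must be divisible by the number of devices it is split over
    for i, dim_spec in enumerate(spec.sharding_sequence):
        if dim_spec == 'R':                          # replicated dimension, nothing to check
            continue
        mesh_axes = [int(c) for c in dim_spec[1:]]   # 'S01' -> sharded over mesh axes 0 and 1
        num_devices = 1
        for axis in mesh_axes:
            num_devices *= spec.mesh_shape[axis]
        dim_size = tensor.shape[i]
        assert dim_size >= num_devices and dim_size % num_devices == 0, \
            f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'

    # 3. the sharding spec's entire shape must match the tensor shape (the check added in this commit)
    assert spec.entire_shape == tensor.shape


# usage: a 4x8 tensor, dim 0 replicated, dim 1 sharded over mesh axis 0 (2 devices)
tensor = torch.randn(4, 8)
spec = ToyShardingSpec(sharding_sequence=['R', 'S0'], entire_shape=(4, 8), mesh_shape=(2, 2))
check_sharding_spec_validity_sketch(spec, tensor)    # all three checks pass

In the actual function the per-dimension entries are richer objects rather than plain strings, but the third assertion shown here is the same shape-equality check that this commit adds.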