Mirror of https://github.com/hpcaitech/ColossalAI.git
[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change
* polish code
@@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     # make sure all dims are covered in sharding spec
     sharding_len = len(sharding_spec.sharding_sequence)
     tensor_num_dim = tensor.dim()
-    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
-    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    num_devices_in_col = sharding_spec.device_mesh.shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.shape[1]
     assert sharding_len == tensor_num_dim, \
         f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

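For context, the API change this diff tracks is that the device mesh now exposes its dimensions via `shape` rather than `mesh_shape`. Below is a minimal, self-contained sketch of the dimension check performed here, using hypothetical stand-in classes (ToyDeviceMesh, ToyShardingSpec) rather than the real ColossalAI DeviceMesh and ShardingSpec; their names and constructor arguments are assumptions for illustration only.

import torch


class ToyDeviceMesh:
    def __init__(self, shape):
        # After the dtensor change, mesh dimensions are read from `.shape`
        # (previously `.mesh_shape`). Hypothetical stand-in class.
        self.shape = shape


class ToyShardingSpec:
    def __init__(self, device_mesh, sharding_sequence):
        # Hypothetical stand-in: holds the mesh and one sharding entry per tensor dim.
        self.device_mesh = device_mesh
        self.sharding_sequence = sharding_sequence


def check_dims_covered(sharding_spec, tensor):
    # The sharding sequence must describe every dimension of the tensor.
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    num_devices_in_col = sharding_spec.device_mesh.shape[0]
    num_devices_in_row = sharding_spec.device_mesh.shape[1]
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for ' \
        f'{sharding_len}-dimension tensor, but the given tensor is ' \
        f'{tensor_num_dim}-dimension ({tensor.shape}).'
    return num_devices_in_col, num_devices_in_row


# Usage: a 2x2 mesh and a sharding sequence covering both dims of a 2-D tensor.
mesh = ToyDeviceMesh(shape=(2, 2))
spec = ToyShardingSpec(mesh, sharding_sequence=['S0', 'S1'])
print(check_dims_covered(spec, torch.randn(4, 4)))  # -> (2, 2)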