mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[test] fixed tests that failed due to dtensor change (#4082)
* [test] fixed tests that failed due to dtensor change * polish code
This commit is contained in:
@@ -188,7 +188,7 @@ class NodeHandler(ABC):
|
||||
remove_strategy_list = []
|
||||
for strategy in self.strategies_vector:
|
||||
shard_axis_list = []
|
||||
last_axis = len(self.device_mesh.mesh_shape) - 1
|
||||
last_axis = len(self.device_mesh.shape) - 1
|
||||
for op_data, sharding_spec in strategy.sharding_specs.items():
|
||||
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
|
||||
for dim, shard_axes in sharding_spec.dim_partition_dict.items():
|
||||
|
@@ -984,7 +984,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
device_mesh_is_1d = True
|
||||
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
|
||||
if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
|
||||
device_mesh_is_1d = False
|
||||
|
||||
if device_mesh_is_1d:
|
||||
@@ -992,10 +992,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
||||
# Sb = Sb x Sb
|
||||
# can be None as it is only for 1D device mesh
|
||||
# only for 1D device mesh
|
||||
if len(self.device_mesh.mesh_shape) == 1:
|
||||
if len(self.device_mesh.shape) == 1:
|
||||
mesh_dim = 0
|
||||
else:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
mesh_dim = self.device_mesh.shape.index(1)
|
||||
strategy_list.append(self.split_one_batch_dim(mesh_dim))
|
||||
else:
|
||||
# for 2D device mesh
|
||||
|
@@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
|
||||
# make sure all dims are covered in sharding spec
|
||||
sharding_len = len(sharding_spec.sharding_sequence)
|
||||
tensor_num_dim = tensor.dim()
|
||||
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
|
||||
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
|
||||
num_devices_in_col = sharding_spec.device_mesh.shape[0]
|
||||
num_devices_in_row = sharding_spec.device_mesh.shape[1]
|
||||
assert sharding_len == tensor_num_dim, \
|
||||
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
|
||||
|
||||
|
Reference in New Issue
Block a user