mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 03:31:56 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change * polish code
This commit is contained in:
@@ -984,7 +984,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
device_mesh_is_1d = True
|
||||
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
|
||||
if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
|
||||
device_mesh_is_1d = False
|
||||
|
||||
if device_mesh_is_1d:
|
||||
@@ -992,10 +992,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
|
||||
# Sb = Sb x Sb
|
||||
# can be None as it is only for 1D device mesh
|
||||
# only for 1D device mesh
|
||||
if len(self.device_mesh.mesh_shape) == 1:
|
||||
if len(self.device_mesh.shape) == 1:
|
||||
mesh_dim = 0
|
||||
else:
|
||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||
mesh_dim = self.device_mesh.shape.index(1)
|
||||
strategy_list.append(self.split_one_batch_dim(mesh_dim))
|
||||
else:
|
||||
# for 2D device mesh
|
||||
|
Reference in New Issue
Block a user