[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -984,7 +984,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
device_mesh_is_1d = False
if device_mesh_is_1d:
@@ -992,10 +992,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
# only for 1D device mesh
if len(self.device_mesh.mesh_shape) == 1:
if len(self.device_mesh.shape) == 1:
mesh_dim = 0
else:
mesh_dim = self.device_mesh.mesh_shape.index(1)
mesh_dim = self.device_mesh.shape.index(1)
strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh