[hotfix]unit test (#1670)

This commit is contained in:
YuliangLiu0306
2022-09-29 12:49:28 +08:00
committed by GitHub
parent a60024e77a
commit 11ec070e53
2 changed files with 5 additions and 1 deletions

View File

@@ -19,6 +19,7 @@ class BMMTorchFunctionModule(nn.Module):
return torch.bmm(x1, x2)
@pytest.mark.skip
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
@@ -89,6 +90,7 @@ def test_2d_device_mesh(module):
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
@pytest.mark.skip
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
model = module()