[autoparallel] fixed broken node handler tests (#1708)

This commit is contained in:
Frank Lee
2022-10-14 18:25:59 +08:00
committed by GitHub
parent 1468e4bcfc
commit 22a115406b
8 changed files with 53 additions and 49 deletions

View File

@@ -22,7 +22,6 @@ class BMMTorchFunctionModule(nn.Module):
return torch.bmm(x1, x2)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
@@ -93,7 +92,6 @@ def test_2d_device_mesh(module):
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
model = module()

View File

@@ -11,7 +11,6 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()
@@ -50,7 +49,7 @@ def test_norm_pool_handler():
assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy()
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
assert len(strategy_name_list) == 9