Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 14:41:53 +00:00)
[autoparallel] fixed broken node handler tests (#1708)
@@ -22,7 +22,6 @@ class BMMTorchFunctionModule(nn.Module):
        return torch.bmm(x1, x2)


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
@@ -93,7 +92,6 @@ def test_2d_device_mesh(module):
    assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
    model = module()
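Both BMM handler tests remain gated behind the AUTO_PARALLEL environment flag, so a plain pytest run skips them. Below is a minimal sketch of an equivalent gate built on pytest.mark.skipif; it assumes run_on_environment_flag skips a test unless the named environment variable is set to '1', which is inferred from its usage here rather than quoted from the ColossalAI source.

# Hypothetical stand-in for colossalai.testing.pytest_wrapper.run_on_environment_flag.
# Assumption: the real decorator skips the test unless the named env var equals '1'.
import os

import pytest


def run_on_environment_flag_sketch(name):
    enabled = os.environ.get(name, '0') == '1'
    return pytest.mark.skipif(not enabled, reason=f'{name} is not set, skipping auto-parallel test')


@run_on_environment_flag_sketch(name='AUTO_PARALLEL')
def test_gated_example():
    assert True  # collected always, executed only when AUTO_PARALLEL=1 is exported

Under that assumption, running the suite with AUTO_PARALLEL=1 set in the environment is what actually exercises these tests.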
@@ -11,7 +11,6 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
    tracer = ColoTracer()
@@ -50,7 +49,7 @@ def test_norm_pool_handler():
    assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
    assert mapping['output'].type == OperationDataType.OUTPUT

-    strategies_vector = handler.register_strategy()
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]
    assert len(strategy_name_list) == 9
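The only behavioural change in the pooling-handler test is that register_strategy is now called with compute_resharding_cost=False, so the assertions check the generated strategy names without the resharding-cost pass. The following self-contained toy sketch illustrates that pattern; every name in it except the compute_resharding_cost keyword is invented for illustration and is not the ColossalAI handler API.

# Toy illustration of a handler whose strategy registration can skip the
# resharding-cost computation.  Only compute_resharding_cost=False comes from
# the diff; the classes and names here are hypothetical.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyStrategy:
    name: str
    resharding_cost: Optional[float] = None


class ToyNodeHandler:
    def __init__(self, candidate_names: List[str]):
        self.candidate_names = candidate_names

    def register_strategy(self, compute_resharding_cost: bool = True) -> List[ToyStrategy]:
        strategies = [ToyStrategy(name) for name in self.candidate_names]
        if compute_resharding_cost:
            # A real handler would query cost models here; the test skips this step.
            for strategy in strategies:
                strategy.resharding_cost = float(len(strategy.name))
        return strategies


handler = ToyNodeHandler([f'S{i}' for i in range(9)])
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
assert len(strategy_name_list) == 9

Skipping the cost pass keeps the unit test focused on strategy enumeration, which matches the assertion on the nine strategy names above.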