[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -1,6 +1,8 @@
import pytest
import torch
from colossalai.testing import clear_cache_before_run, parameterize
try:
from colossalai._analyzer.fx import symbolic_trace
except:
@@ -62,9 +64,10 @@ class AModel(torch.nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def test_mod_dir(bias, bias_addition_split, shape):
model = AModel(bias=bias)
x = torch.rand(shape)
@@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape):
if __name__ == '__main__':
test_mod_dir(True, True, (3, 3, 3))
test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3))