mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 21:22:49 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
|
||||
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
|
||||
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
@@ -24,6 +25,7 @@ except:
|
||||
@pytest.mark.skip(reason='TODO: modify the logger')
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
|
||||
@clear_cache_before_run()
|
||||
def test_linearize():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
@@ -84,6 +86,7 @@ def test_linearize():
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
|
||||
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
|
||||
@clear_cache_before_run()
|
||||
def test_linearize_torch11():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
|
Reference in New Issue
Block a user