[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
This commit is contained in:
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions

View File

@@ -3,7 +3,6 @@ import copy
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import colossalai
@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@@ -26,8 +25,8 @@ except:
withcodegen = False
def _run_C_solver_consistency_test(rank=0):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_C_solver_consistency_test(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
@rerun_if_address_is_in_use()
def test_C_solver_consistency():
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
spawn(_run_C_solver_consistency_test, 1)
if __name__ == '__main__':

View File

@@ -4,7 +4,6 @@ from typing import Callable
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
@rerun_if_address_is_in_use()
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)
spawn(_run_ckpt_solver, 1)
def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver_torch11(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
@rerun_if_address_is_in_use()
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
spawn(_run_ckpt_solver_torch11, 1)
if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.testing import clear_cache_before_run
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@@ -24,6 +25,7 @@ except:
@pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
@@ -84,6 +86,7 @@ def test_linearize():
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
@clear_cache_before_run()
def test_linearize_torch11():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()