Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 06:30:41 +00:00)
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
This commit is contained in:
@@ -3,7 +3,6 @@ import copy
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
|
||||
import colossalai
|
||||
@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
# from colossalai.fx.passes.algorithms import solver_rotor
|
||||
# from colossalai.fx.passes.algorithms.operation import Sequence
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
@@ -26,8 +25,8 @@ except:
|
||||
withcodegen = False
|
||||
|
||||
|
||||
def _run_C_solver_consistency_test(rank=0):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
def _run_C_solver_consistency_test(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
|
||||
model = M()
|
||||
@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
|
||||
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_C_solver_consistency():
|
||||
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
|
||||
spawn(_run_C_solver_consistency_test, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -4,7 +4,6 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torchvision.models as tm
|
||||
from torch.fx import GraphModule
|
||||
|
||||
@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
|
||||
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
|
||||
|
||||
|
||||
def _run_ckpt_solver(rank):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
def _run_ckpt_solver(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
MODEL_LIST = [tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
|
||||
|
||||
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
|
||||
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ckpt_solver():
|
||||
mp.spawn(_run_ckpt_solver, nprocs=1)
|
||||
spawn(_run_ckpt_solver, 1)
|
||||
|
||||
|
||||
def _run_ckpt_solver_torch11(rank):
|
||||
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||
def _run_ckpt_solver_torch11(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
MODEL_LIST = [tm.densenet121]
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
|
||||
|
||||
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
||||
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_ckpt_solver_torch11():
|
||||
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
|
||||
spawn(_run_ckpt_solver_torch11, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
|
||||
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
|
||||
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from colossalai.fx.profiler.tensor import MetaTensor
|
||||
@@ -24,6 +25,7 @@ except:
|
||||
@pytest.mark.skip(reason='TODO: modify the logger')
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
|
||||
@clear_cache_before_run()
|
||||
def test_linearize():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
@@ -84,6 +86,7 @@ def test_linearize():
|
||||
@pytest.mark.skip("TODO(lyl): refactor all tests.")
|
||||
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
|
||||
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
|
||||
@clear_cache_before_run()
|
||||
def test_linearize_torch11():
|
||||
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
|
||||
tracer = ColoTracer()
|
||||
|
Reference in New Issue
Block a user