[test] refactor tests with spawn (#3452)
* [test] added spawn decorator
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
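The refactor replaces direct torch.multiprocessing.spawn calls with a spawn helper from colossalai.testing, so each worker receives its rank, world_size, and rendezvous port from the launcher instead of hard-coding a world size of 1 and calling free_port() itself. A minimal sketch of how such a helper could be built — an illustration only, with a hypothetical _get_free_port standing in for the removed colossalai.utils.free_port; the real colossalai.testing.spawn may differ:

# Illustrative sketch only -- not ColossalAI's actual implementation.
import socket

import torch.multiprocessing as mp


def _get_free_port() -> int:
    # Hypothetical stand-in for the removed colossalai.utils.free_port:
    # ask the OS for an unused TCP port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]


def spawn(func, nprocs=1):
    # Pick one port up front so every worker joins the same rendezvous,
    # then forward (world_size, port) after the rank that mp.spawn
    # prepends, matching the new worker signature in this diff.
    port = _get_free_port()
    mp.spawn(func, args=(nprocs, port), nprocs=nprocs)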
@@ -2,15 +2,13 @@ import copy
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn.functional as F
 from torch.fx import GraphModule

 import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -66,9 +64,9 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
     assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"


-def _run_offload_codegen(rank):
+def _run_offload_codegen(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model and input
     model = MyNet().cuda()
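With the sketch above, the worker's new three-argument signature lines up with what the launcher forwards; a hypothetical trace for the single-process case:

# Hypothetical call chain for spawn(_run_offload_codegen, 1):
#   mp.spawn(_run_offload_codegen, args=(1, port), nprocs=1)
#     -> _run_offload_codegen(rank=0, world_size=1, port=port)
# so colossalai.launch() sees a consistent rank/world_size/port triple
# without each worker calling free_port() on its own.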
@@ -116,13 +114,14 @@ def _run_offload_codegen(rank):


 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@rerun_if_address_is_in_use()
 def test_act_ckpt_codegen():
-    mp.spawn(_run_offload_codegen, nprocs=1)
+    spawn(_run_offload_codegen, 1)


-def _run_offload_codegen_torch11(rank):
+def _run_offload_codegen_torch11(rank, world_size, port):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # build model and input
     model = MyNet().cuda()
@@ -171,8 +170,9 @@ def _run_offload_codegen_torch11(rank):


 @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
+@rerun_if_address_is_in_use()
 def test_act_ckpt_python_code_torch11():
-    mp.spawn(_run_offload_codegen_torch11, nprocs=1)
+    spawn(_run_offload_codegen_torch11, 1)


 if __name__ == "__main__":
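The diff also adds @rerun_if_address_is_in_use() to each test, guarding against a rendezvous port left bound by a previous run. A hedged sketch of what such a decorator could look like, assuming it retries on the usual "address already in use" RuntimeError; the actual colossalai.testing implementation may differ:

# Illustrative sketch only -- not ColossalAI's actual implementation.
import functools


def rerun_if_address_is_in_use(max_attempts=3):
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return test_fn(*args, **kwargs)
                except RuntimeError as e:
                    # Retry only the socket-binding failure left behind
                    # when an earlier run still holds the rendezvous port.
                    if ('address already in use' not in str(e).lower()
                            or attempt == max_attempts):
                        raise
        return wrapper
    return decorator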