From 5a52e21fe351910214c9641baeb748bcaadce260 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 12 Aug 2022 14:52:31 +0800 Subject: [PATCH] [test] fixed the activation codegen test (#1447) * [test] fixed the activation codegen test * polish code --- colossalai/utils/common.py | 1 - .../test_activation_checkpoint_codegen.py | 21 +++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index ccc136858..a52c25530 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- import os -from pprint import pp import random import socket from pathlib import Path diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 411ec0083..fe5c638b2 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -1,6 +1,7 @@ from operator import mod import torch import pytest +import torch.multiprocessing as mp from torch.utils.checkpoint import checkpoint from torch.fx import GraphModule from colossalai.fx import ColoTracer @@ -42,10 +43,9 @@ class MyModule(torch.nn.Module): return y1 + y2 + y3 + y4 -@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') -def test_act_ckpt_codegen(): +def _run_act_ckpt_codegen(rank): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') # build model and run forward model = MyModule() @@ -90,10 +90,14 @@ def test_act_ckpt_codegen(): gpc.destroy() -@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') -def test_act_ckpt_python_code_torch11(): +@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +def test_act_ckpt_codegen(): + mp.spawn(_run_act_ckpt_codegen, nprocs=1) + + +def _run_act_ckpt_python_code_torch11(rank): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') # build model and run forward model = MyModule() @@ -138,6 +142,11 @@ def test_act_ckpt_python_code_torch11(): gpc.destroy() +@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') +def test_act_ckpt_python_code_torch11(): + mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + + if __name__ == '__main__': test_act_ckpt_codegen()