[test] fixed the activation codegen test (#1447)

* [test] fixed the activation codegen test

* polish code
Frank Lee 2022-08-12 14:52:31 +08:00 committed by GitHub
parent 0f3042363c
commit 5a52e21fe3
2 changed files with 15 additions and 7 deletions

File 1 of 2:

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import os
-from pprint import pp
 import random
 import socket
 from pathlib import Path

File 2 of 2:

@@ -1,6 +1,7 @@
 from operator import mod
 import torch
 import pytest
+import torch.multiprocessing as mp
 from torch.utils.checkpoint import checkpoint
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
@@ -42,10 +43,9 @@ class MyModule(torch.nn.Module):
         return y1 + y2 + y3 + y4


-@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
-def test_act_ckpt_codegen():
+def _run_act_ckpt_codegen(rank):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
-    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

     # build model and run forward
     model = MyModule()
@@ -90,10 +90,14 @@ def test_act_ckpt_codegen():
     gpc.destroy()


-@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
-def test_act_ckpt_python_code_torch11():
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+def test_act_ckpt_codegen():
+    mp.spawn(_run_act_ckpt_codegen, nprocs=1)
+
+
+def _run_act_ckpt_python_code_torch11(rank):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
-    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

     # build model and run forward
     model = MyModule()
@@ -138,6 +142,11 @@ def test_act_ckpt_python_code_torch11():
     gpc.destroy()


+@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+def test_act_ckpt_python_code_torch11():
+    mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
+
+
 if __name__ == '__main__':
     test_act_ckpt_codegen()
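
Note on the pattern applied above: the test body moves into a `_run_*(rank)` worker and the pytest entry point only calls `torch.multiprocessing.spawn`, so `colossalai.launch` initializes the NCCL process group in a fresh subprocess rather than in the main pytest process. Below is a minimal sketch of that pattern, not part of the commit; the `_run_demo`/`test_demo` names are illustrative and the `gpc`/`free_port` import paths are assumptions.

import torch.multiprocessing as mp

import colossalai
from colossalai.core import global_context as gpc   # import path assumed
from colossalai.utils import free_port              # import path assumed


def _run_demo(rank):
    # mp.spawn passes the process index as the first positional argument,
    # so the worker forwards it to colossalai.launch as its rank.
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost',
                      port=free_port(), backend='nccl')

    # ... actual test body goes here (build model, trace, run forward, assert) ...

    # tear down the process group so later tests start from a clean state
    gpc.destroy()


def test_demo():
    # the main pytest process never initializes NCCL itself; it only spawns the worker
    mp.spawn(_run_demo, nprocs=1)


if __name__ == '__main__':
    test_demo()

Spawning a fresh process per test case keeps CUDA/NCCL state from leaking between tests, which is the usual reason for moving the launch call out of the main pytest process.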