From 3abbaf8bc68c8a3366241a3dc2e97f6944605fb2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 14:53:04 +0800 Subject: [PATCH] update codegen test --- .../test_autochunk/test_autochunk_codegen.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/test_autochunk/test_autochunk_codegen.py b/tests/test_autochunk/test_autochunk_codegen.py index 8246275eb..c91148e11 100644 --- a/tests/test_autochunk/test_autochunk_codegen.py +++ b/tests/test_autochunk/test_autochunk_codegen.py @@ -1,3 +1,5 @@ +from functools import partial + import pytest import torch import torch.fx @@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): ) -def _test_autochunk_codegen(rank): +def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly colossalai.launch( config={}, @@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank): # build model and input model = evoformer_base().cuda() - msa_len = 32 - pair_len = 64 node = torch.randn(1, msa_len, pair_len, 256).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda() @@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank): MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") ) - codegen = AutoChunkCodeGen(gm_prop) + codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory) graph.set_codegen(codegen) gm = ColoGraphModule(model, graph) gm.recompile() @@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank): gpc.destroy() -def test_autochunk_codegen(): - mp.spawn(_test_autochunk_codegen, nprocs=1) +@pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32]) +@pytest.mark.parametrize("msa_len", [32]) +@pytest.mark.parametrize("pair_len", [64]) +def test_autochunk_codegen(msa_len, pair_len, max_memory): + run_func = partial( + _test_autochunk_codegen, + msa_len=msa_len, + pair_len=pair_len, + max_memory=max_memory, + ) + mp.spawn(run_func, nprocs=1) if __name__ == "__main__": - _test_autochunk_codegen(0) + _test_autochunk_codegen(0, 32, 64, None)