Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 13:41:43 +00:00)
update codegen test
commit 3abbaf8bc6 (parent 74b81395a2)
@@ -1,3 +1,5 @@
+from functools import partial
+
 import pytest
 import torch
 import torch.fx
@@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     )
 
 
-def _test_autochunk_codegen(rank):
+def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
     colossalai.launch(
         config={},
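The launch call is truncated by this hunk, but the surrounding bootstrap follows the single-process pattern used throughout ColossalAI tests of this era. A minimal sketch for orientation only; the host, port, and backend arguments are assumed rather than taken from this diff:

import colossalai
from colossalai.utils import free_port

def _bootstrap(rank):
    # single-process "distributed" launch so colossalai.utils code paths work
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),  # assumed helper, commonly used in these tests
        backend="nccl",
    )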
@@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank):
 
     # build model and input
     model = evoformer_base().cuda()
-    msa_len = 32
-    pair_len = 64
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
 
@@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank):
         MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
     )
 
-    codegen = AutoChunkCodeGen(gm_prop)
+    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()
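The substantive change above threads the new max_memory argument through to AutoChunkCodeGen, which appears to act as the memory budget for the generated chunked code. If one wanted to assert that budget directly, a hypothetical helper (not part of this commit; it assumes max_memory is a budget in MB, as the 20-32 parametrize values below suggest) could look like:

import torch

def assert_peak_memory_under(budget_mb, fn, *args):
    # measure peak CUDA allocation across one forward pass
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    assert peak_mb <= budget_mb, f"peak {peak_mb:.1f} MB over budget {budget_mb} MB"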
@@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank):
     gpc.destroy()
 
 
-def test_autochunk_codegen():
-    mp.spawn(_test_autochunk_codegen, nprocs=1)
+@pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32])
+@pytest.mark.parametrize("msa_len", [32])
+@pytest.mark.parametrize("pair_len", [64])
+def test_autochunk_codegen(msa_len, pair_len, max_memory):
+    run_func = partial(
+        _test_autochunk_codegen,
+        msa_len=msa_len,
+        pair_len=pair_len,
+        max_memory=max_memory,
+    )
+    mp.spawn(run_func, nprocs=1)
 
 
 if __name__ == "__main__":
-    _test_autochunk_codegen(0)
+    _test_autochunk_codegen(0, 32, 64, None)
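Two mechanics in the rewritten test are worth spelling out: stacked @pytest.mark.parametrize decorators generate the cartesian product of their values (5 x 1 x 1 = 5 test cases here), and functools.partial is needed because torch.multiprocessing.spawn passes only the process index to its target function. A self-contained sketch of the same pattern, with illustrative names not taken from this commit:

from functools import partial

import pytest
import torch.multiprocessing as mp


def _worker(rank, size, budget):
    # spawn supplies rank; everything else arrives pre-bound via partial
    assert rank == 0


@pytest.mark.parametrize("budget", [None, 20, 32])
@pytest.mark.parametrize("size", [32])
def test_spawn_pattern(size, budget):
    # a partial of a module-level function stays picklable, which spawn requires
    mp.spawn(partial(_worker, size=size, budget=budget), nprocs=1)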