[autochunk] support autochunk on evoformer (#2497)

This commit is contained in:
oahzxl
2023-01-19 11:41:00 +08:00
committed by GitHub
parent 304f1ba124
commit ecccc91f21
9 changed files with 200 additions and 188 deletions

View File

@@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
# for memory test
# model = model.cuda()
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# gm(node1, pair1)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# node_mask1 = node_mask.clone()
# pair_mask1 = pair_mask.clone()
# gm(node1, pair1, node_mask1, pair_mask1)
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print(
# "autochunk now mem:%.2f max mem:%.2f"
# % (new_now_mem - now_mem, new_max_mem - now_mem)
# )
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
# test forward
model = model.cuda()
@@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
MetaTensor(node_mask, fake_device="cuda:0"),
MetaTensor(pair_mask, fake_device="cuda:0"),
)
# codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
# trace and recompile
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
@@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
"_mask_trans": True,
},
)
# graph.set_codegen(codegen)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
assert "chunk_size" in code
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
gpc.destroy()
@@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
@@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):
if __name__ == "__main__":
_test_evoformer_codegen(0, 32, 64, 25)
_test_evoformer_codegen(0, 32, 64, 24)

View File

@@ -13,7 +13,7 @@ except:
import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx import ColoTracer, symbolic_trace
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
@@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
# for memory test
# torch.cuda.reset_peak_memory_stats()
# now_mem = torch.cuda.memory_allocated() / 1024**2
# with torch.no_grad():
# node1 = node.clone()
# pair1 = pair.clone()
# gm(node1, pair1)
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
# print(
# "autochunk now mem:%.2f max mem:%.2f"
# % (new_now_mem - now_mem, new_max_mem - now_mem)
# )
# test forward
with torch.no_grad():
non_fx_out = model(node, pair)
fx_out = gm(node, pair)
@@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
# meta info prop
meta_graph = symbolic_trace(model,
meta_args={
"node": node.to(torch.device("meta")),
"pair": pair.to(torch.device("meta")),
}) # must use symbolic_trace
interp = MetaInfoProp(meta_graph)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
# trace the module and replace codegen
graph = ColoTracer().trace(
model,
@@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
"pair": pair.to(torch.device("meta")),
},
)
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
# now run it twice to get meta info in graph module, not necessary
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph)
gm.recompile()
# assert we have inserted chunk
code = graph.python_code("self").src
assert "chunk_size" in code
# print(code)
assert "chunk_result = None; chunk_size = None;" in code
_test_fwd(model, gm, node, pair)
gpc.destroy()

View File

@@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
str(target_regions),
)
for region in target_regions:
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
str(region),
msa_len,
pair_len,
max_memory,
str(max_memory),
)
for region in found_regions:
assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
str(region),
msa_len,
pair_len,
max_memory,
str(max_memory),
)