mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[autochunk] support transformer (#2526)
This commit is contained in:
@@ -1,94 +0,0 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from simple_evoformer import base_evoformer, openfold_evoformer
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
|
||||
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
loop = 3
|
||||
with torch.no_grad():
|
||||
for _ in range(loop // 2 + 1):
|
||||
if chunk_size:
|
||||
model(node, pair, chunk_size)
|
||||
else:
|
||||
model(node, pair)
|
||||
torch.cuda.synchronize()
|
||||
time1 = time.time()
|
||||
for _ in range(loop):
|
||||
if chunk_size:
|
||||
model(node, pair, chunk_size)
|
||||
else:
|
||||
model(node, pair)
|
||||
torch.cuda.synchronize()
|
||||
time2 = time.time()
|
||||
|
||||
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))
|
||||
|
||||
|
||||
def _build_autochunk(model, max_memory, node, pair):
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"node": node.to(torch.device("meta")),
|
||||
"pair": pair.to(torch.device("meta")),
|
||||
},
|
||||
)
|
||||
|
||||
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||
interp = MetaInfoProp(gm_prop)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
# now run it twice to get meta info in graph module, not necessary
|
||||
gm = torch.fx.GraphModule(model, graph)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
# set code_gen
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# print
|
||||
# code = graph.python_code("self").src
|
||||
# print(code)
|
||||
return gm
|
||||
|
||||
|
||||
def benchmark_evoformer():
|
||||
# init data and model
|
||||
msa_len = 128
|
||||
pair_len = 256
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
model = base_evoformer().cuda()
|
||||
|
||||
# build autochunk model
|
||||
# max_memory = 1000 # MB, fit memory mode
|
||||
max_memory = None # min memory mode
|
||||
autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)
|
||||
|
||||
# build openfold
|
||||
chunk_size = 64
|
||||
openfold = openfold_evoformer().cuda()
|
||||
|
||||
# benchmark
|
||||
_benchmark_evoformer(model, node, pair, "base")
|
||||
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_evoformer()
|
122
tests/test_autochunk/test_alphafold/test_alphafold_utils.py
Normal file
122
tests/test_autochunk/test_alphafold/test_alphafold_utils.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
import colossalai
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.autochunk.utils import flat_list
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def assert_codegen_run(
|
||||
model: Any,
|
||||
meta_args: List,
|
||||
concrete_args: List = None,
|
||||
max_memory: int = None,
|
||||
print_mem: bool = False,
|
||||
print_progress: bool = False,
|
||||
print_code: bool = False,
|
||||
) -> List[Dict]:
|
||||
if concrete_args is None:
|
||||
concrete_args = []
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
|
||||
concrete_args={k: v for k, v in concrete_args},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
|
||||
interp.propagate(*meta_tensors)
|
||||
codegen = AutoChunkCodeGen(
|
||||
meta_graph,
|
||||
max_memory=max_memory,
|
||||
print_mem=print_mem,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
chunks = codegen.chunk_infos
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
|
||||
concrete_args={k: v for k, v in concrete_args},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert chunk in code
|
||||
code = graph.python_code("self").src
|
||||
if print_code:
|
||||
print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
# assert result
|
||||
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
|
||||
model.cuda()
|
||||
with torch.no_grad():
|
||||
out_gm = gm(*inputs)
|
||||
out_model = model(*inputs)
|
||||
out_gm = flat_list(out_gm)
|
||||
out_model = flat_list(out_model)
|
||||
for out_gm_i, out_model_i in zip(out_gm, out_model):
|
||||
assert torch.allclose(out_gm_i, out_model_i,
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(out_gm_i - out_model_i))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def run_test(
|
||||
rank: int,
|
||||
data_args: tuple,
|
||||
max_memory: int,
|
||||
get_model: Any,
|
||||
get_data: Any,
|
||||
print_code: bool,
|
||||
print_mem: bool,
|
||||
print_progress: bool,
|
||||
get_chunk_target: Any = None,
|
||||
) -> None:
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = get_model()
|
||||
meta_args, concrete_args = get_data(*data_args)
|
||||
chunks = assert_codegen_run(
|
||||
model,
|
||||
meta_args=meta_args,
|
||||
concrete_args=concrete_args,
|
||||
max_memory=max_memory,
|
||||
print_code=print_code,
|
||||
print_mem=print_mem,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
|
||||
if get_chunk_target is not None:
|
||||
chunk_found = [i["region"] for i in chunks]
|
||||
chunk_target = get_chunk_target()[max_memory]
|
||||
assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
|
||||
str(chunk_found),
|
||||
str(chunk_target),
|
||||
)
|
95
tests/test_autochunk/test_alphafold/test_evoformer_block.py
Normal file
95
tests/test_autochunk/test_alphafold/test_evoformer_block.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerBlock
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
from test_alphafold_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
|
||||
|
||||
def get_model():
|
||||
model = EvoformerBlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
meta_args = [
|
||||
("m", node),
|
||||
("z", pair),
|
||||
("msa_mask", node_mask),
|
||||
("pair_mask", pair_mask),
|
||||
]
|
||||
concrete_args = [("chunk_size", None), ("_mask_trans", True)]
|
||||
return meta_args, concrete_args
|
||||
|
||||
|
||||
def get_chunk_target() -> Dict:
|
||||
return {
|
||||
None: [(118, 123), (219, 237), (264, 289), (302, 309), (97, 104), (144, 152), (185, 193), (241, 242), (21, 46)],
|
||||
20: [(118, 123), (230, 237), (275, 282), (305, 306), (100, 101), (32, 39), (73, 79)],
|
||||
24: [(118, 123)],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 24])
|
||||
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
def test_evoformer_block(data_args, max_memory):
|
||||
run_func = partial(
|
||||
run_test,
|
||||
data_args=data_args,
|
||||
max_memory=max_memory,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
get_chunk_target=get_chunk_target,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data_args=(32, 64),
|
||||
max_memory=20,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
90
tests/test_autochunk/test_alphafold/test_evoformer_stack.py
Normal file
90
tests/test_autochunk/test_alphafold/test_evoformer_stack.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerStack
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
from test_alphafold_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
|
||||
|
||||
def get_model():
|
||||
model = EvoformerStack(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
c_s=384,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
no_blocks=2, # 48
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.25,
|
||||
blocks_per_ckpt=None,
|
||||
inf=1000000000.0,
|
||||
eps=1e-08,
|
||||
clear_cache_between_blocks=False,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
meta_args = [
|
||||
("m", node),
|
||||
("z", pair),
|
||||
("msa_mask", node_mask),
|
||||
("pair_mask", pair_mask),
|
||||
]
|
||||
concrete_args = [("chunk_size", None), ("_mask_trans", True)]
|
||||
return meta_args, concrete_args
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 24])
|
||||
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
def test_evoformer_stack(data_args, max_memory):
|
||||
run_func = partial(
|
||||
run_test,
|
||||
data_args=data_args,
|
||||
max_memory=max_memory,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data_args=(32, 64),
|
||||
max_memory=20,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
96
tests/test_autochunk/test_alphafold/test_extramsa_block.py
Normal file
96
tests/test_autochunk/test_alphafold/test_extramsa_block.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import ExtraMSABlock
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
from test_alphafold_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
|
||||
|
||||
def get_model():
|
||||
model = ExtraMSABlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
ckpt=False,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
meta_args = [
|
||||
("m", node),
|
||||
("z", pair),
|
||||
("msa_mask", node_mask),
|
||||
("pair_mask", pair_mask),
|
||||
]
|
||||
concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)]
|
||||
return meta_args, concrete_args
|
||||
|
||||
|
||||
def get_chunk_target() -> Dict:
|
||||
return {
|
||||
None: [(126, 131), (227, 245), (272, 297), (310, 317), (105, 112), (152, 160), (193, 201), (249, 250),
|
||||
(33, 46)],
|
||||
20: [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)],
|
||||
24: [(126, 131)],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 24])
|
||||
@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
|
||||
def test_extramsa_block(data_args, max_memory):
|
||||
run_func = partial(
|
||||
run_test,
|
||||
data_args=data_args,
|
||||
max_memory=max_memory,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data_args=(32, 64),
|
||||
max_memory=20,
|
||||
get_model=get_model,
|
||||
get_data=get_data,
|
||||
get_chunk_target=get_chunk_target,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
120
tests/test_autochunk/test_diffuser/test_diffuser_utils.py
Normal file
120
tests/test_autochunk/test_diffuser/test_diffuser_utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
import colossalai
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def assert_codegen_run(
|
||||
model: Any,
|
||||
meta_args: List,
|
||||
concrete_args: List = None,
|
||||
max_memory: int = None,
|
||||
print_mem: bool = False,
|
||||
print_progress: bool = False,
|
||||
print_code: bool = False,
|
||||
) -> List[Dict]:
|
||||
if concrete_args is None:
|
||||
concrete_args = []
|
||||
model = model()
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
|
||||
concrete_args={k: v for k, v in concrete_args},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
|
||||
interp.propagate(*meta_tensors)
|
||||
codegen = AutoChunkCodeGen(
|
||||
meta_graph,
|
||||
max_memory=max_memory,
|
||||
print_mem=print_mem,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
chunks = codegen.chunk_infos
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model.cuda(),
|
||||
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
|
||||
concrete_args={k: v for k, v in concrete_args},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert chunk in code
|
||||
code = graph.python_code("self").src
|
||||
if print_code:
|
||||
print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
# assert result
|
||||
inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
|
||||
model.cuda().eval()
|
||||
gm.eval()
|
||||
with torch.no_grad():
|
||||
out_gm = gm(*inputs)
|
||||
out_model = model(*inputs)
|
||||
assert torch.allclose(out_gm["sample"], out_model["sample"],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(out_gm["sample"] - out_model["sample"]))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def run_test(
|
||||
rank: int,
|
||||
model: Any,
|
||||
data: tuple,
|
||||
max_memory: int,
|
||||
print_code: bool,
|
||||
print_mem: bool,
|
||||
print_progress: bool,
|
||||
get_chunk_target: Any = None,
|
||||
) -> None:
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
meta_args, concrete_args = data
|
||||
chunks = assert_codegen_run(
|
||||
model,
|
||||
meta_args=meta_args,
|
||||
concrete_args=concrete_args,
|
||||
max_memory=max_memory,
|
||||
print_code=print_code,
|
||||
print_mem=print_mem,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
|
||||
if get_chunk_target is not None:
|
||||
chunk_found = [i["region"] for i in chunks]
|
||||
chunk_target = get_chunk_target()[max_memory]
|
||||
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
|
||||
str(chunk_found),
|
||||
str(chunk_target),
|
||||
)
|
||||
|
||||
gpc.destroy()
|
70
tests/test_autochunk/test_diffuser/test_unet.py
Normal file
70
tests/test_autochunk/test_diffuser/test_unet.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from diffusers import UNet2DModel
|
||||
MODELS = [UNet2DModel]
|
||||
HAS_REPO = True
|
||||
except:
|
||||
MODELS = []
|
||||
HAS_REPO = False
|
||||
|
||||
from test_diffuser_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
|
||||
BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 5
|
||||
HEIGHT = 224
|
||||
WIDTH = 224
|
||||
IN_CHANNELS = 3
|
||||
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
|
||||
|
||||
|
||||
def get_data(shape: tuple) -> Tuple[List, List]:
|
||||
sample = torch.randn(shape)
|
||||
meta_args = [
|
||||
("sample", sample),
|
||||
]
|
||||
concrete_args = [("timestep", 50)]
|
||||
return meta_args, concrete_args
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
True,
|
||||
reason="not implemented",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
|
||||
@pytest.mark.parametrize("max_memory", [64])
|
||||
def test_evoformer_block(model, shape, max_memory):
|
||||
run_func = partial(
|
||||
run_test,
|
||||
max_memory=max_memory,
|
||||
model=model,
|
||||
data=get_data(shape),
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data=get_data(LATENTS_SHAPE),
|
||||
max_memory=64,
|
||||
model=UNet2DModel,
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
@@ -1,163 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerBlock
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
|
||||
# for memory test
|
||||
# model = model.cuda()
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# with torch.no_grad():
|
||||
# node1 = node.clone()
|
||||
# pair1 = pair.clone()
|
||||
# node_mask1 = node_mask.clone()
|
||||
# pair_mask1 = pair_mask.clone()
|
||||
# gm(node1, pair1, node_mask1, pair_mask1)
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
|
||||
|
||||
# test forward
|
||||
model = model.cuda()
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair, node_mask, pair_mask)
|
||||
fx_out = gm(node, pair, node_mask, pair_mask)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _build_openfold():
|
||||
model = EvoformerBlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = _build_openfold()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_mask_trans": True,
|
||||
},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"),
|
||||
MetaTensor(pair, fake_device="cuda:0"),
|
||||
MetaTensor(node_mask, fake_device="cuda:0"),
|
||||
MetaTensor(pair_mask, fake_device="cuda:0"),
|
||||
)
|
||||
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_mask_trans": True,
|
||||
},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have inserted chunk
|
||||
code = graph.python_code("self").src
|
||||
# print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_evoformer_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_evoformer_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_evoformer_codegen(0, 32, 64, 24)
|
@@ -1,163 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerStack
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
|
||||
# for memory test
|
||||
# model = model.cuda()
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# with torch.no_grad():
|
||||
# node1 = node.clone()
|
||||
# pair1 = pair.clone()
|
||||
# node_mask1 = node_mask.clone()
|
||||
# pair_mask1 = pair_mask.clone()
|
||||
# gm(node1, pair1, node_mask1, pair_mask1, None)
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
|
||||
|
||||
# test forward
|
||||
model = model.cuda()
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair, node_mask, pair_mask, None)
|
||||
fx_out = gm(node, pair, node_mask, pair_mask, None)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _build_openfold():
|
||||
model = EvoformerStack(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
c_s=384,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
no_blocks=2, # 48
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.25,
|
||||
blocks_per_ckpt=None,
|
||||
inf=1000000000.0,
|
||||
eps=1e-08,
|
||||
clear_cache_between_blocks=False,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def _test_evoformer_stack_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = _build_openfold()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_mask_trans": True,
|
||||
},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"),
|
||||
MetaTensor(node_mask, fake_device="cuda:0"), MetaTensor(pair_mask, fake_device="cuda:0"), None)
|
||||
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False, print_progress=False)
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_mask_trans": True,
|
||||
},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have inserted chunk
|
||||
code = graph.python_code("self").src
|
||||
# print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_evoformer_stack_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_evoformer_stack_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_evoformer_stack_codegen(0, 32, 64, None)
|
@@ -1,164 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import ExtraMSABlock
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
|
||||
# for memory test
|
||||
# model = model.cuda()
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# with torch.no_grad():
|
||||
# node1 = node.clone()
|
||||
# pair1 = pair.clone()
|
||||
# node_mask1 = node_mask.clone()
|
||||
# pair_mask1 = pair_mask.clone()
|
||||
# gm(node1, pair1, node_mask1, pair_mask1)
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
|
||||
|
||||
# test forward
|
||||
model = model.cuda()
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair, node_mask, pair_mask)
|
||||
fx_out = gm(node, pair, node_mask, pair_mask)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _build_openfold():
|
||||
model = ExtraMSABlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
ckpt=False,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = _build_openfold()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_chunk_logits": 1024,
|
||||
},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"),
|
||||
MetaTensor(pair, fake_device="cuda:0"),
|
||||
MetaTensor(node_mask, fake_device="cuda:0"),
|
||||
MetaTensor(pair_mask, fake_device="cuda:0"),
|
||||
)
|
||||
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_chunk_logits": 1024,
|
||||
},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have inserted chunk
|
||||
code = graph.python_code("self").src
|
||||
# print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_extramsa_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_extramsa_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_extramsa_codegen(0, 32, 64, None)
|
@@ -1,104 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from simple_evoformer import base_evoformer
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx import ColoTracer, symbolic_trace
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
|
||||
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair)
|
||||
fx_out = gm(node, pair)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = base_evoformer().cuda()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
# meta info prop
|
||||
meta_graph = symbolic_trace(model,
|
||||
meta_args={
|
||||
"node": node.to(torch.device("meta")),
|
||||
"pair": pair.to(torch.device("meta")),
|
||||
}) # must use symbolic_trace
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
|
||||
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"node": node.to(torch.device("meta")),
|
||||
"pair": pair.to(torch.device("meta")),
|
||||
},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have inserted chunk
|
||||
code = graph.python_code("self").src
|
||||
# print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
_test_fwd(model, gm, node, pair)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_simple_evoformer_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_simple_evoformer_codegen(0, 32, 64, 25)
|
@@ -1,97 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from simple_evoformer import base_evoformer
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx import symbolic_trace
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
|
||||
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
||||
found_regions = [i["region"] for i in chunk_infos]
|
||||
|
||||
if msa_len == 32 and pair_len == 64:
|
||||
if max_memory is None:
|
||||
target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191),
|
||||
(161, 166), (198, 203), (7, 57)]
|
||||
elif max_memory == 20:
|
||||
target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)]
|
||||
elif max_memory == 25:
|
||||
target_regions = [(144, 154), (369, 370)]
|
||||
elif max_memory == 30:
|
||||
target_regions = [(144, 154)]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % (
|
||||
str(found_regions),
|
||||
str(target_regions),
|
||||
)
|
||||
|
||||
|
||||
def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = base_evoformer().cuda()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
meta_graph = symbolic_trace(model,
|
||||
meta_args={
|
||||
"node": node.to(torch.device("meta")),
|
||||
"pair": pair.to(torch.device("meta")),
|
||||
}) # must use symbolic_trace
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
|
||||
chunk_infos = codegen.chunk_infos
|
||||
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
|
||||
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0")
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_simple_evoformer_search,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_simple_evoformer_search(0, 32, 64, 20)
|
65
tests/test_autochunk/test_transformer/test_autochunk_gpt.py
Normal file
65
tests/test_autochunk/test_transformer/test_autochunk_gpt.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
MODELS = [GPT2Model]
|
||||
HAS_REPO = True
|
||||
except:
|
||||
MODELS = []
|
||||
HAS_REPO = False
|
||||
|
||||
from test_transformer_utils import run_test
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
|
||||
BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 256
|
||||
|
||||
|
||||
def get_data(shape: tuple) -> Tuple[List, List]:
|
||||
input_ids = torch.zeros(shape, dtype=torch.int64)
|
||||
token_type_ids = torch.zeros(shape, dtype=torch.int64)
|
||||
attention_mask = torch.ones(shape, dtype=torch.int64)
|
||||
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
|
||||
concrete_args = {"past_key_values": None}
|
||||
sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
|
||||
return meta_args, concrete_args, sequence
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
|
||||
@pytest.mark.parametrize("max_memory", [None, 4.5, 5])
|
||||
def test_gpt(model, shape, max_memory):
|
||||
run_func = partial(
|
||||
run_test,
|
||||
data=get_data(shape),
|
||||
max_memory=max_memory,
|
||||
model=model,
|
||||
config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4),
|
||||
print_code=False,
|
||||
print_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data=get_data((BATCH_SIZE, SEQ_LENGTH)),
|
||||
max_memory=None,
|
||||
model=GPT2Model,
|
||||
config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
|
||||
print_code=True,
|
||||
print_mem=True,
|
||||
print_progress=False,
|
||||
)
|
123
tests/test_autochunk/test_transformer/test_transformer_utils.py
Normal file
123
tests/test_autochunk/test_transformer/test_transformer_utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
import colossalai
|
||||
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def assert_codegen_run(
|
||||
model: Any,
|
||||
data: tuple,
|
||||
max_memory: int = None,
|
||||
print_mem: bool = False,
|
||||
print_progress: bool = False,
|
||||
print_code: bool = False,
|
||||
) -> List[Dict]:
|
||||
meta_args, concrete_args, sequence = data
|
||||
if concrete_args is None:
|
||||
concrete_args = {}
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
|
||||
concrete_args={k: v for k, v in concrete_args.items()},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
|
||||
meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
|
||||
interp.propagate(*meta_tensors)
|
||||
codegen = AutoChunkCodeGen(
|
||||
meta_graph,
|
||||
max_memory=max_memory,
|
||||
print_mem=print_mem,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
chunks = codegen.chunk_infos
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model.cuda(),
|
||||
meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
|
||||
concrete_args={k: v for k, v in concrete_args.items()},
|
||||
)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph, ckpt_codegen=False)
|
||||
gm.recompile()
|
||||
|
||||
# assert chunk in code
|
||||
code = graph.python_code("self").src
|
||||
if print_code:
|
||||
print(code)
|
||||
assert "chunk_result = None; chunk_size = None;" in code
|
||||
|
||||
# assert result
|
||||
inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
|
||||
inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
|
||||
model.cuda().eval()
|
||||
gm.eval()
|
||||
with torch.no_grad():
|
||||
out_gm = gm(*inputs)
|
||||
out_model = model(*inputs)
|
||||
for k in out_model.keys():
|
||||
if torch.is_tensor(out_gm[k]):
|
||||
assert torch.equal(
|
||||
out_model[k], out_gm[k]
|
||||
), f'{model.__class__.__name__} has incorrect output {k}, expect {out_model[k]}, but got {out_gm[k]}'
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def run_test(
|
||||
rank: int,
|
||||
model: Any,
|
||||
config: Any,
|
||||
data: tuple,
|
||||
max_memory: int,
|
||||
print_code: bool,
|
||||
print_mem: bool,
|
||||
print_progress: bool,
|
||||
get_chunk_target: Any = None,
|
||||
) -> None:
|
||||
model = model(config=config)
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
chunks = assert_codegen_run(
|
||||
model,
|
||||
data=data,
|
||||
max_memory=max_memory,
|
||||
print_code=print_code,
|
||||
print_mem=print_mem,
|
||||
print_progress=print_progress,
|
||||
)
|
||||
|
||||
if get_chunk_target is not None:
|
||||
chunk_found = [i["region"] for i in chunks]
|
||||
chunk_target = get_chunk_target()[max_memory]
|
||||
assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
|
||||
str(chunk_found),
|
||||
str(chunk_target),
|
||||
)
|
||||
|
||||
gpc.destroy()
|
Reference in New Issue
Block a user