mirror of https://github.com/hpcaitech/ColossalAI.git
commit 72341e65f4 (parent 0f02b8c6e6)
@@ -6,12 +6,7 @@ from torch.fx.node import Node, map_arg
 from colossalai.fx.profiler import activation_size, parameter_size
-from .utils import (
-    delete_free_var_from_last_use,
-    find_idx_by_name,
-    get_node_shape,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node


 class EstimateMemory(object):
@@ -240,7 +235,7 @@ class EstimateMemory(object):
             elif node.op == "output":
                 continue
             # no change for non compute node
-            elif is_non_compute_node_except_placeholder(node):
+            elif is_non_memory_node(node):
                 act_memory_peak_log.append(act_memory)
             # node is a compute op
             # calculate tmp, output node and delete node memory
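For context (standard torch.fx behavior, not part of the diff): `getitem` nodes typically come from indexing a multi-output op and alias existing storage, and the `output` node allocates nothing, which is why the new `is_non_memory_node` predicate (added to `.utils` later in this commit) lets such nodes log the running activation count unchanged. A minimal illustration of where `getitem` nodes appear in a traced graph:

    import torch
    import torch.fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            parts = x.split(1, dim=0)
            return parts[0] + parts[1]  # parts[i] trace to operator.getitem nodes

    gm = torch.fx.symbolic_trace(Toy())
    for node in gm.graph.nodes:
        print(node.op, node.name)  # placeholder, split, getitem, getitem_1, add, output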
@@ -118,16 +118,34 @@ class TraceFlow(object):

     def _assgin_single_node_flow(
         self,
-        arg_node,
-        start_idx,
-        end_idx,
-        cur_node_dim,
-        cur_node_compute,
-        cur_node_source,
-        cur_node_fix_dim,
-        all_node_info,
-        next_node_list,
-    ):
+        arg_node: Node,
+        start_idx: int,
+        end_idx: int,
+        cur_node_dim: int,
+        cur_node_compute: Dict,
+        cur_node_source: Dict,
+        cur_node_fix_dim: List,
+        all_node_info: Dict,
+        next_node_list: List,
+    ) -> bool:
+        """
+        Given the current node and one of its arg nodes,
+        this function finds out the arg node's chunk dim and fix dim.
+
+        Args:
+            arg_node (Node): input node
+            start_idx (int): chunk region start
+            end_idx (int): chunk region end
+            cur_node_dim (int): current node chunk dim
+            cur_node_compute (Dict): current node compute dict
+            cur_node_source (Dict): current node source dict
+            cur_node_fix_dim (List): current node fix dim
+            all_node_info (Dict): all node chunk info in the chunk region
+            next_node_list (List)
+
+        Returns:
+            bool: True if this node can be added to the flow, False otherwise.
+        """
         arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
         # arg in chunk range or be inputs
         if not (start_idx <= arg_idx < end_idx):
@@ -142,6 +160,9 @@ class TraceFlow(object):
                 arg_dim = None
             else:
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+                # chunk dim should be None if shape size is 1
+                if get_node_shape(arg_node)[arg_dim] == 1:
+                    arg_dim = None
         else:
             arg_dim = None

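A likely rationale for the new size-1 guard (inferred; plain PyTorch semantics, not part of the diff): a dim of size 1 is broadcast rather than meaningfully chunkable, so it cannot serve as a chunk dim.

    import torch

    a = torch.randn(4, 1, 8)
    b = torch.randn(4, 6, 8)
    print((a + b).shape)  # torch.Size([4, 6, 8]); dim 1 of `a` broadcasts,
                          # so chunking `a` along dim 1 would be meaningless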
@@ -184,7 +205,7 @@ class TraceFlow(object):

         # get all valid args
         arg_list = []
-        for arg in cur_node.args:
+        for arg in cur_node.all_input_nodes:
             if type(arg) != type(cur_node):
                 continue
             if is_non_compute_node(arg):
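For context (documented torch.fx API, not part of the diff): `Node.args` may contain non-Node constants and misses inputs passed as keywords, while `Node.all_input_nodes` yields every upstream `Node` from both `args` and `kwargs`, so the new loop sees exactly the graph dependencies. With `all_input_nodes`, every element is already a `Node`, so the `type(arg) != type(cur_node)` filter above becomes a no-op safeguard.

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.add(x, 1, alpha=2)

    gm = torch.fx.symbolic_trace(M())
    add = next(n for n in gm.graph.nodes if n.op == "call_function")
    print(add.args)             # (x, 1) -- includes the non-Node constant 1
    print(add.all_input_nodes)  # [x]   -- Nodes only, drawn from args and kwargs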
@@ -432,6 +432,38 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)

+    def _assign_cat_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for cat op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = flat_list(node.args[0])
+        self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+        for n in nodes_in[1:]:
+            self._mark_computation_from_node(n, node)
+        cat_dim = node.kwargs["dim"]
+        self._del_dim(node_idx, cat_dim)
+        self._add_dim(node_idx, cat_dim)
+
+    def _assign_sum_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for sum op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = flat_list(node.args[0])
+        self._add_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+        for n in nodes_in[1:]:
+            self._mark_computation_from_node(n, node)
+        cat_dim = node.kwargs["dim"]
+        self._del_dim(node_idx, cat_dim)
+
     def _assign_getitem_indice(self, node: Node, node_idx: int):
         """
         Assign indice for getitem.
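Shape intuition for the two new handlers (plain PyTorch semantics, not part of the diff): `cat` keeps the rank but replaces the concat dim's index range, matching the `_del_dim` + `_add_dim` pair, while `sum` over a dim drops it, matching the lone `_del_dim`.

    import torch

    a = torch.randn(2, 3, 4)
    b = torch.randn(2, 5, 4)
    print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 8, 4]): dim 1 is rebuilt
    print(a.sum(dim=1).shape)              # torch.Size([2, 4]):    dim 1 is removed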
@@ -442,7 +474,16 @@ class TraceIndice(object):
             node_idx (int)
         """
         node_args = flat_list(node.args[1:])
-        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+        flag = False
+        for node_arg in node_args:
+            node_arg_str = str(node_arg)
+            if any(i == node_arg_str for i in ["None", "Ellipsis"]):
+                flag = True
+                break
+            if "slice" in node_arg_str:
+                flag = True
+                break
+        if flag == False:
             return

         # node args should be like [Ellipsis, slice(start, step, end), None]
@@ -461,8 +502,11 @@ class TraceIndice(object):
                 shape_gap = len(node_shape) - len(node_args) + 1
                 origin_idx_count += shape_gap
                 new_idx_count += shape_gap
-            # slice(None, None, None) means all indexes, doesn't support other slice
-            elif "slice(None, None, None)" == node_arg_str:
+            # slice(None, None, None) means all indexes
+            elif "slice" in node_arg_str:
+                if "slice(None, None, None)" != node_arg_str:
+                    self._del_dim(node_idx, new_idx_count)
+                    self._add_dim(node_idx, new_idx_count)
                 origin_idx_count += 1
                 new_idx_count += 1
             # None means a new dim
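For context (plain torch.fx tracing, not part of the diff): fancy indexing becomes a `getitem` node whose index tuple stringifies to exactly the `Ellipsis` / `slice(...)` / `None` forms the string matching above keys on, and the new branch now also walks non-trivial slices such as `slice(1, 3, None)`.

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            return x[..., 1:3, None]

    gm = torch.fx.symbolic_trace(M())
    node = next(n for n in gm.graph.nodes if "getitem" in n.name)
    print([str(a) for a in node.args[1]])
    # ['Ellipsis', 'slice(1, 3, None)', 'None']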
@@ -565,7 +609,7 @@ class TraceIndice(object):
                     self._assign_view_reshape_indice(node, idx)
                 elif "unsqueeze" in node.name:
                     self._assign_unsqueeze_indice(node, idx)
-                elif any(i in node.name for i in ["to", "contiguous"]):
+                elif any(i in node.name for i in ["to", "contiguous", "clone"]):
                     self._assgin_no_change_indice(node, idx)
                 elif "new_ones" in node.name:
                     self._assign_ones_like_indice(node, idx)
@@ -574,6 +618,8 @@ class TraceIndice(object):
             elif node.op == "call_function":
                 if "linear" in node.name:
                     self._assign_linear_indice(node, idx)
+                elif "cat" in node.name:
+                    self._assign_cat_indice(node, idx)
                 elif "matmul" in node.name:
                     self._assign_matmul_indice(node, idx)
                 elif "softmax" in node.name:
@@ -586,6 +632,8 @@ class TraceIndice(object):
                     self._assign_dropout_indice(node, idx)
                 elif "einsum" in node.name:
                     self._assign_einsum_indice(node, idx)
+                elif "sum" in node.name:
+                    self._assign_sum_indice(node, idx)
                 elif "layer_norm" in node.name:
                     self._assign_layernorm_indice(node, idx)
                 elif "getitem" in node.name:
@@ -3,10 +3,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
 from torch.fx.node import Node


-def flat_list(inputs):
+def flat_list(inputs: Any) -> List:
     """
     flat a list by recursion
     """
+    if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
+        return [inputs]
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
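A behavior check for the patched helper; the extend/append tail of the loop is assumed from context, since the diff cuts off after the `isinstance` test:

    from typing import Any, List

    def flat_list(inputs: Any) -> List:
        # mirrors the helper above; the loop body past `isinstance` is assumed
        if not isinstance(inputs, (list, set, tuple)):
            return [inputs]
        res = []
        for i in inputs:
            if isinstance(i, (list, set, tuple)):
                res.extend(flat_list(i))
            else:
                res.append(i)
        return res

    assert flat_list(5) == [5]                 # the new early return: scalars wrap
    assert flat_list([1, [2, (3, 4)]]) == [1, 2, 3, 4]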
@@ -16,7 +18,7 @@ def flat_list(inputs):
     return res


-def find_first_tensor_arg(node):
+def find_first_tensor_arg(node: Node) -> Node:
     """
     Find the first input tensor arg for a node
     """
@@ -26,7 +28,7 @@ def find_first_tensor_arg(node):
     raise RuntimeError()


-def is_non_compute_node(node):
+def is_non_compute_node(node: Node) -> bool:
     if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
         return True
     if "getitem" in node.name:
@@ -34,16 +36,26 @@ def is_non_compute_node(node):
         for node_arg in node_args:
             if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
                 return False
+            if "slice" in str(node_arg):
+                return False
         return True
     return False


-def get_node_shape(node):
+def get_node_shape(node: Node) -> List:
     if hasattr(node.meta["tensor_meta"], "shape"):
         return node.meta["tensor_meta"].shape
     return None


+def is_non_memory_node(node: Node) -> bool:
+    if "getitem" in node.name:
+        return True
+    if "output" in node.op:
+        return True
+    return is_non_compute_node(node)
+
+
 def is_non_compute_node_except_placeholder(node):
     if "placeholder" in node.op:
         return False
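For context (upstream torch.fx, not part of the diff): `node.meta["tensor_meta"]` is populated by a shape-propagation pass (ColossalAI's `MetaInfoProp` here, or `ShapeProp` in stock torch.fx), and `get_node_shape` guards with `hasattr` because nodes producing non-tensors carry no shape.

    import torch
    import torch.fx
    from torch.fx.passes.shape_prop import ShapeProp

    class M(torch.nn.Module):
        def forward(self, x):
            return x.relu()

    gm = torch.fx.symbolic_trace(M())
    ShapeProp(gm).propagate(torch.randn(2, 3))
    for n in gm.graph.nodes:
        print(n.name, getattr(n.meta.get("tensor_meta"), "shape", None))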
@@ -130,7 +130,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         },
     )
     graph.set_codegen(codegen)
-    gm = ColoGraphModule(model, graph)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

     # assert we have inserted chunk
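A plausible reading of the new `ckpt_codegen=False` flag (hedged; inferred from this diff rather than from documented behavior): `ColoGraphModule` otherwise installs its activation-checkpoint codegen, which would overwrite the `AutoChunkCodeGen` just attached via `set_codegen`, so the tests opt out:

    graph.set_codegen(codegen)                              # custom AutoChunkCodeGen
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)  # keep the custom codegen
    gm.recompile()                                          # emit the chunked forward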
tests/test_autochunk/test_extramsa_codegen.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+try:
+    from fastfold.model.nn.evoformer import ExtraMSABlock
+    HAS_REPO = True
+except:
+    HAS_REPO = False
+
+import colossalai
+from colossalai.core import global_context as gpc
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
+    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
+    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
+    # for memory test
+    # model = model.cuda()
+    # torch.cuda.reset_peak_memory_stats()
+    # now_mem = torch.cuda.memory_allocated() / 1024**2
+    # with torch.no_grad():
+    #     node1 = node.clone()
+    #     pair1 = pair.clone()
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1)
+    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
+
+    # test forward
+    model = model.cuda()
+    with torch.no_grad():
+        non_fx_out = model(node, pair, node_mask, pair_mask)
+        fx_out = gm(node, pair, node_mask, pair_mask)
+
+    assert torch.allclose(non_fx_out[0], fx_out[0],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[1] - fx_out[1]))
+
+
+def _build_openfold():
+    model = ExtraMSABlock(
+        c_m=256,
+        c_z=128,
+        c_hidden_msa_att=32,
+        c_hidden_opm=32,
+        c_hidden_mul=128,
+        c_hidden_pair_att=32,
+        no_heads_msa=8,
+        no_heads_pair=4,
+        transition_n=4,
+        msa_dropout=0.15,
+        pair_dropout=0.15,
+        inf=1e4,
+        eps=1e-4,
+        ckpt=False,
+        is_multimer=False,
+    ).eval().cuda()
+    return model
+
+
+def _test_extramsa_codegen(rank, msa_len, pair_len, max_memory):
+    # launch colossalai
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=1,
+        host="localhost",
+        port=free_port(),
+        backend="nccl",
+    )
+
+    # build model and input
+    model = _build_openfold()
+    node = torch.randn(1, msa_len, pair_len, 256).cuda()
+    node_mask = torch.randn(1, msa_len, pair_len).cuda()
+    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
+
+    # trace the meta graph and setup codegen
+    meta_graph = symbolic_trace(
+        model,
+        meta_args={
+            "m": node.to(torch.device("meta")),
+            "z": pair.to(torch.device("meta")),
+            "msa_mask": node_mask.to(torch.device("meta")),
+            "pair_mask": pair_mask.to(torch.device("meta")),
+        },
+        concrete_args={
+            "chunk_size": None,
+            "_chunk_logits": 1024,
+        },
+    )
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(
+        MetaTensor(node, fake_device="cuda:0"),
+        MetaTensor(pair, fake_device="cuda:0"),
+        MetaTensor(node_mask, fake_device="cuda:0"),
+        MetaTensor(pair_mask, fake_device="cuda:0"),
+    )
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
+
+    # trace and recompile
+    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+    graph = ColoTracer().trace(
+        model,
+        meta_args={
+            "m": node.to(torch.device("meta")),
+            "z": pair.to(torch.device("meta")),
+            "msa_mask": node_mask.to(torch.device("meta")),
+            "pair_mask": pair_mask.to(torch.device("meta")),
+        },
+        concrete_args={
+            "chunk_size": None,
+            "_chunk_logits": 1024,
+        },
+    )
+    graph.set_codegen(codegen)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+    gm.recompile()
+
+    # assert we have inserted chunk
+    code = graph.python_code("self").src
+    # print(code)
+    assert "chunk_result = None; chunk_size = None;" in code
+
+    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
+    gpc.destroy()
+
+
+@pytest.mark.skipif(
+    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
+    reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
+@pytest.mark.parametrize("msa_len", [32])
+@pytest.mark.parametrize("pair_len", [64])
+def test_extramsa_codegen(msa_len, pair_len, max_memory):
+    run_func = partial(
+        _test_extramsa_codegen,
+        msa_len=msa_len,
+        pair_len=pair_len,
+        max_memory=max_memory,
+    )
+    mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+    _test_extramsa_codegen(0, 32, 64, None)
@@ -73,7 +73,7 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         },
     )
     graph.set_codegen(codegen)
-    gm = ColoGraphModule(model, graph)
+    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
     gm.recompile()

     # assert we have inserted chunk
@@ -13,6 +13,7 @@ except:

 import colossalai
 from colossalai.core import global_context as gpc
+from colossalai.fx import symbolic_trace
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
@@ -28,10 +29,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):

     if msa_len == 32 and pair_len == 64:
         if max_memory is None:
-            target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191),
-                              (161, 166), (198, 203), (6, 69)]
+            target_regions = [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191),
+                              (161, 166), (198, 203), (7, 57)]
         elif max_memory == 20:
-            target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
+            target_regions = [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)]
         elif max_memory == 25:
             target_regions = [(144, 154), (369, 370)]
         elif max_memory == 30:
@@ -41,25 +42,10 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
     else:
         raise NotImplementedError()

-    assert len(found_regions) == len(
-        target_regions), "len of found regions %s doesn't equal len of target regions %s" % (
-            str(found_regions),
-            str(target_regions),
-        )
-    for region in target_regions:
-        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
-            str(region),
-            msa_len,
-            pair_len,
-            str(max_memory),
-        )
-    for region in found_regions:
-        assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
-            str(region),
-            msa_len,
-            pair_len,
-            str(max_memory),
-        )
+    assert found_regions == target_regions, "found regions %s doesn't equal target regions %s" % (
+        str(found_regions),
+        str(target_regions),
+    )


 def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
@@ -78,11 +64,14 @@ def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()

-    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
-    interp = MetaInfoProp(gm_prop)
+    meta_graph = symbolic_trace(model,
+                                meta_args={
+                                    "node": node.to(torch.device("meta")),
+                                    "pair": pair.to(torch.device("meta")),
+                                })    # must use symbolic_trace
+    interp = MetaInfoProp(meta_graph)
     interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
     chunk_infos = codegen.chunk_infos
     assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
Loading…
Reference in New Issue
Block a user