refactor structure

This commit is contained in:
oahzxl 2023-01-06 11:07:57 +08:00
parent 71e72c4890
commit 27ab524096
19 changed files with 29 additions and 34 deletions

View File

@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
def emit_code_with_chunk( def emit_code_with_chunk(
body, body,
ckpt_func,
nodes, nodes,
emit_node_func, emit_node_func,
delete_unused_value_func, delete_unused_value_func,
meta_nodes, chunk_region_search,
meta_graph, chunk_infos
max_memory=None,
): ):
"""Emit code with nested activation checkpoint """Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use When we detect some of the node.activation_checkpoint is a List, we will use
@ -1988,23 +1986,19 @@ def emit_code_with_chunk(
""" """
node_list = list(nodes) node_list = list(nodes)
# find the chunk regions chunk_regions = [i["region"] for i in chunk_infos]
chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
chunk_search = chunk_region_search.search_region()
chunk_regions = [i["region"] for i in chunk_search]
chunk_starts = [i[0] for i in chunk_regions] chunk_starts = [i[0] for i in chunk_regions]
chunk_ends = [i[1] for i in chunk_regions] chunk_ends = [i[1] for i in chunk_regions]
chunk_inputs = [i["inputs"] for i in chunk_search] chunk_inputs = [i["inputs"] for i in chunk_infos]
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search] chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
j.name for i in chunk_inputs_non_chunk for j in i j.name for i in chunk_inputs_non_chunk for j in i
] ]
chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
node_idx = 0 node_idx = 0
@ -2022,7 +2016,7 @@ def emit_code_with_chunk(
chunk_inputs[region_idx], chunk_inputs[region_idx],
chunk_outputs[region_idx], chunk_outputs[region_idx],
chunk_outputs_dim[region_idx], chunk_outputs_dim[region_idx],
chunk_search[region_idx]["chunk_size"], chunk_infos[region_idx]["chunk_size"],
) )
) )
@ -2041,14 +2035,14 @@ def emit_code_with_chunk(
# ones like # ones like
if "ones_like" in node.name: if "ones_like" in node.name:
meta_node = chunk_region_search.index_tracer.node_list[node_idx] meta_node = chunk_region_search.index_tracer.node_list[node_idx]
chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][ chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
"chunk_dim" "chunk_dim"
] ]
if _get_node_shape(meta_node)[chunk_dim] != 1: if _get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0] source_node = meta_node.args[0].args[0]
if ( if (
source_node not in chunk_search[region_idx]["node_chunk_dim"] source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_search[region_idx]["node_chunk_dim"][source_node][ or chunk_infos[region_idx]["node_chunk_dim"][source_node][
"chunk_dim" "chunk_dim"
] ]
is None is None
@ -2060,7 +2054,7 @@ def emit_code_with_chunk(
body[-1], node.args[0].name, node.args[0].name + chunk_slice body[-1], node.args[0].name, node.args[0].name + chunk_slice
) )
body[-1] = _replace_reshape_size( body[-1] = _replace_reshape_size(
body[-1], node.name, chunk_search[region_idx]["reshape_size"] body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
) )
body[-1] = " " + body[-1] body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names) delete_unused_value_func(node, body, chunk_inputs_names)
@ -2092,6 +2086,9 @@ if CODEGEN_AVAILABLE:
self.meta_graph = meta_graph self.meta_graph = meta_graph
self.max_memory = max_memory self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes) self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
self.chunk_infos = self.chunk_region_search.search_region()
def _gen_python_code( def _gen_python_code(
self, nodes, root_module: str, namespace: _Namespace self, nodes, root_module: str, namespace: _Namespace
@ -2323,13 +2320,11 @@ if CODEGEN_AVAILABLE:
# will use nested type of activation checkpoint codegen # will use nested type of activation checkpoint codegen
emit_code_with_chunk( emit_code_with_chunk(
body, body,
ckpt_func,
nodes, nodes,
emit_node, emit_node,
delete_unused_values, delete_unused_values,
self.meta_node, self.chunk_region_search,
self.meta_graph, self.chunk_infos
self.max_memory,
) )
if len(body) == 0: if len(body) == 0:

View File

@ -3,13 +3,13 @@ import time
import torch import torch
import torch.fx import torch.fx
from chunk_codegen import ChunkCodeGen from autochunk.chunk_codegen import ChunkCodeGen
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
from evoformer.evoformer import evoformer_base from autochunk.evoformer.evoformer import evoformer_base
from openfold.evoformer import EvoformerBlock from autochunk.openfold.evoformer import EvoformerBlock
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None): def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
@ -94,23 +94,23 @@ def _build_openfold():
def benchmark_evoformer(): def benchmark_evoformer():
# init data and model # init data and model
msa_len = 256 msa_len = 256
pair_len = 2048 pair_len = 1024
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
model = evoformer_base().cuda() model = evoformer_base().cuda()
# build autochunk model # build autochunk model
max_memory = 10000 # MB fit memory mode # max_memory = 10000 # MB fit memory mode
# max_memory = None # min memory mode max_memory = None # min memory mode
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair) autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
# build openfold # build openfold
chunk_size = 64 chunk_size = 64
openfold = _build_openfold() # openfold = _build_openfold()
# benchmark # benchmark
_benchmark_evoformer(model, node, pair, "base") # _benchmark_evoformer(model, node, pair, "base")
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size) # _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
_benchmark_evoformer(autochunk, node, pair, "autochunk") _benchmark_evoformer(autochunk, node, pair, "autochunk")

View File

@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
from evoformer.evoformer import evoformer_base from autochunk.evoformer.evoformer import evoformer_base
from chunk_codegen import ChunkCodeGen from autochunk.chunk_codegen import ChunkCodeGen
with_codegen = True with_codegen = True