mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-21 21:22:04 +00:00
refactor structure
This commit is contained in:
parent
71e72c4890
commit
27ab524096
@ -1967,13 +1967,11 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
|
|||||||
|
|
||||||
def emit_code_with_chunk(
|
def emit_code_with_chunk(
|
||||||
body,
|
body,
|
||||||
ckpt_func,
|
|
||||||
nodes,
|
nodes,
|
||||||
emit_node_func,
|
emit_node_func,
|
||||||
delete_unused_value_func,
|
delete_unused_value_func,
|
||||||
meta_nodes,
|
chunk_region_search,
|
||||||
meta_graph,
|
chunk_infos
|
||||||
max_memory=None,
|
|
||||||
):
|
):
|
||||||
"""Emit code with nested activation checkpoint
|
"""Emit code with nested activation checkpoint
|
||||||
When we detect some of the node.activation_checkpoint is a List, we will use
|
When we detect some of the node.activation_checkpoint is a List, we will use
|
||||||
@ -1988,23 +1986,19 @@ def emit_code_with_chunk(
|
|||||||
"""
|
"""
|
||||||
node_list = list(nodes)
|
node_list = list(nodes)
|
||||||
|
|
||||||
# find the chunk regions
|
chunk_regions = [i["region"] for i in chunk_infos]
|
||||||
chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
|
|
||||||
chunk_search = chunk_region_search.search_region()
|
|
||||||
|
|
||||||
chunk_regions = [i["region"] for i in chunk_search]
|
|
||||||
chunk_starts = [i[0] for i in chunk_regions]
|
chunk_starts = [i[0] for i in chunk_regions]
|
||||||
chunk_ends = [i[1] for i in chunk_regions]
|
chunk_ends = [i[1] for i in chunk_regions]
|
||||||
|
|
||||||
chunk_inputs = [i["inputs"] for i in chunk_search]
|
chunk_inputs = [i["inputs"] for i in chunk_infos]
|
||||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_search]
|
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
|
||||||
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_search]
|
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
|
||||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
||||||
j.name for i in chunk_inputs_non_chunk for j in i
|
j.name for i in chunk_inputs_non_chunk for j in i
|
||||||
]
|
]
|
||||||
|
|
||||||
chunk_outputs = [i["outputs"][0] for i in chunk_search]
|
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
|
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
||||||
|
|
||||||
node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
|
node_list = chunk_region_search.index_tracer.reorder_node_list(node_list)
|
||||||
node_idx = 0
|
node_idx = 0
|
||||||
@ -2022,7 +2016,7 @@ def emit_code_with_chunk(
|
|||||||
chunk_inputs[region_idx],
|
chunk_inputs[region_idx],
|
||||||
chunk_outputs[region_idx],
|
chunk_outputs[region_idx],
|
||||||
chunk_outputs_dim[region_idx],
|
chunk_outputs_dim[region_idx],
|
||||||
chunk_search[region_idx]["chunk_size"],
|
chunk_infos[region_idx]["chunk_size"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2041,14 +2035,14 @@ def emit_code_with_chunk(
|
|||||||
# ones like
|
# ones like
|
||||||
if "ones_like" in node.name:
|
if "ones_like" in node.name:
|
||||||
meta_node = chunk_region_search.index_tracer.node_list[node_idx]
|
meta_node = chunk_region_search.index_tracer.node_list[node_idx]
|
||||||
chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][
|
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
|
||||||
"chunk_dim"
|
"chunk_dim"
|
||||||
]
|
]
|
||||||
if _get_node_shape(meta_node)[chunk_dim] != 1:
|
if _get_node_shape(meta_node)[chunk_dim] != 1:
|
||||||
source_node = meta_node.args[0].args[0]
|
source_node = meta_node.args[0].args[0]
|
||||||
if (
|
if (
|
||||||
source_node not in chunk_search[region_idx]["node_chunk_dim"]
|
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||||
or chunk_search[region_idx]["node_chunk_dim"][source_node][
|
or chunk_infos[region_idx]["node_chunk_dim"][source_node][
|
||||||
"chunk_dim"
|
"chunk_dim"
|
||||||
]
|
]
|
||||||
is None
|
is None
|
||||||
@ -2060,7 +2054,7 @@ def emit_code_with_chunk(
|
|||||||
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||||
)
|
)
|
||||||
body[-1] = _replace_reshape_size(
|
body[-1] = _replace_reshape_size(
|
||||||
body[-1], node.name, chunk_search[region_idx]["reshape_size"]
|
body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
|
||||||
)
|
)
|
||||||
body[-1] = " " + body[-1]
|
body[-1] = " " + body[-1]
|
||||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||||
@ -2092,6 +2086,9 @@ if CODEGEN_AVAILABLE:
|
|||||||
self.meta_graph = meta_graph
|
self.meta_graph = meta_graph
|
||||||
self.max_memory = max_memory
|
self.max_memory = max_memory
|
||||||
self.meta_node = list(meta_graph.graph.nodes)
|
self.meta_node = list(meta_graph.graph.nodes)
|
||||||
|
# find the chunk regions
|
||||||
|
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory)
|
||||||
|
self.chunk_infos = self.chunk_region_search.search_region()
|
||||||
|
|
||||||
def _gen_python_code(
|
def _gen_python_code(
|
||||||
self, nodes, root_module: str, namespace: _Namespace
|
self, nodes, root_module: str, namespace: _Namespace
|
||||||
@ -2323,13 +2320,11 @@ if CODEGEN_AVAILABLE:
|
|||||||
# will use nested type of activation checkpoint codegen
|
# will use nested type of activation checkpoint codegen
|
||||||
emit_code_with_chunk(
|
emit_code_with_chunk(
|
||||||
body,
|
body,
|
||||||
ckpt_func,
|
|
||||||
nodes,
|
nodes,
|
||||||
emit_node,
|
emit_node,
|
||||||
delete_unused_values,
|
delete_unused_values,
|
||||||
self.meta_node,
|
self.chunk_region_search,
|
||||||
self.meta_graph,
|
self.chunk_infos
|
||||||
self.max_memory,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(body) == 0:
|
if len(body) == 0:
|
@ -3,13 +3,13 @@ import time
|
|||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
|
|
||||||
from chunk_codegen import ChunkCodeGen
|
from autochunk.chunk_codegen import ChunkCodeGen
|
||||||
from colossalai.fx import ColoTracer
|
from colossalai.fx import ColoTracer
|
||||||
from colossalai.fx.graph_module import ColoGraphModule
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
from colossalai.fx.profiler import MetaTensor
|
from colossalai.fx.profiler import MetaTensor
|
||||||
from evoformer.evoformer import evoformer_base
|
from autochunk.evoformer.evoformer import evoformer_base
|
||||||
from openfold.evoformer import EvoformerBlock
|
from autochunk.openfold.evoformer import EvoformerBlock
|
||||||
|
|
||||||
|
|
||||||
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||||
@ -94,23 +94,23 @@ def _build_openfold():
|
|||||||
def benchmark_evoformer():
|
def benchmark_evoformer():
|
||||||
# init data and model
|
# init data and model
|
||||||
msa_len = 256
|
msa_len = 256
|
||||||
pair_len = 2048
|
pair_len = 1024
|
||||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
model = evoformer_base().cuda()
|
model = evoformer_base().cuda()
|
||||||
|
|
||||||
# build autochunk model
|
# build autochunk model
|
||||||
max_memory = 10000 # MB fit memory mode
|
# max_memory = 10000 # MB fit memory mode
|
||||||
# max_memory = None # min memory mode
|
max_memory = None # min memory mode
|
||||||
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
||||||
|
|
||||||
# build openfold
|
# build openfold
|
||||||
chunk_size = 64
|
chunk_size = 64
|
||||||
openfold = _build_openfold()
|
# openfold = _build_openfold()
|
||||||
|
|
||||||
# benchmark
|
# benchmark
|
||||||
_benchmark_evoformer(model, node, pair, "base")
|
# _benchmark_evoformer(model, node, pair, "base")
|
||||||
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
# _benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,8 +12,8 @@ from colossalai.core import global_context as gpc
|
|||||||
from colossalai.fx.graph_module import ColoGraphModule
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
|
||||||
from colossalai.fx.profiler import MetaTensor
|
from colossalai.fx.profiler import MetaTensor
|
||||||
from evoformer.evoformer import evoformer_base
|
from autochunk.evoformer.evoformer import evoformer_base
|
||||||
from chunk_codegen import ChunkCodeGen
|
from autochunk.chunk_codegen import ChunkCodeGen
|
||||||
with_codegen = True
|
with_codegen = True
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user