This commit is contained in:
oahzxl
2023-01-06 17:09:37 +08:00
parent c3d72f7db9
commit da4076846d
6 changed files with 19 additions and 20 deletions

View File

@@ -17,7 +17,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
import colossalai
from .chunk_region_search import ChunkRegionSearch
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
CODEGEN_AVAILABLE = True
@@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes,
emit_node_func,
delete_unused_value_func,
chunk_region_search: ChunkRegionSearch,
chunk_region_search: SearchChunk,
chunk_infos,
):
"""Emit code with nested activation checkpoint
@@ -220,7 +220,7 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = ChunkRegionSearch(
self.chunk_region_search = SearchChunk(
meta_graph, max_memory, print_mem
)
self.chunk_infos = self.chunk_region_search.search_region()