close mem and code print

This commit is contained in:
oahzxl 2023-01-06 14:19:45 +08:00
parent 1a6d2a740b
commit 8a634af2f5
3 changed files with 10 additions and 7 deletions

View File

@ -214,13 +214,13 @@ def emit_code_with_chunk(
if CODEGEN_AVAILABLE: if CODEGEN_AVAILABLE:
class AutoChunkCodeGen(CodeGen): class AutoChunkCodeGen(CodeGen):
def __init__(self, meta_graph, max_memory=None): def __init__(self, meta_graph, max_memory=None, print_mem=False):
super().__init__() super().__init__()
self.meta_graph = meta_graph self.meta_graph = meta_graph
self.max_memory = max_memory self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes) self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions # find the chunk regions
self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory) self.chunk_region_search = ChunkRegionSearch(meta_graph, max_memory, print_mem)
self.chunk_infos = self.chunk_region_search.search_region() self.chunk_infos = self.chunk_region_search.search_region()
def _gen_python_code( def _gen_python_code(

View File

@ -6,8 +6,9 @@ from .utils import is_non_compute_node, is_non_compute_node_except_placeholder,
class ChunkRegionSearch(object): class ChunkRegionSearch(object):
def __init__(self, gm, max_memory=None) -> None: def __init__(self, gm, max_memory=None, print_mem=False) -> None:
self.gm = gm self.gm = gm
self.print_mem = print_mem
self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index() self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer) self.memory_estimator = MemoryEstimator(self.index_tracer)
@ -204,6 +205,8 @@ class ChunkRegionSearch(object):
) )
if self._stop_search(init_mem_peak, mem_peak): if self._stop_search(init_mem_peak, mem_peak):
break break
if self.print_mem:
self.print_mem = False
self.memory_estimator.estimate_chunk_inference_mem( self.memory_estimator.estimate_chunk_inference_mem(
self.index_tracer.node_list, chunk_infos, print_mem=True self.index_tracer.node_list, chunk_infos, print_mem=True
) )

View File

@ -64,7 +64,7 @@ def _build_autochunk(model, max_memory, node, pair):
) )
# set code_gen # set code_gen
codegen = AutoChunkCodeGen(gm_prop, max_memory) codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
graph.set_codegen(codegen) graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
gm.recompile() gm.recompile()