code style

This commit is contained in:
oahzxl
2023-01-06 17:31:59 +08:00
parent a6cdbf9161
commit c3a2bf48b4
5 changed files with 46 additions and 36 deletions

View File

@@ -103,7 +103,7 @@ def emit_code_with_chunk(
nodes,
emit_node_func,
delete_unused_value_func,
chunk_region_search: SearchChunk,
search_chunk: SearchChunk,
chunk_infos,
):
"""Emit code with nested activation checkpoint
@@ -133,7 +133,7 @@ def emit_code_with_chunk(
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list)
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
node_idx = 0
region_idx = 0
within_chunk_region = False
@@ -167,7 +167,7 @@ def emit_code_with_chunk(
)
# ones like
if "ones_like" in node.name:
meta_node = chunk_region_search.trace_index.node_list[node_idx]
meta_node = search_chunk.trace_index.node_list[node_idx]
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
"chunk_dim"
]
@@ -220,10 +220,8 @@ if CODEGEN_AVAILABLE:
self.max_memory = max_memory
self.meta_node = list(meta_graph.graph.nodes)
# find the chunk regions
self.chunk_region_search = SearchChunk(
meta_graph, max_memory, print_mem
)
self.chunk_infos = self.chunk_region_search.search_region()
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
self.chunk_infos = self.search_chunk.search_region()
def _gen_python_code(
self, nodes, root_module: str, namespace: _Namespace
@@ -458,7 +456,7 @@ if CODEGEN_AVAILABLE:
nodes,
emit_node,
delete_unused_values,
self.chunk_region_search,
self.search_chunk,
self.chunk_infos,
)