mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
code style
This commit is contained in:
@@ -103,7 +103,7 @@ def emit_code_with_chunk(
|
||||
nodes,
|
||||
emit_node_func,
|
||||
delete_unused_value_func,
|
||||
chunk_region_search: SearchChunk,
|
||||
search_chunk: SearchChunk,
|
||||
chunk_infos,
|
||||
):
|
||||
"""Emit code with nested activation checkpoint
|
||||
@@ -133,7 +133,7 @@ def emit_code_with_chunk(
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
||||
|
||||
node_list = chunk_region_search.reorder_graph.reorder_node_list(node_list)
|
||||
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
|
||||
node_idx = 0
|
||||
region_idx = 0
|
||||
within_chunk_region = False
|
||||
@@ -167,7 +167,7 @@ def emit_code_with_chunk(
|
||||
)
|
||||
# ones like
|
||||
if "ones_like" in node.name:
|
||||
meta_node = chunk_region_search.trace_index.node_list[node_idx]
|
||||
meta_node = search_chunk.trace_index.node_list[node_idx]
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
|
||||
"chunk_dim"
|
||||
]
|
||||
@@ -220,10 +220,8 @@ if CODEGEN_AVAILABLE:
|
||||
self.max_memory = max_memory
|
||||
self.meta_node = list(meta_graph.graph.nodes)
|
||||
# find the chunk regions
|
||||
self.chunk_region_search = SearchChunk(
|
||||
meta_graph, max_memory, print_mem
|
||||
)
|
||||
self.chunk_infos = self.chunk_region_search.search_region()
|
||||
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
|
||||
self.chunk_infos = self.search_chunk.search_region()
|
||||
|
||||
def _gen_python_code(
|
||||
self, nodes, root_module: str, namespace: _Namespace
|
||||
@@ -458,7 +456,7 @@ if CODEGEN_AVAILABLE:
|
||||
nodes,
|
||||
emit_node,
|
||||
delete_unused_values,
|
||||
self.chunk_region_search,
|
||||
self.search_chunk,
|
||||
self.chunk_infos,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user