Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[autochunk] refactor chunk memory estimation (#2762)
* refactor memory code
* don't log free var memory
* add memory align
* update chunk target
* update setting for new memory
* finish test
* update tracer
* fix typo
* update test
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple
 
 import torch
@@ -216,14 +216,13 @@ def _add_node_slice(
     return body
 
 
-def emit_code_with_chunk(
-    body: List[str],
-    nodes: Iterable[Node],
-    emit_node_func,
-    delete_unused_value_func,
-    search_chunk: SearchChunk,
-    chunk_infos: List,
-):
+def emit_code_with_chunk(body: List[str],
+                         nodes: Iterable[Node],
+                         emit_node_func: Callable,
+                         delete_unused_value_func: Callable,
+                         search_chunk: SearchChunk,
+                         chunk_infos: List,
+                         eval_mem: bool = False):
     """
     Emit code with chunk according to chunk_infos.
@@ -260,6 +259,9 @@ def emit_code_with_chunk(
     region_idx = 0
     within_chunk_region = False
 
+    if eval_mem:
+        body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n")
+
     while node_idx < len(node_list):
         node = node_list[node_idx]
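The init_memory line seeds the generated forward with a baseline reading taken once at entry; the probes added in the next hunk print torch.cuda.max_memory_allocated() minus that baseline and then reset the peak counter, so each probe measures only its own interval. A minimal standalone sketch of the same baseline-and-peak pattern, outside the codegen (measure_peak_mb is a hypothetical helper, not part of this commit):

    import torch

    def measure_peak_mb(fn, *args):
        # Mirror of the emitted probes: peak CUDA memory in MB above the
        # baseline taken before running fn; reset the peak stats afterwards.
        torch.cuda.reset_peak_memory_stats()
        init_memory = torch.cuda.memory_allocated() / 1024**2    # baseline, MB
        out = fn(*args)
        print(fn.__name__, torch.cuda.max_memory_allocated() / 1024**2 - init_memory)
        return out

    if torch.cuda.is_available():
        x = torch.randn(2048, 2048, device="cuda")
        measure_peak_mb(torch.matmul, x, x)    # prints the matmul's peak MB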
@@ -289,10 +291,18 @@ def emit_code_with_chunk(
             body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
             body[-1] = "    " + body[-1]
             delete_unused_value_func(node, body, chunk_inputs_names)
+            if eval_mem:
+                body.append(
+                    "    if chunk_idx == 0:\n        print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+                    % (node.name))
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
                 delete_unused_value_func(node, body, chunk_inputs_names)
+            if eval_mem:
+                body.append(
+                    "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+                    % (node.name))
 
         # generate chunk region end
         if node_idx in chunk_ends:
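Note the two emission sites: inside a chunk region the probe string is indented into the generated chunk loop and guarded by "if chunk_idx == 0:" so only the first iteration reports, while outside a region it is a plain print appended after the node's own line. In both cases the trailing reset_peak_memory_stats() makes each probe cover only the interval since the previous one. A simplified, hypothetical sketch of this string-level emission (emit_probe and the linear_1 node line are illustrative, not the commit's code):

    def emit_probe(body, node_name, within_chunk_region):
        # Build the source line that reports peak-MB-above-baseline for one node.
        probe = ("print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); "
                 "torch.cuda.reset_peak_memory_stats()\n" % node_name)
        if within_chunk_region:
            # indent into the chunk loop; report only its first iteration
            probe = "    if chunk_idx == 0:\n        " + probe
        body.append(probe)

    body = ["init_memory = torch.cuda.memory_allocated() / 1024**2\n"]
    body.append("linear_1 = torch.nn.functional.linear(x, weight)\n")    # emitted node code
    emit_probe(body, "linear_1", within_chunk_region=False)
    print("".join(body), end="")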
@@ -312,8 +322,10 @@ if AUTOCHUNK_AVAILABLE:
                  meta_graph,
                  max_memory: int = None,
                  print_mem: bool = False,
-                 print_progress: bool = False) -> None:
+                 print_progress: bool = False,
+                 eval_mem: bool = False) -> None:
         super().__init__()
+        self.eval_mem = eval_mem
         # find the chunk regions
         self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
         self.chunk_infos = self.search_chunk.search_region()
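With the flag stored on the instance, a caller opts in at construction time and the probes are woven into the compiled forward. A hedged wiring sketch, assuming the class guarded by AUTOCHUNK_AVAILABLE is AutoChunkCodeGen from the file this commit touches, and that meta_graph comes from the meta-tensor tracing step the diff does not show:

    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

    def make_codegen(meta_graph, budget=None):
        # meta_graph: an FX graph traced with meta tensors (tracing step assumed).
        # eval_mem=True emits the per-node CUDA memory probes shown above.
        return AutoChunkCodeGen(meta_graph,
                                max_memory=budget,
                                print_mem=False,
                                print_progress=False,
                                eval_mem=True)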
@@ -511,14 +523,8 @@ if AUTOCHUNK_AVAILABLE:
 
         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(
-            body,
-            nodes,
-            emit_node,
-            delete_unused_values,
-            self.search_chunk,
-            self.chunk_infos,
-        )
+        emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
+                             self.eval_mem)
 
         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
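Because every probe prints and then resets the peak counter, an eval_mem-instrumented forward emits one line per node of the form "<node_name> <peak MB above baseline>", each line covering only the interval since the previous probe; the largest values flag the nodes worth comparing against the refactored memory estimate, which appears to be the point of the flag given the commit title. A small post-processing sketch (run_and_rank and instrumented_forward are hypothetical):

    import contextlib
    import io

    def run_and_rank(instrumented_forward, *args, top_k=5):
        # Capture the probe prints from an eval_mem-instrumented forward
        # and rank nodes by their measured peak memory.
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            out = instrumented_forward(*args)
        pairs = [line.split() for line in buf.getvalue().splitlines() if line]
        for name, mb in sorted(pairs, key=lambda p: float(p[1]), reverse=True)[:top_k]:
            print(name, mb)
        return out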