[auto-chunk] support extramsa (#3) (#2504)

oahzxl
2023-01-20 10:13:03 +08:00
committed by GitHub
parent 0f02b8c6e6
commit 72341e65f4
8 changed files with 283 additions and 54 deletions


@@ -6,12 +6,7 @@ from torch.fx.node import Node, map_arg
 from colossalai.fx.profiler import activation_size, parameter_size
-from .utils import (
-    delete_free_var_from_last_use,
-    find_idx_by_name,
-    get_node_shape,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
 class EstimateMemory(object):
@@ -240,7 +235,7 @@ class EstimateMemory(object):
elif node.op == "output":
continue
# no change for non compute node
elif is_non_compute_node_except_placeholder(node):
elif is_non_memory_node(node):
act_memory_peak_log.append(act_memory)
# node is a compute op
# calculate tmp, output node and delete node memory
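Below is a minimal sketch, not the ColossalAI implementation, of the kind of check this commit swaps in and how the estimator's per-node walk can use it. The body of is_non_memory_node and the 1-MB-per-compute-node accounting are assumptions made purely for illustration; the real EstimateMemory logic in the hunk above computes actual tmp/output/delete sizes.

# Hedged sketch: a node that allocates no new activation memory leaves the
# running total unchanged, so its peak-log entry is just the current total.
import torch
from torch.fx import GraphModule, symbolic_trace
from torch.fx.node import Node


def is_non_memory_node(node: Node) -> bool:
    # Assumed behaviour: tuple indexing, attribute access, and graph
    # inputs/outputs allocate no new activation memory.
    if "getitem" in node.name:
        return True
    return node.op in ("placeholder", "get_attr", "output")


def log_peak_memory(gm: GraphModule) -> list:
    """Walk the traced graph and record a running activation total per node."""
    act_memory = 0.0
    act_memory_peak_log = []
    for node in gm.graph.nodes:
        if is_non_memory_node(node):
            # no change for non-memory nodes: log the current total as-is
            act_memory_peak_log.append(act_memory)
        else:
            # placeholder accounting: pretend every compute node adds 1 MB
            act_memory += 1.0
            act_memory_peak_log.append(act_memory)
    return act_memory_peak_log


if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    print(log_peak_memory(symbolic_trace(model)))

Treating such nodes as memory-neutral keeps the peak log aligned one-to-one with the node list without inflating the estimate, which matches the "append without change" branch shown in the diff.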