[autochunk] refactor chunk memory estimation (#2762)

* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test
This commit is contained in:
Xuanlei Zhao
2023-03-08 16:22:30 +08:00
committed by GitHub
parent b51bfec357
commit 2ca9728cbb
12 changed files with 294 additions and 422 deletions

View File

@@ -11,8 +11,8 @@ logger = get_dist_logger()
class NodeMgr(object):
def __init__(self, gm) -> None:
self._node_list = list(gm.graph.nodes)
def __init__(self, nodes_list: List[Node]) -> None:
self._node_list = nodes_list
self._node_dict = {}
self._set_node_dict()
@@ -76,6 +76,8 @@ def flat_list(inputs: Any) -> List:
for i in inputs:
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
res.extend(flat_list(i))
elif isinstance(i, dict):
res.extend(flat_list(list(i.keys())))
else:
res.append(i)
return res
@@ -135,13 +137,6 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
return is_non_compute_node_except_placeholder(node)
def find_node_idx(name: str, nodes_list: List) -> int:
for idx, node in enumerate(nodes_list):
if node.name == name:
return idx
raise RuntimeError("name %s not found in node list" % name)
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
for key, value in user_to_last_uses.items():
for n in value: