mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[autochunk] refactor chunk memory estimation (#2762)
* refact memory code * dont log free var memory * add memory align * update chunk target * update setting for new memory * finish test * update tracer * update typo * update test
This commit is contained in:
@@ -11,8 +11,8 @@ logger = get_dist_logger()
|
||||
|
||||
class NodeMgr(object):
|
||||
|
||||
def __init__(self, gm) -> None:
|
||||
self._node_list = list(gm.graph.nodes)
|
||||
def __init__(self, nodes_list: List[Node]) -> None:
|
||||
self._node_list = nodes_list
|
||||
self._node_dict = {}
|
||||
self._set_node_dict()
|
||||
|
||||
@@ -76,6 +76,8 @@ def flat_list(inputs: Any) -> List:
|
||||
for i in inputs:
|
||||
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
|
||||
res.extend(flat_list(i))
|
||||
elif isinstance(i, dict):
|
||||
res.extend(flat_list(list(i.keys())))
|
||||
else:
|
||||
res.append(i)
|
||||
return res
|
||||
@@ -135,13 +137,6 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
|
||||
return is_non_compute_node_except_placeholder(node)
|
||||
|
||||
|
||||
def find_node_idx(name: str, nodes_list: List) -> int:
|
||||
for idx, node in enumerate(nodes_list):
|
||||
if node.name == name:
|
||||
return idx
|
||||
raise RuntimeError("name %s not found in node list" % name)
|
||||
|
||||
|
||||
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
|
||||
for key, value in user_to_last_uses.items():
|
||||
for n in value:
|
||||
|
Reference in New Issue
Block a user