mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy. 1. rewrite search strategy to support multi outputs chunk search 2. fix many, many bugs 3. update tests
This commit is contained in:
@@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg
|
||||
|
||||
from colossalai.fx.profiler import activation_size, parameter_size
|
||||
|
||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
|
||||
from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
|
||||
|
||||
|
||||
class EstimateMemory(object):
|
||||
@@ -14,8 +14,8 @@ class EstimateMemory(object):
|
||||
Estimate memory with chunk
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, node_mgr: NodeMgr) -> None:
|
||||
self.node_mgr = node_mgr
|
||||
|
||||
def _get_meta_node_size(self, x):
|
||||
x = x.meta["tensor_meta"]
|
||||
@@ -78,7 +78,7 @@ class EstimateMemory(object):
|
||||
nodes_to_delete = []
|
||||
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
|
||||
chunk_input_users = chunk_input.users.keys()
|
||||
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
|
||||
chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
|
||||
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
|
||||
if chunk_input not in nodes_to_delete:
|
||||
nodes_to_delete.append(chunk_input)
|
||||
@@ -212,7 +212,7 @@ class EstimateMemory(object):
|
||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
|
||||
] + [j.name for i in chunk_inputs_non_chunk for j in i]
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_outputs = [i["outputs"] for i in chunk_infos]
|
||||
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
|
||||
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
|
||||
|
||||
@@ -221,7 +221,7 @@ class EstimateMemory(object):
|
||||
if use_chunk and idx in chunk_starts:
|
||||
chunk_within = True
|
||||
chunk_region_idx = chunk_starts.index(idx)
|
||||
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
|
||||
act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
|
||||
|
||||
# determine chunk ratio for current node
|
||||
if chunk_within:
|
||||
|
Reference in New Issue
Block a user