[autochunk] support multi outputs chunk search (#2538)

Support multi-output chunk search. Previously, we only supported single-output chunk search. The new search is more flexible and improves performance by a large margin. For transformers, it reduces memory usage by 40% compared with the previous search strategy.

1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
This commit is contained in:
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -6,7 +6,7 @@ from torch.fx.node import Node, map_arg
from colossalai.fx.profiler import activation_size, parameter_size
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node
from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node
class EstimateMemory(object):
@@ -14,8 +14,8 @@ class EstimateMemory(object):
Estimate memory with chunk
"""
def __init__(self) -> None:
pass
def __init__(self, node_mgr: NodeMgr) -> None:
self.node_mgr = node_mgr
def _get_meta_node_size(self, x):
x = x.meta["tensor_meta"]
@@ -78,7 +78,7 @@ class EstimateMemory(object):
nodes_to_delete = []
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
chunk_input_users = chunk_input.users.keys()
chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users]
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
if chunk_input not in nodes_to_delete:
nodes_to_delete.append(chunk_input)
@@ -212,7 +212,7 @@ class EstimateMemory(object):
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
chunk_inputs_names = [j.name for i in chunk_inputs for j in i
] + [j.name for i in chunk_inputs_non_chunk for j in i]
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
@@ -221,7 +221,7 @@ class EstimateMemory(object):
if use_chunk and idx in chunk_starts:
chunk_within = True
chunk_region_idx = chunk_starts.index(idx)
act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2)
# determine chunk ratio for current node
if chunk_within: