Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi-output chunk search. Previously we only supported single-output chunk search. The new strategy is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy.
1. rewrite the search strategy to support multi-output chunk search
2. fix many, many bugs
3. update tests
@@ -1,7 +1,7 @@
 from .estimate_memory import EstimateMemory
 from .reorder_graph import ReorderGraph
 from .trace_indice import TraceIndice
-from .utils import is_non_compute_node
+from .utils import NodeMgr, is_non_compute_node
 
 
 class SelectChunk(object):
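The import change above replaces direct node-list access with a NodeMgr helper. The helper itself is not shown in this diff; the following is a minimal sketch, assuming NodeMgr simply wraps the traced node list and exposes the two accessors used later in this commit (get_node_list and get_node_slice_by_idx). Every detail here is an assumption for illustration, not the actual ColossalAI class.

# Minimal sketch (assumption): a NodeMgr-style wrapper around the traced node
# list, exposing only the accessors this diff relies on.
from typing import Any, List


class NodeMgrSketch:
    def __init__(self, node_list: List[Any]):
        # the traced graph nodes of the model (assumed input)
        self._node_list = node_list

    def get_node_list(self) -> List[Any]:
        # full node list, replacing direct trace_indice.node_list access
        return self._node_list

    def get_node_slice_by_idx(self, start: int, end: int) -> List[Any]:
        # half-open slice [start, end); callers pass end + 1 for inclusive ranges
        return self._node_list[start:end]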
@@ -11,11 +11,13 @@ class SelectChunk(object):
         trace_indice: TraceIndice,
         estimate_memory: EstimateMemory,
         reorder_graph: ReorderGraph,
+        node_mgr: NodeMgr,
         max_memory=None,
     ):
         self.trace_indice = trace_indice
         self.estimate_memory = estimate_memory
         self.reorder_graph = reorder_graph
+        self.node_mgr = node_mgr
         if max_memory is not None:
             self.stratge = "fit_memory"
             self.max_memory = max_memory  # MB
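With the extra parameter, SelectChunk is now wired up with the node manager alongside the other helpers. A hedged usage sketch, assuming the helper objects are built earlier in the autochunk pipeline (the variable names and the memory value are illustrative):

# Illustrative only: constructing SelectChunk with the new node_mgr argument.
# trace_indice, estimate_memory, reorder_graph and node_mgr are assumed to be
# created elsewhere during autochunk tracing.
select_chunk = SelectChunk(
    trace_indice=trace_indice,
    estimate_memory=estimate_memory,
    reorder_graph=reorder_graph,
    node_mgr=node_mgr,
    max_memory=4096,  # optional cap in MB; enables the "fit_memory" strategy
)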
@@ -68,7 +70,7 @@ class SelectChunk(object):
         regions_dict = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
             cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
             cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
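The loop above scores each candidate chunk region by its peak memory inside the maximal possible chunk window: estimate_chunk_inference_mem is assumed to return a per-node memory curve, and the slice over max_possible_chunk_region picks out the window of interest. A simplified, self-contained sketch of that scoring step; all names apart from the slicing logic are hypothetical.

# Hedged sketch: rank candidate chunk regions by peak memory inside a window.
# mem_curves maps each candidate region to a per-node memory estimate (MB),
# which is an assumption about what estimate_chunk_inference_mem returns.
from typing import Dict, List, Tuple


def pick_lowest_peak_region(
    mem_curves: Dict[Tuple[int, int], List[float]],
    window: Tuple[int, int],
) -> Tuple[Tuple[int, int], float]:
    start, end = window
    best_region, best_peak = None, float("inf")
    for region, curve in mem_curves.items():
        peak = max(curve[start:end + 1])  # inclusive window, as in the diff
        if peak < best_peak:
            best_region, best_peak = region, peak
    return best_region, best_peak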
@@ -134,7 +136,7 @@ class SelectChunk(object):
 
     def _get_compute_node_num(self, start, end):
         count = 0
-        for i in self.trace_indice.node_list[start:end + 1]:
+        for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
             if not is_non_compute_node(i):
                 count += 1
         return count
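The refactor in _get_compute_node_num keeps the old inclusive-range semantics: node_list[start:end + 1] becomes get_node_slice_by_idx(start, end + 1). A small standalone sketch of the counting logic, with the predicate stubbed out as an assumption in place of the real is_non_compute_node utility:

# Illustration of the counting logic after the refactor; the predicate argument
# is a stand-in for the real is_non_compute_node helper.
from typing import Any, Callable, List


def count_compute_nodes(
    node_list: List[Any],
    start: int,
    end: int,
    is_non_compute_node: Callable[[Any], bool],
) -> int:
    count = 0
    for node in node_list[start:end + 1]:  # inclusive of index `end`
        if not is_non_compute_node(node):
            count += 1
    return count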
@@ -161,7 +163,7 @@ class SelectChunk(object):
         regions_dict_list = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
             cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
             cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]