mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[autochunk] support diffusion for autochunk (#2621)
* add alphafold benchmark * renae alphafold test * rename tests * rename diffuser * renme * rename * update transformer * update benchmark * update benchmark * update bench memory * update transformer benchmark * rename * support diffuser * support unet metainfo prop * fix bug and simplify code * update linear and support some op * optimize max region search, support conv * update unet test * support some op * support groupnorm and interpolate * update flow search * add fix dim in node flow * fix utils * rename * support diffusion * update diffuser * update chunk search * optimize imports * import * finish autochunk
This commit is contained in:
@@ -8,14 +8,7 @@ from .reorder_graph import ReorderGraph
|
||||
from .select_chunk import SelectChunk
|
||||
from .trace_flow import TraceFlow
|
||||
from .trace_indice import TraceIndice
|
||||
from .utils import (
|
||||
NodeMgr,
|
||||
find_chunk_compute_input_and_output_nodes,
|
||||
get_logger,
|
||||
get_node_shape,
|
||||
is_non_compute_node,
|
||||
is_non_compute_node_except_placeholder,
|
||||
)
|
||||
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
|
||||
|
||||
|
||||
class SearchChunk(object):
|
||||
@@ -75,8 +68,8 @@ class SearchChunk(object):
|
||||
max_chunk_region_list = []
|
||||
while True:
|
||||
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
|
||||
cur_node_idx = max_chunk_region[1]
|
||||
if cur_node_idx == len(active_nodes) - 1:
|
||||
cur_node_idx = max_chunk_region[1] + 1
|
||||
if cur_node_idx >= len(active_nodes) - 1:
|
||||
break
|
||||
max_chunk_region_list.append(max_chunk_region)
|
||||
|
||||
@@ -135,6 +128,7 @@ class SearchChunk(object):
|
||||
min_active_node_num = min(active_node_num[free_var_num:])
|
||||
threshold = max(free_var_num, min_active_node_num)
|
||||
|
||||
# normal search
|
||||
# from peak_node to free_var
|
||||
inside_flag = False
|
||||
chunk_region_start = free_var_num
|
||||
@@ -144,7 +138,6 @@ class SearchChunk(object):
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_start = i + 1
|
||||
break
|
||||
|
||||
# from peak_node to len-2
|
||||
inside_flag = False
|
||||
chunk_region_end = len(active_node) - 1
|
||||
@@ -155,6 +148,22 @@ class SearchChunk(object):
|
||||
chunk_region_end = i
|
||||
break
|
||||
|
||||
# if normal search fails, use approximate search
|
||||
if (chunk_region_end - chunk_region_start) > 250:
|
||||
window_size = 100
|
||||
# search min for start
|
||||
min_num = 1e3
|
||||
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
|
||||
if active_node_num[i] < min_num:
|
||||
min_num = active_node_num[i]
|
||||
chunk_region_start = i
|
||||
# search min for end
|
||||
min_num = 1e3
|
||||
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
|
||||
if active_node_num[i] < min_num:
|
||||
min_num = active_node_num[i]
|
||||
chunk_region_end = i
|
||||
|
||||
# avoid chunk regions overlap
|
||||
if chunk_regions is not None:
|
||||
for i in chunk_regions:
|
||||
@@ -271,12 +280,6 @@ class SearchChunk(object):
|
||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
||||
def _stop_search(self, init_mem_peak, mem_peak):
|
||||
sorted_init_mem_peak = sorted(init_mem_peak)
|
||||
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def search_region(self) -> Dict:
|
||||
"""
|
||||
Search all chunk regions:
|
||||
@@ -291,11 +294,7 @@ class SearchChunk(object):
|
||||
get_logger().info("AutoChunk start searching chunk regions")
|
||||
|
||||
chunk_infos = []
|
||||
(
|
||||
init_mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
|
||||
init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
@@ -304,18 +303,13 @@ class SearchChunk(object):
|
||||
break
|
||||
chunk_infos.append(chunk_info)
|
||||
|
||||
(
|
||||
mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
|
||||
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.node_mgr.get_node_list(), chunk_infos)
|
||||
|
||||
if self.print_progress:
|
||||
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
|
||||
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
|
||||
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
|
||||
|
Reference in New Issue
Block a user