[autochunk] support diffusion for autochunk (#2621)

* add alphafold benchmark

* renae alphafold test

* rename tests

* rename diffuser

* renme

* rename

* update transformer

* update benchmark

* update benchmark

* update bench memory

* update transformer benchmark

* rename

* support diffuser

* support unet metainfo prop

* fix bug and simplify code

* update linear and support some op

* optimize max region search, support conv

* update unet test

* support some op

* support groupnorm and interpolate

* update flow search

* add fix dim in node flow

* fix utils

* rename

* support diffusion

* update diffuser

* update chunk search

* optimize imports

* import

* finish autochunk
This commit is contained in:
oahzxl
2023-02-07 16:32:45 +08:00
committed by GitHub
parent 291b051171
commit 6ba8364881
6 changed files with 216 additions and 166 deletions

View File

@@ -8,14 +8,7 @@ from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
NodeMgr,
find_chunk_compute_input_and_output_nodes,
get_logger,
get_node_shape,
is_non_compute_node,
is_non_compute_node_except_placeholder,
)
from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
class SearchChunk(object):
@@ -75,8 +68,8 @@ class SearchChunk(object):
max_chunk_region_list = []
while True:
max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx)
cur_node_idx = max_chunk_region[1]
if cur_node_idx == len(active_nodes) - 1:
cur_node_idx = max_chunk_region[1] + 1
if cur_node_idx >= len(active_nodes) - 1:
break
max_chunk_region_list.append(max_chunk_region)
@@ -135,6 +128,7 @@ class SearchChunk(object):
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)
# normal search
# from peak_node to free_var
inside_flag = False
chunk_region_start = free_var_num
@@ -144,7 +138,6 @@ class SearchChunk(object):
if inside_flag and active_node_num[i] > threshold:
chunk_region_start = i + 1
break
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
@@ -155,6 +148,22 @@ class SearchChunk(object):
chunk_region_end = i
break
# if normal search fails, use approximate search
if (chunk_region_end - chunk_region_start) > 250:
window_size = 100
# search min for start
min_num = 1e3
for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_start = i
# search min for end
min_num = 1e3
for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1):
if active_node_num[i] < min_num:
min_num = active_node_num[i]
chunk_region_end = i
# avoid chunk regions overlap
if chunk_regions is not None:
for i in chunk_regions:
@@ -271,12 +280,6 @@ class SearchChunk(object):
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
return best_chunk_region
def _stop_search(self, init_mem_peak, mem_peak):
sorted_init_mem_peak = sorted(init_mem_peak)
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
return True
return False
def search_region(self) -> Dict:
"""
Search all chunk regions:
@@ -291,11 +294,7 @@ class SearchChunk(object):
get_logger().info("AutoChunk start searching chunk regions")
chunk_infos = []
(
init_mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
mem_peak = init_mem_peak
while True:
@@ -304,18 +303,13 @@ class SearchChunk(object):
break
chunk_infos.append(chunk_info)
(
mem_peak,
_,
active_node,
) = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(), chunk_infos)
mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
self.node_mgr.get_node_list(), chunk_infos)
if self.print_progress:
get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
(len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
if self._stop_search(init_mem_peak, mem_peak):
break
if self.print_mem:
self.print_mem = False
self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),