[autochunk] refactor chunk memory estimation (#2762)

* refact memory code

* dont log free var memory

* add memory align

* update chunk target

* update setting for new memory

* finish test

* update tracer

* update typo

* update test
This commit is contained in:
Xuanlei Zhao
2023-03-08 16:22:30 +08:00
committed by GitHub
parent b51bfec357
commit 2ca9728cbb
12 changed files with 294 additions and 422 deletions

View File

@@ -24,29 +24,16 @@ class SelectChunk(object):
else:
self.stratge = "min_memory"
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
if self.stratge == "min_memory":
best_region = self._select_min_memory_chunk_region(
possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
elif self.stratge == "fit_memory":
best_region = self._select_fit_memory_chunk_region(
possible_chunk_regions,
chunk_infos,
peak_node,
max_chunk_region,
mem_peak,
)
best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
else:
raise RuntimeError()
return best_region
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
# stop chunk if max memory satisfy memory limit
if max(mem_peak) < self.max_memory:
return None
@@ -63,17 +50,14 @@ class SelectChunk(object):
if len(possible_chunk_regions) == 0:
return None
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append({
@@ -141,8 +125,7 @@ class SelectChunk(object):
count += 1
return count
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions: