mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
[autochunk] refactor chunk memory estimation (#2762)
* refactor memory code * don't log free var memory * add memory alignment * update chunk target * update setting for new memory * finish test * update tracer * update typo * update test
This commit is contained in:
@@ -24,29 +24,16 @@ class SelectChunk(object):
|
||||
else:
|
||||
self.stratge = "min_memory"
|
||||
|
||||
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
|
||||
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
|
||||
if self.stratge == "min_memory":
|
||||
best_region = self._select_min_memory_chunk_region(
|
||||
possible_chunk_regions,
|
||||
chunk_infos,
|
||||
peak_node,
|
||||
max_chunk_region,
|
||||
mem_peak,
|
||||
)
|
||||
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
|
||||
elif self.stratge == "fit_memory":
|
||||
best_region = self._select_fit_memory_chunk_region(
|
||||
possible_chunk_regions,
|
||||
chunk_infos,
|
||||
peak_node,
|
||||
max_chunk_region,
|
||||
mem_peak,
|
||||
)
|
||||
best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
|
||||
else:
|
||||
raise RuntimeError()
|
||||
return best_region
|
||||
|
||||
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
|
||||
mem_peak):
|
||||
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
|
||||
# stop chunking once the peak memory already satisfies the memory limit
|
||||
if max(mem_peak) < self.max_memory:
|
||||
return None
|
||||
@@ -63,17 +50,14 @@ class SelectChunk(object):
|
||||
if len(possible_chunk_regions) == 0:
|
||||
return None
|
||||
|
||||
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
|
||||
max([i["region"][1] for i in possible_chunk_regions]))
|
||||
|
||||
# get mem for chunk region
|
||||
regions_dict = []
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
|
||||
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
|
||||
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
if cur_chunk_region_max_peak < self.max_memory:
|
||||
regions_dict.append({
|
||||
@@ -141,8 +125,7 @@ class SelectChunk(object):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
|
||||
mem_peak):
|
||||
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||
# remove illegal regions
|
||||
illegal_regions = []
|
||||
for i in possible_chunk_regions:
|
||||
|
Reference in New Issue
Block a user