mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-14 06:05:26 +00:00
update min memory stratege, reduce mem usage by 30%
This commit is contained in:
parent
9c5e028a62
commit
55cb713f36
@ -1433,7 +1433,11 @@ class ChunkSelector(object):
|
|||||||
):
|
):
|
||||||
if self.stratge == "min_memory":
|
if self.stratge == "min_memory":
|
||||||
best_region = self._select_min_memory_chunk_region(
|
best_region = self._select_min_memory_chunk_region(
|
||||||
possible_chunk_regions, chunk_infos
|
possible_chunk_regions,
|
||||||
|
chunk_infos,
|
||||||
|
peak_node,
|
||||||
|
max_chunk_region,
|
||||||
|
mem_peak,
|
||||||
)
|
)
|
||||||
elif self.stratge == "fit_memory":
|
elif self.stratge == "fit_memory":
|
||||||
best_region = self._select_fit_memory_chunk_region(
|
best_region = self._select_fit_memory_chunk_region(
|
||||||
@ -1561,19 +1565,52 @@ class ChunkSelector(object):
|
|||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
|
def _select_min_memory_chunk_region(
|
||||||
max_region_range = 0
|
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||||
best_region = None
|
):
|
||||||
while len(possible_chunk_regions) > 0:
|
# remove illegal regions
|
||||||
for i in possible_chunk_regions:
|
illegal_regions = []
|
||||||
if i["region"][1] - i["region"][0] > max_region_range:
|
for i in possible_chunk_regions:
|
||||||
best_region = i
|
if not self._is_legal_region(i, chunk_infos):
|
||||||
max_region_range = i["region"][1] - i["region"][0]
|
illegal_regions.append(i)
|
||||||
if self._is_legal_region(best_region, chunk_infos):
|
for i in illegal_regions:
|
||||||
break
|
if i in possible_chunk_regions:
|
||||||
possible_chunk_regions.remove(i)
|
possible_chunk_regions.remove(i)
|
||||||
max_region_range = 0
|
|
||||||
best_region = None
|
if len(possible_chunk_regions) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get mem for chunk region
|
||||||
|
regions_dict = []
|
||||||
|
for region in possible_chunk_regions:
|
||||||
|
cur_region = region.copy()
|
||||||
|
cur_node_list, cur_region = self.index_tracer.tmp_reorder(
|
||||||
|
self.index_tracer.node_list, cur_region
|
||||||
|
)
|
||||||
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
|
cur_node_list, cur_chunk_infos
|
||||||
|
)[0]
|
||||||
|
cur_chunk_region_peak = cur_mem_peak[
|
||||||
|
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||||
|
]
|
||||||
|
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||||
|
regions_dict.append(
|
||||||
|
{
|
||||||
|
"chunk_info": region,
|
||||||
|
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||||
|
"chunk_len": self._get_compute_node_num(
|
||||||
|
region["region"][0], region["region"][1]
|
||||||
|
),
|
||||||
|
"reorder_chunk_info": cur_region,
|
||||||
|
"reorder_node_list": cur_node_list,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# select the min mem
|
||||||
|
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
|
||||||
|
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
|
||||||
|
best_region = regions_dict[best_region_idx]["chunk_info"]
|
||||||
if best_region is not None:
|
if best_region is not None:
|
||||||
best_region["chunk_size"] = 1
|
best_region["chunk_size"] = 1
|
||||||
return best_region
|
return best_region
|
||||||
|
Loading…
Reference in New Issue
Block a user