mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -5,7 +5,6 @@ from .utils import NodeMgr, is_non_compute_node
|
||||
|
||||
|
||||
class SelectChunk(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_indice: TraceIndice,
|
||||
@@ -20,7 +19,7 @@ class SelectChunk(object):
|
||||
self.node_mgr = node_mgr
|
||||
if max_memory is not None:
|
||||
self.stratge = "fit_memory"
|
||||
self.max_memory = max_memory # MB
|
||||
self.max_memory = max_memory # MB
|
||||
else:
|
||||
self.stratge = "min_memory"
|
||||
|
||||
@@ -57,16 +56,18 @@ class SelectChunk(object):
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
|
||||
cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
|
||||
cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
if cur_chunk_region_max_peak < self.max_memory:
|
||||
regions_dict.append({
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
})
|
||||
regions_dict.append(
|
||||
{
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
}
|
||||
)
|
||||
# no region found
|
||||
if len(regions_dict) == 0:
|
||||
raise RuntimeError("Search failed. Try a larger memory threshold.")
|
||||
@@ -90,13 +91,15 @@ class SelectChunk(object):
|
||||
chunk_size *= 2
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
|
||||
cur_chunk_infos)[0]
|
||||
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1])
|
||||
# search exact size
|
||||
chunk_info = chunk_region_dict["chunk_info"]
|
||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
|
||||
chunk_infos)
|
||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
|
||||
)
|
||||
return chunk_info
|
||||
|
||||
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
|
||||
@@ -109,9 +112,10 @@ class SelectChunk(object):
|
||||
mid = int((left + right) / 2 + 0.5)
|
||||
chunk_info["chunk_size"] = mid
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
|
||||
cur_chunk_infos)[0]
|
||||
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1])
|
||||
if cur_chunk_max_mem >= self.max_memory:
|
||||
right = mid - gap
|
||||
else:
|
||||
@@ -139,8 +143,10 @@ class SelectChunk(object):
|
||||
return None
|
||||
|
||||
# get max possible chunk region
|
||||
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
|
||||
max([i["region"][1] for i in possible_chunk_regions]))
|
||||
max_possible_chunk_region = (
|
||||
min([i["region"][0] for i in possible_chunk_regions]),
|
||||
max([i["region"][1] for i in possible_chunk_regions]),
|
||||
)
|
||||
|
||||
# get mem for chunk region
|
||||
regions_dict_list = []
|
||||
@@ -149,15 +155,17 @@ class SelectChunk(object):
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
|
||||
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
regions_dict_list.append({
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
})
|
||||
regions_dict_list.append(
|
||||
{
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
}
|
||||
)
|
||||
|
||||
# select the min mem
|
||||
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
|
||||
@@ -175,7 +183,9 @@ class SelectChunk(object):
|
||||
return False
|
||||
for i in chunk_infos:
|
||||
region = i["region"]
|
||||
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
|
||||
(chunk_region_start < region[0] and chunk_region_end < region[0])):
|
||||
if not (
|
||||
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
Reference in New Issue
Block a user