[autochunk] support transformer (#2526)

This commit is contained in:
oahzxl
2023-01-31 16:00:06 +08:00
committed by GitHub
parent 6e0faa70e0
commit 63199c6687
20 changed files with 1214 additions and 1084 deletions

View File

@@ -5,6 +5,7 @@ from .utils import is_non_compute_node
class SelectChunk(object):
def __init__(
self,
trace_indice: TraceIndice,
@@ -17,13 +18,11 @@ class SelectChunk(object):
self.reorder_graph = reorder_graph
if max_memory is not None:
self.stratge = "fit_memory"
self.max_memory = max_memory # MB
self.max_memory = max_memory # MB
else:
self.stratge = "min_memory"
def _select_best_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
if self.stratge == "min_memory":
best_region = self._select_min_memory_chunk_region(
possible_chunk_regions,
@@ -44,9 +43,8 @@ class SelectChunk(object):
raise RuntimeError()
return best_region
def _select_fit_memory_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
# stop chunk if max memory satisfy memory limit
if max(mem_peak) < self.max_memory:
return None
@@ -63,33 +61,26 @@ class SelectChunk(object):
if len(possible_chunk_regions) == 0:
return None
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.trace_indice.node_list, cur_region
)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
if cur_chunk_region_max_peak < self.max_memory:
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1]
),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
regions_dict.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
# no region found
if len(regions_dict) == 0:
raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -113,20 +104,13 @@ class SelectChunk(object):
chunk_size *= 2
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
cur_mem_peak[
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
+ 1
]
)
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
# search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
)
chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
chunk_infos)
return chunk_info
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -139,12 +123,9 @@ class SelectChunk(object):
mid = int((left + right) / 2 + 0.5)
chunk_info["chunk_size"] = mid
cur_chunk_infos = chunk_infos + [chunk_info]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
chunk_region_dict["reorder_node_list"], cur_chunk_infos
)[0]
cur_chunk_max_mem = max(
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
)
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
cur_chunk_infos)[0]
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
if cur_chunk_max_mem >= self.max_memory:
right = mid - gap
else:
@@ -153,14 +134,13 @@ class SelectChunk(object):
def _get_compute_node_num(self, start, end):
    """Count the compute nodes inside an inclusive node-index range.

    Args:
        start: index of the first node of the range (inclusive).
        end: index of the last node of the range (inclusive).

    Returns:
        The number of nodes in ``self.trace_indice.node_list[start:end + 1]``
        for which ``is_non_compute_node`` is falsy.
    """
    # ``end`` is inclusive, hence the ``+ 1`` on the slice bound.
    region_nodes = self.trace_indice.node_list[start:end + 1]
    return sum(1 for node in region_nodes if not is_non_compute_node(node))
def _select_min_memory_chunk_region(
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
):
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
mem_peak):
# remove illegal regions
illegal_regions = []
for i in possible_chunk_regions:
@@ -173,37 +153,31 @@ class SelectChunk(object):
if len(possible_chunk_regions) == 0:
return None
# get max possible chunk region
max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
max([i["region"][1] for i in possible_chunk_regions]))
# get mem for chunk region
regions_dict = []
regions_dict_list = []
for region in possible_chunk_regions:
cur_region = region.copy()
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
self.trace_indice.node_list, cur_region
)
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
cur_chunk_infos = chunk_infos + [cur_region]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos
)[0]
cur_chunk_region_peak = cur_mem_peak[
max_chunk_region[0] : max_chunk_region[1] + 1
]
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
regions_dict.append(
{
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1]
),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
}
)
regions_dict_list.append({
"chunk_info": region,
"chunk_max_mem": cur_chunk_region_max_peak,
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list,
})
# select the min mem
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
best_region = regions_dict[best_region_idx]["chunk_info"]
best_region = regions_dict_list[best_region_idx]["chunk_info"]
if best_region is not None:
best_region["chunk_size"] = 1
return best_region
@@ -216,9 +190,7 @@ class SelectChunk(object):
return False
for i in chunk_infos:
region = i["region"]
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
(chunk_region_start < region[0] and chunk_region_end < region[0])):
return False
return True