mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-03 04:39:43 +00:00
Support multi-output chunk search. Previously, only single-output chunk search was supported. The new search is more flexible and improves performance by a large margin: for transformers, it reduces memory by 40% compared with the previous search strategy. Changes: 1. rewrite the search strategy to support multi-output chunk search; 2. fix many bugs; 3. update tests.
199 lines
8.6 KiB
Python
199 lines
8.6 KiB
Python
from .estimate_memory import EstimateMemory
|
|
from .reorder_graph import ReorderGraph
|
|
from .trace_indice import TraceIndice
|
|
from .utils import NodeMgr, is_non_compute_node
|
|
|
|
|
|
class SelectChunk(object):
    """Choose the next chunk region (and its chunk size) from a list of candidates.

    The selection strategy is fixed at construction time:

    * ``"min_memory"`` (no ``max_memory`` given): among legal candidates, pick the
      one with the lowest estimated peak memory and use ``chunk_size = 1``.
    * ``"fit_memory"`` (``max_memory`` given): among legal candidates whose
      estimated peak stays below ``max_memory``, pick the one spanning the fewest
      compute nodes, then search for the largest chunk size that still fits.

    Candidate and already-selected chunks are dicts with at least a ``"region"``
    entry ``(start, end)`` of inclusive node indices. The chosen candidate dict is
    returned with its ``"chunk_size"`` entry filled in.
    """

    def __init__(
        self,
        trace_indice: "TraceIndice",
        estimate_memory: "EstimateMemory",
        reorder_graph: "ReorderGraph",
        node_mgr: "NodeMgr",
        max_memory=None,
    ):
        # Annotations are quoted forward references, so defining the class does not
        # require these project types to be resolvable at definition time.
        self.trace_indice = trace_indice
        self.estimate_memory = estimate_memory
        self.reorder_graph = reorder_graph
        self.node_mgr = node_mgr
        # NOTE: the attribute keeps its historical spelling "stratge" so existing readers still work.
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory    # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
        """Dispatch to the selector for the configured strategy.

        Returns:
            The chosen candidate dict (with ``"chunk_size"`` set), or ``None`` when
            no chunk should be added.

        Raises:
            RuntimeError: if the strategy is unknown, or (``fit_memory``) if no
                candidate fits under ``max_memory``.
        """
        if self.stratge == "min_memory":
            selector = self._select_min_memory_chunk_region
        elif self.stratge == "fit_memory":
            selector = self._select_fit_memory_chunk_region
        else:
            raise RuntimeError(f"Unknown chunk selection strategy: {self.stratge!r}")
        return selector(possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak)

    def _remove_illegal_regions(self, possible_chunk_regions, chunk_infos):
        """Drop, in place, every candidate rejected by ``_is_legal_region``.

        The list object itself is filtered so that callers holding a reference to
        it see the pruned result, as before.
        """
        possible_chunk_regions[:] = [
            region for region in possible_chunk_regions if self._is_legal_region(region, chunk_infos)
        ]

    def _estimate_regions(self, possible_chunk_regions, chunk_infos, mem_limit=None):
        """Estimate the peak memory of each candidate after a trial reorder.

        All candidates are measured over the same node window (from the smallest
        start to the largest end among the candidates) so their peaks are comparable.

        Args:
            possible_chunk_regions: non-empty list of candidate dicts.
            chunk_infos: already-selected chunk dicts.
            mem_limit: if given, candidates whose windowed peak is ``>= mem_limit``
                are skipped.

        Returns:
            A list, in candidate order, of dicts with keys ``"chunk_info"`` (the
            original candidate), ``"chunk_max_mem"``, ``"chunk_len"`` (compute-node
            count of the candidate's region), and ``"reorder_chunk_info"`` /
            ``"reorder_node_list"`` (the outputs of ``reorder_graph.tmp_reorder``).
        """
        window_start = min(region["region"][0] for region in possible_chunk_regions)
        window_end = max(region["region"][1] for region in possible_chunk_regions)

        results = []
        for region in possible_chunk_regions:
            # tmp_reorder gets a shallow copy so top-level edits don't leak into the candidate.
            reorder_node_list, reorder_info = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(),
                                                                             region.copy())
            mem_peak = self.estimate_memory.estimate_chunk_inference_mem(reorder_node_list,
                                                                         chunk_infos + [reorder_info])[0]
            max_peak = max(mem_peak[window_start:window_end + 1])
            if mem_limit is not None and max_peak >= mem_limit:
                continue
            results.append({
                "chunk_info": region,
                "chunk_max_mem": max_peak,
                "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
                "reorder_chunk_info": reorder_info,
                "reorder_node_list": reorder_node_list,
            })
        return results

    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
                                        mem_peak):
        """``fit_memory`` strategy: shortest legal region that fits, with the largest fitting chunk size.

        ``peak_node`` and ``max_chunk_region`` are accepted for interface parity but
        are not used. ``possible_chunk_regions`` is pruned in place.

        Returns:
            ``None`` if ``mem_peak`` already fits under ``max_memory`` or no legal
            candidate remains; otherwise the chosen candidate dict with
            ``"chunk_size"`` set.

        Raises:
            RuntimeError: if no legal candidate's estimated peak fits under ``max_memory``.
        """
        # Stop chunking once the whole graph already fits in the budget.
        if max(mem_peak) < self.max_memory:
            return None

        self._remove_illegal_regions(possible_chunk_regions, chunk_infos)
        if not possible_chunk_regions:
            return None

        fitting = self._estimate_regions(possible_chunk_regions, chunk_infos, mem_limit=self.max_memory)
        if not fitting:
            raise RuntimeError("Search failed. Try a larger memory threshold.")

        # Prefer the region with the fewest compute nodes; min() keeps the first on ties.
        best = min(fitting, key=lambda entry: entry["chunk_len"])
        return self._get_fit_chunk_size(best, chunk_infos)

    def _get_chunk_max_mem(self, chunk_region_dict, chunk_infos, chunk_size):
        """Estimated peak memory inside the reordered region when chunked with ``chunk_size``.

        Side effect: writes ``chunk_size`` into ``chunk_region_dict["reorder_chunk_info"]``.
        """
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        reorder_chunk_info["chunk_size"] = chunk_size
        mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
                                                                     chunk_infos + [reorder_chunk_info])[0]
        start, end = reorder_chunk_info["region"]
        return max(mem_peak[start:end + 1])

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        """Find a large chunk size that keeps the region under ``max_memory``.

        First doubles the chunk size until the estimate reaches ``max_memory``,
        then binary-searches between the last two sizes.

        Returns:
            ``chunk_region_dict["chunk_info"]`` with ``"chunk_size"`` set.
        """
        chunk_size = 1
        chunk_region_dict["reorder_chunk_info"]["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # NOTE(review): this loop only ends once the estimate reaches max_memory. If the
        # estimate did not grow with chunk size, it would not terminate. Confirm that
        # estimate_chunk_inference_mem scales with chunk_size.
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            cur_chunk_max_mem = self._get_chunk_max_mem(chunk_region_dict, chunk_infos, chunk_size)

        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
                                                                  chunk_infos)
        return chunk_info

    def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
        """Binary-search the largest chunk size in ``[left, right)`` whose estimate fits.

        Invariant: ``left`` is assumed to fit and ``right`` is known not to fit.
        The returned ``left`` therefore never exceeds a size observed to overflow.
        For ``left >= 16`` the search stops within a tolerance of 4 rather than 1,
        to save estimator calls.

        Assumes the estimated memory does not decrease as the chunk size grows.
        """
        gap = 4 if left >= 16 else 1
        while right - left > gap:
            # Round the midpoint up; since right - left >= 2 here, left < mid < right.
            mid = (left + right + 1) // 2
            if self._get_chunk_max_mem(chunk_region_dict, chunk_infos, mid) >= self.max_memory:
                right = mid
            else:
                left = mid
        return left

    def _get_compute_node_num(self, start, end):
        """Count nodes in the inclusive index range ``[start, end]`` that are compute nodes."""
        return sum(1 for node in self.node_mgr.get_node_slice_by_idx(start, end + 1) if not is_non_compute_node(node))

    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
                                        mem_peak):
        """``min_memory`` strategy: the legal region with the lowest estimated peak, at chunk size 1.

        ``peak_node``, ``max_chunk_region`` and ``mem_peak`` are accepted for
        interface parity but are not used. ``possible_chunk_regions`` is pruned in place.

        Returns:
            ``None`` if no legal candidate remains; otherwise the chosen candidate
            dict with ``"chunk_size"`` set to 1.
        """
        self._remove_illegal_regions(possible_chunk_regions, chunk_infos)
        if not possible_chunk_regions:
            return None

        estimates = self._estimate_regions(possible_chunk_regions, chunk_infos)
        # min() keeps the first candidate on ties.
        best_region = min(estimates, key=lambda entry: entry["chunk_max_mem"])["chunk_info"]
        best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        """Return True if the candidate is new, well-formed and disjoint from every selected region."""
        start, end = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos or end < start:
            return False
        # Given start <= end, the candidate is disjoint from an existing region exactly
        # when it lies wholly after it or wholly before it.
        return all(start > info["region"][1] or end < info["region"][0] for info in chunk_infos)