diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 90bde8730..15e15517b 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Callable, Dict, Iterable, List, Tuple import torch @@ -216,14 +216,13 @@ def _add_node_slice( return body -def emit_code_with_chunk( - body: List[str], - nodes: Iterable[Node], - emit_node_func, - delete_unused_value_func, - search_chunk: SearchChunk, - chunk_infos: List, -): +def emit_code_with_chunk(body: List[str], + nodes: Iterable[Node], + emit_node_func: Callable, + delete_unused_value_func: Callable, + search_chunk: SearchChunk, + chunk_infos: List, + eval_mem: bool = False): """ Emit code with chunk according to chunk_infos. @@ -260,6 +259,9 @@ def emit_code_with_chunk( region_idx = 0 within_chunk_region = False + if eval_mem: + body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n") + while node_idx < len(node_list): node = node_list[node_idx] @@ -289,10 +291,18 @@ def emit_code_with_chunk( body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"]) body[-1] = " " + body[-1] delete_unused_value_func(node, body, chunk_inputs_names) + if eval_mem: + body.append( + " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" + % (node.name)) else: emit_node_func(node, body) if node_idx not in chunk_inputs: delete_unused_value_func(node, body, chunk_inputs_names) + if eval_mem: + body.append( + "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n" + % (node.name)) # generate chunk region end if node_idx in chunk_ends: @@ -312,8 +322,10 @@ if AUTOCHUNK_AVAILABLE: meta_graph, max_memory: int = None, print_mem: bool = False, - print_progress: bool = False) -> None: + print_progress: bool = False, + eval_mem: bool = False) -> None: super().__init__() + self.eval_mem = eval_mem # find the chunk regions self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress) self.chunk_infos = self.search_chunk.search_region() @@ -511,14 +523,8 @@ if AUTOCHUNK_AVAILABLE: # if any node has a list of labels for activation_checkpoint, we # will use nested type of activation checkpoint codegen - emit_code_with_chunk( - body, - nodes, - emit_node, - delete_unused_values, - self.search_chunk, - self.chunk_infos, - ) + emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos, + self.eval_mem) if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py index f457696e6..08a55f9aa 100644 --- a/colossalai/autochunk/estimate_memory.py +++ b/colossalai/autochunk/estimate_memory.py @@ -2,11 +2,11 @@ import copy from typing import Any, Callable, Dict, Iterable, List, Tuple import torch -from torch.fx.node import Node, map_arg +from torch.fx.node import Node from colossalai.fx.profiler import activation_size, parameter_size -from .utils import NodeMgr, delete_free_var_from_last_use, get_node_shape, is_non_memory_node +from .utils import NodeMgr, get_node_shape, is_non_memory_node class EstimateMemory(object): @@ -14,102 +14,85 @@ class EstimateMemory(object): Estimate memory with chunk """ - def __init__(self, node_mgr: NodeMgr) -> 
None: - self.node_mgr = node_mgr + def __init__(self) -> None: + pass - def _get_meta_node_size(self, x): + def _get_node_size(self, x: Node) -> float: + """ + return node size in MB + """ x = x.meta["tensor_meta"] - x = x.numel * torch.tensor([], dtype=x.dtype).element_size() - return x + if not hasattr(x, "numel"): + out = sum([i.numel * torch.tensor([], dtype=i.dtype).element_size() for i in x]) + else: + out = x.numel * torch.tensor([], dtype=x.dtype).element_size() + out = float(out) / 1024**2 + return out - def _get_output_node(self, n): - out_size = activation_size(n.meta["fwd_out"]) - out_node = [n.name] if out_size > 0 else [] - return out_size, out_node + def _add_active_node(self, n: Node, active_nodes: Dict, chunk_ratio: float) -> None: + """ + add an active node and its shape to active node dict + """ + if get_node_shape(n) is None: + return + if n.op == "placeholder": + return + if n not in active_nodes: + node_size = self._get_node_size(n) * chunk_ratio + active_nodes[n] = node_size - def _get_output_node_size(self, n): - return self._get_output_node(n)[0] + def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict: + """ + build delete node dict, means node should be deleted at what time + """ + delete_node_dict = {} + for idx, node in enumerate(node_mgr.get_node_list()): + # skip non shape node + if get_node_shape(node) is None: + continue + # dont remove free nodes + elif node.op == "placeholder": + delete_node_dict[node] = len(node_mgr.get_node_list()) + # node no user + elif len(node.users) == 0: + delete_node_dict[node] = idx + # log max use + else: + node_user_idx = [node_mgr.find_node_idx(i) for i in node.users.keys()] + delete_node_dict[node] = max(node_user_idx) + return delete_node_dict - def _add_active_node(self, n, active_list): - new_active = self._get_output_node(n)[1] - if n.op == "placeholder" and get_node_shape(n) is not None: - new_active.append(n.name) - for i in new_active: - if i not in active_list and get_node_shape(n) is not None: - active_list.append(i) + def _remove_deactive_node(self, + user_idx: int, + user: Node, + active_nodes: List, + delete_node_dict: List, + kept_nodes: List = None) -> None: + """ + remove deactivate nodes from active nodes + """ + if kept_nodes is None: + kept_nodes = [] + if user.op in ("output",): + return - def _get_delete_node(self, user, user_to_last_uses, to_keep=None): - delete_size = 0 - delete_node = [] - if user.op not in ("output",): - nodes_to_delete = user_to_last_uses.get(user, []) - if len(user.users) == 0: - nodes_to_delete.append(user) - if to_keep is not None: - keep_list = [] - for n in nodes_to_delete: - if n.name in to_keep: - keep_list.append(n) - for n in keep_list: - if n in nodes_to_delete: - nodes_to_delete.remove(n) - if len(nodes_to_delete): - out_node = [self._get_output_node(i) for i in nodes_to_delete] - delete_size = sum([i[0] for i in out_node]) - for i in range(len(out_node)): - if out_node[i][0] > 0: - delete_node.append(out_node[i][1][0]) - elif nodes_to_delete[i].op == "placeholder": - delete_node.append(nodes_to_delete[i].name) - # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']): - # delete_node.append(nodes_to_delete[i].name) - return delete_size, delete_node + for node in list(active_nodes.keys()): + # dont delete kept nodes + if node in kept_nodes: + continue + # should be deleted + if delete_node_dict[node] <= user_idx: + active_nodes.pop(node) - def _get_delete_node_size(self, user, user_to_last_uses, to_keep): - return 
self._get_delete_node(user, user_to_last_uses, to_keep)[0] - - def _remove_deactive_node(self, user, user_to_last_uses, active_list): - delete_node = self._get_delete_node(user, user_to_last_uses)[1] - for i in delete_node: - if i in active_list: - active_list.remove(i) - - def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx): - nodes_to_delete = [] - for chunk_input in chunk_inputs + chunk_inputs_non_chunk: - chunk_input_users = chunk_input.users.keys() - chunk_input_users_idx = [self.node_mgr.find_node_idx(i) for i in chunk_input_users] - if all(i <= chunk_end_idx for i in chunk_input_users_idx): - if chunk_input not in nodes_to_delete: - nodes_to_delete.append(chunk_input) - out_node = [self._get_output_node(i) for i in nodes_to_delete] - delete_size = sum([i[0] for i in out_node]) - return delete_size - - def _get_last_usr(self, nodes): - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - return user_to_last_uses - - def _get_contiguous_memory(self, node, not_contiguous_list, delete=False): + def _get_tmp_memory(self, node, not_contiguous_list, delete=False): mem = 0 not_contiguous_ops = ["permute"] - inherit_contiguous_ops = ["transpose", "view"] if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]): for n in node.args: if n in not_contiguous_list: # matmul won't change origin tensor, but create a tmp copy - mem += self._get_output_node_size(n) + mem += self._get_node_size(n) elif node.op == "call_module": for n in node.args: if n in not_contiguous_list: @@ -129,31 +112,7 @@ class EstimateMemory(object): if chunk_dim is None: return 1.0 else: - return float(chunk_size) / node_shape[chunk_dim] - - def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names): - # if any(j in user.name for j in ['transpose', 'permute', 'view']): - # return 0 - if user.op in ("placeholder", "output"): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - if len(user.users) == 0: - nodes_to_delete.append(user) - delete_size = 0 - for n in nodes_to_delete: - if n.name in chunk_inputs_names: - continue - delete_size += self._get_output_node_size(n) * chunk_ratio - return delete_size - - def _print_mem_log(self, log, nodes, title=None): - if title: - print(title) - for idx, (l, n) in enumerate(zip(log, nodes)): - print("%s:%.2f \t" % (n.name, l), end="") - if (idx + 1) % 3 == 0: - print("") - print("\n") + return chunk_size / float(node_shape[chunk_dim]) def _print_compute_op_mem_log(self, log, nodes, title=None): if title: @@ -168,12 +127,22 @@ class EstimateMemory(object): print("") print("\n") - def estimate_chunk_inference_mem( - self, - node_list: List, - chunk_infos=None, - print_mem=False, - ): + def _add_active_nodes_from_list(self, active_nodes: List, nodes: List) -> List: + """ + add active nodes from nodes + """ + for n in nodes: + self._add_active_node(n, active_nodes, 1) + + def _get_memory_from_active_nodes(self, active_nodes: Dict) -> float: + """ + sum all memory of active nodes + """ + out = [i for i in active_nodes.values()] + out = sum(out) + return out + + def estimate_chunk_inference_mem(self, node_list: List, 
chunk_infos: Dict = None, print_mem: bool = False): """ Estimate inference memory with chunk @@ -191,18 +160,17 @@ class EstimateMemory(object): act_memory = 0.0 act_memory_peak_log = [] act_memory_after_node_log = [] - active_node_list = [] - active_node_list_log = [] + active_nodes = {} + active_nodes_log = [] not_contiguous_list = [] - user_to_last_uses = self._get_last_usr(node_list) - user_to_last_uses_no_free_var = self._get_last_usr(node_list) - delete_free_var_from_last_use(user_to_last_uses_no_free_var) + node_mgr = NodeMgr(node_list) + delete_node_dict = self._build_delete_node_dict(node_mgr) use_chunk = True if chunk_infos is not None else False chunk_within = False chunk_region_idx = None chunk_ratio = 1 # use it to estimate chunk mem - chunk_inputs_names = [] + chunk_inputs_all = [] if use_chunk: chunk_regions = [i["region"] for i in chunk_infos] @@ -210,30 +178,30 @@ class EstimateMemory(object): chunk_ends = [i[1] for i in chunk_regions] chunk_inputs = [i["inputs"] for i in chunk_infos] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_names = [j.name for i in chunk_inputs for j in i - ] + [j.name for i in chunk_inputs_non_chunk for j in i] + chunk_inputs_all = [j for i in chunk_inputs for j in i] + [j for i in chunk_inputs_non_chunk for j in i] chunk_outputs = [i["outputs"] for i in chunk_infos] chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos] - for idx, node in enumerate(node_list): + for idx, node in enumerate(node_mgr.get_node_list()): + # if node in chunk start nodes, change chunk ratio and add chunk_tensor if use_chunk and idx in chunk_starts: chunk_within = True chunk_region_idx = chunk_starts.index(idx) - act_memory += sum(self._get_output_node_size(i) for i in chunk_outputs[chunk_region_idx]) / (1024**2) + self._add_active_nodes_from_list(active_nodes, chunk_outputs[chunk_region_idx]) # determine chunk ratio for current node if chunk_within: - chunk_ratio = self._get_chunk_ratio( - node, - chunk_node_dim[chunk_region_idx], - chunk_sizes[chunk_region_idx], - ) + chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx], + chunk_sizes[chunk_region_idx]) + + # add current node as active node + self._add_active_node(node, active_nodes, chunk_ratio) + act_memory = self._get_memory_from_active_nodes(active_nodes) # if node is placeholder, just add the size of the node if node.op == "placeholder": - act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2) act_memory_peak_log.append(act_memory) # skip output elif node.op == "output": @@ -241,83 +209,32 @@ class EstimateMemory(object): # no change for non compute node elif is_non_memory_node(node): act_memory_peak_log.append(act_memory) - # node is a compute op - # calculate tmp, output node and delete node memory + # node is a compute op, calculate tmp else: # forward memory # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose - act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2)) - act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2)) + tmp_memory = self._get_tmp_memory(node, not_contiguous_list, delete=True) * chunk_ratio # record max act memory - act_memory_peak_log.append(act_memory) - # delete useless memory - act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / - (1024**2)) - # delete unused vars not in chunk_input_list - 
# we can't delete input nodes until chunk ends - if chunk_within: - act_memory -= self._get_chunk_delete_node_size( - node, - user_to_last_uses_no_free_var, - chunk_ratio, - chunk_inputs_names, - ) / (1024**2) - else: - act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var, - chunk_inputs_names) / (1024**2) + act_memory_peak_log.append(act_memory + tmp_memory) - # log active node, only effective without chunk - self._add_active_node(node, active_node_list) - self._remove_deactive_node(node, user_to_last_uses, active_node_list) + # remove_deactive_node + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict, kept_nodes=chunk_inputs_all) # if node in chunk end nodes, restore chunk settings if use_chunk and idx in chunk_ends: - act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2)) - act_memory -= self._get_chunk_inputs_size( - chunk_inputs[chunk_region_idx], - chunk_inputs_non_chunk[chunk_region_idx], - node_list, - chunk_regions[chunk_region_idx][1], - ) / (1024**2) + self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now chunk_within = False chunk_ratio = 1 chunk_region_idx = None + act_memory = self._get_memory_from_active_nodes(active_nodes) act_memory_after_node_log.append(act_memory) - active_node_list_log.append(copy.deepcopy(active_node_list)) + active_nodes_log.append(active_nodes.copy()) if print_mem: print("with chunk" if use_chunk else "without chunk") - # self._print_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_mem_log(act_memory_after_node_log, node_list, "after") - self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak") - # self._print_compute_op_mem_log( - # act_memory_after_node_log, node_list, "after" - # ) + self._print_compute_op_mem_log(act_memory_peak_log, node_mgr.get_node_list(), "peak") # param_memory = parameter_size(gm) # all_memory = act_memory + param_memory - return act_memory_peak_log, act_memory_after_node_log, active_node_list_log - - def get_active_nodes(self, node_list: List) -> List: - """ - Get active nodes for every node - - Args: - node_list (List): _description_ - - Returns: - active_node_list_log (List): active nodes of every node. active nodes refer to - nodes generated but not deleted. 
- """ - active_node_list = [] - active_node_list_log = [] - user_to_last_uses = self._get_last_usr(node_list) - user_to_last_uses_no_free_var = self._get_last_usr(node_list) - delete_free_var_from_last_use(user_to_last_uses_no_free_var) - for _, node in enumerate(node_list): - # log active node, only effective without chunk - self._add_active_node(node, active_node_list) - self._remove_deactive_node(node, user_to_last_uses, active_node_list) - active_node_list_log.append(copy.deepcopy(active_node_list)) - return active_node_list_log + return act_memory_peak_log, act_memory_after_node_log, active_nodes_log diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py index eb9949095..326445ee9 100644 --- a/colossalai/autochunk/search_chunk.py +++ b/colossalai/autochunk/search_chunk.py @@ -42,10 +42,11 @@ class SearchChunk(object): def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None: self.print_mem = print_mem + self.max_memory = max_memory self.print_progress = print_progress - self.node_mgr = NodeMgr(gm) + self.node_mgr = NodeMgr(list(gm.graph.nodes)) self.trace_indice = TraceIndice(self.node_mgr) - self.estimate_memory = EstimateMemory(self.node_mgr) + self.estimate_memory = EstimateMemory() self._init_trace() self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr) self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr) @@ -63,45 +64,46 @@ class SearchChunk(object): reduce the computation complexity of trace_indice """ # find all max ranges - active_nodes = self.estimate_memory.get_active_nodes(self.node_mgr.get_node_list()) - cur_node_idx = len(self._get_free_var_idx()) - max_chunk_region_list = [] - while True: - max_chunk_region = self._search_max_chunk_region(active_nodes, cur_node_idx) - cur_node_idx = max_chunk_region[1] + 1 - if cur_node_idx >= len(active_nodes) - 1: - break - max_chunk_region_list.append(max_chunk_region) - - # nothing to limit for the first range - max_chunk_region_list = max_chunk_region_list[1:] - max_chunk_region_list[0] = (0, max_chunk_region_list[0][1]) - + active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2] # set trace range and do the trace if self.print_progress: get_logger().info("AutoChunk start tracing indice") - self.trace_indice.set_trace_range(max_chunk_region_list, active_nodes) + self.trace_indice.set_active_nodes(active_nodes) self.trace_indice.trace_indice() - def _find_peak_node(self, mem_peak: List) -> int: + def _find_peak_region(self, mem_peak: List) -> int: + """ + find peak node, along with its neighbour nodes exceeds max mem + """ max_value = max(mem_peak) max_idx = mem_peak.index(max_value) - return max_idx + peak_region = [max_idx, max_idx] + if self.max_memory is None: + return peak_region - def _get_free_var_idx(self) -> List: - """ - Get free var index + # to left + count = 0 + for i in range(max_idx - 1, -1, -1): + if mem_peak[i] > self.max_memory: + peak_region[0] = i + else: + count += 1 + if count >= 3: + break + # to right + count = 0 + for i in range(max_idx + 1, len(mem_peak) - 1): + if mem_peak[i] > self.max_memory: + peak_region[1] = i + count = 0 + else: + count += 1 + if count >= 3: + break - Returns: - free_var_idx (List): all indexs of free vars - """ - free_var_idx = [] - for idx, n in enumerate(self.node_mgr.get_node_list()): - if n.op == "placeholder" and get_node_shape(n) is not None: - free_var_idx.append(idx) - return free_var_idx + return peak_region - def _search_max_chunk_region(self, 
active_node: List, peak_node_idx: int, chunk_regions: List = None) -> Tuple: + def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple: """ Search max chunk region according to peak memory node @@ -119,50 +121,24 @@ class SearchChunk(object): # check if peak node already in chunkinfo if chunk_regions is not None: for i in chunk_regions: - if i["region"][0] < peak_node_idx <= i["region"][1]: + if i["region"][0] < peak_region[0] <= i["region"][1] or \ + i["region"][0] < peak_region[1] <= i["region"][1]: return None - free_vars = self._get_free_var_idx() - free_var_num = len(free_vars) active_node_num = [len(i) for i in active_node] - min_active_node_num = min(active_node_num[free_var_num:]) - threshold = max(free_var_num, min_active_node_num) - - # normal search - # from peak_node to free_var - inside_flag = False - chunk_region_start = free_var_num - for i in range(peak_node_idx, -1, -1): - if active_node_num[i] <= threshold: - inside_flag = True - if inside_flag and active_node_num[i] > threshold: - chunk_region_start = i + 1 - break - # from peak_node to len-2 - inside_flag = False - chunk_region_end = len(active_node) - 1 - for i in range(peak_node_idx, len(active_node)): - if active_node_num[i] <= threshold: - inside_flag = True - if inside_flag and active_node_num[i] > threshold: + window_size = 100 + # search min for start + min_num = 1e4 + for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1): + if active_node_num[i] < min_num: + min_num = active_node_num[i] + chunk_region_start = i + # search min for end + min_num = 1e4 + for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))): + if active_node_num[i] < min_num: + min_num = active_node_num[i] chunk_region_end = i - break - - # if normal search fails, use approximate search - if (chunk_region_end - chunk_region_start) > 250: - window_size = 100 - # search min for start - min_num = 1e3 - for i in range(max(peak_node_idx - window_size, 0), peak_node_idx + 1): - if active_node_num[i] < min_num: - min_num = active_node_num[i] - chunk_region_start = i - # search min for end - min_num = 1e3 - for i in range(min(peak_node_idx + window_size, len(active_node_num) - 1), peak_node_idx - 1, -1): - if active_node_num[i] < min_num: - min_num = active_node_num[i] - chunk_region_end = i # avoid chunk regions overlap if chunk_regions is not None: @@ -214,7 +190,7 @@ class SearchChunk(object): chunk_infos.append(chunk_info) return chunk_infos - def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List: + def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List: """ Search every possible region within the max chunk region. 
@@ -235,8 +211,8 @@ class SearchChunk(object): cur_trace[arg] = self.trace_indice._find_trace_from_node(arg) input_trace.append(cur_trace) - for start_idx in range(max_chunk_region[0], peak_node + 1): - for end_idx in range(peak_node, max_chunk_region[1] + 1): + for start_idx in range(max_chunk_region[0], peak_region[0] + 1): + for end_idx in range(peak_region[1], max_chunk_region[1] + 1): # skip non compute nodes if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node( self.node_mgr.get_node_by_idx(end_idx)): @@ -270,13 +246,12 @@ class SearchChunk(object): Returns: best_chunk_region (Dict) """ - peak_node = self._find_peak_node(mem_peak) - max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos) + peak_region = self._find_peak_region(mem_peak) + max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos) if max_chunk_region == None: return None - possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node) - best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node, - max_chunk_region, mem_peak) + possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region) + best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak) best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region) return best_chunk_region diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py index 1bb7d318c..94a29bfd5 100644 --- a/colossalai/autochunk/select_chunk.py +++ b/colossalai/autochunk/select_chunk.py @@ -24,29 +24,16 @@ class SelectChunk(object): else: self.stratge = "min_memory" - def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak): + def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak): if self.stratge == "min_memory": - best_region = self._select_min_memory_chunk_region( - possible_chunk_regions, - chunk_infos, - peak_node, - max_chunk_region, - mem_peak, - ) + best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) elif self.stratge == "fit_memory": - best_region = self._select_fit_memory_chunk_region( - possible_chunk_regions, - chunk_infos, - peak_node, - max_chunk_region, - mem_peak, - ) + best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak) else: raise RuntimeError() return best_region - def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, - mem_peak): + def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak): # stop chunk if max memory satisfy memory limit if max(mem_peak) < self.max_memory: return None @@ -63,17 +50,14 @@ class SelectChunk(object): if len(possible_chunk_regions) == 0: return None - max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]), - max([i["region"][1] for i in possible_chunk_regions])) - # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) cur_chunk_infos = chunk_infos + [cur_region] - cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] - cur_chunk_region_peak = 
cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1] + cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] + cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1] cur_chunk_region_max_peak = max(cur_chunk_region_peak) if cur_chunk_region_max_peak < self.max_memory: regions_dict.append({ @@ -141,8 +125,7 @@ class SelectChunk(object): count += 1 return count - def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, - mem_peak): + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): # remove illegal regions illegal_regions = [] for i in possible_chunk_regions: diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py index 1e41073d7..92199b79a 100644 --- a/colossalai/autochunk/trace_indice.py +++ b/colossalai/autochunk/trace_indice.py @@ -33,7 +33,6 @@ class TraceIndice(object): self.indice_trace_list = self._init_indice_trace_list() self.indice_view_list = {} self.indice_count = -1 - self.trace_range = [] self.active_node_list = [] def _init_indice_trace_list(self) -> List: @@ -50,8 +49,7 @@ class TraceIndice(object): indice_trace_list.append(cur_trace) return indice_trace_list - def set_trace_range(self, trace_range: List, active_node_list: List) -> None: - self.trace_range = trace_range + def set_active_nodes(self, active_node_list: List) -> None: self.active_node_list = active_node_list def _add_indice(self) -> int: @@ -731,23 +729,35 @@ class TraceIndice(object): dim_from.reverse() # search view list - for view_node, view_dict in self.indice_view_list.items(): - if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from - and view_dict["dim_from"] == dim_to): - # inheirt indice from current node - if len_diff == 1: - if origin_shape[dim_from[0]] == 1: - self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False) - elif origin_shape[dim_from[1]] == 1: - self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) - elif len_diff == -1: - if target_shape[dim_to[0]] == 1: - self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False) - elif target_shape[dim_to[1]] == 1: - self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) - # inherid indice from input node of last view - for dim_to_i in dim_to: - self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False) + # for view_node, view_dict in self.indice_view_list.items(): + # if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from + # and view_dict["dim_from"] == dim_to): + # # inheirt indice from current node + # if len_diff == 1: + # if origin_shape[dim_from[0]] == 1: + # self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False) + # elif origin_shape[dim_from[1]] == 1: + # self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) + # elif len_diff == -1: + # if target_shape[dim_to[0]] == 1: + # self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False) + # elif target_shape[dim_to[1]] == 1: + # self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) + # # inherid indice from input node of last view + # for dim_to_i in dim_to: + # self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False) + + # inheirt indice from current node + if len_diff == 1: + if origin_shape[dim_from[0]] == 1: + self._inherit_indice(origin_node, dim_from[1], 
node, dim_to[0], init=False) + elif origin_shape[dim_from[1]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) + elif len_diff == -1: + if target_shape[dim_to[0]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False) + elif target_shape[dim_to[1]] == 1: + self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False) # log view, not used now view_dict = { @@ -762,32 +772,22 @@ class TraceIndice(object): """ clear too far trace to speed up computation """ - trace_range = None - for i in range(len(self.trace_range)): - if self.trace_range[i][1] == node_idx: - trace_range = (self.trace_range[i][0], self.trace_range[i][1]) - break - if self.trace_range[i][1] > node_idx: - break - if trace_range is None: - return + trace_barrier = max(node_idx - 100, 0) + active_nodes = self.active_node_list[trace_barrier] + active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()] - active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1] - active_nodes = set(flat_list(active_nodes)) - active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes] - for i in range(trace_range[0], trace_range[1] + 1): - trace = self.indice_trace_list[i] - # clear compute - for dim_compute in trace["compute"]: - for i in range(len(dim_compute) - 1, -1, -1): - if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes): - dim_compute.pop(i) - continue - # clear source - for dim_source in trace["source"]: - for k in list(dim_source.keys()): - if k < trace_range[0] and k not in active_nodes: - dim_source.pop(k) + trace = self.indice_trace_list[node_idx] + # clear compute + for dim_compute in trace["compute"]: + for i in range(len(dim_compute) - 1, -1, -1): + if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes): + dim_compute.pop(i) + continue + # clear source + for dim_source in trace["source"]: + for k in list(dim_source.keys()): + if k < trace_barrier and k not in active_nodes: + dim_source.pop(k) def trace_indice(self) -> None: for idx, node in enumerate(self.node_mgr.get_node_list()): diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py index c6bbc219e..7c0bc29b5 100644 --- a/colossalai/autochunk/utils.py +++ b/colossalai/autochunk/utils.py @@ -11,8 +11,8 @@ logger = get_dist_logger() class NodeMgr(object): - def __init__(self, gm) -> None: - self._node_list = list(gm.graph.nodes) + def __init__(self, nodes_list: List[Node]) -> None: + self._node_list = nodes_list self._node_dict = {} self._set_node_dict() @@ -76,6 +76,8 @@ def flat_list(inputs: Any) -> List: for i in inputs: if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple): res.extend(flat_list(i)) + elif isinstance(i, dict): + res.extend(flat_list(list(i.keys()))) else: res.append(i) return res @@ -135,13 +137,6 @@ def is_non_compute_node_except_placeholder_output(node: Node) -> bool: return is_non_compute_node_except_placeholder(node) -def find_node_idx(name: str, nodes_list: List) -> int: - for idx, node in enumerate(nodes_list): - if node.name == name: - return idx - raise RuntimeError("name %s not found in node list" % name) - - def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None: for key, value in user_to_last_uses.items(): for n in value: diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index 2f56f139a..896751e40 100644 --- 
a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -61,7 +61,7 @@ def _benchmark_evoformer_stack_gm( # bench mem = _benchmark_memory(gm, inputs) speed = _benchmark_speed(gm, inputs) - print("evoformer stack gm, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args))) + print("evoformer stack gm, mem: %.2fMB, time: %.4fs" % (mem, speed)) def _benchmark_evoformer_stack_origin( @@ -83,14 +83,15 @@ def _benchmark_evoformer_stack_origin( # bench mem = _benchmark_memory(model, inputs) speed = _benchmark_speed(model, inputs) - print("evoformer stack origin, mem: %.2fMB, time: %.4fs, data_args: %s" % (mem, speed, str(data_args))) + print("evoformer stack origin, mem: %.2fMB, time: %.4fs" % (mem, speed)) + return mem def _benchmark_memory(model, inputs): with torch.no_grad(): torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 - model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + model(*inputs) new_max_mem = torch.cuda.max_memory_allocated() / 1024**2 return new_max_mem - now_mem @@ -108,13 +109,18 @@ def _benchmark_speed(model, inputs, loop=5): return (time2 - time1) / loop -def benchmark_evoformer_stack(): +def benchmark_evoformer_stack(data_args): from test_autochunk_evoformer_stack import get_data, get_model - data_args = [128, 256] - print("") - _benchmark_evoformer_stack_origin(data_args, get_model, get_data) - _benchmark_evoformer_stack_gm(data_args, 600, get_model, get_data) - _benchmark_evoformer_stack_gm(data_args, 400, get_model, get_data) + print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1])) + max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data) + for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]: + try: + _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data) + except RuntimeError as e: + if e.args[0] == 'Search failed. 
Try a larger memory threshold.': + break + except Exception as e: + raise e _benchmark_evoformer_stack_gm(data_args, None, get_model, get_data) @@ -128,4 +134,7 @@ if __name__ == "__main__": port=free_port(), backend="nccl", ) - benchmark_evoformer_stack() + benchmark_evoformer_stack((256, 256)) + benchmark_evoformer_stack((256, 512)) + benchmark_evoformer_stack((256, 1024)) + benchmark_evoformer_stack((256, 1280)) diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py index be727701c..17a5abf4c 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -55,10 +55,10 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: def get_chunk_target() -> Dict: return { - None: [(120, 123), (222, 237), (269, 289), (305, 311), (100, 105), (146, 152), (187, 193), (241, 242), - (25, 50)], - 20: [(120, 123), (232, 237), (277, 282), (305, 306), (100, 101), (34, 39)], - 24: [(120, 123)], + None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184), + (140, 145), (162, 163), (203, 204)], + 20: [(120, 123), (232, 237), (277, 282), (305, 306)], + 24: [(122, 123)], } diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py index f8102f351..ad955479e 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -53,15 +53,6 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: return meta_args, concrete_args -def get_chunk_target() -> Dict: - return { - None: [(128, 131), (230, 245), (277, 297), (313, 319), (108, 113), (154, 160), (195, 201), (249, 250), - (36, 46)], - 20: [(128, 131), (240, 245), (285, 290), (313, 314), (108, 109), (41, 46)], - 24: [(128, 131)], - } - - @pytest.mark.skipif( not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", @@ -75,7 +66,6 @@ def test_extramsa_block(data_args, max_memory): max_memory=max_memory, get_model=get_model, get_data=get_data, - get_chunk_target=get_chunk_target, ) mp.spawn(run_func, nprocs=1) @@ -87,7 +77,6 @@ if __name__ == "__main__": max_memory=None, get_model=get_model, get_data=get_data, - get_chunk_target=get_chunk_target, print_code=False, print_mem=False, print_progress=False, diff --git a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py index 43cefcb74..5791af351 100644 --- a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py +++ b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py @@ -95,7 +95,7 @@ def _benchmark_memory(model, inputs): with torch.no_grad(): torch.cuda.reset_peak_memory_stats() now_mem = float(torch.cuda.memory_allocated()) / 1024**2 - model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs]) + model(*inputs) new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2 return new_max_mem - now_mem @@ -116,8 +116,7 @@ def _benchmark_speed(model, inputs, loop=5): def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12): from test_autochunk_gpt import 
GPT2Config, GPT2Model, get_data model = GPT2Model - config = GPT2Config(n_embd=n_embd, n_position=seq, n_layer=2, n_head=n_head) - config.max_position_embeddings = seq + config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head) model = model(config=config) shape = [batch, seq] print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head)) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py index 6e1076ec7..018a2557a 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -44,20 +44,19 @@ def test_autochunk_gpt(model, shape, max_memory): data=get_data(shape), max_memory=max_memory, model=model, - config=GPT2Config(n_embd=96, n_position=shape[1], n_layer=2, n_head=4), + config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4), ) mp.spawn(run_func, nprocs=1) if __name__ == "__main__": - run_test( - rank=0, - data=get_data((BATCH_SIZE, SEQ_LENGTH)), - max_memory=None, - model=GPT2Model, - config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), - print_code=False, - print_est_mem=False, - print_mem=False, - print_progress=False, - ) + run_test(rank=0, + data=get_data((BATCH_SIZE, SEQ_LENGTH)), + max_memory=None, + model=GPT2Model, + config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4), + print_code=False, + print_est_mem=False, + print_mem=False, + print_progress=False, + eval_mem=False) diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py index cc26168c7..bc5eda7ed 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -24,6 +24,7 @@ def assert_codegen_run( print_mem: bool = False, print_progress: bool = False, print_code: bool = False, + eval_mem: bool = False, ) -> List[Dict]: meta_args, concrete_args, sequence = data if concrete_args is None: @@ -39,12 +40,11 @@ def assert_codegen_run( meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence] meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors] interp.propagate(*meta_tensors) - codegen = AutoChunkCodeGen( - meta_graph, - max_memory=max_memory, - print_mem=print_est_mem, - print_progress=print_progress, - ) + codegen = AutoChunkCodeGen(meta_graph, + max_memory=max_memory, + print_mem=print_est_mem, + print_progress=print_progress, + eval_mem=eval_mem) chunks = codegen.chunk_infos # trace and recompile @@ -108,6 +108,7 @@ def run_test( print_est_mem: bool = False, print_mem: bool = False, print_progress: bool = False, + eval_mem: bool = False, get_chunk_target: Any = None, ) -> None: model = model(config=config) @@ -122,15 +123,14 @@ def run_test( ) # build model and input - chunks = assert_codegen_run( - model, - data=data, - max_memory=max_memory, - print_code=print_code, - print_est_mem=print_est_mem, - print_mem=print_mem, - print_progress=print_progress, - ) + chunks = assert_codegen_run(model, + data=data, + max_memory=max_memory, + print_code=print_code, + print_est_mem=print_est_mem, + print_mem=print_mem, + print_progress=print_progress, + eval_mem=eval_mem) if get_chunk_target is not 
None: chunk_found = [i["region"] for i in chunks]
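
The rewritten EstimateMemory above drops the old map_arg-based last-use tracking in favour of two pieces of bookkeeping: a precomputed per-node delete index (_build_delete_node_dict) and a dict of currently live activations whose sizes are summed for the peak log (_add_active_node / _get_memory_from_active_nodes). Below is a minimal standalone sketch of that idea in plain Python, outside the patch; ToyNode, size_mb and the example sizes are invented for illustration, and the real code additionally scales sizes by chunk_ratio inside a chunk region and keeps chunk inputs alive until the region ends.

    from dataclasses import dataclass, field
    from typing import Dict, List


    @dataclass(eq=False)
    class ToyNode:
        name: str
        size_mb: float                        # activation size this node produces, in MB
        users: List["ToyNode"] = field(default_factory=list)


    def build_delete_index(nodes: List[ToyNode]) -> Dict[ToyNode, int]:
        # Map every node to the index of its last user; it can be freed after that step.
        idx_of = {n: i for i, n in enumerate(nodes)}
        delete_at = {}
        for i, node in enumerate(nodes):
            if not node.users:                                    # never used again
                delete_at[node] = i
            else:
                delete_at[node] = max(idx_of[u] for u in node.users)
        return delete_at


    def estimate_peak_mb(nodes: List[ToyNode]) -> float:
        # Walk the topologically ordered node list, keep live activations in a dict,
        # and track the largest sum of live sizes seen at any step.
        delete_at = build_delete_index(nodes)
        active: Dict[ToyNode, float] = {}
        peak = 0.0
        for i, node in enumerate(nodes):
            active[node] = node.size_mb                           # node output becomes live
            peak = max(peak, sum(active.values()))
            for n in list(active):                                # free nodes whose last user has run
                if delete_at[n] <= i:
                    active.pop(n)
        return peak


    if __name__ == "__main__":
        a, b, c = ToyNode("a", 4.0), ToyNode("b", 2.0), ToyNode("c", 1.0)
        a.users, b.users = [b, c], [c]                            # both are last used by "c"
        print(estimate_peak_mb([a, b, c]))                        # prints 7.0

Under these assumptions the peak is reached while "a", "b" and "c" are all live (7.0 MB); keeping the delete index precomputed is what lets the patched estimator re-run cheaply for every candidate chunk region in SelectChunk.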