From 54a34a7e46d2f9e0234eb9295f3507e720ba21b2 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 15 Nov 2022 11:30:43 +0800 Subject: [PATCH] update active log --- chunk_codegen.py | 56 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index c1d9e26e7..ade986d1e 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -407,18 +407,41 @@ class MemoryEstimator(object): x = x.numel * torch.tensor([], dtype=x.dtype).element_size() return x - def _get_output_node_size(self, n): + def _get_output_node(self, n): fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')} - return activation_size(fwd_out) + out_size = activation_size(fwd_out) + out_node = [n.name] if out_size > 0 else [] + return out_size, out_node + + def _get_output_node_size(self, n): + return self._get_output_node(n)[0] + + def _add_active_node(self, n, active_list): + new_active = self._get_output_node(n)[1] + for i in new_active: + if i not in active_list: + active_list.append(i) + def _get_delete_node(self, user, user_to_last_uses): + delete_size = 0 + delete_node = [] + if user.op not in ('placeholder', 'output'): + nodes_to_delete = user_to_last_uses.get(user, []) + if len(nodes_to_delete): + out_node = [self._get_output_node(i) for i in nodes_to_delete] + delete_size = sum([i[0] for i in out_node]) + for i in range(len(out_node)): + if out_node[i][0] > 0: + delete_node.append(out_node[i][1][0]) + return delete_size, delete_node + def _get_delete_node_size(self, user, user_to_last_uses): - if user.op in ('placeholder', 'output'): - return 0 - nodes_to_delete = user_to_last_uses.get(user, []) - if len(nodes_to_delete): - delete_size = sum([self._get_output_node_size(i) for i in nodes_to_delete]) - return delete_size - return 0 + return self._get_delete_node(user, user_to_last_uses)[0] + + def _remove_active_node(self, user, user_to_last_uses, active_list): + delete_node = self._get_delete_node(user, user_to_last_uses)[1] + for i in delete_node: + active_list.remove(i) def _get_last_usr(self, nodes): node_to_last_use: Dict[Node, Node] = {} @@ -438,7 +461,7 @@ class MemoryEstimator(object): mem = 0 not_contiguous_ops = ['transpose', 'permute'] - if node.op == 'call_function' and 'matmul' in node.name: + if node.op == 'call_function' and any(n in node.name for n in ['matmul', 'reshape']): for n in node.args: if n in not_contiguous_list: # matmul won't change origin tensor, but create a tmp copy @@ -463,6 +486,8 @@ class MemoryEstimator(object): act_memory_peak_log = [] act_memory_after_node_log = [] not_contiguous_list = [] + active_node_list = [] + active_node_list_log = [] user_to_last_uses = self._get_last_usr(list(gm.graph.nodes)) _delete_free_var_from_last_use(user_to_last_uses) for node in gm.graph.nodes: @@ -470,7 +495,7 @@ class MemoryEstimator(object): if node.op == 'placeholder': act_memory += self._get_meta_node_size(node) / (1024 ** 2) act_memory_peak_log.append(act_memory) - act_memory_after_node_log.append(act_memory) + active_node_list.append(node.name) # skip output elif node.op == 'output': continue @@ -484,8 +509,12 @@ class MemoryEstimator(object): # delete useless memory act_memory -= self._get_delete_node_size(node, user_to_last_uses) / (1024 ** 2) act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2) - act_memory_after_node_log.append(act_memory) + # log active node + self._add_active_node(node, active_node_list) + self._remove_active_node(node, user_to_last_uses, active_node_list) + act_memory_after_node_log.append(act_memory) + active_node_list_log.append(copy.deepcopy(active_node_list)) print("no chunk") self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak") self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after") @@ -551,7 +580,6 @@ class MemoryEstimator(object): # node is an operation, calculate tmp, output node and delete node memory else: # forward memory - # TODO: permute will create a tmp copy if not contiguous act_memory += self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2) act_memory += self._get_output_node_size(node) * chunk_ratio / (1024 ** 2) # record max act memory @@ -694,9 +722,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v within_chunk_region = False node_list = list(nodes) + memory_estimator = MemoryEstimator() memory_estimator.estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2]) memory_estimator.estimate_inference_mem(meta_graph) + node_index_tracer = NodeIndexTracer(meta_graph) node_index_tracer.trace_node_idx()