diff --git a/chunk_codegen.py b/chunk_codegen.py
index baf207795..c8bb433ef 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -55,15 +55,49 @@ def _get_last_usr(nodes):
     return user_to_last_uses
 
 
+def _delete_free_var_from_last_use(user_to_last_uses):
+    for key, value in user_to_last_uses.items():
+        for n in list(value):
+            if n.op == 'placeholder':
+                user_to_last_uses[key].remove(n)
+
+
+def _get_contiguous_memory(node, not_contiguous_list, delete=False):
+    mem = 0
+    not_contiguous_ops = ['transpose', 'permute']
+
+    if node.op == 'call_function' and 'matmul' in node.name:
+        for n in node.args:
+            if n in not_contiguous_list:
+                # matmul does not change the original tensor, but creates a temporary copy
+                mem += _get_output_node_size(n)
+    elif node.op == 'call_module':
+        for n in node.args:
+            if n in not_contiguous_list:
+                # a module call just makes the original tensor contiguous
+                if delete:
+                    not_contiguous_list.remove(n)
+    elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
+        if node not in not_contiguous_list:
+            not_contiguous_list.append(node)
+    elif any(i in node.args for i in not_contiguous_list):
+        if node not in not_contiguous_list:
+            not_contiguous_list.append(node)
+
+    return mem
+
+
 def _estimate_inference_mem(gm: torch.fx.GraphModule):
-    act_memory = 0
+    act_memory = 0.0
     act_memory_peak_log = []
     act_memory_after_node_log = []
+    not_contiguous_list = []
     user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    _delete_free_var_from_last_use(user_to_last_uses)
     for node in gm.graph.nodes:
         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node)
+            act_memory += _get_meta_node_size(node) / (1024 ** 2)
             act_memory_peak_log.append(act_memory)
             act_memory_after_node_log.append(act_memory)
         # skip output
@@ -72,25 +106,21 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
-            act_memory += calculate_fwd_tmp(node)
-            # act_memory += calculate_fwd_out(node)
-            act_memory += _get_output_node_size(node)
+            act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024 ** 2)
+            act_memory += _get_output_node_size(node) / (1024 ** 2)
             # record max act memory
             act_memory_peak_log.append(act_memory)
             # delete useless memory
-            act_memory -= calculate_fwd_tmp(node)
-            act_memory -= _get_delete_node_size(node, user_to_last_uses)
+            act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
+            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024 ** 2)
             act_memory_after_node_log.append(act_memory)
 
-    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
-    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
-    print("no chunk")
-    _print_mem_log(act_memory_peak_log, "peak")
-    _print_mem_log(act_memory_after_node_log, "after")
+    _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
+    _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
 
     param_memory = parameter_size(gm)
-    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
+    return act_memory + param_memory, param_memory
 
 
 def _get_chunk_ratio(node, chunk_dim, chunk_size):
@@ -111,19 +141,23 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
     return delete_size
 
 
-def _print_mem_log(log, title=None):
+def _print_mem_log(log, nodes, title=None):
     if title:
-        print("%-8s" % title, end=' ')
-    for i in log:
-        print("%.2f " % i, end='')
-    print("")
+        print(title)
+    for idx, (l, n) in enumerate(zip(log, nodes)):
+        print("%s:%.2f \t" % (n.name, l), end='')
+        if (idx + 1) % 3 == 0:
+            print("")
+    print("\n")
 
 
 def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
-    act_memory = 0
+    act_memory = 0.0
     act_memory_peak_log = []
     act_memory_after_node_log = []
+    not_contiguous_list = []
     user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    _delete_free_var_from_last_use(user_to_last_uses)
     within_chunk = False
     region_idx = 0
     chunk_ratio = 1    # use it to estimate chunk mem
@@ -134,11 +168,11 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         if idx in start_nodes:
             within_chunk = True
             chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
-            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
+            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024 ** 2)
 
         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node) * chunk_ratio
+            act_memory += _get_meta_node_size(node) * chunk_ratio / (1024 ** 2)
             act_memory_peak_log.append(act_memory)
         # skip output
         elif node.op == 'output':
@@ -146,36 +180,33 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
-            act_memory += calculate_fwd_tmp(node) * chunk_ratio
-            # act_memory += calculate_fwd_out(node)
-            act_memory += _get_output_node_size(node) * chunk_ratio
+            act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024 ** 2)
+            act_memory += _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
             # record max act memory
             act_memory_peak_log.append(act_memory)
             # delete useless memory
-            act_memory -= calculate_fwd_tmp(node) * chunk_ratio
+            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024 ** 2)
             if within_chunk:
                 act_memory -= _get_chunk_delete_node_size(
-                    node, user_to_last_uses, chunk_ratio, node_list, start_nodes[region_idx], end_nodes[region_idx])
+                    node, user_to_last_uses, chunk_ratio, node_list,
+                    start_nodes[region_idx], end_nodes[region_idx]) / (1024 ** 2)
             else:
-                act_memory -= _get_delete_node_size(node, user_to_last_uses)
+                act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024 ** 2)
 
             if idx in end_nodes:
-                act_memory -= _get_output_node_size(node) * chunk_ratio
+                act_memory -= _get_output_node_size(node) * chunk_ratio / (1024 ** 2)
                 within_chunk = False
                 chunk_ratio = 1
                 region_idx += 1
 
             act_memory_after_node_log.append(act_memory)
 
-    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
-    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
-    print("chunk")
-    _print_mem_log(act_memory_peak_log, "peak")
-    _print_mem_log(act_memory_after_node_log, "after")
-
+    _print_mem_log(act_memory_peak_log, node_list, "peak")
+    _print_mem_log(act_memory_after_node_log, node_list, "after")
+
     param_memory = parameter_size(gm)
-    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
+    return act_memory + param_memory, param_memory
 
 
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
@@ -516,7 +547,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     """
 
     # find the offload regions
-    chunk_regions = [(2, 6)]
+    chunk_regions = [(58, 62)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
@@ -683,7 +714,9 @@ if CODEGEN_AVAILABLE:
             for node in reversed(nodes):
                 map_arg(node.args, lambda n: register_last_uses(n, node))
                 map_arg(node.kwargs, lambda n: register_last_uses(n, node))
-
+
+            _delete_free_var_from_last_use(user_to_last_uses)
+
             # NOTE: we add a variable to distinguish body and ckpt_func
             def delete_unused_values(user: Node, body, to_keep=[]):
                 """
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index cc975f2ea..39363a80a 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
 
 
 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print("now:%.2f max:%.2f" %(torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
-    # with torch.no_grad():
-    #     fx_out = gm(node, pair)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - max_mem))
+    now_mem = torch.cuda.memory_allocated() / 1024**2
+    with torch.no_grad():
+        node0 = node.clone()
+        pair0 = pair.clone()
+        node1, pair1 = gm(node0, pair0)
+    new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    print("now:%.2f max:%.2f" %(new_now_mem - now_mem, new_max_mem - now_mem))
 
     # test forward
     with torch.no_grad():
@@ -63,8 +63,8 @@ def _run_offload_codegen(rank):
 
     # build model and input
    model = evoformer_base().cuda()
-    node = torch.randn(1, 16, 32, 256).cuda()
-    pair = torch.randn(1, 32, 32, 128).cuda()
+    node = torch.randn(1, 100, 300, 256).cuda()
+    pair = torch.randn(1, 300, 300, 128).cuda()
 
     # trace the module and replace codegen
     graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
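Usage note (not part of the patch): the sketch below shows how the two estimators touched above are meant to be called together, comparing the unchunked and chunked activation-memory estimates for an already traced GraphModule. Only the function signatures come from chunk_codegen.py; the compare_mem_estimates helper, the chunk_dims/chunk_sizes values, and the import path are illustrative assumptions.

# Hypothetical driver, assuming chunk_codegen.py is importable and `gm` is a traced,
# meta-annotated GraphModule (e.g. the Evoformer trace built in chunk_codegen_run.py).
# The region [58, 62] mirrors the hard-coded chunk_regions = [(58, 62)] above; the
# chunk_dims / chunk_sizes values are placeholders, not tuned settings.
from chunk_codegen import _estimate_inference_mem, _estimate_chunk_inference_mem

def compare_mem_estimates(gm):
    # after this patch both functions already return values in MB
    total_mem, param_mem = _estimate_inference_mem(gm)
    chunk_total_mem, _ = _estimate_chunk_inference_mem(
        gm, start_nodes=[58], end_nodes=[62], chunk_dims=[1], chunk_sizes=[1])
    print("param: %.2f MB, no chunk: %.2f MB, chunk: %.2f MB"
          % (param_mem, total_mem, chunk_total_mem))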