From 78cfe4362b4550635f609a8b52a8489c7f9aa564 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Wed, 2 Nov 2022 13:59:48 +0800
Subject: [PATCH] basic chunk

Replace the saved-tensors-hooks offload path in emit_code_with_chunk with
a first cut of chunked codegen: wrap each marked region in a fixed 4-way
loop over dim 0, and keep chunk inputs alive until the loop finishes via
a new `to_keep` argument on delete_unused_values.

---
 chunk_codegen.py     | 66 ++++++++++++++++++++++----------------------
 chunk_codegen_run.py | 15 +++++-----
 2 files changed, 41 insertions(+), 40 deletions(-)

diff --git a/chunk_codegen.py b/chunk_codegen.py
index 09fda2b98..c605e35f4 100644
--- a/chunk_codegen.py
+++ b/chunk_codegen.py
@@ -46,6 +46,19 @@ def pack_hook_no_input(self, x):
     return pack_hook, unpack_hook
 
 
+def _gen_loop_5(to_keep):
+    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"    # chunk count hardcoded to 4 for now
+    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + 1, :]\n"    # keep dim 0 so cat can restore the shape
+    return context
+
+
+def _gen_loop_5_final(final_name, to_keep):
+    context = "    chunk_result.append(" + final_name + ")\n"    # collect the chunk output inside the loop
+    context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
+    context += final_name + " = chunk_result; chunk_result = None\n"
+    return context
+
+
 def _gen_save_tensors_hooks_context(offload_input=True) -> str:
     """Generate customized saved_tensors_hooks
 
@@ -410,57 +423,40 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
 
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
-    is_hook_inserted = False
     node_idx = 0
-    while 1:
+    to_keep = []
+    while node_idx < len(node_list):
         # break if we finish the processing all the nodes
         if node_idx >= len(node_list):
             break
 
-        # process ckpt_regions
-        if node_idx in start_idx:
-            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
-            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
-            node_idx += len(ckpt_node_list)
-
         # process node in forward function
         else:
             node = node_list[node_idx]
 
             if node_idx in chunk_starts:
-                chunk_label = chunk_labels[chunk_starts.index(node_idx)]
-                _, chunk_input, chunk_bar = chunk_label
+                # save the chunk input var name; don't delete it inside the loop
+                to_keep.append(node.args[0].name)
                 within_chunk_region = True
-
-                # insert hook functions if needed
-                if not is_hook_inserted:
-                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
-                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
-                    is_hook_inserted = True
-
-                if chunk_input and chunk_bar:
-                    body.append(_gen_save_on_cpu_context())
-
-                elif chunk_input:
-                    for par in chunk_inputs[chunk_label[0]]:
-                        body.append(f"setattr({par}, 'offload', True)\n")
-                    body.append(_gen_save_tensors_hooks_context(offload_input=True))
-
-                else:
-                    for par in chunk_inputs[chunk_label[0]]:
-                        body.append(f"setattr({par}, 'offload', False)\n")
-                    body.append(_gen_save_tensors_hooks_context(offload_input=False))
+                # open the chunk for loop
+                body.append(_gen_loop_5(to_keep[0]))
 
             if within_chunk_region:
                 emit_node_func(node, body)
+                # swap the chunk input for the per-iteration slice (naive string replace)
+                if node_idx in chunk_starts:
+                    body[-1] = body[-1].replace(to_keep[0], 'chunk_tensor', 1)
                 body[-1] = '    ' + body[-1]
-                delete_unused_value_func(node, body)
+                delete_unused_value_func(node, body, to_keep)
 
             else:
                 emit_node_func(node, body)
-                delete_unused_value_func(node, body)
+                # to_keep already protects the chunk input from deletion
+                delete_unused_value_func(node, body, to_keep)
 
             if node_idx in chunk_ends:
+                body.append(_gen_loop_5_final(node.name, to_keep))
+                to_keep = []
                 within_chunk_region = False
 
             node_idx += 1
@@ -572,7 +568,7 @@ if CODEGEN_AVAILABLE:
             map_arg(node.kwargs, lambda n: register_last_uses(n, node))
 
         # NOTE: we add a variable to distinguish body and ckpt_func
-        def delete_unused_values(user: Node, body):
+        def delete_unused_values(user: Node, body, to_keep=[]):
             """
             Delete values after their last use. This ensures that values that are
             not used in the remainder of the code are freed and the memory usage
@@ -584,6 +580,9 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
+            # skip variables that the chunk loop still needs
+            nodes_to_delete = [
+                n for n in nodes_to_delete if n.name not in to_keep]
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f'; {to_delete_str}\n')
@@ -693,5 +692,6 @@ if CODEGEN_AVAILABLE:
 {wrap_stmts}
 
 {prologue}
-{code}"""
+{code}"""
+        print(fn_code)    # debug: dump the generated forward source
         return PythonCode(fn_code, globals_)
 
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
index 85164bdad..69b327d4b 100644
--- a/chunk_codegen_run.py
+++ b/chunk_codegen_run.py
@@ -54,6 +54,7 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T
     # test forward
     non_fx_out = model(data)
     fx_out = gm(data)
+    print(non_fx_out.shape, fx_out.shape)
     assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
 
     # test barckward
@@ -86,13 +87,13 @@ def _run_offload_codegen(rank):
             setattr(node, "activation_offload", [0, True, False])
         if node.name == "linear1":
             setattr(node, "activation_offload", [0, True, False])
-        if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
-        if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
-        if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+        # if node.name == "linear2":
+        #     setattr(node, "activation_offload", [1, True, True])
+        # if node.name == "linear4":
+        #     setattr(node, "activation_offload", [2, False, True])
+        # if node.name == "linear5":
+        #     setattr(node, "activation_checkpoint", [0])
+        #     setattr(node, "activation_offload", True)
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
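
--
Reviewer note, not part of the patch: a rough sketch of the forward body this
codegen is expected to emit for one chunked region. The node names (x as the
region input, linear2/linear3 inside the region) are hypothetical; the chunk
count (4) and chunk dim (0) are the values hardcoded in _gen_loop_5.

    chunk_result = []
    for gen_loop_idx in range(4):
        chunk_tensor = x[gen_loop_idx:gen_loop_idx + 1, :]
        linear2 = self.linear2(chunk_tensor)
        linear3 = self.linear3(linear2); linear2 = None
        chunk_result.append(linear3)
    chunk_result = torch.cat(chunk_result, dim=0); x = None
    linear3 = chunk_result; chunk_result = None

_gen_loop_5 opens the loop and slices the kept input, each in-region line is
indented by four spaces (with the input name swapped for chunk_tensor on the
first node), and _gen_loop_5_final concatenates the per-chunk outputs back
along dim 0 and frees the kept input. The new print(fn_code) call can be used
to compare this sketch against the actual generated forward.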