diff --git a/chunk_codegen.py b/chunk_codegen.py index 47cda0f8e..6740cd44a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -827,7 +827,7 @@ def _find_input_and_output_nodes(nodes: List[Node]): for node in nodes: for input_node in node._input_nodes.keys(): node_repr = repr(input_node) - if input_node not in nodes and node_repr not in input_nodes: + if input_node not in nodes and input_node not in input_nodes: input_nodes.append(input_node) # if a node has a user node which is not in the node list @@ -835,7 +835,7 @@ def _find_input_and_output_nodes(nodes: List[Node]): for node in nodes: for output_node in node.users.keys(): node_repr = repr(node) - if output_node not in nodes and node_repr not in output_nodes: + if output_node not in nodes and output_node not in output_nodes: output_nodes.append(output_node) return input_nodes, output_nodes @@ -848,6 +848,16 @@ def _find_idx_by_name(name, nodes_list): raise RuntimeError("name %s not found in node list" % name) +def _replace_name(context, name_from, name_to): + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")] + for p in patterns: + source = p[0] + name_from + p[1] + target = p[0] + name_to + p[1] + if source in context: + context = context.replace(source, target) + return context + + def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph): """Emit code with nested activation checkpoint When we detect some of the node.activation_checkpoint is a List, we will use @@ -905,8 +915,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - if node_idx in chunk_starts: - body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor') + body[-1] = _replace_name(body[-1], chunk_inputs[region_idx][0].name, 'chunk_tensor') body[-1] = ' ' + body[-1] delete_unused_value_func(node, body, chunk_inputs_names)