From 5cdfcfe1d168e39d39a741112c036fa1455f0d06 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 12 Dec 2022 17:29:07 +0800 Subject: [PATCH] code style --- chunk_codegen.py | 49 ++++-------------------------------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 3bea84fae..96dcbfc0f 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -92,24 +92,10 @@ class FlowTracer(object): self._add_trace(i.name) self._add_node(i.name, i) - def _is_non_compute_node(self, node): - if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - - def _is_non_compute_node_except_placeholder(self, node): - if any(i in node.op for i in ["get_attr", "output"]) or any( - i in node.name for i in ["getitem", "getattr"] - ): - return True - return False - def _find_flow_for_node(self, node): if type(self.node_list[0]) != type(node): return None - if self._is_non_compute_node_except_placeholder(node): + if _is_non_compute_node_except_placeholder(node): return None for name, trace in self.flow_trace.items(): for i in trace: @@ -135,7 +121,7 @@ class FlowTracer(object): raise RuntimeError("invalid node") def _get_flow_mix_node(self, node): - if self._is_non_compute_node(node): + if _is_non_compute_node(node): return None _, node_trace = self.find_node_flow(node) if len(node_trace["outside_depend"]) == 0: @@ -160,10 +146,9 @@ class FlowTracer(object): for node in self.node_list: # skip if non compute node if all( - type(arg) != type(node) - or self._is_non_compute_node_except_placeholder(arg) + type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg) for arg in node.args - ) or self._is_non_compute_node(node): + ) or _is_non_compute_node(node): continue node_input_flows = [self._find_flow_for_node(arg) for arg in node.args] @@ -1411,32 +1396,6 @@ def _gen_loop_end( return context -def _find_input_and_output_nodes(nodes: List[Node]): - """ - Find the input and output node names which are not found in the given list of nodes. - """ - input_nodes = [] - output_nodes = [] - - # if a node has an input node which is not in the node list - # we treat that input node as the input of the checkpoint function - for node in nodes: - for input_node in node._input_nodes.keys(): - node_repr = repr(input_node) - if input_node not in nodes and input_node not in input_nodes: - input_nodes.append(input_node) - - # if a node has a user node which is not in the node list - # we treat that user node as the node receiving the current node output - for node in nodes: - for output_node in node.users.keys(): - node_repr = repr(node) - if output_node not in nodes and output_node not in output_nodes: - output_nodes.append(output_node) - - return input_nodes, output_nodes - - def _find_chunk_all_input_nodes(nodes: List[Node]): """ Find non-compute input and output node names.