mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
@@ -10,7 +10,6 @@ logger = get_dist_logger()
|
||||
|
||||
|
||||
class NodeMgr(object):
|
||||
|
||||
def __init__(self, nodes_list: List[Node]) -> None:
|
||||
self._node_list = nodes_list
|
||||
self._node_dict = {}
|
||||
@@ -174,16 +173,22 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List,
|
||||
# we treat that input node as the input of the checkpoint function
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
if (input_node not in nodes and input_node not in input_nodes
|
||||
and not is_non_compute_node_except_placeholder(input_node)):
|
||||
if (
|
||||
input_node not in nodes
|
||||
and input_node not in input_nodes
|
||||
and not is_non_compute_node_except_placeholder(input_node)
|
||||
):
|
||||
input_nodes.append(input_node)
|
||||
|
||||
# if a node has a user node which is not in the node list
|
||||
# we treat that user node as the node receiving the current node output
|
||||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
if (output_node not in nodes and node not in output_nodes
|
||||
and not is_non_compute_node_except_placeholder_output(output_node)):
|
||||
if (
|
||||
output_node not in nodes
|
||||
and node not in output_nodes
|
||||
and not is_non_compute_node_except_placeholder_output(output_node)
|
||||
):
|
||||
output_nodes.append(node)
|
||||
|
||||
return input_nodes, output_nodes
|
||||
@@ -238,7 +243,10 @@ def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
|
||||
for node in node_list:
|
||||
if get_node_shape(node) is not None:
|
||||
out.append(node)
|
||||
elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
|
||||
node.meta['fwd_out'][0], int):
|
||||
elif (
|
||||
len(node.meta["fwd_out"]) > 0
|
||||
and isinstance(node.meta["fwd_out"], list)
|
||||
and isinstance(node.meta["fwd_out"][0], int)
|
||||
):
|
||||
out.append(node)
|
||||
return out
|
||||
|
Reference in New Issue
Block a user