[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,7 +1,9 @@
import torch
from typing import Dict
from torch.fx.node import Node, map_arg
import torch
from torch.fx.graph import Graph
from torch.fx.node import Node, map_arg
def get_comm_size(prev_partition, next_partition):
"""
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
for n in input_nodes:
if n.name in parent_node_names and n not in visited_nodes:
comm_size += n.meta['tensor_meta'].numel
comm_size += n.meta["tensor_meta"].numel
visited_nodes.add(n)
return comm_size
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
"""
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
if node.op == 'output':
if node.op == "output":
map_arg(node.args, lambda n: input_nodes.setdefault(n))
map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
placeholder_nodes = []
for node in input_nodes.keys():
if node.op == 'placeholder':
if node.op == "placeholder":
placeholder_nodes.append(node)
for node in placeholder_nodes:
input_nodes.pop(node)
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
"""
top_node_list = set()
for node in graph.nodes:
if node.op == 'output':
if node.op == "output":
continue
is_top = False
def _get_top(node):
nonlocal is_top
if node.op == 'placeholder':
if node.op == "placeholder":
is_top = True
map_arg(node.args, lambda n: _get_top(n))
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
def get_all_consumers(graph: Graph, node: Node):
"""
Given a graph and a node of this graph, return all consumers of the node.
Returns:
List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``.
"""
@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
for node in gm.graph.nodes:
if hasattr(node, 'bfs_level'):
print(node.name, node.bfs_level)
Output:
graph():
%x : [#users=2] = placeholder[target=x]
@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
while nodes_to_process:
new_process_list = []
for node in nodes_to_process:
if node.op == 'output':
if node.op == "output":
continue
node.bfs_level = current_level
new_process_list.extend(get_all_consumers(graph, node))
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
torch.nn.Module: the module associated with the given node
"""
assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
assert (
node.graph.owning_module is not None
), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
module = node.graph.owning_module.get_submodule(node.target)
return module