Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
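The changes below are mechanical: the refreshed pre-commit hooks reformat every file, flipping single-quoted string literals to double quotes and regrouping imports. A minimal sketch of that rewrite, assuming a Black-style formatter plus an isort-style import sorter (the actual hook configuration is not shown on this page):

# Illustrative sketch, not part of the diff.
# Before the hooks run:
#     import torch
#     from typing import Dict
#     meta = {'tensor_meta': None}
# After: standard-library imports are grouped ahead of third-party ones and
# string literals use double quotes; behaviour is unchanged.
from typing import Dict

import torch  # third-party import, grouped after the stdlib block

meta: Dict[str, None] = {"tensor_meta": None}
print(meta, torch.__version__)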
@@ -1,7 +1,9 @@
-import torch
 from typing import Dict
-from torch.fx.node import Node, map_arg
+
+import torch
+from torch.fx.graph import Graph
+from torch.fx.node import Node, map_arg
 
 
 def get_comm_size(prev_partition, next_partition):
     """
@@ -23,7 +25,7 @@ def get_comm_size(prev_partition, next_partition):
         map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
         for n in input_nodes:
             if n.name in parent_node_names and n not in visited_nodes:
-                comm_size += n.meta['tensor_meta'].numel
+                comm_size += n.meta["tensor_meta"].numel
                 visited_nodes.add(n)
     return comm_size
 
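The line touched in this hunk accumulates the element count of each tensor that crosses the partition boundary. A stand-alone sketch of the quantity involved; the real tensor_meta entry is populated by ColossalAI's meta-propagation passes, so the SimpleNamespace below is only a hand-built stand-in (illustrative, not part of the diff):

from types import SimpleNamespace

import torch

# Pretend an upstream pass recorded metadata for a node whose output is (8, 32).
tensor = torch.randn(8, 32)
meta = {"tensor_meta": SimpleNamespace(numel=tensor.numel())}

comm_size = 0
comm_size += meta["tensor_meta"].numel  # per-tensor contribution: 256 elements
print(comm_size)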
@@ -36,12 +38,12 @@ def get_leaf(graph: Graph):
     """
     input_nodes: Dict[Node, None] = {}
     for node in graph.nodes:
-        if node.op == 'output':
+        if node.op == "output":
             map_arg(node.args, lambda n: input_nodes.setdefault(n))
             map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
     placeholder_nodes = []
     for node in input_nodes.keys():
-        if node.op == 'placeholder':
+        if node.op == "placeholder":
             placeholder_nodes.append(node)
     for node in placeholder_nodes:
         input_nodes.pop(node)
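For context, get_leaf collects the nodes feeding the graph's output node and then drops bare placeholders. A self-contained sketch of that loop on a freshly traced module (illustrative; the diff does not show where the helper itself should be imported from):

from typing import Dict

import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node, map_arg


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


graph = symbolic_trace(Tiny()).graph

# Mirror the loop from the hunk above: every node reaching `output` is a leaf.
input_nodes: Dict[Node, None] = {}
for node in graph.nodes:
    if node.op == "output":
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
print([n.name for n in input_nodes])  # e.g. ['add']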
@@ -60,13 +62,13 @@ def get_top(graph: Graph):
     """
     top_node_list = set()
     for node in graph.nodes:
-        if node.op == 'output':
+        if node.op == "output":
             continue
         is_top = False
 
         def _get_top(node):
             nonlocal is_top
-            if node.op == 'placeholder':
+            if node.op == "placeholder":
                 is_top = True
 
         map_arg(node.args, lambda n: _get_top(n))
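get_top flags a node as top-level when one of its arguments is a placeholder. The closure-plus-map_arg pattern from the hunk, exercised standalone (illustrative, not part of the diff):

import torch
from torch.fx import symbolic_trace
from torch.fx.graph import Graph
from torch.fx.node import map_arg


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


def sketch_get_top(graph: Graph):
    # Mirrors the hunk: a node is "top" if any of its args is a placeholder.
    top_node_list = set()
    for node in graph.nodes:
        if node.op == "output":
            continue
        is_top = False

        def _get_top(n):
            nonlocal is_top
            if n.op == "placeholder":
                is_top = True

        map_arg(node.args, _get_top)
        if is_top:
            top_node_list.add(node)
    return top_node_list


graph = symbolic_trace(Tiny()).graph
print([n.name for n in sketch_get_top(graph)])  # e.g. ['relu'] -- fed directly by the input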
@@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node):
 def get_all_consumers(graph: Graph, node: Node):
     """
     Given a graph and a node of this graph, return all consumers of the node.
-
+
     Returns:
         List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``.
     """
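This hunk only touches whitespace in the helper's docstring; the body is not shown. The documented behaviour, reproduced as a standalone sketch via torch.fx's all_input_nodes (which covers both args and kwargs):

import torch
from torch.fx import symbolic_trace


class Branchy(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return y + y


graph = symbolic_trace(Branchy()).graph
relu_node = next(n for n in graph.nodes if n.target is torch.relu)
# A consumer is any node that uses relu_node among its inputs.
consumers = [n for n in graph.nodes if relu_node in n.all_input_nodes]
print([n.name for n in consumers])  # e.g. ['add']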
@@ -120,7 +122,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
         for node in gm.graph.nodes:
             if hasattr(node, 'bfs_level'):
                 print(node.name, node.bfs_level)
-
+
     Output:
         graph():
             %x : [#users=2] = placeholder[target=x]
@@ -148,7 +150,7 @@ def assign_bfs_level_to_nodes(graph: Graph):
     while nodes_to_process:
         new_process_list = []
         for node in nodes_to_process:
-            if node.op == 'output':
+            if node.op == "output":
                 continue
             node.bfs_level = current_level
             new_process_list.extend(get_all_consumers(graph, node))
@@ -165,8 +167,9 @@ def get_node_module(node) -> torch.nn.Module:
         torch.nn.Module: the module associated with the given node
     """
 
-    assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
-    assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
+    assert (
+        node.graph.owning_module is not None
+    ), "Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object"
+    assert node.op == "call_module", f"Expected node.op to be call_module, but found {node.op}"
     module = node.graph.owning_module.get_submodule(node.target)
     return module
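The reformatted asserts in get_node_module keep the same contract: the node must belong to a graph owned by a GraphModule and must be a call_module node. A short usage sketch of that contract (illustrative, not part of the diff):

import torch
from torch.fx import symbolic_trace


class WithSub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)


gm = symbolic_trace(WithSub())
node = next(n for n in gm.graph.nodes if n.op == "call_module")
# Same lookup the helper performs once its asserts pass.
assert node.graph.owning_module is not None
module = node.graph.owning_module.get_submodule(node.target)
print(type(module).__name__)  # Linear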