[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -1,3 +1,3 @@
from .fx import get_topology as get_fx_topology
__all__ = ['get_fx_topology']
__all__ = ["get_fx_topology"]

View File

@@ -10,7 +10,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
elif is_output:
partition_id = 1
else:
prefix = 'submod_'
prefix = "submod_"
partition_id = int(partition_name.split(prefix)[-1]) + 2
return partition_id
@@ -27,10 +27,10 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False):
def find_input_in_partition(node, partitions, input_partitions=None):
p_input_val = None
direct_def = not node.name.startswith('getitem')
direct_def = not node.name.startswith("getitem")
# search in input
if direct_def and input_partitions is not None:
partition_id = partition_name_to_id('', is_input=True)
partition_id = partition_name_to_id("", is_input=True)
for i, input_node in enumerate(input_partitions):
if input_node == node:
p_input_val = PartitionInputVal(partition_id=partition_id, offset=i)
@@ -57,7 +57,7 @@ def find_input_in_partition(node, partitions, input_partitions=None):
def find_output_in_partition(node, partitions, output_partitions=None):
p_output_val = PartitionOutputVal()
for user in node.users:
direct_use = not user.name.startswith('getitem')
direct_use = not user.name.startswith("getitem")
# user is mid partition
for partition in partitions:
# direct call
@@ -82,7 +82,7 @@ def find_output_in_partition(node, partitions, output_partitions=None):
output_node = output_partitions[0]
if user.op == output_node.op:
output_keys = {}
partition_id = partition_name_to_id('', is_output=True)
partition_id = partition_name_to_id("", is_output=True)
torch.fx.graph.map_arg(output_node.args[0], lambda n: output_keys.setdefault(n))
for i, arg in enumerate(output_keys):
if arg == node:
@@ -99,11 +99,11 @@ def get_topology(gm: GraphModule):
partitions = []
output_partitions = []
for node in gm.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
input_partitions.append(node)
elif node.name.startswith('submod_'):
elif node.name.startswith("submod_"):
partitions.append(node)
elif node.op == 'output':
elif node.op == "output":
output_partitions.append(node)
else:
continue
@@ -127,7 +127,7 @@ def get_topology(gm: GraphModule):
# set output for submodule
direct_use = True
for user in partition.users:
if user.name.startswith('getitem'):
if user.name.startswith("getitem"):
direct_use = False
break
if direct_use:
@@ -146,7 +146,8 @@ def get_topology(gm: GraphModule):
topo_output_partition = Partition()
torch.fx.graph.map_arg(
partition.args[0],
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)))
lambda n: topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions)),
)
topo.set_partitions(partition_id=1, partition=topo_output_partition)
topo.set_output_partition_id(partition_id=1)