[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
__all___ = ['CheckpointSolverBase']
__all___ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output':
if n_src.op == "output":
n_dst.meta = n_src.meta
@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC):
def __init__(
self,
graph: Graph,
@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
@abstractmethod
def solve(self):
"""Solve the checkpointing problem and return the solution.
"""
pass
"""Solve the checkpointing problem and return the solution."""
def get_node_list(self):
"""Get the node list.
"""
"""Get the node list."""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
"""Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
return inplace
def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))
return (
not sum([v for _, v in deps.items()])
and not any(map(_is_inplace, n.users))
and not any(map(_is_shape_consistency, n.users))
)
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
assert (
next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
region = []
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
if len(n.all_input_nodes) == len(
[node for node in n.all_input_nodes if node.name in self.cnode]
) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])