mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
|
||||
)
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
|
||||
__all___ = ['CheckpointSolverBase']
|
||||
__all___ = ["CheckpointSolverBase"]
|
||||
|
||||
|
||||
def _copy_output(src: Graph, dst: Graph):
|
||||
"""Copy the output node from src to dst"""
|
||||
for n_src, n_dst in zip(src.nodes, dst.nodes):
|
||||
if n_src.op == 'output':
|
||||
if n_src.op == "output":
|
||||
n_dst.meta = n_src.meta
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
|
||||
|
||||
|
||||
class CheckpointSolverBase(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def solve(self):
|
||||
"""Solve the checkpointing problem and return the solution.
|
||||
"""
|
||||
pass
|
||||
"""Solve the checkpointing problem and return the solution."""
|
||||
|
||||
def get_node_list(self):
|
||||
"""Get the node list.
|
||||
"""
|
||||
"""Get the node list."""
|
||||
return [[node] for node in self.graph.nodes]
|
||||
|
||||
def _linearize_graph(self) -> List[List[Node]]:
|
||||
@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
|
||||
"""
|
||||
|
||||
def _is_inplace(n: Node):
|
||||
"""Get the inplace argument from ``torch.fx.Node``
|
||||
"""
|
||||
"""Get the inplace argument from ``torch.fx.Node``"""
|
||||
inplace = False
|
||||
if n.op == "call_function":
|
||||
inplace = n.kwargs.get("inplace", False)
|
||||
@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
|
||||
return inplace
|
||||
|
||||
def _is_shape_consistency(n: Node):
|
||||
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
|
||||
"""
|
||||
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
|
||||
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
|
||||
|
||||
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
|
||||
map(_is_shape_consistency, n.users))
|
||||
return (
|
||||
not sum([v for _, v in deps.items()])
|
||||
and not any(map(_is_inplace, n.users))
|
||||
and not any(map(_is_shape_consistency, n.users))
|
||||
)
|
||||
|
||||
# make sure that item in cnode is valid
|
||||
if self.cnode:
|
||||
for name in self.cnode:
|
||||
try:
|
||||
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
|
||||
f"Common node {name} is not an input of the model."
|
||||
assert (
|
||||
next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
|
||||
), f"Common node {name} is not an input of the model."
|
||||
except StopIteration:
|
||||
raise ValueError(f"Common node name {name} not in graph.")
|
||||
|
||||
@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
|
||||
region = []
|
||||
|
||||
# propagate common node attr if possible
|
||||
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
|
||||
]) or _is_cop(n.target):
|
||||
if len(n.all_input_nodes) == len(
|
||||
[node for node in n.all_input_nodes if node.name in self.cnode]
|
||||
) or _is_cop(n.target):
|
||||
self.cnode.append(n.name)
|
||||
else:
|
||||
deps[n] = len([user for user in n.users if user.op != "output"])
|
||||
|
Reference in New Issue
Block a user