mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
||||
|
||||
from .ckpt_solver_base import CheckpointSolverBase
|
||||
|
||||
__all__ = ['CheckpointSolverChen']
|
||||
__all__ = ["CheckpointSolverChen"]
|
||||
|
||||
|
||||
class CheckpointSolverChen(CheckpointSolverBase):
|
||||
|
||||
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase):
|
||||
Returns:
|
||||
graph (Graph): The optimized graph, should be a copy of the original graph.
|
||||
"""
|
||||
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
|
||||
checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
|
||||
ckpt = self.grid_search()
|
||||
for i, seg in enumerate(ckpt):
|
||||
for idx in range(*seg):
|
||||
nodes = self.node_list[idx]
|
||||
for n in nodes:
|
||||
if n.op in checkpointable_op:
|
||||
n.meta['activation_checkpoint'] = i
|
||||
n.meta["activation_checkpoint"] = i
|
||||
return deepcopy(self.graph)
|
||||
|
||||
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
|
||||
|
Reference in New Issue
Block a user