mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[fx] Add activation checkpoint solver rotor (#1496)
* [fx] fix defining ckpt functions inside forward * [fx] Modify activation checkpoint codegen and add ColoGraphModule * [fx] some modification * some modifications * some modifications * some modifications * some modifications * some code modifications * [automatic_parallel] ckpt solver rotor * [fx] add ckpt_solver_rotor * [fx] modification * code refactor * code refactor
This commit is contained in:
@@ -9,7 +9,7 @@ import colossalai
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.passes.algorithms import chen_greedy
|
||||
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
import pytest
|
||||
@@ -22,7 +22,7 @@ except:
|
||||
from colossalai.fx.codegen import python_code_with_activation_checkpoint
|
||||
with_codegen = False
|
||||
|
||||
SOLVERS = [chen_greedy]
|
||||
SOLVERS = [chen_greedy, solver_rotor]
|
||||
|
||||
|
||||
def _is_activation_checkpoint_available(gm: GraphModule):
|
||||
@@ -77,7 +77,10 @@ def _run_ckpt_solver(rank):
|
||||
MetaInfoProp(gm).run(data)
|
||||
codegen = ActivationCheckpointCodeGen()
|
||||
gm.graph.set_codegen(codegen)
|
||||
gm = solver(gm)
|
||||
if solver == solver_rotor:
|
||||
gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
|
||||
else:
|
||||
gm = solver(gm)
|
||||
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
|
||||
assert _is_activation_checkpoint_available(
|
||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||
@@ -106,7 +109,10 @@ def _run_ckpt_solver_torch11(rank):
|
||||
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
|
||||
MetaInfoProp(gm).run(data)
|
||||
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
||||
gm = solver(gm)
|
||||
if solver == solver_rotor:
|
||||
gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
|
||||
else:
|
||||
gm = solver(gm)
|
||||
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
|
||||
assert _is_activation_checkpoint_available(
|
||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||
|
Reference in New Issue
Block a user