[fx] Add activation checkpoint solver rotor (#1496)

* [fx] fix defining ckpt functions inside forward

* [fx] Modify activation checkpoint codegen and add ColoGraphModule

* [fx] some modification

* some modifications

* some modifications

* some modifications

* some modifications

* some code modifications

* [automatic_parallel] ckpt solver rotor

* [fx] add ckpt_solver_rotor

* [fx] modification

* code refactor

* code refactor
This commit is contained in:
Boyuan Yao
2022-08-26 10:34:21 +08:00
committed by GitHub
parent 09c023bee2
commit de1e716dc4
6 changed files with 529 additions and 4 deletions

View File

@@ -9,7 +9,7 @@ import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy
+from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
@@ -22,7 +22,7 @@ except:
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False
-SOLVERS = [chen_greedy]
+SOLVERS = [chen_greedy, solver_rotor]
def _is_activation_checkpoint_available(gm: GraphModule):
@@ -77,7 +77,10 @@ def _run_ckpt_solver(rank):
MetaInfoProp(gm).run(data)
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
-gm = solver(gm)
+if solver == solver_rotor:
+    gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+else:
+    gm = solver(gm)
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
@@ -106,7 +109,10 @@ def _run_ckpt_solver_torch11(rank):
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
-gm = solver(gm)
+if solver == solver_rotor:
+    gm = solver(gm, data, mem_limit=500 * 1024 * 1024, mem_slots=500)
+else:
+    gm = solver(gm)
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"