Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[fx] add rules to linearize computation graphs for searching. (#1461)

* [fx] polish ckpt_test.
* [fx] add rules to linearize computation graphs for searching.
* [fx] remove chen_sqrtn for the sake of simplicity.
* [fx] fix inconsistencies.
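For context: a computation graph is "linearized" when it reads as a single chain, so a checkpoint solver can search over a sequence of segments rather than an arbitrary DAG; in the generated forward code this shows up as no segment returning a tuple of values. A minimal sketch of the structural idea on a plain torch.fx graph follows; the helper name _has_single_users and the toy module are illustrative only, not part of this commit:

    import torch
    from torch.fx import GraphModule, symbolic_trace


    def _has_single_users(gm: GraphModule) -> bool:
        # A chain-shaped (linearized) graph: every non-output node feeds exactly one user.
        for node in gm.graph.nodes:
            if node.op != 'output' and len(node.users) > 1:
                return False
        return True


    class TwoBranch(torch.nn.Module):
        # x is consumed by two branches, so the traced graph is not a chain.
        def forward(self, x):
            return x.relu() + x.sigmoid()


    gm = symbolic_trace(TwoBranch())
    print(_has_single_users(gm))    # False: the placeholder x has two users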
@@ -1,5 +1,6 @@
 from typing import Callable
 import copy
+import re
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
@@ -7,7 +8,7 @@ from torch.fx import GraphModule
 import colossalai
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+from colossalai.fx.passes.algorithms import chen_greedy
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
@@ -20,7 +21,7 @@ except:
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
     with_codegen = False

-SOLVERS = [chen_greedy, chen_sqrtn]
+SOLVERS = [chen_greedy]


 def _is_activation_checkpoint_available(gm: GraphModule):
@@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
     return True


+def _is_graph_linearized(gm: GraphModule):
+    code = gm.code
+    # find patterns like r' return output_1, output_2', which is not expected on a linearized graph
+    pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
+    if pattern.findall(code):
+        return False
+    else:
+        return True
+
+
 def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                                model_cls: Callable[[], torch.nn.Module]):
     criterion = torch.nn.MSELoss()
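As a sanity check on the regex above, the same pattern can be run against two hand-written stand-ins for fx-generated code. This standalone snippet is illustrative only; the code strings are made up, not real codegen output:

    import re

    pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')

    linearized = """
    def forward(self, x):
        relu = torch.relu(x)
        return relu
    """

    not_linearized = """
    def checkpoint_0(self, x):
        relu_1 = torch.relu(x)
        sigmoid_1 = torch.sigmoid(x)
        return relu_1, sigmoid_1
    """

    print(bool(pattern.findall(linearized)))        # False -> passes the check
    print(bool(pattern.findall(not_linearized)))    # True  -> fails the check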
@@ -66,12 +77,13 @@ def _run_ckpt_solver(rank):
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
         gm = solver(gm)
+        assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
         assert _is_activation_checkpoint_available(
             gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
         check_backward_consistency(m, gm, solver, model_cls)
     gpc.destroy()


 @pytest.mark.skip
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
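Stripped of the multiprocessing and pytest harness, the pipeline this test exercises reduces to three steps: trace the model, propagate meta information, then hand the GraphModule to a solver. The sketch below follows the imports in this diff; the ResNet choice, the input shape, and the exact ColoTracer.trace arguments are assumptions rather than lines from this file:

    import torch
    import torchvision.models as tm
    from torch.fx import GraphModule
    from colossalai.fx import ColoTracer
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp
    from colossalai.fx.passes.algorithms import chen_greedy

    m = tm.resnet18()
    # meta tensors let tracing and shape propagation run without real allocation
    data = torch.rand(2, 3, 224, 224, device='meta')

    # 1. trace the model into an fx graph (argument names here are best-effort)
    graph = ColoTracer().trace(m, meta_args={'x': data})
    gm = GraphModule(m, graph, m.__class__.__name__)

    # 2. annotate nodes with the size/memory metadata the solver consumes
    MetaInfoProp(gm).run(data)

    # 3. let the greedy solver choose activation-checkpoint boundaries
    gm = chen_greedy(gm)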
@@ -94,12 +106,13 @@ def _run_ckpt_solver_torch11(rank):
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         gm = solver(gm)
+        assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
         assert _is_activation_checkpoint_available(
             gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
         check_backward_consistency(m, gm, solver, model_cls)
     gpc.destroy()


 @pytest.mark.skip
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_ckpt_solver_torch11():
     mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
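A side note on the torch < 1.12 branch above: python_code_with_activation_checkpoint.__get__(graph) relies on the descriptor protocol to bind a free function to the graph instance, which is how the test swaps in the checkpoint-aware code generator on older torch. A tiny standalone demo of that binding trick, with all names illustrative:

    class Greeter:
        def __init__(self, name):
            self.name = name


    def hello(self):
        # a free function; __get__ binds it to a single instance as a method
        return f"hello, {self.name}"


    g = Greeter("fx")
    g.hello = hello.__get__(g)    # instance-level monkey-patch, like the test does
    print(g.hello())              # -> hello, fx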