[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Author: Super Daniel
Committed: 2022-11-01 10:43:15 +08:00 by GitHub
Parent: 2b859502d5
Commit: 1e88811c7a
16 changed files with 1025 additions and 119 deletions
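The thread running through the diffs below: activation checkpoint and offload annotations move off the torch.fx.Node object itself (via setattr) and into node.meta, the per-node dict torch.fx provides for exactly this kind of pass metadata. A minimal runnable sketch of the new pattern; the TinyNet module is a hypothetical stand-in for the test models:

    import torch
    from torch.fx import symbolic_trace

    class TinyNet(torch.nn.Module):
        # hypothetical module, stands in for the MLPs used in the tests
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear1(x)

    gm = symbolic_trace(TinyNet())
    for node in gm.graph.nodes:
        if node.name == 'linear1':
            # old style, removed by this commit:
            #     setattr(node, 'activation_checkpoint', 0)
            # new style: keep the annotation in node.meta, which torch.fx
            # reserves for per-node pass metadata
            node.meta['activation_checkpoint'] = 0

    annotated = [n.name for n in gm.graph.nodes if 'activation_checkpoint' in n.meta]
    assert annotated == ['linear1']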


@@ -2,11 +2,13 @@ import copy
 import re
 from typing import Callable
-import colossalai

 import pytest
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
+from torch.fx import GraphModule
+
+import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
 from colossalai.fx._compatibility import is_compatible_with_meta
@@ -14,7 +16,6 @@ from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
 from colossalai.utils import free_port
-from torch.fx import GraphModule

 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
@@ -94,6 +95,7 @@ def _run_ckpt_solver(rank):
     gpc.destroy()


+@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
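The solver test is now gated twice: an unconditional skip added while the tests are refactored, plus the existing skipif that keeps it off torch below 1.12.0. A self-contained sketch of that gating pattern; the packaging dependency and the test_feature name are assumptions for illustration:

    import pytest
    import torch
    from packaging import version

    # features such as meta tensors and the checkpoint codegen need torch >= 1.12
    with_codegen = version.parse(torch.__version__) >= version.parse('1.12.0')

    @pytest.mark.skip('TODO: refactor all tests.')    # temporary, mirrors the commit
    @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
    def test_feature():
        assert True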


@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint
+
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta
         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta
         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(model, graph)
     gm.recompile()
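A runnable sketch of what these two hunks now exercise: presence of the annotation is checked with dict membership instead of hasattr, and the offload flag is a plain dict write. Module and node names here are hypothetical stand-ins:

    import torch
    from torch.fx import symbolic_trace

    class MLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 8)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    gm = symbolic_trace(MLP())
    for node in gm.graph.nodes:
        if node.name == 'linear1':
            node.meta['activation_offload'] = True    # annotate the offload start

    # dict membership replaces the old hasattr(node, ...) check
    assert any('activation_offload' in n.meta for n in gm.graph.nodes)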


@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint
+
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1

     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1

     gm = ColoGraphModule(model, graph)
     gm.recompile()
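Reading the values above, a node's 'activation_checkpoint' entry is either a plain int (one flat checkpoint region) or a list giving the region index at each nesting level, with None at levels where the node is not wrapped. That reading is an inference from the test, not documented API; a small sketch under that assumption:

    # hypothetical helper: nesting depth of a node's checkpoint annotation
    def ckpt_depth(annotation):
        if annotation is None:
            return 0
        if isinstance(annotation, list):
            # count the levels the node actually belongs to
            return sum(1 for level in annotation if level is not None)
        return 1    # a bare int marks a single, un-nested region

    assert ckpt_depth([0, 0, 0]) == 3     # linear1: innermost nesting
    assert ckpt_depth([0, 0, None]) == 2  # linear2: two levels deep
    assert ckpt_depth(1) == 1             # linear5: flat region 1
    assert ckpt_depth(None) == 0          # unannotated node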


@@ -1,14 +1,16 @@
 import copy
-import torch
-import torch.nn.functional as F
+
 import pytest
+import torch
 import torch.multiprocessing as mp
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
@@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
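As linear5 shows, one node can carry both annotations at once, and both now live side by side in node.meta. A runnable sketch with a hypothetical single-layer module:

    import torch
    from torch.fx import symbolic_trace

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear5 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear5(x)

    gm = symbolic_trace(Net())
    for node in gm.graph.nodes:
        if node.name == 'linear5':
            node.meta['activation_checkpoint'] = [0]  # checkpointed in region 0
            node.meta['activation_offload'] = True    # and offloaded as well

    node5 = next(n for n in gm.graph.nodes if n.name == 'linear5')
    assert node5.meta['activation_checkpoint'] == [0]
    assert node5.meta['activation_offload'] is True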


@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 from torch.utils.checkpoint import checkpoint
+
+from colossalai.fx import ColoTracer


 class MLP(torch.nn.Module):
@@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation():
     for node in gm.graph.nodes:
         if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 0
+            assert node.meta.get('activation_checkpoint', -1) == 0

     for node in gm.graph.nodes:
         if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 1
+            assert node.meta.get('activation_checkpoint', -1) == 1

     tracer = ColoTracer(trace_act_ckpt=False)
     graph = tracer.trace(module)
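The assertions use -1 as a sentinel default, so a node whose annotation was silently dropped fails the equality check instead of raising; node.meta.get replaces getattr one-for-one. A minimal sketch of the idiom, with a hypothetical helper name:

    def ckpt_region(node_meta: dict) -> int:
        # -1 means "never annotated"; any real region index is >= 0
        return node_meta.get('activation_checkpoint', -1)

    assert ckpt_region({'activation_checkpoint': 0}) == 0
    assert ckpt_region({}) == -1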