Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 05:49:55 +00:00
[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move.
* [autoparallel] add solver rotor.
* [autoparallel] add ckpt solvers.
* [autoparallel] modify codegen.
* [fx] fix annotation in test.
* [fx] remove check.
* [autoparallel] polish docstring.
* [fx] refactor MetaTensor.
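The change that repeats through every hunk below is a move from free-form attributes on torch.fx nodes (setattr/hasattr/getattr with 'activation_checkpoint' or 'activation_offload') to entries in node.meta. A minimal sketch of that pattern with stock torch.fx; TinyMLP and the chosen node name are illustrative and not taken from the test files:

    import torch
    from torch.fx import symbolic_trace


    class TinyMLP(torch.nn.Module):
        # illustrative module; the tests in this diff define their own models
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)
            self.linear2 = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear2(self.linear1(x))


    gm = symbolic_trace(TinyMLP())

    for node in gm.graph.nodes:
        if node.name == "linear1":
            # old style (being removed): free-form attribute on the Node object
            # setattr(node, "activation_checkpoint", 0)
            # new style (introduced here): store the annotation in node.meta,
            # the dict torch.fx reserves for pass-specific metadata
            node.meta["activation_checkpoint"] = 0
            node.meta["activation_offload"] = True

    # reading the annotations back, mirroring the updated assertions in the tests
    for node in gm.graph.nodes:
        if node.name == "linear1":
            assert "activation_checkpoint" in node.meta
            assert node.meta.get("activation_checkpoint", -1) == 0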
@@ -2,11 +2,13 @@ import copy
import re
from typing import Callable

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta

@@ -14,7 +16,6 @@ from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from torch.fx import GraphModule

if is_compatible_with_meta():
    from colossalai.fx.profiler.tensor import MetaTensor

@@ -94,6 +95,7 @@ def _run_ckpt_solver(rank):
    gpc.destroy()


@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_ckpt_solver():
    mp.spawn(_run_ckpt_solver, nprocs=1)
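For background on what these solver tests exercise: activation checkpointing trades compute for memory by discarding intermediate activations during forward and recomputing them during backward. A hand-written sketch of that idea with plain torch.utils.checkpoint on one of the torchvision models imported above; the split point chosen here is arbitrary and is not what chen_greedy or solver_rotor would necessarily select:

    import torch
    import torchvision.models as tm
    from torch.utils.checkpoint import checkpoint

    model = tm.resnet18()
    x = torch.randn(2, 3, 224, 224, requires_grad=True)

    def front(t):
        # the "checkpointed" slice: its activations are not kept, only its input
        return model.layer2(model.layer1(model.maxpool(model.relu(model.bn1(model.conv1(t))))))

    hidden = checkpoint(front, x)  # recomputed inside backward
    out = model.fc(torch.flatten(model.avgpool(model.layer4(model.layer3(hidden))), 1))
    out.sum().backward()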
@@ -1,14 +1,15 @@
import torch
import torch.nn.functional as F
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from torch.utils.checkpoint import checkpoint

import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen

@@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank):
    offload_starts = ['mlp1_linear1']
    for node in graph.nodes:
        if node.name in ckpt_nodes:
            assert hasattr(node, 'activation_checkpoint')
            assert 'activation_checkpoint' in node.meta

        # annotate the selected node for offload
        if node.name in offload_starts:
            setattr(node, 'activation_offload', True)
            node.meta['activation_offload'] = True

    gm = ColoGraphModule(model, graph)
    gm.recompile()

@@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank):
    offload_starts = ['mlp1_linear1']
    for node in graph.nodes:
        if node.name in ckpt_nodes:
            assert hasattr(node, 'activation_checkpoint')
            assert 'activation_checkpoint' in node.meta

        # annotate the selected node for offload
        if node.name in offload_starts:
            setattr(node, 'activation_offload', True)
            node.meta['activation_offload'] = True

    gm = ColoGraphModule(model, graph)
    gm.recompile()
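One likely benefit of keeping these annotations in node.meta rather than as ad-hoc attributes (my reading; the diff itself states no rationale) is that torch.fx copies meta along with a node when graphs are rebuilt, whereas attributes added with setattr are not part of the Node API and are silently dropped. A small illustration with stock torch.fx:

    import torch
    from torch.fx import Graph, symbolic_trace


    def double(x):
        return x * 2


    gm = symbolic_trace(double)

    # annotate the multiply node both ways
    for node in gm.graph.nodes:
        if node.op == 'call_function':
            setattr(node, 'activation_checkpoint', 0)    # old style: plain attribute
            node.meta['activation_checkpoint'] = 0       # new style: metadata dict

    # copy the graph node by node, as many fx passes do before emitting code
    new_graph = Graph()
    new_graph.graph_copy(gm.graph, val_map={})

    for node in new_graph.nodes:
        if node.op == 'call_function':
            assert node.meta.get('activation_checkpoint', -1) == 0    # meta travels with the node
            assert not hasattr(node, 'activation_checkpoint')         # the ad-hoc attribute does not

That difference matters for codegen passes, which typically rebuild the graph before generating source.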
@@ -1,14 +1,15 @@
import torch
import torch.nn.functional as F
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from torch.utils.checkpoint import checkpoint

import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen

@@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank):
    # annotate nested checkpoint
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            node.meta['activation_checkpoint'] = [0, 0, 0]
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
            node.meta['activation_checkpoint'] = [0, 0, None]
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
            node.meta['activation_checkpoint'] = [0, 0, 1]
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
            node.meta['activation_checkpoint'] = [0, 1, None]
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
            node.meta['activation_checkpoint'] = 1
    gm = ColoGraphModule(model, graph)
    gm.recompile()

@@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank):
    # annotate nested checkpoint
    for node in graph.nodes:
        if node.name == "linear1":
            setattr(node, "activation_checkpoint", [0, 0, 0])
            node.meta['activation_checkpoint'] = [0, 0, 0]
            continue
        if node.name == "linear2":
            setattr(node, "activation_checkpoint", [0, 0, None])
            node.meta['activation_checkpoint'] = [0, 0, None]
        if node.name == "linear3":
            setattr(node, "activation_checkpoint", [0, 0, 1])
            node.meta['activation_checkpoint'] = [0, 0, 1]
        if node.name == "linear4":
            setattr(node, "activation_checkpoint", [0, 1, None])
            node.meta['activation_checkpoint'] = [0, 1, None]
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", 1)
            node.meta['activation_checkpoint'] = 1
    gm = ColoGraphModule(model, graph)
    gm.recompile()
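The list-valued annotations above ([0, 0, 0], [0, 0, None], ...) describe nested checkpoint regions; the exact encoding belongs to ColossalAI's codegen and is not reproduced here. As a plain-PyTorch picture of what nesting means at runtime (an inner region is recomputed again inside the recomputation of the outer region), a short hand-written sketch with torch.utils.checkpoint; the layers and the split are made up:

    import torch
    from torch.utils.checkpoint import checkpoint

    lin1 = torch.nn.Linear(8, 8)
    lin2 = torch.nn.Linear(8, 8)
    lin3 = torch.nn.Linear(8, 8)

    def inner(x):
        return lin2(torch.relu(x))

    def outer(x):
        h = lin1(x)
        h = checkpoint(inner, h)   # inner checkpoint region, nested inside `outer`
        return lin3(h)

    x = torch.randn(4, 8, requires_grad=True)
    y = checkpoint(outer, x)       # outer checkpoint region
    y.sum().backward()             # backward recomputes outer, and inner again within it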
@@ -1,14 +1,16 @@
import copy
import torch
import torch.nn.functional as F

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer

import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen

@@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
    # of input offload
    for node in graph.nodes:
        if node.name == "linear0":
            setattr(node, "activation_offload", [0, True, False])
            node.meta['activation_offload'] = [0, True, False]
        if node.name == "linear1":
            setattr(node, "activation_offload", [0, True, False])
            node.meta['activation_offload'] = [0, True, False]
        if node.name == "linear2":
            setattr(node, "activation_offload", [1, True, True])
            node.meta['activation_offload'] = [1, True, True]
        if node.name == "linear4":
            setattr(node, "activation_offload", [2, False, True])
            node.meta['activation_offload'] = [2, False, True]
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", [0])
            setattr(node, "activation_offload", True)
            node.meta['activation_checkpoint'] = [0]
            node.meta['activation_offload'] = True

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

@@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
    # of input offload
    for node in graph.nodes:
        if node.name == "linear0":
            setattr(node, "activation_offload", [0, True, False])
            node.meta['activation_offload'] = [0, True, False]
        if node.name == "linear1":
            setattr(node, "activation_offload", [0, True, False])
            node.meta['activation_offload'] = [0, True, False]
        if node.name == "linear2":
            setattr(node, "activation_offload", [1, True, True])
            node.meta['activation_offload'] = [1, True, True]
        if node.name == "linear4":
            setattr(node, "activation_offload", [2, False, True])
            node.meta['activation_offload'] = [2, False, True]
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", [0])
            setattr(node, "activation_offload", True)
            node.meta['activation_checkpoint'] = [0]
            node.meta['activation_offload'] = True

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()
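The activation_offload annotations ask the generated code to park saved activations in host memory and bring them back for backward. As a rough stand-in using only stock PyTorch (this is not the code ColossalAI's codegen emits), torch.autograd.graph.save_on_cpu achieves the same effect for everything saved inside its scope; the small model below is made up:

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
    )
    x = torch.randn(8, 64, requires_grad=True)

    # activations saved for backward inside this block are moved to CPU,
    # then copied back to the compute device when backward needs them
    with torch.autograd.graph.save_on_cpu(pin_memory=False):
        y = model(x)
    y.sum().backward()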
@@ -1,9 +1,10 @@
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from torch.utils.checkpoint import checkpoint

from colossalai.fx import ColoTracer


class MLP(torch.nn.Module):

@@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation():

    for node in gm.graph.nodes:
        if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
            assert getattr(node, 'activation_checkpoint', -1) == 0
            assert node.meta.get('activation_checkpoint', -1) == 0

    for node in gm.graph.nodes:
        if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
            assert getattr(node, 'activation_checkpoint', -1) == 1
            assert node.meta.get('activation_checkpoint', -1) == 1

    tracer = ColoTracer(trace_act_ckpt=False)
    graph = tracer.trace(module)
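This last file checks that tracing with trace_act_ckpt records, in node.meta, which checkpointed call each node came from. A sketch of that flow reusing only the calls visible in the diff (ColoTracer(trace_act_ckpt=...), tracer.trace(module)); the Net/MLP layout and the expectation that the region indices come out as 0 and 1 are assumptions modelled on the node names in the assertions above, not copied from the test file:

    import torch
    from torch.fx import GraphModule
    from torch.utils.checkpoint import checkpoint

    from colossalai.fx import ColoTracer


    class MLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 8)
            self.linear2 = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.linear2(self.linear1(x))


    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp_1 = MLP()
            self.mlp_2 = MLP()

        def forward(self, x):
            # each checkpointed call is expected to become its own region (0, then 1)
            x = checkpoint(self.mlp_1, x)
            x = checkpoint(self.mlp_2, x)
            return x


    module = Net()
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(module)
    gm = GraphModule(module, graph)

    for node in gm.graph.nodes:
        if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
            assert node.meta.get('activation_checkpoint', -1) == 0
        if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
            assert node.meta.get('activation_checkpoint', -1) == 1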