[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
Author: Super Daniel
Date: 2022-11-01 10:43:15 +08:00
Committed by: GitHub
Parent: 2b859502d5
Commit: 1e88811c7a
16 changed files with 1025 additions and 119 deletions

View File

@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint
+
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta
 
         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta
 
         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(model, graph)
     gm.recompile()
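
The substantive change across these tests is where the annotation lives: instead of bolting ad-hoc attributes onto `torch.fx.Node` with `setattr`/`hasattr`, the flag now goes into `node.meta`, the dictionary torch.fx itself reserves for per-node metadata. A minimal sketch of the new style, using vanilla `torch.fx.symbolic_trace` in place of `ColoTracer`/`ColoGraphModule` and an illustrative two-layer module:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear2(self.linear1(x))

gm = symbolic_trace(MLP())
for node in gm.graph.nodes:
    if node.name == 'linear1':
        # after this commit: stash annotations in node.meta, the dict torch.fx
        # reserves for per-node metadata, instead of setattr(node, ...)
        node.meta['activation_checkpoint'] = 0
        node.meta['activation_offload'] = True

# membership test replaces hasattr(), mirroring the assertions in the tests
ckpt_nodes = [n.name for n in gm.graph.nodes if 'activation_checkpoint' in n.meta]
assert ckpt_nodes == ['linear1']

Keeping annotations in `node.meta` also matches how upstream fx passes such as shape propagation store their results, so the checkpoint codegen can treat all per-node metadata uniformly.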

View File

@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
 import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint
+
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1
 
     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1
 
     gm = ColoGraphModule(model, graph)
     gm.recompile()
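
Reading the values off this test, the nested form of the annotation looks positional: element i names the checkpoint region the node belongs to at nesting depth i, `None` marks a depth at which the node is not wrapped, and a bare int is the flat, un-nested case. The exact contract lives in `ActivationCheckpointCodeGen`, so treat this as an inference from the test data; a small sketch of that reading, with hypothetical helper names:

# Hedged reading of the nested annotation format, inferred from the test
# values above (not an official ColossalAI API).
from typing import List, Optional, Union

Annotation = Union[int, List[Optional[int]]]

def normalize(ann: Annotation) -> List[Optional[int]]:
    """Lift the bare-int shorthand into list form."""
    return [ann] if isinstance(ann, int) else list(ann)

def region_path(ann: Annotation) -> List[int]:
    """Indices of the nested checkpoint regions wrapping a node."""
    path = []
    for level in normalize(ann):
        if level is None:
            break
        path.append(level)
    return path

# linear1 above: region 0, sub-region 0, sub-sub-region 0
assert region_path([0, 0, 0]) == [0, 0, 0]
# linear2: wrapped only two levels deep
assert region_path([0, 0, None]) == [0, 0]
# linear5: a bare int reads as a single top-level region
assert region_path(1) == [1]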

View File

@@ -1,14 +1,16 @@
 import copy
-import torch
-import torch.nn.functional as F
+
 import pytest
+import torch
 import torch.multiprocessing as mp
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
@@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
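
The offload annotation in these two tests is a three-element list. The values are consistent with `[offload_region_index, flag, flag]`, where the two booleans plausibly steer how the region's activations are moved to CPU, but that interpretation is read off the tests rather than any documented contract. A sketch that relies only on the first element, grouping nodes by offload region (hypothetical helper, assumed semantics):

# Hedged sketch: group nodes by the first element of their 'activation_offload'
# annotation, which the test values suggest is an offload-region index. The two
# boolean flags are left uninterpreted here -- their meaning is an assumption.
from collections import defaultdict

def offload_regions(graph):
    regions = defaultdict(list)
    for node in graph.nodes:
        ann = node.meta.get('activation_offload')
        if isinstance(ann, list):    # list form: [region_idx, flag, flag]
            regions[ann[0]].append(node.name)
    return dict(regions)

# With the annotations above this would yield:
# {0: ['linear0', 'linear1'], 1: ['linear2'], 2: ['linear4']}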