[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor.
This commit is contained in:
Super Daniel
2022-11-01 10:43:15 +08:00
committed by GitHub
parent 2b859502d5
commit 1e88811c7a
16 changed files with 1025 additions and 119 deletions

View File

@@ -1,14 +1,16 @@
import copy
import torch
import torch.nn.functional as F
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear2":
setattr(node, "activation_offload", [1, True, True])
node.meta['activation_offload'] = [1, True, True]
if node.name == "linear4":
setattr(node, "activation_offload", [2, False, True])
node.meta['activation_offload'] = [2, False, True]
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
node.meta['activation_checkpoint'] = [0]
node.meta['activation_offload'] = True
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()
@@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
node.meta['activation_offload'] = [0, True, False]
if node.name == "linear2":
setattr(node, "activation_offload", [1, True, True])
node.meta['activation_offload'] = [1, True, True]
if node.name == "linear4":
setattr(node, "activation_offload", [2, False, True])
node.meta['activation_offload'] = [2, False, True]
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
node.meta['activation_checkpoint'] = [0]
node.meta['activation_offload'] = True
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()