[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)

* [autoparallel] first move.

* [autoparallel] add solver rotor.

* [autoparallel] add ckpt solvers.

* [autoparallel] modify codegen.

* [fx] fix annotation in test.

* [fx] remove check.

* [autoparallel] polish docstring.

* [fx] refactor MetaTensor (see the conceptual sketch below).
Author: Super Daniel (committed by GitHub)
Date:   2022-11-01 10:43:15 +08:00
Parent: 2b859502d5
Commit: 1e88811c7a

16 changed files with 1025 additions and 119 deletions
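The last commit item touches MetaTensor, colossalai.fx's shape-inference tensor. As a rough conceptual sketch only (this uses PyTorch's built-in meta device, not ColossalAI's actual MetaTensor API): a meta tensor records shape and dtype but owns no storage, so a forward pass over meta tensors propagates output shapes at essentially zero memory cost, which is what lets checkpoint solvers estimate activation sizes before committing to a schedule.

import torch
import torch.nn as nn

# Plain-PyTorch illustration of the meta-tensor idea (a hypothetical
# stand-in for colossalai.fx's MetaTensor): parameters and inputs live
# on the 'meta' device, so forward() computes shapes/dtypes without
# allocating any real memory.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 64)).to('meta')
x = torch.empty(32, 128, device='meta')

out = model(x)
print(out.shape)   # torch.Size([32, 64]) -- no real allocation happened
print(out.device)  # meta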


@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 from torch.utils.checkpoint import checkpoint
+
+from colossalai.fx import ColoTracer

 class MLP(torch.nn.Module):
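The MLP body and the module actually traced are outside this hunk. Judging from the checkpoint import above and the node names asserted below (mlp_1_linear1 … mlp_2_linear2), the test plausibly wraps two MLP submodules in torch.utils.checkpoint.checkpoint so the tracer sees two distinct checkpoint regions. A hedged reconstruction, with every name other than MLP hypothetical:

class MLP(torch.nn.Module):
    def __init__(self, dim: int = 4):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.linear1(x))


class Net(torch.nn.Module):  # hypothetical top-level module
    def __init__(self):
        super().__init__()
        self.mlp_1 = MLP()
        self.mlp_2 = MLP()

    def forward(self, x):
        # Each checkpoint() call should become one annotated region:
        # nodes from mlp_1 get region id 0, nodes from mlp_2 get id 1.
        y = checkpoint(self.mlp_1, x)
        return checkpoint(self.mlp_2, y)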
@@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation():
     for node in gm.graph.nodes:
         if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 0
+            assert node.meta.get('activation_checkpoint', -1) == 0
     for node in gm.graph.nodes:
         if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 1
+            assert node.meta.get('activation_checkpoint', -1) == 1
     tracer = ColoTracer(trace_act_ckpt=False)
     graph = tracer.trace(module)
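The substantive change in this hunk is where the annotation lives: the checkpoint region id moves from an ad-hoc attribute on the fx Node (read via getattr) into node.meta, the dict torch.fx reserves for per-node metadata. A minimal consumption sketch under that convention, assuming trace_act_ckpt=True enables the annotation (the test's final re-trace with trace_act_ckpt=False suggests as much) and reusing the module instance from the test:

tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(module)
gm = GraphModule(module, graph)

# Group nodes by checkpoint region; un-annotated nodes carry no key at all.
regions = {}
for node in gm.graph.nodes:
    ckpt_id = node.meta.get('activation_checkpoint', None)
    if ckpt_id is not None:
        regions.setdefault(ckpt_id, []).append(node.name)

print(regions)  # e.g. {0: ['mlp_1_linear1', ...], 1: ['mlp_2_linear1', ...]}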