[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move.
* [autoparallel] add solver rotor.
* [autoparallel] add ckpt solvers.
* [autoparallel] modify codegen.
* [fx] fix annotation in test.
* [fx] remove check.
* [autoparallel] polish docstring.
* [fx] refactor MetaTensor.
@@ -1,9 +1,10 @@
 import torch
 import torch.nn as nn
-from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 from torch.utils.checkpoint import checkpoint
 
+from colossalai.fx import ColoTracer
+
 
 class MLP(torch.nn.Module):
 
@@ -44,11 +45,11 @@ def test_activation_checkpoint_annotation():
 
     for node in gm.graph.nodes:
         if node.name in ['mlp_1_linear1', 'mlp_1_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 0
+            assert node.meta.get('activation_checkpoint', -1) == 0
 
     for node in gm.graph.nodes:
         if node.name in ['mlp_2_linear1', 'mlp_2_linear2']:
-            assert getattr(node, 'activation_checkpoint', -1) == 1
+            assert node.meta.get('activation_checkpoint', -1) == 1
 
     tracer = ColoTracer(trace_act_ckpt=False)
     graph = tracer.trace(module)
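The asserts above capture the "[fx] fix annotation in test" item from the commit message: the activation checkpoint region index is no longer read as an attribute hung directly on the fx node (getattr(node, 'activation_checkpoint', -1)) but from the node's meta dict (node.meta.get('activation_checkpoint', -1)). Below is a minimal sketch of the trace-and-check flow implied by this test; the Net/MLP layout, layer sizes, and the printed output are illustrative assumptions, not the repository's exact test code.

import torch
import torch.nn as nn
from torch.fx import GraphModule
from torch.utils.checkpoint import checkpoint

from colossalai.fx import ColoTracer


class MLP(nn.Module):
    # Illustrative two-layer block (the layer layout is an assumption made
    # for this sketch, not copied from the repository's test file).
    def __init__(self, dim=4):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.linear2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear2(self.linear1(x))


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp_1 = MLP()
        self.mlp_2 = MLP()

    def forward(self, x):
        # Each torch.utils.checkpoint.checkpoint call marks one
        # activation-checkpoint region for the tracer to annotate.
        x = checkpoint(self.mlp_1, x)
        x = checkpoint(self.mlp_2, x)
        return x


module = Net()
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(module)
gm = GraphModule(module, graph)

for node in gm.graph.nodes:
    # Post-refactor, the region index lives in node.meta; the default of -1
    # means the node is not inside any checkpointed region.
    region = node.meta.get('activation_checkpoint', -1)
    if region != -1:
        print(node.name, region)

Keeping the index in node.meta follows torch.fx's convention for per-node metadata, which graph copies and transforms already carry along, so the codegen touched by this commit can read the same key without relying on ad-hoc node attributes.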