Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 01:28:31 +00:00
[autoparallel] move ckpt solvers to autoparallel folder / refactor code (#1764)
* [autoparallel] first move.
* [autoparallel] add solver rotor.
* [autoparallel] add ckpt solvers.
* [autoparallel] modify codegen.
* [fx] fix annotation in test.
* [fx] remove check.
* [autoparallel] polish docstring.
* [fx] refactor MetaTensor.
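The change repeated across all three test diffs below is the annotation mechanism: instead of tagging torch.fx nodes with ad-hoc Python attributes via setattr(), the tests now write into node.meta, the dictionary torch.fx reserves for per-node metadata. A minimal sketch of the before/after, using vanilla torch.fx tracing instead of ColoTracer (the toy model and names are illustrative, not from this commit):

    import torch
    from torch.fx import symbolic_trace

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    gm = symbolic_trace(model)

    for node in gm.graph.nodes:
        if node.op == 'call_module':
            # old style (removed by this commit):
            #     setattr(node, 'activation_checkpoint', 0)
            # new style: store the annotation in fx's per-node metadata dict
            node.meta['activation_checkpoint'] = 0

    # consumers now test dict membership instead of hasattr()
    ckpt = [n.name for n in gm.graph.nodes if 'activation_checkpoint' in n.meta]

One practical upside of node.meta is that fx's own graph-copy utilities (e.g. Graph.node_copy) carry it along, whereas arbitrary attributes bolted onto Node objects are easy to lose in a transform.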
@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
+import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint
 
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -92,11 +93,11 @@ def _run_act_ckpt_codegen(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta
 
         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -148,11 +149,11 @@ def _run_act_ckpt_python_code_torch11(rank):
     offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
-            assert hasattr(node, 'activation_checkpoint')
+            assert 'activation_checkpoint' in node.meta
 
         # annotate the selected node for offload
         if node.name in offload_starts:
-            setattr(node, 'activation_offload', True)
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(model, graph)
     gm.recompile()
@@ -1,14 +1,15 @@
-import torch
-import torch.nn.functional as F
+import pytest
+import torch
 import torch.multiprocessing as mp
-from torch.utils.checkpoint import checkpoint
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
+from torch.utils.checkpoint import checkpoint
 
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -57,16 +58,16 @@ def _run_act_ckpt_codegen(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1
     gm = ColoGraphModule(model, graph)
     gm.recompile()
 
@@ -114,16 +115,16 @@ def _run_act_ckpt_python_code_torch11(rank):
     # annotate nested checkpoint
     for node in graph.nodes:
         if node.name == "linear1":
-            setattr(node, "activation_checkpoint", [0, 0, 0])
+            node.meta['activation_checkpoint'] = [0, 0, 0]
             continue
         if node.name == "linear2":
-            setattr(node, "activation_checkpoint", [0, 0, None])
+            node.meta['activation_checkpoint'] = [0, 0, None]
         if node.name == "linear3":
-            setattr(node, "activation_checkpoint", [0, 0, 1])
+            node.meta['activation_checkpoint'] = [0, 0, 1]
         if node.name == "linear4":
-            setattr(node, "activation_checkpoint", [0, 1, None])
+            node.meta['activation_checkpoint'] = [0, 1, None]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", 1)
+            node.meta['activation_checkpoint'] = 1
     gm = ColoGraphModule(model, graph)
     gm.recompile()
 
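A note on the nested annotations above: activation_checkpoint is either a bare region index (linear5 gets 1) or a list with one entry per nesting level, and judging from the fixtures, None in a slot means the node is not wrapped at that depth (a reading inferred from the test data, not a documented contract). A sketch of replaying these annotations onto any traced graph; the helper name annotate_ckpt is ours, not ColossalAI API:

    # region encodings copied verbatim from the test above
    NESTED_REGIONS = {
        'linear1': [0, 0, 0],
        'linear2': [0, 0, None],
        'linear3': [0, 0, 1],
        'linear4': [0, 1, None],
        'linear5': 1,    # bare int: a plain, un-nested checkpoint region
    }

    def annotate_ckpt(graph, regions=NESTED_REGIONS):
        # copy each region encoding onto the matching node's meta dict
        for node in graph.nodes:
            if node.name in regions:
                node.meta['activation_checkpoint'] = regions[node.name]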
@@ -1,14 +1,16 @@
 import copy
-import torch
-import torch.nn.functional as F
 
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn.functional as F
 from torch.fx import GraphModule
-from colossalai.fx import ColoTracer
 
 import colossalai
-from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
+from colossalai.fx import ColoTracer
 from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.utils import free_port
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -83,16 +85,16 @@ def _run_offload_codegen(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
@@ -138,16 +140,16 @@ def _run_offload_codegen_torch11(rank):
     # of input offload
     for node in graph.nodes:
         if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
+            node.meta['activation_offload'] = [0, True, False]
         if node.name == "linear2":
-            setattr(node, "activation_offload", [1, True, True])
+            node.meta['activation_offload'] = [1, True, True]
         if node.name == "linear4":
-            setattr(node, "activation_offload", [2, False, True])
+            node.meta['activation_offload'] = [2, False, True]
         if node.name == "linear5":
-            setattr(node, "activation_checkpoint", [0])
-            setattr(node, "activation_offload", True)
+            node.meta['activation_checkpoint'] = [0]
+            node.meta['activation_offload'] = True
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
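The offload annotations follow the same pattern: a three-element list that bundles what looks like a region index plus two boolean flags (their exact semantics live in the codegen, not in this diff), or a plain activation_offload = True on a node that is also checkpointed. Below is a hedged end-to-end sketch of the flow these tests exercise on torch >= 1.12, where fx graphs support set_codegen (trace, annotate via node.meta, install the codegen, recompile); the toy model, meta_args key, and node name are illustrative assumptions:

    import copy

    import torch

    from colossalai.fx import ColoTracer
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    from colossalai.fx.graph_module import ColoGraphModule

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

    # trace on meta tensors; 'input' matches nn.Sequential.forward's argument name
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={'input': torch.rand(2, 8, device='meta')})
    graph.set_codegen(ActivationCheckpointCodeGen())

    for node in graph.nodes:
        if node.name == '_0':    # hypothetical node name; inspect graph.nodes for yours
            node.meta['activation_offload'] = [0, True, False]

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()    # regenerates forward() through the installed codegen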