[fx] allow native ckpt trace and codegen. (#2438)

Super Daniel authored on 2023-01-11 13:49:59 +08:00; committed by GitHub
parent 41429b9b28
commit c41e59e5ad
3 changed files with 37 additions and 23 deletions


@@ -1,17 +1,21 @@
 import os
 import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set, Type, Union
+
 import torch
 import torch.nn as nn
 from torch.nn.modules.module import _addindent
-from typing import Type, Dict, List, Any, Union, Optional, Set
-from pathlib import Path
+
 try:
-    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
-    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+    from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
+    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+
     from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
     COLOGM = True
 except:
-    from torch.fx.graph_module import GraphModule
     from torch.fx.graph import Graph
+    from torch.fx.graph_module import GraphModule
     COLOGM = False
+
 if COLOGM:
@@ -19,6 +23,7 @@ if COLOGM:
     class ColoGraphModule(GraphModule):
 
         def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            graph.set_codegen(ActivationCheckpointCodeGen())
             super().__init__(root, graph, class_name)
 
         def bind(self, ckpt_def, globals):
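The practical effect of the change above is that any graph wrapped in a ColoGraphModule has ActivationCheckpointCodeGen installed (via graph.set_codegen) before GraphModule generates its forward, so activation-checkpoint regions recorded during tracing survive code generation. A minimal sketch of exercising the new default follows; the MLP module, tensor shapes, and variable names are illustrative assumptions, not part of the commit:

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    from colossalai.fx.graph_module import ColoGraphModule

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(16, 16)
            self.fc2 = nn.Linear(16, 16)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    model = MLP()
    graph = symbolic_trace(model).graph    # a plain torch.fx trace

    # ColoGraphModule.__init__ now calls graph.set_codegen(ActivationCheckpointCodeGen())
    # before delegating to GraphModule, so the generated forward goes through the
    # activation-checkpoint codegen path.
    gm = ColoGraphModule(model, graph)
    out = gm(torch.rand(4, 16))

Doing the set_codegen call in the constructor keeps the fallback path intact: when the required torch.fx internals cannot be imported, COLOGM is False and ColoGraphModule degrades to a stock GraphModule.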