Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00).
Commit: "[fx] allow native ckpt trace and codegen. (#2438)".
This commit contains the following changes:
@@ -1,17 +1,21 @@
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.module import _addindent
|
||||
from typing import Type, Dict, List, Any, Union, Optional, Set
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
|
||||
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
|
||||
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
|
||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
|
||||
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
COLOGM = True
|
||||
except:
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.graph_module import GraphModule
|
||||
COLOGM = False
|
||||
|
||||
if COLOGM:
|
||||
@@ -19,6 +23,7 @@ if COLOGM:
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
    """Construct a GraphModule whose codegen supports activation checkpointing.

    The graph's code generator is replaced with ``ActivationCheckpointCodeGen``
    *before* delegating to the base ``GraphModule`` initializer, so the forward
    compiled from ``graph`` is produced by the checkpoint-aware codegen.
    """
    ckpt_codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(ckpt_codegen)
    super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
|
Reference in New Issue
Block a user