[fx] allow native ckpt trace and codegen. (#2438)

Super Daniel 2023-01-11 13:49:59 +08:00 committed by GitHub
parent 41429b9b28
commit c41e59e5ad
3 changed files with 37 additions and 23 deletions

View File

@@ -1,17 +1,21 @@
 import os
 import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set, Type, Union
+
 import torch
 import torch.nn as nn
 from torch.nn.modules.module import _addindent
-from typing import Type, Dict, List, Any, Union, Optional, Set
-from pathlib import Path
+
 try:
-    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
-    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+    from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
+    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+
+    from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
     COLOGM = True
 except:
-    from torch.fx.graph_module import GraphModule
     from torch.fx.graph import Graph
+    from torch.fx.graph_module import GraphModule
     COLOGM = False
 
 if COLOGM:
@@ -19,6 +23,7 @@ if COLOGM:
     class ColoGraphModule(GraphModule):
 
         def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            graph.set_codegen(ActivationCheckpointCodeGen())
             super().__init__(root, graph, class_name)
 
         def bind(self, ckpt_def, globals):
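
Note: the added line above is the codegen half of this commit. Every ColoGraphModule now installs ActivationCheckpointCodeGen on its graph before handing it to GraphModule, so recompilation emits checkpoint-aware forward() source. A minimal sketch of the hook itself, assuming torch.fx's CodeGen API (recent PyTorch releases); CommentedCodeGen is a hypothetical toy subclass, not Colossal-AI's codegen, swapped in through the same Graph.set_codegen() entry point:

import torch
from torch.fx import symbolic_trace as fx_trace
from torch.fx.graph import CodeGen


class CommentedCodeGen(CodeGen):

    def gen_fn_def(self, free_vars, maybe_return_annotation):
        # prepend a marker comment to the generated forward() definition
        return '# emitted by CommentedCodeGen\n' + super().gen_fn_def(free_vars, maybe_return_annotation)


gm = fx_trace(torch.nn.Linear(4, 4))
gm.graph.set_codegen(CommentedCodeGen())
gm.recompile()    # regenerate gm.code through the custom codegen
print(gm.code)    # forward() source now carries the marker comment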

View File

@@ -13,6 +13,7 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     """
     Symbolic tracing API
@@ -49,6 +50,6 @@ def symbolic_trace(
     This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
     """
-    graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+    graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
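
Note: with the new trace_act_ckpt flag, activation-checkpoint boundaries can be traced natively. A minimal usage sketch; the import path and the Toy module are assumptions for illustration, and the exact source emitted into gm.code is up to ActivationCheckpointCodeGen:

import torch
import torch.utils.checkpoint as ckpt

from colossalai.fx import symbolic_trace    # assumed import path


class Toy(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # region recorded by the tracer when trace_act_ckpt=True
        return ckpt.checkpoint(self.linear, x)


gm = symbolic_trace(Toy(), meta_args={'x': torch.rand(2, 4, device='meta')}, trace_act_ckpt=True)
print(gm.code)    # generated forward() with the checkpoint region preserved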

View File

@@ -1,7 +1,7 @@
 import enum
 import functools
-import operator
 import inspect
+import operator
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
@@ -286,7 +286,6 @@ class ColoTracer(Tracer):
         self.graph.lint()
         return self.graph
 
-
     @contextmanager
     def trace_activation_checkpoint(self, enabled: bool):
         if enabled:
@@ -316,7 +315,6 @@ class ColoTracer(Tracer):
             # recover the checkpoint function upon exit
            torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
 
-
     def _post_check(self, non_concrete_arg_names: Set[str]):
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
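
Note: trace_activation_checkpoint above follows the standard patch-and-restore idiom: replace torch.utils.checkpoint.CheckpointFunction for the duration of tracing, then put the original back. A self-contained sketch of that idiom, with the replacement function object left abstract:

from contextlib import contextmanager

import torch
import torch.utils.checkpoint


@contextmanager
def patched_checkpoint_function(replacement):
    orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
    torch.utils.checkpoint.CheckpointFunction = replacement
    try:
        yield
    finally:
        # recover the checkpoint function upon exit, as the tracer does
        torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func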
@@ -385,18 +383,23 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     if is_compatible_with_meta():
         if meta_args is not None:
             root.to(default_device())
             wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
-            graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
+            graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+                                                                    concrete_args=concrete_args,
+                                                                    meta_args=tree_map(wrap_fn, meta_args))
             root.cpu()
         else:
             graph = Tracer().trace(root, concrete_args=concrete_args)
     else:
         from .tracer import ColoTracer as OrigColoTracer
-        graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+        graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+                                                                    concrete_args=concrete_args,
+                                                                    meta_args=meta_args)
 
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
         node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
                                                node.kwargs)
 
+
 def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
     if kind == 'placeholder':
-        meta_out = meta_args[target] if target in meta_args else concrete_args.get(
-            _truncate_suffix(target), None)
+        meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
     elif kind == 'get_attr':
         attr_itr = root
         atoms = target.split(".")
@@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
         else:
             if target not in _TensorPropertyMethod:
                 meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                               **tree_map(unwrap_fn, kwargs))
+                                                                **tree_map(unwrap_fn, kwargs))
     elif kind == 'call_module':
         mod = root.get_submodule(target)
         meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
         meta_out = None
     return meta_out
 
+
 def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
         meta_out = meta_args[target]
@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     return meta_out
 
 
-def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
+def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
     result_graph = Graph()
     value_remap = {}
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
@@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
                 if target == torch.nn.functional.linear:
                     if 'bias' in kwargs and kwargs['bias'] is not None:
                         function_to_substitute = func_to_func_dict[target]
-                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                    function_to_substitute)
                 else:
                     function_to_substitute = func_to_func_dict[target]
-                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                function_to_substitute)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
                 function_to_substitute = func_to_func_dict[target]
-                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
+                                                                     function_to_substitute)
 
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
             if bias_addition_method.has(method):
                 function_to_substitute = method_to_func_dict[method]
-                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
+                                                          function_to_substitute)
 
         elif kind == "call_module":
             # if not hasattr(self, "orig_forward"):
@@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
            mod_type = type(mod)
            if bias_addition_module.has(mod_type) and mod.bias is not None:
                function_to_substitute = module_to_func_dict[mod_type]
-                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
+                                                            function_to_substitute)
 
         if handle is not None:
             handle.generate()
             for node_inserted in tracer.graph.nodes:
-                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
+                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
                 last_node = value_remap[node_inserted]
             value_remap[orig_node] = last_node
         else:
-            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
+            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
 
     del tracer
 
     gm.graph = result_graph
     gm.recompile()
     meta_prop_pass(gm, root_model, meta_args)
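
Note: the rewrite that bias_addition_pass performs reduces to a simple identity: a biased op is split into its bias-free form plus an explicit add, so the matmul and the bias become separate graph nodes. A sketch of that identity for F.linear; the real pass routes through the registered bias_addition_function/method/module handles and emits the nodes into result_graph:

import torch
import torch.nn.functional as F

x = torch.rand(2, 4)
w = torch.rand(3, 4)
b = torch.rand(3)

fused = F.linear(x, w, bias=b)    # one fused node before the pass
split = F.linear(x, w) + b        # matmul node + explicit add after the pass
assert torch.allclose(fused, split)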