From c41e59e5adc27d08b17234eada91ebcb3d876b23 Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Wed, 11 Jan 2023 13:49:59 +0800
Subject: [PATCH] [fx] allow native ckpt trace and codegen. (#2438)

---
 colossalai/fx/graph_module.py           | 15 ++++++---
 colossalai/fx/tracer/_symbolic_trace.py |  3 +-
 colossalai/fx/tracer/experimental.py    | 42 +++++++++++++++----------
 3 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index fbafd326c..2d6a71f19 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -1,17 +1,21 @@
 import os
 import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set, Type, Union
+
 import torch
 import torch.nn as nn
 from torch.nn.modules.module import _addindent
-from typing import Type, Dict, List, Any, Union, Optional, Set
-from pathlib import Path
+
 try:
-    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
-    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+    from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
+    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+
+    from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
     COLOGM = True
 except:
-    from torch.fx.graph_module import GraphModule
     from torch.fx.graph import Graph
+    from torch.fx.graph_module import GraphModule
     COLOGM = False
 
 if COLOGM:
@@ -19,6 +23,7 @@ if COLOGM:
     class ColoGraphModule(GraphModule):
 
         def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+            graph.set_codegen(ActivationCheckpointCodeGen())
             super().__init__(root, graph, class_name)
 
         def bind(self, ckpt_def, globals):
diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py
index bff2f6a10..5c04eeace 100644
--- a/colossalai/fx/tracer/_symbolic_trace.py
+++ b/colossalai/fx/tracer/_symbolic_trace.py
@@ -13,6 +13,7 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
     """
     Symbolic tracing API
@@ -49,6 +50,6 @@
        This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
 
     """
-    graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+    graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py
index 6fee5f5d0..88b65b618 100644
--- a/colossalai/fx/tracer/experimental.py
+++ b/colossalai/fx/tracer/experimental.py
@@ -1,7 +1,7 @@
 import enum
 import functools
-import operator
 import inspect
+import operator
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
@@ -286,7 +286,6 @@ class ColoTracer(Tracer):
         self.graph.lint()
         return self.graph
 
-
     @contextmanager
     def trace_activation_checkpoint(self, enabled: bool):
         if enabled:
@@ -316,7 +315,6 @@
             # recover the checkpoint function upon exit
             torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
 
-
     def _post_check(self, non_concrete_arg_names: Set[str]):
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
@@ -385,18 +383,23 @@ def symbolic_trace(
     root: Union[torch.nn.Module, Callable[..., Any]],
     concrete_args: Optional[Dict[str, Any]] = None,
     meta_args: Optional[Dict[str, Any]] = None,
+    trace_act_ckpt=False,
 ) -> ColoGraphModule:
 
     if is_compatible_with_meta():
         if meta_args is not None:
             root.to(default_device())
             wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
-            graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
+            graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+                                                                    concrete_args=concrete_args,
+                                                                    meta_args=tree_map(wrap_fn, meta_args))
             root.cpu()
         else:
             graph = Tracer().trace(root, concrete_args=concrete_args)
     else:
         from .tracer import ColoTracer as OrigColoTracer
-        graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+        graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+                                                                    concrete_args=concrete_args,
+                                                                    meta_args=meta_args)
     name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
     return ColoGraphModule(root, graph, name)
@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
             node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
                                                    node.kwargs)
 
+
 def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
     if kind == 'placeholder':
-        meta_out = meta_args[target] if target in meta_args else concrete_args.get(
-            _truncate_suffix(target), None)
+        meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
     elif kind == 'get_attr':
         attr_itr = root
         atoms = target.split(".")
@@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
         else:
             if target not in _TensorPropertyMethod:
                 meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                                   **tree_map(unwrap_fn, kwargs))
+                                                               **tree_map(unwrap_fn, kwargs))
     elif kind == 'call_module':
         mod = root.get_submodule(target)
         meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
         meta_out = None
     return meta_out
 
+
 def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
         meta_out = meta_args[target]
@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
     return meta_out
 
 
-def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
+def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
     result_graph = Graph()
     value_remap = {}
     unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
@@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
                 if target == torch.nn.functional.linear:
                     if 'bias' in kwargs and kwargs['bias'] is not None:
                         function_to_substitute = func_to_func_dict[target]
-                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                        handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                    function_to_substitute)
                 else:
                     function_to_substitute = func_to_func_dict[target]
-                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                    handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+                                                                function_to_substitute)
             elif bias_addition_function.has(target.__name__):
                 # use name for some builtin op like @ (matmul)
                 function_to_substitute = func_to_func_dict[target]
-                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
+                                                                     function_to_substitute)
 
         elif kind == "call_method":
             method = getattr(args_metas[0].__class__, target)
             if bias_addition_method.has(method):
                 function_to_substitute = method_to_func_dict[method]
-                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
+                                                          function_to_substitute)
 
         elif kind == "call_module":
             # if not hasattr(self, "orig_forward"):
@@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
             mod_type = type(mod)
             if bias_addition_module.has(mod_type) and mod.bias is not None:
                 function_to_substitute = module_to_func_dict[mod_type]
-                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+                handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
+                                                            function_to_substitute)
 
         if handle is not None:
             handle.generate()
             for node_inserted in tracer.graph.nodes:
-                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
+                value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
                 last_node = value_remap[node_inserted]
             value_remap[orig_node] = last_node
         else:
-            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
+            value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
 
     del tracer
 
     gm.graph = result_graph
     gm.recompile()
     meta_prop_pass(gm, root_model, meta_args)
-
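
Usage sketch. A minimal example of how the new trace_act_ckpt flag is meant to be driven, judging from the hunks above; the toy CkptModule and the `from colossalai.fx import symbolic_trace` import path are assumptions for illustration, while the trace_act_ckpt parameter and the meta_args convention come directly from the diff.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from colossalai.fx import symbolic_trace    # import path assumed


class CkptModule(nn.Module):
    """Toy module (hypothetical) whose forward runs a checkpointed block."""

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x):
        # torch.utils.checkpoint.CheckpointFunction is what ColoTracer
        # patches inside trace_activation_checkpoint, so this region is
        # recorded as a checkpoint boundary instead of being traced through.
        x = checkpoint(self.linear1, x)
        return self.linear2(x)


model = CkptModule()
# trace_act_ckpt=True enables the checkpoint-aware tracing added by this
# patch; ColoGraphModule now installs ActivationCheckpointCodeGen, so the
# generated gm.code should preserve the checkpointed region.
gm = symbolic_trace(model, meta_args={'x': torch.rand(2, 4, device='meta')}, trace_act_ckpt=True)
print(gm.code)

Judging from the trace_activation_checkpoint context manager in the diff, the flag only swaps torch.utils.checkpoint.CheckpointFunction during tracing and restores it on exit, so tracing without the flag still works but loses the checkpoint boundaries.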