mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 12:43:55 +00:00
[fx] allow native ckpt trace and codegen. (#2438)
This commit is contained in:
parent
41429b9b28
commit
c41e59e5ad
@ -1,17 +1,21 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.modules.module import _addindent
|
from torch.nn.modules.module import _addindent
|
||||||
from typing import Type, Dict, List, Any, Union, Optional, Set
|
|
||||||
from pathlib import Path
|
|
||||||
try:
|
try:
|
||||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
|
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
|
||||||
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
|
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
|
||||||
|
|
||||||
|
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||||
COLOGM = True
|
COLOGM = True
|
||||||
except:
|
except:
|
||||||
from torch.fx.graph_module import GraphModule
|
|
||||||
from torch.fx.graph import Graph
|
from torch.fx.graph import Graph
|
||||||
|
from torch.fx.graph_module import GraphModule
|
||||||
COLOGM = False
|
COLOGM = False
|
||||||
|
|
||||||
if COLOGM:
|
if COLOGM:
|
||||||
@ -19,6 +23,7 @@ if COLOGM:
|
|||||||
class ColoGraphModule(GraphModule):
|
class ColoGraphModule(GraphModule):
|
||||||
|
|
||||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
||||||
|
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||||
super().__init__(root, graph, class_name)
|
super().__init__(root, graph, class_name)
|
||||||
|
|
||||||
def bind(self, ckpt_def, globals):
|
def bind(self, ckpt_def, globals):
|
||||||
|
@ -13,6 +13,7 @@ def symbolic_trace(
|
|||||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||||
concrete_args: Optional[Dict[str, Any]] = None,
|
concrete_args: Optional[Dict[str, Any]] = None,
|
||||||
meta_args: Optional[Dict[str, Any]] = None,
|
meta_args: Optional[Dict[str, Any]] = None,
|
||||||
|
trace_act_ckpt=False,
|
||||||
) -> ColoGraphModule:
|
) -> ColoGraphModule:
|
||||||
"""
|
"""
|
||||||
Symbolic tracing API
|
Symbolic tracing API
|
||||||
@ -49,6 +50,6 @@ def symbolic_trace(
|
|||||||
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
|
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
||||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||||
return ColoGraphModule(root, graph, name)
|
return ColoGraphModule(root, graph, name)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import operator
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import operator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
@ -286,7 +286,6 @@ class ColoTracer(Tracer):
|
|||||||
self.graph.lint()
|
self.graph.lint()
|
||||||
return self.graph
|
return self.graph
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def trace_activation_checkpoint(self, enabled: bool):
|
def trace_activation_checkpoint(self, enabled: bool):
|
||||||
if enabled:
|
if enabled:
|
||||||
@ -316,7 +315,6 @@ class ColoTracer(Tracer):
|
|||||||
# recover the checkpoint function upon exit
|
# recover the checkpoint function upon exit
|
||||||
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
|
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
|
||||||
|
|
||||||
|
|
||||||
def _post_check(self, non_concrete_arg_names: Set[str]):
|
def _post_check(self, non_concrete_arg_names: Set[str]):
|
||||||
# This is necessary because concrete args are added as input to the traced module since
|
# This is necessary because concrete args are added as input to the traced module since
|
||||||
# https://github.com/pytorch/pytorch/pull/55888.
|
# https://github.com/pytorch/pytorch/pull/55888.
|
||||||
@ -385,18 +383,23 @@ def symbolic_trace(
|
|||||||
root: Union[torch.nn.Module, Callable[..., Any]],
|
root: Union[torch.nn.Module, Callable[..., Any]],
|
||||||
concrete_args: Optional[Dict[str, Any]] = None,
|
concrete_args: Optional[Dict[str, Any]] = None,
|
||||||
meta_args: Optional[Dict[str, Any]] = None,
|
meta_args: Optional[Dict[str, Any]] = None,
|
||||||
|
trace_act_ckpt=False,
|
||||||
) -> ColoGraphModule:
|
) -> ColoGraphModule:
|
||||||
if is_compatible_with_meta():
|
if is_compatible_with_meta():
|
||||||
if meta_args is not None:
|
if meta_args is not None:
|
||||||
root.to(default_device())
|
root.to(default_device())
|
||||||
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
|
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
|
||||||
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
|
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
|
||||||
|
concrete_args=concrete_args,
|
||||||
|
meta_args=tree_map(wrap_fn, meta_args))
|
||||||
root.cpu()
|
root.cpu()
|
||||||
else:
|
else:
|
||||||
graph = Tracer().trace(root, concrete_args=concrete_args)
|
graph = Tracer().trace(root, concrete_args=concrete_args)
|
||||||
else:
|
else:
|
||||||
from .tracer import ColoTracer as OrigColoTracer
|
from .tracer import ColoTracer as OrigColoTracer
|
||||||
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
|
graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
|
||||||
|
concrete_args=concrete_args,
|
||||||
|
meta_args=meta_args)
|
||||||
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
|
||||||
return ColoGraphModule(root, graph, name)
|
return ColoGraphModule(root, graph, name)
|
||||||
|
|
||||||
@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
|
|||||||
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
|
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
|
||||||
node.kwargs)
|
node.kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
|
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
|
||||||
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
|
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
|
||||||
if kind == 'placeholder':
|
if kind == 'placeholder':
|
||||||
meta_out = meta_args[target] if target in meta_args else concrete_args.get(
|
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
|
||||||
_truncate_suffix(target), None)
|
|
||||||
elif kind == 'get_attr':
|
elif kind == 'get_attr':
|
||||||
attr_itr = root
|
attr_itr = root
|
||||||
atoms = target.split(".")
|
atoms = target.split(".")
|
||||||
@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
|
|||||||
else:
|
else:
|
||||||
if target not in _TensorPropertyMethod:
|
if target not in _TensorPropertyMethod:
|
||||||
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
|
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
|
||||||
**tree_map(unwrap_fn, kwargs))
|
**tree_map(unwrap_fn, kwargs))
|
||||||
elif kind == 'call_module':
|
elif kind == 'call_module':
|
||||||
mod = root.get_submodule(target)
|
mod = root.get_submodule(target)
|
||||||
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
|
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
|
||||||
@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
|
|||||||
meta_out = None
|
meta_out = None
|
||||||
return meta_out
|
return meta_out
|
||||||
|
|
||||||
|
|
||||||
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
|
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
|
||||||
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
|
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
|
||||||
meta_out = meta_args[target]
|
meta_out = meta_args[target]
|
||||||
@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
|
|||||||
return meta_out
|
return meta_out
|
||||||
|
|
||||||
|
|
||||||
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
|
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
|
||||||
result_graph = Graph()
|
result_graph = Graph()
|
||||||
value_remap = {}
|
value_remap = {}
|
||||||
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
|
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
|
||||||
@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
|
|||||||
if target == torch.nn.functional.linear:
|
if target == torch.nn.functional.linear:
|
||||||
if 'bias' in kwargs and kwargs['bias'] is not None:
|
if 'bias' in kwargs and kwargs['bias'] is not None:
|
||||||
function_to_substitute = func_to_func_dict[target]
|
function_to_substitute = func_to_func_dict[target]
|
||||||
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
|
||||||
|
function_to_substitute)
|
||||||
else:
|
else:
|
||||||
function_to_substitute = func_to_func_dict[target]
|
function_to_substitute = func_to_func_dict[target]
|
||||||
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
|
||||||
|
function_to_substitute)
|
||||||
elif bias_addition_function.has(target.__name__):
|
elif bias_addition_function.has(target.__name__):
|
||||||
# use name for some builtin op like @ (matmul)
|
# use name for some builtin op like @ (matmul)
|
||||||
function_to_substitute = func_to_func_dict[target]
|
function_to_substitute = func_to_func_dict[target]
|
||||||
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
|
||||||
|
function_to_substitute)
|
||||||
|
|
||||||
elif kind == "call_method":
|
elif kind == "call_method":
|
||||||
method = getattr(args_metas[0].__class__, target)
|
method = getattr(args_metas[0].__class__, target)
|
||||||
if bias_addition_method.has(method):
|
if bias_addition_method.has(method):
|
||||||
function_to_substitute = method_to_func_dict[method]
|
function_to_substitute = method_to_func_dict[method]
|
||||||
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
|
||||||
|
function_to_substitute)
|
||||||
|
|
||||||
elif kind == "call_module":
|
elif kind == "call_module":
|
||||||
# if not hasattr(self, "orig_forward"):
|
# if not hasattr(self, "orig_forward"):
|
||||||
@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
|
|||||||
mod_type = type(mod)
|
mod_type = type(mod)
|
||||||
if bias_addition_module.has(mod_type) and mod.bias is not None:
|
if bias_addition_module.has(mod_type) and mod.bias is not None:
|
||||||
function_to_substitute = module_to_func_dict[mod_type]
|
function_to_substitute = module_to_func_dict[mod_type]
|
||||||
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
|
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
|
||||||
|
function_to_substitute)
|
||||||
|
|
||||||
if handle is not None:
|
if handle is not None:
|
||||||
handle.generate()
|
handle.generate()
|
||||||
for node_inserted in tracer.graph.nodes:
|
for node_inserted in tracer.graph.nodes:
|
||||||
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
|
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
|
||||||
last_node = value_remap[node_inserted]
|
last_node = value_remap[node_inserted]
|
||||||
value_remap[orig_node] = last_node
|
value_remap[orig_node] = last_node
|
||||||
else:
|
else:
|
||||||
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
|
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
|
||||||
|
|
||||||
del tracer
|
del tracer
|
||||||
|
|
||||||
gm.graph = result_graph
|
gm.graph = result_graph
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
meta_prop_pass(gm, root_model, meta_args)
|
meta_prop_pass(gm, root_model, meta_args)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user