Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 20:40:34 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
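Most hunks below are mechanical: the refreshed hooks rewrite single-quoted strings to double quotes, reflow yapf-style aligned continuations into parenthesized blocks, and drop blank lines after `def`/`class` headers. A minimal sketch of that normalization, assuming the formatter behind the new style is black (the commit message only says pre-commit was updated and rerun):

# Hedged illustration, not part of the commit: black.format_str applies the
# same quote normalization and call-site reflowing seen throughout this diff.
import black

src = "x = 'hello'\nf(a,\n  b)\n"
print(black.format_str(src, mode=black.Mode(line_length=120)), end="")
# x = "hello"
# f(a, b)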
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import torch
 
@@ -22,7 +22,7 @@ from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_a
 import colossalai
 from colossalai.fx._compatibility import compatibility
 
-_register_custom_builtin('colossalai', 'import colossalai', colossalai)
+_register_custom_builtin("colossalai", "import colossalai", colossalai)
 
 
 def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
@@ -43,17 +43,17 @@ def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
     """
     Generate the checkpoint function call code text
     """
-    outputs = ', '.join(output_vars)
-    inputs = ', '.join(input_vars)
-    return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
+    outputs = ", ".join(output_vars)
+    inputs = ", ".join(input_vars)
+    return f"{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})"
 
 
 def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
     """
     Check if the node could end the ckpt region at `ckpt_level`
     """
-    if len(node.meta['info'].activation_checkpoint) > ckpt_level:
-        return node.meta['info'].activation_checkpoint[ckpt_level] is not None
+    if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+        return node.meta["info"].activation_checkpoint[ckpt_level] is not None
     return True
 
 
@@ -94,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     current_region = None
 
     for idx, node in enumerate(node_list):
-        if len(node.meta['info'].activation_checkpoint) > ckpt_level:
-            act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
+        if len(node.meta["info"].activation_checkpoint) > ckpt_level:
+            act_ckpt_label = node.meta["info"].activation_checkpoint[ckpt_level]
 
             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
@@ -131,13 +131,9 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     return ckpt_regions
 
 
-def emit_ckpt_func(body,
-                   ckpt_func,
-                   node_list: List[Node],
-                   emit_node_func,
-                   delete_unused_value_func,
-                   ckpt_level=0,
-                   in_ckpt=False):
+def emit_ckpt_func(
+    body, ckpt_func, node_list: List[Node], emit_node_func, delete_unused_value_func, ckpt_level=0, in_ckpt=False
+):
     """Emit ckpt function in nested way
 
     Args:
@@ -156,12 +152,12 @@ def emit_ckpt_func(body,
 
     # label given by each layer, e.g. if you are currently at level (0, 1, 1)
     # the label will be '0_1_1'
-    label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
+    label = "_".join([str(idx) for idx in node_list[0].meta["info"].activation_checkpoint[: ckpt_level + 1]])
     ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
-    ckpt_func.append(f'{ckpt_fn_def}\n')
+    ckpt_func.append(f"{ckpt_fn_def}\n")
 
     # if there is more level to fetch
-    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
+    if ckpt_level + 1 < max(map(lambda node: len(node.meta["info"].activation_checkpoint), node_list)):
         ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
         start_idx = [item[0] for item in ckpt_regions]
         end_idx = [item[1] for item in ckpt_regions]
@@ -174,33 +170,40 @@ def emit_ckpt_func(body,
                 break
 
             if node_idx in start_idx:
-                ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
-                emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
-                               ckpt_level + 1, True)
+                ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
+                emit_ckpt_func(
+                    ckpt_func,
+                    ckpt_func_buffer,
+                    ckpt_node_list,
+                    emit_node_func,
+                    delete_unused_value_func,
+                    ckpt_level + 1,
+                    True,
+                )
                 node_idx += len(ckpt_node_list)
 
             else:
                 node = node_list[node_idx]
                 emit_node_func(node, ckpt_func)
-                ckpt_func[-1] = '    ' + ckpt_func[-1]
+                ckpt_func[-1] = "    " + ckpt_func[-1]
                 delete_unused_value_func(node, ckpt_func)
                 node_idx += 1
 
-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+        ckpt_func.append("    " + _gen_ckpt_output(outputs) + "\n\n")
         ckpt_func += ckpt_func_buffer
 
     # last level
     else:
         for node in node_list:
             emit_node_func(node, ckpt_func)
-            ckpt_func[-1] = '    ' + ckpt_func[-1]
+            ckpt_func[-1] = "    " + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
 
-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+        ckpt_func.append("    " + _gen_ckpt_output(outputs) + "\n\n")
 
-    usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
+    usage = _gen_ckpt_usage(label, inputs, outputs, False) + "\n"
     if in_ckpt:
-        usage = '    ' + usage
+        usage = "    " + usage
     body.append(usage)
 
 
@@ -229,7 +232,7 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 
         # process ckpt_regions
         if node_idx in start_idx:
-            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+            ckpt_node_list = node_list[node_idx : end_idx[start_idx.index(node_idx)] + 1]
             emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
             node_idx += len(ckpt_node_list)
 
@@ -243,7 +246,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 
 @compatibility(is_backward_compatible=True)
 class ActivationCheckpointCodeGen(CodeGen):
-
     def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
         free_vars: List[str] = []
         body: List[str] = []
@@ -251,7 +253,7 @@ class ActivationCheckpointCodeGen(CodeGen):
         wrapped_fns: Dict[str, None] = {}
 
         # Wrap string in list to pass by reference
-        maybe_return_annotation: List[str] = ['']
+        maybe_return_annotation: List[str] = [""]
 
         def add_global(name_hint: str, obj: Any):
             """Add an obj to be tracked as a global.
@@ -259,7 +261,7 @@ class ActivationCheckpointCodeGen(CodeGen):
             Graph, like functions or types.
             Returns: the global name that should be used to reference 'obj' in generated source.
             """
-            if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
+            if _is_from_torch(obj) and obj != torch.device:  # to support registering torch.device
                 # HACK: workaround for how torch custom ops are registered. We
                 # can't import them like normal modules so they must retain their
                 # fully qualified name.
@@ -281,16 +283,16 @@ class ActivationCheckpointCodeGen(CodeGen):
         def type_repr(o: Any):
             if o == ():
                 # Empty tuple is used for empty tuple type annotation Tuple[()]
-                return '()'
+                return "()"
 
             typename = _type_repr(o)
 
-            if hasattr(o, '__origin__'):
+            if hasattr(o, "__origin__"):
                 # This is a generic type, e.g. typing.List[torch.Tensor]
                 origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                 origin_typename = add_global(_type_repr(origin_type), origin_type)
 
-                if hasattr(o, '__args__'):
+                if hasattr(o, "__args__"):
                     # Assign global names for each of the inner type variables.
                     args = [type_repr(arg) for arg in o.__args__]
 
@@ -309,19 +311,18 @@ class ActivationCheckpointCodeGen(CodeGen):
             return add_global(typename, o)
 
         def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
-
             def _get_repr(arg):
                 # Handle NamedTuples (if it has `_fields`) via add_global.
-                if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+                if isinstance(arg, tuple) and hasattr(arg, "_fields"):
                     qualified_name = _get_qualified_name(type(arg))
                     global_name = add_global(qualified_name, type(arg))
                     return f"{global_name}{repr(tuple(arg))}"
                 return repr(arg)
 
-            args_s = ', '.join(_get_repr(a) for a in args)
-            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+            args_s = ", ".join(_get_repr(a) for a in args)
+            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
             if args_s and kwargs_s:
-                return f'{args_s}, {kwargs_s}'
+                return f"{args_s}, {kwargs_s}"
             return args_s or kwargs_s
 
         # Run through reverse nodes and record the first instance of a use
@@ -347,82 +348,94 @@ class ActivationCheckpointCodeGen(CodeGen):
             not used in the remainder of the code are freed and the memory usage
             of the code is optimal.
             """
-            if user.op == 'placeholder':
+            if user.op == "placeholder":
                 return
-            if user.op == 'output':
-                body.append('\n')
+            if user.op == "output":
+                body.append("\n")
                 return
            nodes_to_delete = user_to_last_uses.get(user, [])
            if len(nodes_to_delete):
-                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
-                body.append(f';  {to_delete_str}\n')
+                to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+                body.append(f";  {to_delete_str}\n")
            else:
-                body.append('\n')
+                body.append("\n")
 
        # NOTE: we add a variable to distinguish body and ckpt_func
        def emit_node(node: Node, body):
-            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
-            if node.op == 'placeholder':
+            maybe_type_annotation = "" if node.type is None else f" : {type_repr(node.type)}"
+            if node.op == "placeholder":
                assert isinstance(node.target, str)
-                maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
-                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
-                raw_name = node.target.replace('*', '')
+                maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
+                free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+                raw_name = node.target.replace("*", "")
                if raw_name != repr(node):
-                    body.append(f'{repr(node)} = {raw_name}\n')
+                    body.append(f"{repr(node)} = {raw_name}\n")
                return
-            elif node.op == 'call_method':
+            elif node.op == "call_method":
                assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
-                            f'({_format_args(node.args[1:], node.kwargs)})')
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+                    f"({_format_args(node.args[1:], node.kwargs)})"
+                )
                return
-            elif node.op == 'call_function':
+            elif node.op == "call_function":
                assert callable(node.target)
                # pretty print operators
-                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+                if node.target.__module__ == "_operator" and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
-                    body.append(f'{repr(node)}{maybe_type_annotation} = '
-                                f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = "
+                        f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+                    )
                    return
 
                # pretty print inplace operators; required for jit.script to work properly
                # not currently supported in normal FX graphs, but generated by torchdynamo
-                if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
-                    body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  '
-                                f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+                if node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods:
+                    body.append(
+                        f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))};  "
+                        f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+                    )
                    return
 
                qualified_name = _get_qualified_name(node.target)
                global_name = add_global(qualified_name, node.target)
                # special case for getattr: node.args could be 2-argument or 3-argument
                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
-                if global_name == 'getattr' and \
-                   isinstance(node.args, tuple) and \
-                   isinstance(node.args[1], str) and \
-                   node.args[1].isidentifier() and \
-                   len(node.args) == 2:
+                if (
+                    global_name == "getattr"
+                    and isinstance(node.args, tuple)
+                    and isinstance(node.args[1], str)
+                    and node.args[1].isidentifier()
+                    and len(node.args) == 2
+                ):
                    body.append(
-                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+                    )
                    return
                body.append(
-                    f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
-                if node.meta.get('is_wrapped', False):
+                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+                )
+                if node.meta.get("is_wrapped", False):
                    wrapped_fns.setdefault(global_name)
                return
-            elif node.op == 'call_module':
+            elif node.op == "call_module":
                assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = '
-                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+                body.append(
+                    f"{repr(node)}{maybe_type_annotation} = "
+                    f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+                )
                return
-            elif node.op == 'get_attr':
+            elif node.op == "get_attr":
                assert isinstance(node.target, str)
-                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+                body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
                return
-            elif node.op == 'output':
+            elif node.op == "output":
                if node.type is not None:
                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                body.append(self.generate_output(node.args[0]))
                return
-            raise NotImplementedError(f'node: {node.op} {node.target}')
+            raise NotImplementedError(f"node: {node.op} {node.target}")
 
        # Modified for activation checkpointing
        ckpt_func = []
@@ -432,13 +445,13 @@ class ActivationCheckpointCodeGen(CodeGen):
            # If the Graph has no non-placeholder nodes, no lines for the body
            # have been emitted. To continue to have valid Python code, emit a
            # single pass statement
-            body.append('pass\n')
+            body.append("pass\n")
 
        if len(wrapped_fns) > 0:
-            wrap_name = add_global('wrap', torch.fx.wrap)
-            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+            wrap_name = add_global("wrap", torch.fx.wrap)
+            wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
        else:
-            wrap_stmts = ''
+            wrap_stmts = ""
 
        if self._body_transformer:
            body = self._body_transformer(body)
@@ -447,11 +460,11 @@ class ActivationCheckpointCodeGen(CodeGen):
            add_global(name, value)
 
        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
-        prologue = ''.join(ckpt_func) + prologue
+        prologue = "".join(ckpt_func) + prologue
        prologue = prologue
 
-        code = ''.join(body)
-        code = '\n'.join('    ' + line for line in code.split('\n'))
+        code = "".join(body)
+        code = "\n".join("    " + line for line in code.split("\n"))
        fn_code = f"""
{wrap_stmts}
{prologue}
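For orientation between files: the codegen above turns each activation-checkpoint region into a method plus a checkpoint call in the forward body. A runnable sketch of the emitted pattern (module and names hypothetical, not taken from the commit):

import torch
import torch.utils.checkpoint


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def checkpoint_0(self, x):  # stands in for a generated region function
        return self.linear(x).relu()

    def forward(self, x):
        # the shape of the string built by _gen_ckpt_usage(label="0", ..., use_reentrant=False)
        return torch.utils.checkpoint.checkpoint(self.checkpoint_0, x, use_reentrant=False)


out = Toy()(torch.randn(2, 4, requires_grad=True))
out.sum().backward()  # the region's activations are recomputed during backward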
@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
 
 try:
     from torch.fx.graph import _PyTreeCodeGen
+
     SUPPORT_PT_CODEGEN = True
 except ImportError:
     SUPPORT_PT_CODEGEN = False
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
 # This is a copy of torch.fx.graph_module._WrappedCall.
 # It should be removed when we stop supporting torch < 1.12.0.
 class _WrappedCall:
-
     def __init__(self, cls, cls_call):
         self.cls = cls
         self.cls_call = cls_call
@@ -50,12 +50,14 @@ class _WrappedCall:
 
         # constituent substrings of the error message
         tb_repr = traceback.format_exc()
-        custom_msg = ("Call using an FX-traced Module, "
-                      f"line {err_lineno} of the traced Module's "
-                      "generated forward function:")
-        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+        custom_msg = (
+            "Call using an FX-traced Module, "
+            f"line {err_lineno} of the traced Module's "
+            "generated forward function:"
+        )
+        before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
         marker = "~" * err_line_len + "~~~ <--- HERE"
-        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+        err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
 
         # joined message
         return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
@@ -65,11 +67,14 @@ class _WrappedCall:
             if self.cls_call is not None:
                 return self.cls_call(obj, *args, **kwargs)
             else:
-                return super(self.cls, obj).__call__(*args, **kwargs)    # type: ignore[misc]
+                return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
         except Exception as e:
             assert e.__traceback__
-            topmost_framesummary: traceback.FrameSummary = \
-                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]    # type: ignore[arg-type]
+            topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
+                traceback.walk_tb(e.__traceback__)
+            )[
+                -1
+            ]  # type: ignore[arg-type]
             if "eval_with_key" in topmost_framesummary.filename:
                 print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
                 raise e.with_traceback(None)
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
     code.
     """
 
-    def __init__(self,
-                 root: Union[torch.nn.Module, Dict[str, Any]],
-                 graph: torch.fx.Graph,
-                 class_name: str = 'GraphModule'):
+    def __init__(
+        self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
+    ):
         super().__init__(root, graph, class_name)
 
     def bind(self, ckpt_def, globals):
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
             self._in_spec = self._graph._codegen.pytree_info.in_spec
             self._out_spec = self._graph._codegen.pytree_info.out_spec
-        python_code = self._graph.python_code(root_module='self')
+        python_code = self._graph.python_code(root_module="self")
         self._code = python_code.src
 
         # To split ckpt functions code and forward code
@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
         # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
         cls_call = cls.__call__ if "__call__" in vars(cls) else None
 
-        if '_wrapped_call' not in vars(cls):
-            cls._wrapped_call = _WrappedCall(cls, cls_call)    # type: ignore[attr-defined]
+        if "_wrapped_call" not in vars(cls):
+            cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]
 
         def call_wrapped(self, *args, **kwargs):
             return self._wrapped_call(self, *args, **kwargs)
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         """
         folder = Path(folder)
         Path(folder).mkdir(exist_ok=True)
-        torch.save(self.state_dict(), folder / 'state_dict.pt')
+        torch.save(self.state_dict(), folder / "state_dict.pt")
         tab = " " * 4
 
         # we add import colossalai here
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
         for module_name, module in self.named_children():
             module_str = _gen_model_repr(module_name, module)
             if module_str is None:
-                module_file = folder / f'{module_name}.pt'
+                module_file = folder / f"{module_name}.pt"
                 torch.save(module, module_file)
                 blobified_modules.append(module_name)
-                module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+                module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
                 module_str = f"torch.load(r'{module_file}') # {module_repr}"
             model_str += f"{tab*2}self.{module_name} = {module_str}\n"
 
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
         model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
         model_str += f"{_addindent(self.code, 4)}\n"
 
-        module_file = folder / 'module.py'
+        module_file = folder / "module.py"
         module_file.write_text(model_str)
 
-        init_file = folder / '__init__.py'
-        init_file.write_text('from .module import *')
+        init_file = folder / "__init__.py"
+        init_file.write_text("from .module import *")
 
         if len(blobified_modules) > 0:
-            warnings.warn("Was not able to save the following children modules as reprs -"
-                          f"saved as pickled files instead: {blobified_modules}")
+            warnings.warn(
+                "Was not able to save the following children modules as reprs -"
+                f"saved as pickled files instead: {blobified_modules}"
+            )
@@ -1,9 +1,9 @@
 from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from torch.autograd.profiler_util import _format_memory, _format_time
-from torch.fx import Graph, GraphModule, Node
+from torch.autograd.profiler_util import _format_memory
+from torch.fx import Node
 
 from colossalai._analyzer.envs import MeshConfig
 
@@ -85,12 +85,12 @@ class MetaInfo:
     node: Node
 
     # directory
-    mod_dir: str = ''
+    mod_dir: str = ""
 
     # ctx[data_ptr] = Tensor
     # mark the storage for ctx.save_for_backward
-    global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})    # globally shared
-    curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})    # global_ctx till this node
+    global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})  # globally shared
+    curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})  # global_ctx till this node
 
     # should be updated after each graph manipulation
     # ============================== Update ====================================
@@ -100,7 +100,7 @@ class MetaInfo:
 
     inputs: Tuple[torch.Tensor] = ()
     outputs: Tuple[torch.Tensor] = ()
-    is_alias: Tuple[bool] = ()    # whether the output is an alias of input
+    is_alias: Tuple[bool] = ()  # whether the output is an alias of input
 
     # compute cost
     fwd_flop: Optional[int] = 0
@@ -112,29 +112,29 @@ class MetaInfo:
 
     # should keep the same whenever manipulated
     # ============================= Invariant ==================================
-    activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
+    activation_checkpoint: Tuple[torch.Tensor] = ()  # (region_0, region_1, ...) support nested codegen
     to_offload: Optional[bool] = False
-    sharding_spec: str = 'RR'
+    sharding_spec: str = "RR"
 
     def __new__(cls, node: Node, **kwargs):
         orig_init = cls.__init__
 
         # if initialized, return the existing one
         # should disable the __init__ function
-        if node.meta.get('info', None) is not None:
+        if node.meta.get("info", None) is not None:
 
             def _dummy(self, *args, **kwargs):
-                if getattr(self, '_is_init', False):
+                if getattr(self, "_is_init", False):
                     self._is_init = True
                     orig_init(self, *args, **kwargs)
                     cls.__init__ = orig_init
 
             cls.__init__ = _dummy
-            return node.meta['info']
+            return node.meta["info"]
         return super().__new__(cls)
 
     def __post_init__(self):
-        self.node.meta['info'] = self
+        self.node.meta["info"] = self
 
     @property
     def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
@@ -188,24 +188,26 @@ class MetaInfo:
         return compute_size_in_bytes(self.inputs)
 
     def __repr__(self):
-        s = f'Node {self.node.name}'
+        s = f"Node {self.node.name}"
         if self.parameters:
-            s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+            s += f"\n\thas parameter of size {_format_memory(self.param_size)}"
         if self.buffers:
-            s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+            s += f"\n\thas buffer of size {_format_memory(self.buffer_size)}"
         if self.output_size:
-            s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+            s += f"\n\thas output activation of size {_format_memory(self.output_size)}"
         # if self.total_size:
         #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
         if self.temp_size:
-            s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+            s += f"\n\thas temp activation of size {_format_memory(self.temp_size)}"
         if self.backward_size:
-            s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
-        s += f'\n\tfwd_flop = {self.fwd_flop}'\
-             f'\n\tbwd_flop = {self.bwd_flop}'\
-             f'\n\tfwd_comm = {self.fwd_comm}'\
-             f'\n\tbwd_comm = {self.bwd_comm}'\
-             f'\n\tto_recompute = {self.to_recompute}'\
-             f'\n\tto_offload = {self.to_offload}'\
-             f'\n\tsharding_spec = {self.sharding_spec}'
+            s += f"\n\thas backward activation of size {_format_memory(self.backward_size)}"
+        s += (
+            f"\n\tfwd_flop = {self.fwd_flop}"
+            f"\n\tbwd_flop = {self.bwd_flop}"
+            f"\n\tfwd_comm = {self.fwd_comm}"
+            f"\n\tbwd_comm = {self.bwd_comm}"
+            f"\n\tto_recompute = {self.to_recompute}"
+            f"\n\tto_offload = {self.to_offload}"
+            f"\n\tsharding_spec = {self.sharding_spec}"
+        )
         return s
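The MetaInfo diff above preserves an idiom worth spelling out: __new__ returns the instance already cached on the node and swaps __init__ for a stub, so constructing MetaInfo(node) twice yields one object per node. A toy re-implementation of just that caching (illustration only, not the library class):

class NodeInfo:
    def __new__(cls, node):
        if node.meta.get("info") is not None:  # already built for this node
            return node.meta["info"]
        return super().__new__(cls)

    def __init__(self, node):
        if node.meta.get("info") is self:  # second construction: skip re-init
            return
        self.node = node
        node.meta["info"] = self


class FakeNode:  # stand-in for torch.fx.Node
    def __init__(self):
        self.meta = {}


n = FakeNode()
assert NodeInfo(n) is NodeInfo(n)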
@@ -1,8 +1,8 @@
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.fx
-from torch.autograd.profiler_util import _format_memory, _format_time
+from torch.autograd.profiler_util import _format_memory
 from torch.fx import GraphModule
 from torch.fx.node import Argument, Node, Target
 
@@ -13,14 +13,14 @@ from colossalai._analyzer.fx.node_util import MetaInfo
 def _format_flops(flops: float) -> str:
     """Returns a formatted FLOP size string"""
     if flops > 1e12:
-        return f'{flops / 1e12:.2f} TFLOPs'
+        return f"{flops / 1e12:.2f} TFLOPs"
     elif flops > 1e9:
-        return f'{flops / 1e9:.2f} GFLOPs'
+        return f"{flops / 1e9:.2f} GFLOPs"
     elif flops > 1e6:
-        return f'{flops / 1e6:.2f} MFLOPs'
+        return f"{flops / 1e6:.2f} MFLOPs"
     elif flops > 1e3:
-        return f'{flops / 1e3:.2f} kFLOPs'
-    return f'{flops} FLOPs'
+        return f"{flops / 1e3:.2f} kFLOPs"
+    return f"{flops} FLOPs"
 
 
 def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
@@ -42,10 +42,11 @@ class GraphProfiler(torch.fx.Interpreter):
     Fetch shape argument from ``ShapeProp`` without re-executing
     the ``GraphModule`` from scratch.
     """
+
     _profileable = [
-        'call_function',
-        'call_module',
-        'call_method',
+        "call_function",
+        "call_module",
+        "call_method",
     ]
 
     def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
@@ -77,14 +78,13 @@ class GraphProfiler(torch.fx.Interpreter):
         self.args_iter: Iterator[Any] = iter(args)
 
         for node in self.module.graph.nodes:
-
-            self.run_node(node)    # No need to store.
+            self.run_node(node)  # No need to store.
 
             if self.garbage_collect_values:
                 for to_delete in self.user_to_last_uses.get(node, []):
                     del self.env[to_delete]
 
-            if node.op == 'output':
+            if node.op == "output":
                 output_val = self.env[node]
                 return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
 
@@ -133,9 +133,11 @@ class GraphProfiler(torch.fx.Interpreter):
         try:
             from tabulate import tabulate
         except ImportError:
-            print("`summary` relies on the library `tabulate`, "
-                  "which could not be found on this machine. Run `pip "
-                  "install tabulate` to install the library.")
+            print(
+                "`summary` relies on the library `tabulate`, "
+                "which could not be found on this machine. Run `pip "
+                "install tabulate` to install the library."
+            )
 
         # Build up a list of summary information for each node
         node_summaries: List[List[Any]] = []
@@ -145,36 +147,38 @@ class GraphProfiler(torch.fx.Interpreter):
             node: Node
             n_info = MetaInfo(node)
             last_n_info = last_n_info or n_info
-            node_summaries.append([
-                node.op,
-                str(node),
-                _format_memory(n_info.accumulate_size),
-                _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
-                _format_memory(n_info.output_size),
-                _format_memory(n_info.temp_size),
-                _format_memory(n_info.param_size),
-                _format_memory(n_info.backward_size),
-                _format_flops(n_info.fwd_flop),
-                _format_flops(n_info.bwd_flop),
-            ])
+            node_summaries.append(
+                [
+                    node.op,
+                    str(node),
+                    _format_memory(n_info.accumulate_size),
+                    _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
+                    _format_memory(n_info.output_size),
+                    _format_memory(n_info.temp_size),
+                    _format_memory(n_info.param_size),
+                    _format_memory(n_info.backward_size),
+                    _format_flops(n_info.fwd_flop),
+                    _format_flops(n_info.bwd_flop),
+                ]
+            )
             last_n_info = n_info
 
         # Use the ``tabulate`` library to create a well-formatted table
         # presenting our summary information
         headers: List[str] = [
-            'Op type',
-            'Op',
-            'Accumulate size',
-            'Incremental size',
-            'Output size',
-            'Temp size',
-            'Param size',
-            'Backward size',
-            'Fwd FLOPs',
-            'Bwd FLOPs',
+            "Op type",
+            "Op",
+            "Accumulate size",
+            "Incremental size",
+            "Output size",
+            "Temp size",
+            "Param size",
+            "Backward size",
+            "Fwd FLOPs",
+            "Bwd FLOPs",
         ]
 
-        return tabulate(node_summaries, headers=headers, stralign='right')
+        return tabulate(node_summaries, headers=headers, stralign="right")
 
 
 class CommunicationProfiler(GraphProfiler):
@@ -222,6 +226,7 @@ class FlopProfiler(GraphProfiler):
        >>> def my_fn_flop_count_impl(*args, **kwargs):
        >>>     return 0, 0
    """
+
    _custom_flop_count_impl = {}
 
    def run_node(self, n: torch.fx.Node) -> Any:
@@ -246,11 +251,13 @@ class FlopProfiler(GraphProfiler):
            (
                n_info.fwd_flop,
                n_info.bwd_flop,
-            ) = getattr(self, n.op)(n.target, args, kwargs)
+            ) = getattr(
+                self, n.op
+            )(n.target, args, kwargs)
        except Exception as e:
            raise RuntimeError(
-                f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
-                f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
+                f"Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. "
+                f"Please refer to function's docstring to register the relevant profile_impl for this node!"
            ) from e
 
        # retain the autograd graph
@@ -259,7 +266,7 @@ class FlopProfiler(GraphProfiler):
 
        return _denormalize_tuple(n_info.outputs)
 
-    def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the profiling result.
        Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
@@ -283,7 +290,7 @@ class FlopProfiler(GraphProfiler):
        else:
            return flop_count(target, *args, **kwargs)
 
-    def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the profiling result.
 
@@ -301,7 +308,7 @@ class FlopProfiler(GraphProfiler):
        assert isinstance(target, str)
        return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
 
-    def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_module(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the profiling result.
 
@@ -336,9 +343,10 @@ def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule
    Returns:
        GraphModule: The same GraphModule with profiling information
    """
-    for profiler_cls in (FlopProfiler,
-                         # CommunicationProfiler,    # TODO: add communication profiling
-                        ):
+    for profiler_cls in (
+        FlopProfiler,
+        # CommunicationProfiler, # TODO: add communication profiling
+    ):
        profiler = profiler_cls(module)
        profiler.propagate(*args, device=_current_device(module))
 
@@ -54,7 +54,7 @@ def _current_device(module):
    try:
        return next(module.parameters()).device
    except StopIteration:
-        return torch.device('cpu')
+        return torch.device("cpu")
 
 
 @compatibility(is_backward_compatible=False)
@@ -90,6 +90,7 @@ class ShapeProp(torch.fx.Interpreter):
        >>>     # do something here
        >>>     return torch.empty(output_shape, device=output_device)
    """
+
    _custom_dispatch_func = {}
    _mode = MetaTensorMode()
 
@@ -115,15 +116,14 @@ class ShapeProp(torch.fx.Interpreter):
            r = getattr(self, n.op)(n.target, args, kwargs)
 
        def unwrap_fn(elem):
-
            def _convert_meta(t: torch.Tensor):
-                if t.device == 'meta':
+                if t.device == "meta":
                    return t
                else:
-                    return t.to('meta')
+                    return t.to("meta")
 
            if isinstance(elem, MetaTensor):
-                if getattr(self, '_is_param', False):
+                if getattr(self, "_is_param", False):
                    return torch.nn.Parameter(_convert_meta(elem._tensor))
                return _convert_meta(elem._tensor)
 
@@ -139,21 +139,24 @@ class ShapeProp(torch.fx.Interpreter):
        n_info = MetaInfo(n)
        n_info.outputs = _normalize_tuple(r)
 
-        if n.op == 'call_module':
+        if n.op == "call_module":
            submod = self.fetch_attr(n.target)
            n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
            n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
 
        else:
-            n_info.parameters.update({
-                k.name: MetaTensor(v)
-                for k, v in zip(n.args, args)
-                if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
-            })
+            n_info.parameters.update(
+                {
+                    k.name: MetaTensor(v)
+                    for k, v in zip(n.args, args)
+                    if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
+                }
+            )
        n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
 
-        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
-                        tuple(v for v in kwargs.values() if is_pure_tensor(v))
+        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + tuple(
+            v for v in kwargs.values() if is_pure_tensor(v)
+        )
 
        # align with SPMD
        if isinstance(r, (tuple, list)):
@@ -168,7 +171,7 @@ class ShapeProp(torch.fx.Interpreter):
        n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
        return r
 
-    def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_function(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.
        If the target of ``Node`` is registered with ``@register_shape_impl``,
@@ -197,7 +200,7 @@ class ShapeProp(torch.fx.Interpreter):
        else:
            return res
 
-    def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+    def call_method(self, target: "Target", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.
 
@@ -218,7 +221,8 @@ class ShapeProp(torch.fx.Interpreter):
 
        convert_to_parameter = False
        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
-                args[0], torch.nn.parameter.Parameter):
+            args[0], torch.nn.parameter.Parameter
+        ):
            convert_to_parameter = True
        # Execute the method and return the result
        assert isinstance(target, str)
@@ -1,5 +1,3 @@
-import torch
-import torch.fx
 from torch.fx import GraphModule
 
 from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
@@ -7,7 +5,6 @@ from .passes.graph_profile import FlopProfiler
 
 
 def register_flop_count_impl(func):
-
     def wrapper(impl):
         FlopProfiler._custom_flop_count_impl[func] = impl
         return impl
@@ -16,7 +13,6 @@ def register_flop_count_impl(func):
 
 
 def register_shape_impl(func):
-
     def wrapper(impl):
         ShapeProp._custom_dispatch_func[func] = impl
         return impl
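register_flop_count_impl and register_shape_impl above share one pattern: a decorator factory that files the decorated implementation under a key in a class-level dict and returns it unchanged. A self-contained sketch of the idiom (class name hypothetical, standing in for FlopProfiler):

from typing import Any, Callable, Dict

import torch


class ProfilerSketch:  # stand-in for FlopProfiler
    _custom_flop_count_impl: Dict[Callable, Callable] = {}


def register_flop_count_impl(func: Callable):
    def wrapper(impl: Callable) -> Callable:
        # file the implementation under the profiled target, return it unchanged
        ProfilerSketch._custom_flop_count_impl[func] = impl
        return impl

    return wrapper


@register_flop_count_impl(torch.sigmoid)
def sigmoid_flop_count_impl(*args: Any, **kwargs: Any):
    return 0, 0  # (fwd_flop, bwd_flop), mirroring the docstring example above


assert ProfilerSketch._custom_flop_count_impl[torch.sigmoid] is sigmoid_flop_count_impl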
@@ -12,7 +12,7 @@ from .tracer import register_tracer_impl
 __all__ = []
 
 
-@register_tracer_impl(F.linear, name='_bias_addition_impl')
+@register_tracer_impl(F.linear, name="_bias_addition_impl")
 def linear_impl(input, weight, bias=None):
     if bias is None:
         return F.linear(input, weight)
@@ -20,116 +20,130 @@ def linear_impl(input, weight, bias=None):
         return F.linear(input, weight) + bias
 
 
-@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv1d, name="_bias_addition_impl")
 def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
     if bias is None:
         return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
         return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1))
+            (-1, 1)
+        )
 
 
-@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv2d, name="_bias_addition_impl")
 def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
     if bias is None:
         return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
         return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1, 1))
+            (-1, 1, 1)
+        )
 
 
-@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
+@register_tracer_impl(F.conv3d, name="_bias_addition_impl")
 def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
     if bias is None:
         return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
         return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
-            (-1, 1, 1, 1))
+            (-1, 1, 1, 1)
+        )
 
 
-@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_single(1),
-                          padding=_single(0),
-                          output_padding=_single(0),
-                          groups=1,
-                          dilation=_single(1)):
+@register_tracer_impl(F.conv_transpose1d, name="_bias_addition_impl")
+def conv_transpose1d_impl(
+    input,
+    weight,
+    bias=None,
+    stride=_single(1),
+    padding=_single(0),
+    output_padding=_single(0),
+    groups=1,
+    dilation=_single(1),
+):
     if bias is None:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose1d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        )
     else:
-        return F.conv_transpose1d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1))
+        return F.conv_transpose1d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ) + bias.reshape((-1, 1))
 
 
-@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_pair(1),
-                          padding=_pair(0),
-                          output_padding=_pair(0),
-                          groups=1,
-                          dilation=_pair(1)):
+@register_tracer_impl(F.conv_transpose2d, name="_bias_addition_impl")
+def conv_transpose2d_impl(
+    input, weight, bias=None, stride=_pair(1), padding=_pair(0), output_padding=_pair(0), groups=1, dilation=_pair(1)
+):
     if bias is None:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose2d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        )
     else:
-        return F.conv_transpose2d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1))
+        return F.conv_transpose2d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ) + bias.reshape((-1, 1, 1))
 
 
-@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input,
-                          weight,
-                          bias=None,
-                          stride=_triple(1),
-                          padding=_triple(0),
-                          output_padding=_triple(0),
-                          groups=1,
-                          dilation=_triple(1)):
+@register_tracer_impl(F.conv_transpose3d, name="_bias_addition_impl")
+def conv_transpose3d_impl(
+    input,
+    weight,
+    bias=None,
+    stride=_triple(1),
+    padding=_triple(0),
+    output_padding=_triple(0),
+    groups=1,
+    dilation=_triple(1),
+):
     if bias is None:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation)
+        return F.conv_transpose3d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        )
     else:
-        return F.conv_transpose3d(input,
-                                  weight,
-                                  stride=stride,
-                                  padding=padding,
-                                  output_padding=output_padding,
-                                  groups=groups,
-                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))
+        return F.conv_transpose3d(
+            input,
+            weight,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            dilation=dilation,
+        ) + bias.reshape((-1, 1, 1, 1))
 
 
-@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addmm, name="_bias_addition_impl")
 def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
     if alpha != 1 and beta != 1:
         return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
@@ -141,8 +155,8 @@ def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
         return F.linear(mat1, mat2.transpose(0, 1)) + input
 
 
-@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
-@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.addbmm, name="_bias_addition_impl")
+@register_tracer_impl(torch.Tensor.addbmm, name="_bias_addition_impl")
 def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
     if alpha != 1 and beta != 1:
         return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
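The reflowed implementations above all follow one rule: when a bias is present, compute the unbiased op and add the (reshaped) bias as a separate operation, so FX records the addition as its own node. A quick equivalence check of that rewrite for F.linear (illustration only, not part of the commit):

import torch
import torch.nn.functional as F

x, w, b = torch.randn(2, 3), torch.randn(4, 3), torch.randn(4)
# F.linear(x, w, b) and F.linear(x, w) + b are numerically equivalent,
# but the second form exposes the bias add as a separate traceable op.
assert torch.allclose(F.linear(x, w, b), F.linear(x, w) + b, atol=1e-6)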
@@ -4,6 +4,7 @@ from .tracer import register_leaf_module, register_leaf_module_impl
 
 try:
     import apex
+
     register_leaf_module(apex.normalization.FusedLayerNorm)
     register_leaf_module(apex.normalization.FusedRMSNorm)
     register_leaf_module(apex.normalization.MixedFusedLayerNorm)
@@ -1,10 +1,8 @@
 import operator
-from typing import Any, Callable, Dict, Optional, Set, Union
+from typing import Any, Callable, Dict, Optional, Union
 
 import torch
 import torch.nn as nn
-from torch.fx import Graph, Node, Proxy, Tracer
-from torch.fx.graph import _Namespace
+from torch.fx import Node, Proxy
 from torch.utils._pytree import tree_map
 
 from colossalai._analyzer._subclasses import MetaTensor
@@ -32,7 +30,7 @@ class ColoProxy(Proxy):
     def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
         kwargs = {} if kwargs is None else kwargs
         if orig_method in cls._func_dispatch:
-            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
+            impl = cls._func_dispatch.pop(orig_method)  # avoid recursion
             proxy = impl(*args, **kwargs)
             cls._func_dispatch[orig_method] = impl
             return proxy
@@ -72,7 +70,7 @@ class ColoProxy(Proxy):
         return ColoAttribute(self, k, getattr(self._meta_data, k, None))
 
     def __setitem__(self, key, value):
-        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+        proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
         proxy.meta_data = self._meta_data
         return proxy
 
@@ -89,7 +87,6 @@ class ColoProxy(Proxy):
 
 
 class ColoAttribute(ColoProxy):
-
     def __init__(self, root, attr: str, data=None):
         self.root = root
         self.attr = attr
@@ -102,11 +99,11 @@ class ColoAttribute(ColoProxy):
         # the node for attributes is added lazily, since most will just be method calls
         # which do not rely on the getitem call
         if self._node is None:
-            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+            self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
         return self._node
 
     def __call__(self, *args, **kwargs):
-        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
 
     def __repr__(self):
         return f"ColoAttribute({self.node.name}, attr={self.attr})"
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Union
 
 import torch
 from torch.fx import Tracer
@@ -8,6 +8,7 @@ from colossalai._analyzer._subclasses import MetaTensor
 
 try:
     from ..codegen import ActivationCheckpointCodeGen
+
     SUPPORT_ACTIVATION = True
 except:
     SUPPORT_ACTIVATION = False
@@ -16,7 +17,7 @@ from .tracer import ColoTracer
 
 
 def _default_device():
-    return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
 
 
 def _current_device(module: torch.nn.Module):
@@ -144,10 +145,9 @@ def symbolic_trace(
     if meta_args:
         device, orig_device = _default_device(), _current_device(root)
         wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
-        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
-                           bias_addition_split=bias_addition_split).trace(root.to(device),
-                                                                          concrete_args=concrete_args,
-                                                                          meta_args=tree_map(wrap_fn, meta_args))
+        graph = ColoTracer(trace_act_ckpt=trace_act_ckpt, bias_addition_split=bias_addition_split).trace(
+            root.to(device), concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
+        )
         if trace_act_ckpt and SUPPORT_ACTIVATION:
             graph.set_codegen(ActivationCheckpointCodeGen())
         root.to(orig_device)
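A hedged usage sketch of the entry point reformatted above; the import path is assumed from this repository's layout and may differ in your checkout:

import torch

# from colossalai._analyzer.fx import symbolic_trace  # assumed path, verify locally

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
# Passing meta_args keeps tracing on the "meta" device via the MetaTensor
# wrap_fn above, so no real activations are materialized:
# gm = symbolic_trace(model, meta_args={"input": torch.randn(2, 8)}, trace_act_ckpt=True)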
@@ -20,11 +20,10 @@ def _truncate_suffix(s: str):
     import re
 
     # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
-    return re.sub(r'_\d+$', '', s)
+    return re.sub(r"_\d+$", "", s)
 
 
-def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
-
+def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = "_custom_impl"):
     def wrapper(impl):
         assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
         getattr(ColoTracer, name)[func] = impl
@@ -34,7 +33,6 @@ def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custo
 
 
 def register_leaf_module_impl(module: nn.Module):
-
     def wrapper(impl):
         ColoTracer._custom_leaf_module_impl[module] = impl
         return impl
@@ -76,7 +74,7 @@ class ColoTracer(Tracer):
         self.ckpt_regions = []
         self.ckpt_idx = 0
 
-        self.mod_dir = ''
+        self.mod_dir = ""
 
         # whether the tracer should split the bias_add ops into two ops
         self.bias_addition_split = bias_addition_split
@@ -87,35 +85,41 @@ class ColoTracer(Tracer):
         if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
             return False
         # user can specify which modules are leaf modules and which are not
-        return (type(m) not in self._custom_non_leaf_module
-                and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
+        return type(m) not in self._custom_non_leaf_module and (
+            type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)
+        )
 
-    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
-                    kwargs: Dict[str, Any]) -> Any:
+    def call_module(
+        self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
+    ) -> Any:
         curr_dir = self.mod_dir
-        self.mod_dir = 'self.' + self.path_of_module(m)
+        self.mod_dir = "self." + self.path_of_module(m)
         rst = super().call_module(m, forward, args, kwargs)
         self.mod_dir = curr_dir
         return rst
 
-    def proxy(self, node: Node) -> 'ColoProxy':
+    def proxy(self, node: Node) -> "ColoProxy":
         return ColoProxy(node, self)
 
-    def create_proxy(self,
-                     kind: str,
-                     target: Target,
-                     args: Tuple[Any, ...],
-                     kwargs: Dict[str, Any],
-                     name: Optional[str] = None,
-                     type_expr: Optional[Any] = None,
-                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
-
+    def create_proxy(
+        self,
+        kind: str,
+        target: Target,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        name: Optional[str] = None,
+        type_expr: Optional[Any] = None,
+        proxy_factory_fn: Callable[[Node], "Proxy"] = None,
+    ):
         proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
         unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
-        if kind == 'placeholder':
-            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
-                _truncate_suffix(target), None)
-        elif kind == 'get_attr':
+        if kind == "placeholder":
+            proxy.meta_data = (
+                self.meta_args[target]
+                if target in self.meta_args
+                else self.concrete_args.get(_truncate_suffix(target), None)
+            )
+        elif kind == "get_attr":
             self.disable_module_getattr = True
             try:
                 attr_itr = self.root
@@ -125,20 +129,21 @@ class ColoTracer(Tracer):
                 proxy.meta_data = attr_itr
             finally:
                 self.disable_module_getattr = False
-        elif kind == 'call_function':
+        elif kind == "call_function":
             proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
-        elif kind == 'call_method':
+        elif kind == "call_method":
             self.disable_module_getattr = True
             try:
-                if target == '__call__':
+                if target == "__call__":
                     proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                 else:
                     if target not in _TensorPropertyMethod:
-                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
-                                                                               **tree_map(unwrap_fn, kwargs))
+                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
+                            *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
+                        )
             finally:
                 self.disable_module_getattr = False
-        elif kind == 'call_module':
+        elif kind == "call_module":
             mod = self.root.get_submodule(target)
             self.disable_module_getattr = True
             try:
@@ -158,11 +163,12 @@ class ColoTracer(Tracer):
         n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
         return node
 
-    def trace(self,
-              root: torch.nn.Module,
-              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
-              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
-
+    def trace(
+        self,
+        root: torch.nn.Module,
+        concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+        meta_args: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Graph:
         if meta_args is None:
             meta_args = {}
 
@@ -177,9 +183,7 @@ class ColoTracer(Tracer):
         non_concrete_arg_names = sig_names - concrete_arg_names
         # update concrete args with default values
         for k, v in sig.parameters.items():
-            if k in sig_names - meta_arg_names and \
-                    k not in concrete_args and \
-                    v.default is not inspect.Parameter.empty:
+            if k in sig_names - meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
                 concrete_args[k] = v.default
 
         def _check_arg_name_valid(names: Iterable[str]):
@@ -194,9 +198,9 @@ class ColoTracer(Tracer):
         self.meta_args = meta_args
 
         with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
-            self.mod_dir = 'self'
+            self.mod_dir = "self"
             self.graph = super().trace(root, concrete_args=concrete_args)
-            self.mod_dir = ''
+            self.mod_dir = ""
         self.graph.lint()
 
         for node in self.graph.nodes:
@@ -266,17 +270,17 @@ class ColoTracer(Tracer):
         # override the torch factory functions to create a proxy when the method
         # is called during ``symbolic_trace()``.
         def wrap_factory_method(target):
-
             @functools.wraps(target)
             def wrapper(*args, **kwargs):
                 is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
-                    isinstance(p, ColoProxy) for p in kwargs.values())
+                    isinstance(p, ColoProxy) for p in kwargs.values()
+                )
                 if is_proxy:
                     # if the arg is a proxy, then need to record this function called on this proxy
                     # e.g. torch.ones(size) where size is an input proxy
                     self.disable_module_getattr = True
                     try:
-                        proxy = self.create_proxy('call_function', target, args, kwargs)
+                        proxy = self.create_proxy("call_function", target, args, kwargs)
                     finally:
                         self.disable_module_getattr = False
                     return proxy
@@ -341,10 +345,13 @@ class ColoTracer(Tracer):
                 if attr_val is p:
                     if n not in parameter_proxy_cache:
                         kwargs = {}
-                        if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
-                            kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
                                                          lambda node: ColoProxy(self, node, n, attr_val))
-                        val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs)    # type: ignore[arg-type]
+                        if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
+                            kwargs["proxy_factory_fn"] = (
+                                None
+                                if not self.param_shapes_constant
+                                else lambda node: ColoProxy(self, node, n, attr_val)
+                            )
+                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                         parameter_proxy_cache[n] = val_proxy
                     return parameter_proxy_cache[n]
             return None
@@ -355,8 +362,9 @@ class ColoTracer(Tracer):
                 return maybe_buffer_proxy
 
         if isinstance(attr_val, torch.nn.Parameter):
-            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
                                                             parameter_proxy_cache)
+            maybe_parameter_proxy = maybe_get_proxy_for_attr(
+                attr_val, self.root.named_parameters(), parameter_proxy_cache
+            )
             if maybe_parameter_proxy is not None:
                 return maybe_parameter_proxy
 
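Finally, the create_proxy changes above are where meta data flows through the tracer: every proxy carries a meta tensor, and a call_function node derives its own meta_data by running the target on the unwrapped metas. A toy stand-in for that step (illustration only, not the library code):

import torch


def propagate_call_function(target, arg_metas):
    # mirrors: proxy.meta_data = target(*tree_map(unwrap_fn, args), ...)
    return target(*arg_metas)


meta = torch.empty(2, 3, device="meta")
out = propagate_call_function(torch.relu, (meta,))
assert out.is_meta and out.shape == (2, 3)  # shape inferred with no real compute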