[autoparallel] support addbmm computation (#2102)

This commit is contained in:
YuliangLiu0306
2022-12-08 21:15:11 +08:00
committed by GitHub
parent d3d4630495
commit 0fecbb9e20
7 changed files with 179 additions and 65 deletions

View File

@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from .registry import (
bias_addition_function,
bias_addition_method,
bias_addition_module,
meta_patched_function,
meta_patched_module,
)
__all__ = ['ColoTracer']
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_function.has(method):
handle = bias_addition_function.get(method)(self, target, args, kwargs)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):