mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[autoparallel] support addbmm computation (#2102)
This commit is contained in:
@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
|
||||
|
||||
from ..proxy import ColoProxy
|
||||
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
|
||||
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
|
||||
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
|
||||
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
|
||||
from .registry import (
|
||||
bias_addition_function,
|
||||
bias_addition_method,
|
||||
bias_addition_module,
|
||||
meta_patched_function,
|
||||
meta_patched_module,
|
||||
)
|
||||
|
||||
__all__ = ['ColoTracer']
|
||||
|
||||
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
|
||||
|
||||
elif kind == "call_method":
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
if bias_addition_function.has(method):
|
||||
handle = bias_addition_function.get(method)(self, target, args, kwargs)
|
||||
if bias_addition_method.has(method):
|
||||
function_to_substitute = method_to_func_dict[method]
|
||||
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
|
||||
|
||||
elif kind == "call_module":
|
||||
if not hasattr(self, "orig_forward"):
|
||||
|
Reference in New Issue
Block a user