[autoparallel] add bias addtion function class (#2098)

* [autoparallel] add bias addtion function class

* polish code

* polish
This commit is contained in:
YuliangLiu0306
2022-12-08 11:31:51 +08:00
committed by GitHub
parent 3af7e65dea
commit b175e6d58e
5 changed files with 216 additions and 33 deletions

View File

@@ -20,7 +20,7 @@ from torch.fx.proxy import ParameterProxy, Proxy
from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import module_to_func_dict
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
__all__ = ['ColoTracer']
@@ -96,7 +96,8 @@ class ColoTracer(Tracer):
handle = None
if kind == "call_function":
if bias_addition_function.has(target):
handle = bias_addition_function.get(target)(self, target, args, kwargs)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)