mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[autoparallel] add bias addtion function class (#2098)
* [autoparallel] add bias addtion function class * polish code * polish
This commit is contained in:
@@ -20,7 +20,7 @@ from torch.fx.proxy import ParameterProxy, Proxy
|
||||
|
||||
from ..proxy import ColoProxy
|
||||
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
|
||||
from .bias_addition_patch import module_to_func_dict
|
||||
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
|
||||
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
|
||||
|
||||
__all__ = ['ColoTracer']
|
||||
@@ -96,7 +96,8 @@ class ColoTracer(Tracer):
|
||||
handle = None
|
||||
if kind == "call_function":
|
||||
if bias_addition_function.has(target):
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs)
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
|
||||
|
Reference in New Issue
Block a user