[autoparallel] support linear function bias addition (#2104)

This commit is contained in:
YuliangLiu0306
2022-12-09 10:31:36 +08:00
committed by GitHub
parent 6a71d3a0d9
commit d87baa85d9
5 changed files with 211 additions and 2 deletions

View File

@@ -102,8 +102,13 @@ class ColoTracer(Tracer):
handle = None
if kind == "call_function":
if bias_addition_function.has(target):
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]