[autoparallel] support linear function bias addition (#2104)

This commit is contained in:
YuliangLiu0306
2022-12-09 10:31:36 +08:00
committed by GitHub
parent 6a71d3a0d9
commit d87baa85d9
5 changed files with 211 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
from .linear import Linear

View File

@@ -106,6 +106,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
func_to_func_dict = {
torch.addmm: F.linear,
torch.addbmm: torch.bmm,
F.linear: F.linear,
}
method_to_func_dict = {

View File

@@ -0,0 +1,25 @@
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
    """Bias-addition handler for ``F.linear``.

    Rewrites a traced ``F.linear(input, weight, bias=...)`` call into a
    bias-free linear proxy followed by an explicit bias-addition proxy, so
    downstream passes can shard the matmul and the bias add independently.
    """

    def extract_kwargs_from_origin_func(self):
        """Return ``{'bias': <bias kwarg>}`` taken from the original call.

        Raises:
            AssertionError: if the traced call carried no ``bias`` kwarg.
                The tracer is expected to dispatch here only when a bias
                is present (see the ColoTracer linear branch), so absence
                indicates a tracing bug.
        """
        assert 'bias' in self.kwargs
        # The original wrapped this in `if 'bias' in self.kwargs:`, which
        # the assert above makes always-true — dead branch removed.
        return {'bias': self.kwargs['bias']}

    def generate(self):
        """Build and return the proxy for ``linear(input, weight) + bias``."""
        # args[0] is the input tensor proxy, args[1] the weight proxy.
        non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
        kwargs = self.extract_kwargs_from_origin_func()
        # Chain an explicit elementwise add of the bias onto the matmul proxy.
        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
        return bias_addition_proxy

View File

@@ -102,8 +102,13 @@ class ColoTracer(Tracer):
handle = None
if kind == "call_function":
if bias_addition_function.has(target):
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else:
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]