[autoparallel] support linear function bias addition (#2104)

This commit is contained in:
YuliangLiu0306
2022-12-09 10:31:36 +08:00
committed by GitHub
parent 6a71d3a0d9
commit d87baa85d9
5 changed files with 211 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
from .linear import Linear

View File

@@ -106,6 +106,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
func_to_func_dict = {
torch.addmm: F.linear,
torch.addbmm: torch.bmm,
F.linear: F.linear,
}
method_to_func_dict = {

View File

@@ -0,0 +1,25 @@
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_function.register(F.linear)
class Linear(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
assert 'bias' in self.kwargs
kwargs = {}
if 'bias' in self.kwargs:
kwargs['bias'] = self.kwargs['bias']
return kwargs
def generate(self):
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
kwargs = self.extract_kwargs_from_origin_func()
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
return bias_addition_proxy