mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[autoparallel] support linear function bias addition (#2104)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from .addbmm import Addbmm
|
||||
from .addmm import Addmm
|
||||
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
|
||||
from .linear import Linear
|
||||
|
@@ -106,6 +106,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
|
||||
func_to_func_dict = {
|
||||
torch.addmm: F.linear,
|
||||
torch.addbmm: torch.bmm,
|
||||
F.linear: F.linear,
|
||||
}
|
||||
|
||||
method_to_func_dict = {
|
||||
|
@@ -0,0 +1,25 @@
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...registry import bias_addition_function
|
||||
from .bias_addition_function import LinearBasedBiasFunc
|
||||
|
||||
|
||||
@bias_addition_function.register(F.linear)
|
||||
class Linear(LinearBasedBiasFunc):
|
||||
|
||||
def extract_kwargs_from_origin_func(self):
|
||||
assert 'bias' in self.kwargs
|
||||
kwargs = {}
|
||||
if 'bias' in self.kwargs:
|
||||
kwargs['bias'] = self.kwargs['bias']
|
||||
return kwargs
|
||||
|
||||
def generate(self):
|
||||
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
|
||||
kwargs = self.extract_kwargs_from_origin_func()
|
||||
bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
|
||||
|
||||
return bias_addition_proxy
|
@@ -102,8 +102,13 @@ class ColoTracer(Tracer):
|
||||
handle = None
|
||||
if kind == "call_function":
|
||||
if bias_addition_function.has(target):
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
if target == torch.nn.functional.linear:
|
||||
if 'bias' in kwargs and kwargs['bias'] is not None:
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
else:
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
|
Reference in New Issue
Block a user