mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 06:30:41 +00:00
[autoparallel] support addbmm computation (#2102)
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
from .addbmm import Addbmm
|
||||
from .addmm import Addmm
|
||||
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict
|
||||
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict
|
||||
|
@@ -0,0 +1,75 @@
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...registry import bias_addition_function, bias_addition_method
|
||||
from .bias_addition_function import LinearBasedBiasFunc
|
||||
|
||||
|
||||
@bias_addition_method.register(torch.Tensor.addbmm)
|
||||
@bias_addition_function.register(torch.addbmm)
|
||||
class Addbmm(LinearBasedBiasFunc):
|
||||
|
||||
def extract_kwargs_from_origin_func(self):
|
||||
kwargs = {}
|
||||
if 'beta' in self.kwargs:
|
||||
kwargs['beta'] = self.kwargs['beta']
|
||||
if 'alpha' in self.kwargs:
|
||||
kwargs['alpha'] = self.kwargs['alpha']
|
||||
return kwargs
|
||||
|
||||
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
|
||||
"""
|
||||
This method is used to create the non_bias_func proxy, the node created by this proxy will
|
||||
compute the main computation, such as convolution, with bias option banned.
|
||||
"""
|
||||
assert self.substitute_func == torch.bmm
|
||||
node_kind = 'call_function'
|
||||
node_target = self.substitute_func
|
||||
|
||||
node_args = (input_proxy, other_proxy)
|
||||
# torch.bmm does not have any kwargs
|
||||
node_kwargs = {}
|
||||
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return non_bias_func_proxy
|
||||
|
||||
def insert_sum_node(self, input_proxy, sum_dims=0):
|
||||
'''
|
||||
This method is used to sum the input_proxy through the sum_dims.
|
||||
'''
|
||||
node_kind = 'call_function'
|
||||
node_target = torch.sum
|
||||
node_args = (input_proxy, sum_dims)
|
||||
node_kwargs = {}
|
||||
sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return sum_proxy
|
||||
|
||||
def generate(self):
|
||||
# The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2))
|
||||
|
||||
# doing the non-bias computation(temp_0 = torch.bmm(b1, b2))
|
||||
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
|
||||
|
||||
# doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0))
|
||||
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
|
||||
kwargs = self.extract_kwargs_from_origin_func()
|
||||
|
||||
if 'beta' in kwargs:
|
||||
beta = kwargs['beta']
|
||||
# doing the multiplication with beta if it exists(temp_2 = beta * input)
|
||||
beta_proxy = self.create_mul_node(self.args[0], beta)
|
||||
else:
|
||||
beta_proxy = self.args[0]
|
||||
|
||||
if 'alpha' in kwargs:
|
||||
alpha = kwargs['alpha']
|
||||
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
|
||||
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
|
||||
else:
|
||||
alpha_proxy = sum_proxy
|
||||
|
||||
# doing the addition(temp_4 = temp_2 + temp_3)
|
||||
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
|
||||
|
||||
return bias_addition_proxy
|
@@ -3,10 +3,11 @@ import operator
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...registry import bias_addition_function
|
||||
from ...registry import bias_addition_function, bias_addition_method
|
||||
from .bias_addition_function import LinearBasedBiasFunc
|
||||
|
||||
|
||||
@bias_addition_method.register(torch.Tensor.addmm)
|
||||
@bias_addition_function.register(torch.addmm)
|
||||
class Addmm(LinearBasedBiasFunc):
|
||||
|
||||
@@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
|
||||
kwargs['alpha'] = self.kwargs['alpha']
|
||||
return kwargs
|
||||
|
||||
def coefficent_for_addmm(self, input_proxy, coefficent):
|
||||
"""
|
||||
This method is used to create a coefficent node for the numerical correctness.
|
||||
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
|
||||
Therefore, we need to use this method insert two more operator.mul nodes for
|
||||
the computation graph to compute the final result.
|
||||
"""
|
||||
node_kind = 'call_function'
|
||||
node_target = operator.mul
|
||||
node_args = (
|
||||
input_proxy,
|
||||
coefficent,
|
||||
)
|
||||
node_kwargs = {}
|
||||
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return mul_proxy
|
||||
|
||||
def transpose_other_operand_for_linear(self, other_proxy):
|
||||
'''
|
||||
This method is used to transpose the other operand for linear function.
|
||||
@@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
|
||||
|
||||
if 'beta' in kwargs:
|
||||
beta = kwargs['beta']
|
||||
beta_proxy = self.coefficent_for_addmm(self.args[0], beta)
|
||||
beta_proxy = self.create_mul_node(self.args[0], beta)
|
||||
else:
|
||||
beta_proxy = self.args[0]
|
||||
|
||||
if 'alpha' in kwargs:
|
||||
alpha = kwargs['alpha']
|
||||
alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy)
|
||||
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
|
||||
else:
|
||||
alpha_proxy = non_bias_linear_func_proxy
|
||||
|
||||
|
@@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def create_mul_node(self, input_proxy, coefficent):
|
||||
"""
|
||||
This method is used to create a coefficent node for the numerical correctness.
|
||||
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
|
||||
Therefore, we need to use this method insert two more operator.mul nodes for
|
||||
the computation graph to compute the final result.
|
||||
"""
|
||||
node_kind = 'call_function'
|
||||
node_target = operator.mul
|
||||
node_args = (
|
||||
input_proxy,
|
||||
coefficent,
|
||||
)
|
||||
node_kwargs = {}
|
||||
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
|
||||
return mul_proxy
|
||||
|
||||
|
||||
class LinearBasedBiasFunc(BiasAdditionFunc):
|
||||
"""
|
||||
@@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
|
||||
|
||||
func_to_func_dict = {
|
||||
torch.addmm: F.linear,
|
||||
torch.addbmm: torch.bmm,
|
||||
}
|
||||
|
||||
method_to_func_dict = {
|
||||
torch.Tensor.addmm: F.linear,
|
||||
torch.Tensor.addbmm: torch.bmm,
|
||||
}
|
||||
|
@@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution
|
||||
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
|
||||
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
|
||||
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
|
||||
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')
|
||||
|
@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
|
||||
|
||||
from ..proxy import ColoProxy
|
||||
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
|
||||
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
|
||||
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
|
||||
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
|
||||
from .registry import (
|
||||
bias_addition_function,
|
||||
bias_addition_method,
|
||||
bias_addition_module,
|
||||
meta_patched_function,
|
||||
meta_patched_module,
|
||||
)
|
||||
|
||||
__all__ = ['ColoTracer']
|
||||
|
||||
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
|
||||
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
|
||||
elif bias_addition_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
|
||||
function_to_substitute = func_to_func_dict[target]
|
||||
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
|
||||
|
||||
elif kind == "call_method":
|
||||
method = getattr(args_metas[0].__class__, target)
|
||||
if bias_addition_function.has(method):
|
||||
handle = bias_addition_function.get(method)(self, target, args, kwargs)
|
||||
if bias_addition_method.has(method):
|
||||
function_to_substitute = method_to_func_dict[method]
|
||||
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
|
||||
|
||||
elif kind == "call_module":
|
||||
if not hasattr(self, "orig_forward"):
|
||||
|
Reference in New Issue
Block a user