[autoparallel] support addbmm computation (#2102)

This commit is contained in:
YuliangLiu0306
2022-12-08 21:15:11 +08:00
committed by GitHub
parent d3d4630495
commit 0fecbb9e20
7 changed files with 179 additions and 65 deletions

View File

@@ -1,2 +1,3 @@
from .addbmm import Addbmm
from .addmm import Addmm
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict
from .bias_addition_function import BiasAdditionFunc, LinearBasedBiasFunc, func_to_func_dict, method_to_func_dict

View File

@@ -0,0 +1,75 @@
import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addbmm)
@bias_addition_function.register(torch.addbmm)
class Addbmm(LinearBasedBiasFunc):
def extract_kwargs_from_origin_func(self):
kwargs = {}
if 'beta' in self.kwargs:
kwargs['beta'] = self.kwargs['beta']
if 'alpha' in self.kwargs:
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def create_non_bias_func_proxy(self, input_proxy, other_proxy):
"""
This method is used to create the non_bias_func proxy, the node created by this proxy will
compute the main computation, such as convolution, with bias option banned.
"""
assert self.substitute_func == torch.bmm
node_kind = 'call_function'
node_target = self.substitute_func
node_args = (input_proxy, other_proxy)
# torch.bmm does not have any kwargs
node_kwargs = {}
non_bias_func_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return non_bias_func_proxy
def insert_sum_node(self, input_proxy, sum_dims=0):
'''
This method is used to sum the input_proxy through the sum_dims.
'''
node_kind = 'call_function'
node_target = torch.sum
node_args = (input_proxy, sum_dims)
node_kwargs = {}
sum_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return sum_proxy
def generate(self):
# The formula for addbmm is output = beta * input + alpha * (torch.bmm(b1, b2))
# doing the non-bias computation(temp_0 = torch.bmm(b1, b2))
non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], self.args[2])
# doing sum on the batch dimension(temp_1 = torch.sum(temp_0, 0))
sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
kwargs = self.extract_kwargs_from_origin_func()
if 'beta' in kwargs:
beta = kwargs['beta']
# doing the multiplication with beta if it exists(temp_2 = beta * input)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
# doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
alpha_proxy = self.create_mul_node(alpha, sum_proxy)
else:
alpha_proxy = sum_proxy
# doing the addition(temp_4 = temp_2 + temp_3)
bias_addition_proxy = self.create_bias_addition_proxy(alpha_proxy, beta_proxy)
return bias_addition_proxy

View File

@@ -3,10 +3,11 @@ import operator
import torch
import torch.nn.functional as F
from ...registry import bias_addition_function
from ...registry import bias_addition_function, bias_addition_method
from .bias_addition_function import LinearBasedBiasFunc
@bias_addition_method.register(torch.Tensor.addmm)
@bias_addition_function.register(torch.addmm)
class Addmm(LinearBasedBiasFunc):
@@ -18,23 +19,6 @@ class Addmm(LinearBasedBiasFunc):
kwargs['alpha'] = self.kwargs['alpha']
return kwargs
def coefficent_for_addmm(self, input_proxy, coefficent):
"""
This method is used to create a coefficent node for the numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_target = operator.mul
node_args = (
input_proxy,
coefficent,
)
node_kwargs = {}
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return mul_proxy
def transpose_other_operand_for_linear(self, other_proxy):
'''
This method is used to transpose the other operand for linear function.
@@ -61,13 +45,13 @@ class Addmm(LinearBasedBiasFunc):
if 'beta' in kwargs:
beta = kwargs['beta']
beta_proxy = self.coefficent_for_addmm(self.args[0], beta)
beta_proxy = self.create_mul_node(self.args[0], beta)
else:
beta_proxy = self.args[0]
if 'alpha' in kwargs:
alpha = kwargs['alpha']
alpha_proxy = self.coefficent_for_addmm(alpha, non_bias_linear_func_proxy)
alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
else:
alpha_proxy = non_bias_linear_func_proxy

View File

@@ -52,6 +52,23 @@ class BiasAdditionFunc(ABC):
"""
pass
def create_mul_node(self, input_proxy, coefficent):
"""
This method is used to create a coefficent node for the numerical correctness.
The formula for torch.addmm is out = beta * input + alpha * (m1 @ m2)
Therefore, we need to use this method insert two more operator.mul nodes for
the computation graph to compute the final result.
"""
node_kind = 'call_function'
node_target = operator.mul
node_args = (
input_proxy,
coefficent,
)
node_kwargs = {}
mul_proxy = self.tracer.create_proxy(node_kind, node_target, node_args, node_kwargs)
return mul_proxy
class LinearBasedBiasFunc(BiasAdditionFunc):
"""
@@ -88,4 +105,10 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
func_to_func_dict = {
torch.addmm: F.linear,
torch.addbmm: torch.bmm,
}
method_to_func_dict = {
torch.Tensor.addmm: F.linear,
torch.Tensor.addbmm: torch.bmm,
}

View File

@@ -25,3 +25,4 @@ meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition')
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition')
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition')

View File

@@ -20,8 +20,14 @@ from torch.fx.proxy import ParameterProxy, Proxy
from ..proxy import ColoProxy
from ._tracer_utils import compute_meta_data_for_functions_proxy, extract_meta, is_element_in_list
from .bias_addition_patch import func_to_func_dict, module_to_func_dict
from .registry import bias_addition_function, bias_addition_module, meta_patched_function, meta_patched_module
from .bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from .registry import (
bias_addition_function,
bias_addition_method,
bias_addition_module,
meta_patched_function,
meta_patched_module,
)
__all__ = ['ColoTracer']
@@ -100,12 +106,14 @@ class ColoTracer(Tracer):
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs)
function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(self, target, args, kwargs, function_to_substitute)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_function.has(method):
handle = bias_addition_function.get(method)(self, target, args, kwargs)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(self, target, args, kwargs, function_to_substitute)
elif kind == "call_module":
if not hasattr(self, "orig_forward"):