Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,7 +1,4 @@
-import operator
-
 import torch
-import torch.nn.functional as F
 
 from ...registry import bias_addition_function, bias_addition_method
 from .bias_addition_function import LinearBasedBiasFunc
@@ -10,13 +7,12 @@ from .bias_addition_function import LinearBasedBiasFunc
 @bias_addition_method.register(torch.Tensor.addbmm)
 @bias_addition_function.register(torch.addbmm)
 class Addbmm(LinearBasedBiasFunc):
-
     def extract_kwargs_from_origin_func(self):
         kwargs = {}
-        if 'beta' in self.kwargs:
-            kwargs['beta'] = self.kwargs['beta']
-        if 'alpha' in self.kwargs:
-            kwargs['alpha'] = self.kwargs['alpha']
+        if "beta" in self.kwargs:
+            kwargs["beta"] = self.kwargs["beta"]
+        if "alpha" in self.kwargs:
+            kwargs["alpha"] = self.kwargs["alpha"]
         return kwargs
 
     def create_non_bias_func_proxy(self, input_proxy, other_proxy):
@@ -25,7 +21,7 @@ class Addbmm(LinearBasedBiasFunc):
         compute the main computation, such as convolution, with bias option banned.
         """
         assert self.substitute_func == torch.bmm
-        node_kind = 'call_function'
+        node_kind = "call_function"
         node_target = self.substitute_func
 
         node_args = (input_proxy, other_proxy)
@@ -35,10 +31,10 @@ class Addbmm(LinearBasedBiasFunc):
         return non_bias_func_proxy
 
     def insert_sum_node(self, input_proxy, sum_dims=0):
-        '''
+        """
         This method is used to sum the input_proxy through the sum_dims.
-        '''
-        node_kind = 'call_function'
+        """
+        node_kind = "call_function"
         node_target = torch.sum
         node_args = (input_proxy, sum_dims)
         node_kwargs = {}
@@ -55,15 +51,15 @@ class Addbmm(LinearBasedBiasFunc):
         sum_proxy = self.insert_sum_node(non_bias_linear_func_proxy)
         kwargs = self.extract_kwargs_from_origin_func()
 
-        if 'beta' in kwargs:
-            beta = kwargs['beta']
+        if "beta" in kwargs:
+            beta = kwargs["beta"]
             # doing the multiplication with beta if it exists(temp_2 = beta * input)
             beta_proxy = self.create_mul_node(self.args[0], beta)
         else:
             beta_proxy = self.args[0]
 
-        if 'alpha' in kwargs:
-            alpha = kwargs['alpha']
+        if "alpha" in kwargs:
+            alpha = kwargs["alpha"]
             # doing the multiplication with alpha if it exists(temp_3 = alpha * temp_1)
             alpha_proxy = self.create_mul_node(alpha, sum_proxy)
         else:
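For context, the Addbmm patch above rewrites torch.addbmm into a bias-free torch.bmm, a sum over the batch dimension, two optional scalar multiplications, and a final add. A standalone sanity check of that identity (illustrative only, not part of this diff):

import torch

inp = torch.rand(3, 5)
batch1 = torch.rand(4, 3, 2)
batch2 = torch.rand(4, 2, 5)

# torch.addbmm(input, b1, b2) == beta * input + alpha * sum_i(b1[i] @ b2[i])
reference = torch.addbmm(inp, batch1, batch2, beta=2.0, alpha=3.0)
decomposed = 2.0 * inp + 3.0 * torch.sum(torch.bmm(batch1, batch2), dim=0)
assert torch.allclose(reference, decomposed, atol=1e-6)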
@@ -1,7 +1,4 @@
-import operator
-
 import torch
-import torch.nn.functional as F
 
 from ...registry import bias_addition_function, bias_addition_method
 from .bias_addition_function import LinearBasedBiasFunc
@@ -10,17 +7,16 @@ from .bias_addition_function import LinearBasedBiasFunc
 @bias_addition_method.register(torch.Tensor.addmm)
 @bias_addition_function.register(torch.addmm)
 class Addmm(LinearBasedBiasFunc):
-
     def extract_kwargs_from_origin_func(self):
         kwargs = {}
-        if 'beta' in self.kwargs:
-            kwargs['beta'] = self.kwargs['beta']
-        if 'alpha' in self.kwargs:
-            kwargs['alpha'] = self.kwargs['alpha']
+        if "beta" in self.kwargs:
+            kwargs["beta"] = self.kwargs["beta"]
+        if "alpha" in self.kwargs:
+            kwargs["alpha"] = self.kwargs["alpha"]
         return kwargs
 
     def transpose_other_operand_for_linear(self, other_proxy):
-        '''
+        """
         This method is used to transpose the other operand for linear function.
         For example:
             input = torch.rand(3, 4)
@@ -30,8 +26,8 @@ class Addmm(LinearBasedBiasFunc):
             # To keep the computation graph consistent with the origin computation graph, we need to transpose the m2
             # before we call the linear function.
             new_output = torch.linear(m1, m2.transpose(0, 1)) + input
-        '''
-        node_kind = 'call_function'
+        """
+        node_kind = "call_function"
         node_target = torch.transpose
         node_args = (other_proxy, 0, 1)
         node_kwargs = {}
@@ -43,14 +39,14 @@ class Addmm(LinearBasedBiasFunc):
         non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[1], transpose_proxy)
         kwargs = self.extract_kwargs_from_origin_func()
 
-        if 'beta' in kwargs:
-            beta = kwargs['beta']
+        if "beta" in kwargs:
+            beta = kwargs["beta"]
             beta_proxy = self.create_mul_node(self.args[0], beta)
         else:
             beta_proxy = self.args[0]
 
-        if 'alpha' in kwargs:
-            alpha = kwargs['alpha']
+        if "alpha" in kwargs:
+            alpha = kwargs["alpha"]
             alpha_proxy = self.create_mul_node(alpha, non_bias_linear_func_proxy)
         else:
             alpha_proxy = non_bias_linear_func_proxy
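The Addmm patch follows the same pattern but routes the matrix product through F.linear, which is why the second operand is transposed first (F.linear multiplies by the transpose of its weight argument). A quick check of that equivalence, mirroring the shapes in the docstring above (illustrative, not part of the diff):

import torch
import torch.nn.functional as F

inp = torch.rand(3, 4)
m1 = torch.rand(3, 5)
m2 = torch.rand(5, 4)

# torch.addmm(input, m1, m2) == beta * input + alpha * F.linear(m1, m2.transpose(0, 1))
reference = torch.addmm(inp, m1, m2, beta=2.0, alpha=3.0)
decomposed = 2.0 * inp + 3.0 * F.linear(m1, m2.transpose(0, 1))
assert torch.allclose(reference, decomposed, atol=1e-6)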
@@ -29,7 +29,6 @@ class BiasAdditionFunc(ABC):
         to insert two more operator.mul nodes for the computation graph to compute the
         final result.
         """
-        pass
 
     @abstractmethod
     def generate(self):
@@ -50,7 +49,6 @@ class BiasAdditionFunc(ABC):
         %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
         %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
         """
-        pass
 
     def create_mul_node(self, input_proxy, coefficent):
         """
@@ -59,7 +57,7 @@ class BiasAdditionFunc(ABC):
         Therefore, we need to use this method insert two more operator.mul nodes for
         the computation graph to compute the final result.
         """
-        node_kind = 'call_function'
+        node_kind = "call_function"
         node_target = operator.mul
         node_args = (
             input_proxy,
@@ -82,7 +80,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
         compute the main computation, such as convolution, with bias option banned.
         """
         assert self.substitute_func == torch.nn.functional.linear
-        node_kind = 'call_function'
+        node_kind = "call_function"
         node_target = self.substitute_func
 
         node_args = (input_proxy, other_proxy)
@@ -96,7 +94,7 @@ class LinearBasedBiasFunc(BiasAdditionFunc):
         This method is used to create the bias_addition_proxy, the node created by this proxy will
         compute the sum of non_bias_func result and bias with some reshape operation if needed.
         """
-        bias_add_node_kind = 'call_function'
+        bias_add_node_kind = "call_function"
         bias_add_node_target = operator.add
         bias_add_args = (non_bias_func_proxy, bias_proxy)
         bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
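The docstrings above describe the target torch.fx graph: the bias-carrying call is split into a bias-free call_function node plus operator.mul / operator.add nodes created through tracer.create_proxy. A minimal sketch that produces a graph of the same shape, assuming only stock torch.fx rather than the ColossalAI tracer:

import operator

import torch
import torch.fx
import torch.nn.functional as F


def rewritten_addmm(inp, m1, m2):
    # the form Addmm generates for torch.addmm(inp, m1, m2, alpha=2)
    linear = F.linear(m1, torch.transpose(m2, 0, 1))
    return operator.add(operator.mul(2, linear), inp)


gm = torch.fx.symbolic_trace(rewritten_addmm)
print(gm.graph)  # call_function nodes for transpose, linear, mul and add, as in the docstring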
@@ -1,6 +1,3 @@
-import operator
-
-import torch
 import torch.nn.functional as F
 
 from ...registry import bias_addition_function
@@ -9,17 +6,16 @@ from .bias_addition_function import LinearBasedBiasFunc
 
 @bias_addition_function.register(F.linear)
 class Linear(LinearBasedBiasFunc):
-
     def extract_kwargs_from_origin_func(self):
-        assert 'bias' in self.kwargs
+        assert "bias" in self.kwargs
         kwargs = {}
-        if 'bias' in self.kwargs:
-            kwargs['bias'] = self.kwargs['bias']
+        if "bias" in self.kwargs:
+            kwargs["bias"] = self.kwargs["bias"]
         return kwargs
 
     def generate(self):
         non_bias_linear_func_proxy = self.create_non_bias_func_proxy(self.args[0], self.args[1])
         kwargs = self.extract_kwargs_from_origin_func()
-        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs['bias'])
+        bias_addition_proxy = self.create_bias_addition_proxy(non_bias_linear_func_proxy, kwargs["bias"])
 
         return bias_addition_proxy
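The Linear patch relies on the fact that F.linear with a bias equals the bias-free F.linear plus a broadcast add, which is exactly what create_non_bias_func_proxy and create_bias_addition_proxy emit. A quick illustrative check of that assumption:

import torch
import torch.nn.functional as F

x = torch.rand(2, 8)
w = torch.rand(4, 8)
b = torch.rand(4)

# F.linear(x, w, bias=b) == F.linear(x, w) + b
assert torch.allclose(F.linear(x, w, bias=b), F.linear(x, w) + b, atol=1e-6)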
@@ -27,8 +27,8 @@ class BiasAdditionModule(ABC):
         Note: this function will be invoked during module initializing,
         you should never call this function.
         """
-        weight_node_kind = 'get_attr'
-        weight_node_target = self.target + '.weight'
+        weight_node_kind = "get_attr"
+        weight_node_target = self.target + ".weight"
         weight_proxy = self.tracer.create_proxy(weight_node_kind, weight_node_target, (), {})
         return weight_proxy
 
@@ -39,8 +39,8 @@ class BiasAdditionModule(ABC):
         Note: this function will be invoked during module initializing,
         you should never call this function.
         """
-        bias_node_kind = 'get_attr'
-        bias_node_target = self.target + '.bias'
+        bias_node_kind = "get_attr"
+        bias_node_target = self.target + ".bias"
         bias_proxy = self.tracer.create_proxy(bias_node_kind, bias_node_target, (), {})
         return bias_proxy
 
@@ -54,14 +54,13 @@ class BiasAdditionModule(ABC):
         considered during module initializing. However, we need to consider those attributes as kwargs
         in F.conv2d.
         """
-        pass
 
     def create_non_bias_func_proxy(self, input_proxy=None):
         """
         This method is used to create the non_bias_func proxy, the node created by this proxy will
         compute the main computation, such as convolution, with bias option banned.
         """
-        node_kind = 'call_function'
+        node_kind = "call_function"
         node_target = self.substitute_func
         if input_proxy is None:
             input_proxy = self.args[0]
@@ -75,7 +74,7 @@ class BiasAdditionModule(ABC):
         This method is used to create the bias_addition_proxy, the node created by this proxy will
         compute the sum of non_bias_func result and bias with some reshape operation if needed.
         """
-        bias_add_node_kind = 'call_function'
+        bias_add_node_kind = "call_function"
         bias_add_node_target = operator.add
         bias_add_args = (non_bias_func_proxy, bias_proxy)
         bias_add_proxy = self.tracer.create_proxy(bias_add_node_kind, bias_add_node_target, tuple(bias_add_args), {})
@@ -100,7 +99,6 @@ class BiasAdditionModule(ABC):
         %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
         %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
         """
-        pass
 
 
 module_to_func_dict = {
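BiasAdditionModule performs the same split for modules: the weight and bias become get_attr nodes, the module call becomes a bias-free call_function (looked up via module_to_func_dict), and the bias is added back with operator.add. In eager terms, and only as an illustration of what the rewritten graph computes for nn.Linear:

import torch
import torch.nn.functional as F

lin = torch.nn.Linear(8, 4)
x = torch.rand(2, 8)

# get_attr weight / get_attr bias, bias-free linear, then operator.add
rewritten = F.linear(x, lin.weight) + lin.bias
assert torch.allclose(lin(x), rewritten, atol=1e-6)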
@@ -1,6 +1,5 @@
 import torch
-import torch.nn.functional as F
-from torch.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple
+from torch.nn.modules.utils import _pair, _single, _triple
 
 from ...registry import bias_addition_module
 from .bias_addition_module import BiasAdditionModule
@@ -10,17 +9,16 @@ from .bias_addition_module import BiasAdditionModule
 @bias_addition_module.register(torch.nn.Conv2d)
 @bias_addition_module.register(torch.nn.Conv3d)
 class BiasAdditionConv(BiasAdditionModule):
-
     def extract_kwargs_from_mod(self):
         root = self.tracer.root
         conv_module = root.get_submodule(self.target)
-        kwarg_attributes = ['groups', 'dilation', 'stride']
+        kwarg_attributes = ["groups", "dilation", "stride"]
         non_bias_kwargs = {}
         for attr_name in kwarg_attributes:
             if hasattr(conv_module, attr_name):
                 non_bias_kwargs[attr_name] = getattr(conv_module, attr_name)
         if conv_module.padding_mode != "zeros":
-            #TODO: non zeros mode requires some extra processing for input
+            # TODO: non zeros mode requires some extra processing for input
             conv_type = type(conv_module)
             if conv_type == "torch.nn.Conv1d":
                 padding_element = _single(0)
@@ -28,9 +26,9 @@ class BiasAdditionConv(BiasAdditionModule):
                 padding_element = _pair(0)
             elif conv_type == "torch.nn.Conv3d":
                 padding_element = _triple(0)
-            non_bias_kwargs['padding'] = padding_element
+            non_bias_kwargs["padding"] = padding_element
         else:
-            non_bias_kwargs['padding'] = getattr(conv_module, 'padding')
+            non_bias_kwargs["padding"] = getattr(conv_module, "padding")
 
         return non_bias_kwargs
 
@@ -41,11 +39,12 @@ class BiasAdditionConv(BiasAdditionModule):
         """
         bias_shape = [1] * (dimensions - 1)
         bias_shape[0] = -1
-        bias_reshape_node_kind = 'call_method'
-        bias_reshape_node_target = 'view'
+        bias_reshape_node_kind = "call_method"
+        bias_reshape_node_target = "view"
         bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
-        bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
-                                                      bias_reshape_node_args, {})
+        bias_reshape_proxy = self.tracer.create_proxy(
+            bias_reshape_node_kind, bias_reshape_node_target, bias_reshape_node_args, {}
+        )
         return bias_reshape_proxy
 
     def generate(self):
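For convolutions, the module attributes collected by extract_kwargs_from_mod are passed to the bias-free functional call, and the bias is reshaped with view so it broadcasts over the spatial dimensions (the [1, -1, 1, 1] shape in the docstring earlier). An eager-mode sketch of the equivalence for a zero-padding nn.Conv2d (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, dilation=1, groups=1)
x = torch.rand(1, 3, 16, 16)

# kwargs pulled from the module, as extract_kwargs_from_mod does
non_bias_kwargs = {"stride": conv.stride, "padding": conv.padding, "dilation": conv.dilation, "groups": conv.groups}
non_bias_out = F.conv2d(x, conv.weight, **non_bias_kwargs)
rewritten = non_bias_out + conv.bias.view(1, -1, 1, 1)
assert torch.allclose(conv(x), rewritten, atol=1e-6)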
@@ -1,5 +1,4 @@
 import torch
-import torch.nn.functional as F
 
 from ...registry import bias_addition_module
 from .bias_addition_module import BiasAdditionModule
@@ -7,7 +6,6 @@ from .bias_addition_module import BiasAdditionModule
 
 @bias_addition_module.register(torch.nn.Linear)
 class BiasAdditionLinear(BiasAdditionModule):
-
     def extract_kwargs_from_mod(self):
         return {}
 