mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Parameter
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
@@ -24,11 +24,7 @@ class Linear(nn.Module):
|
||||
adding bias but instead return it.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
bias=True,
|
||||
skip_bias_add=False):
|
||||
def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
|
||||
super(Linear, self).__init__()
|
||||
|
||||
# Keep input parameters
|
||||
@@ -36,9 +32,12 @@ class Linear(nn.Module):
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
self.weight = Parameter(torch.empty(self.output_size,
|
||||
self.input_size,
|
||||
))
|
||||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
self.output_size,
|
||||
self.input_size,
|
||||
)
|
||||
)
|
||||
init.normal_(self.weight)
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.output_size))
|
||||
@@ -46,7 +45,7 @@ class Linear(nn.Module):
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, input_):
|
||||
# Matrix multiply.
|
||||
@@ -59,5 +58,7 @@ class Linear(nn.Module):
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
|
||||
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
|
||||
return (
|
||||
f"Linear(in_features={self.input_size}, out_features={self.output_size}, "
|
||||
+ f"bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})"
|
||||
)
|
||||
|
Reference in New Issue
Block a user