[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions


@@ -1,14 +1,13 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, matmul, nn
from torch import matmul, nn
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
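The hunk above only shows the constructor; as a rough sketch (an assumption, not this file's actual forward pass), a layer norm with a learnable scale but no bias can look like this:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class BiasFreeLayerNorm(nn.Module):
        def __init__(self, dim, eps=1e-5):
            super().__init__()
            self.eps = eps
            self.gamma = nn.Parameter(torch.ones(dim))  # scale only, no bias term

        def forward(self, x):
            # normalize over the last dimension, rescale with gamma only
            return F.layer_norm(x, x.shape[-1:], weight=self.gamma, eps=self.eps)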
@@ -24,7 +23,6 @@ class LayerNorm(nn.Module):
class ParallelResidual(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
@@ -38,16 +36,15 @@ class ParallelResidual(nn.Module):
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, *, device):
seq = torch.arange(max_seq_len, device=device)
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
#freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
# freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
# freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
return torch.cat((freqs, freqs), dim=-1)
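For reference, a minimal check (values are illustrative) that the matmul-with-reshape form kept here matches the commented-out einsum / torch.outer formulations of the outer product:

    import torch

    seq = torch.arange(8).float()   # positions
    inv_freq = torch.rand(4)        # inverse frequencies
    i, j = len(seq), len(inv_freq)

    a = torch.einsum("i,j->ij", seq, inv_freq)                       # commented-out form
    b = torch.outer(seq, inv_freq)                                   # also commented out
    c = torch.matmul(seq.reshape(i, 1), inv_freq.reshape(1, j))      # form kept in the file

    assert torch.allclose(a, b) and torch.allclose(a, c)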
@@ -69,7 +66,6 @@ def apply_rotary_pos_emb(pos, t):
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
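A quick usage sketch (shapes are illustrative): SwiGLU halves the last dimension, so whatever projection feeds it must produce twice the desired hidden width:

    import torch

    x = torch.randn(2, 16, 512)   # (batch, seq, 2 * hidden)
    out = SwiGLU()(x)             # chunks into (x, gate) halves, returns silu(gate) * x
    print(out.shape)              # torch.Size([2, 16, 256])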
@@ -87,7 +83,6 @@ def FeedForward(dim, mult=4):
# attention
class Attention(nn.Module):
def __init__(self, dim, dim_head=64, heads=8):
super().__init__()
inner_dim = dim_head * heads
@@ -160,7 +155,7 @@ class Attention(nn.Module):
# similarity
#sim = einsum("b h i d, b j d -> b h i j", q, k)
# sim = einsum("b h i d, b j d -> b h i j", q, k)
sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
sim = sim.reshape(b, h, i, j)
@@ -178,7 +173,7 @@ class Attention(nn.Module):
# aggregate values
#out = einsum("b h i j, b j d -> b h i d", attn, v)
# out = einsum("b h i j, b j d -> b h i d", attn, v)
out = matmul(attn.reshape(b_, h_ * i_, j_), v)
out = out.reshape(b_, h_, i_, d_)
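A small self-contained check (shapes are illustrative; the head-free k follows the einsum string, which implies shared keys across heads) that the reshaped matmul reproduces the commented-out einsum for the similarity scores:

    import torch

    b, h, i, j, d = 2, 4, 8, 8, 16
    q = torch.randn(b, h, i, d)   # per-head queries
    k = torch.randn(b, j, d)      # keys without a head dimension, as in the einsum

    sim_einsum = torch.einsum("b h i d, b j d -> b h i j", q, k)
    sim_matmul = torch.matmul(q.reshape(b, h * i, d), k.transpose(1, 2)).reshape(b, h, i, j)
    assert torch.allclose(sim_einsum, sim_matmul, atol=1e-5)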
@@ -193,12 +188,17 @@ class Attention(nn.Module):
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
net = nn.Sequential(
nn.Embedding(num_tokens, dim), *[
nn.Embedding(num_tokens, dim),
*[
ParallelResidual(
Attention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
) for _ in range(depth)
], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
)
for _ in range(depth)
],
LayerNorm(dim),
nn.Linear(dim, num_tokens, bias=False),
)
# they used embedding weight tied projection out to logits, not common, but works
net[-1].weight = net[0].weight
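A hedged usage sketch (hyperparameters are illustrative, and it assumes the stack runs end to end on token ids as the Sequential layout above suggests): build the model and confirm the tied output projection:

    import torch

    model = PaLM(dim=512, num_tokens=20000, depth=4)
    tokens = torch.randint(0, 20000, (1, 128))
    logits = model(tokens)                      # expected shape (1, 128, 20000)
    assert model[-1].weight is model[0].weight  # tied embedding / unembedding weights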