[example] add palm pytorch version (#2172)

This commit is contained in:
Jiarui Fang
2022-12-22 10:15:34 +08:00
committed by GitHub
parent 12e7bcd720
commit 27327a4c90
7 changed files with 454 additions and 0 deletions

View File

@@ -0,0 +1 @@
from palm_pytorch.palm_pytorch import PaLM

View File

@@ -0,0 +1,77 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
# helper function
def exists(val):
return val is not None
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# top k filtering
def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float("-inf"))
probs.scatter_(1, ind, val)
return probs
class AutoregressiveWrapper(nn.Module):
def __init__(self, net, max_seq_len=2048, pad_value=0):
super().__init__()
self.max_seq_len = max_seq_len
self.pad_value = pad_value
self.net = net
@torch.no_grad()
@eval_decorator
def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs):
b, t, device = *start_tokens.shape, start_tokens.device
out = start_tokens
for _ in range(seq_len):
logits = self.net(out, **kwargs)[:, -1, :]
filtered_logits = top_k(logits, thres=filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)
sample = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
if exists(eos_token):
is_eos_token = out == eos_token
if is_eos_token.any(dim=-1).all():
# mask out everything after the eos tokens
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
out = out.masked_fill(mask, self.pad_value)
break
out = out[:, t:]
return out
def forward(self, x, **kwargs):
x_inp, x_labels = x[:, :-1], x[:, 1:]
logits = self.net(x_inp, **kwargs)
return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels)

View File

@@ -0,0 +1,198 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn
# normalization
# they use layernorm without bias, something that pytorch does not offer
class LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(dim))
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
# parallel with residual
# discovered by Wang et al + EleutherAI from GPT-J fame
class ParallelResidual(nn.Module):
def __init__(self, *fns):
super().__init__()
self.fns = nn.ModuleList(fns)
def forward(self, x):
return x + sum([fn(x) for fn in self.fns])
# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, *, device):
seq = torch.arange(max_seq_len, device=device)
freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
return torch.cat((freqs, freqs), dim=-1)
def rotate_half(x):
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
# feedforward
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU
# https://arxiv.org/abs/2002.05202
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
LayerNorm(dim),
nn.Linear(dim, inner_dim * 2, bias=False),
SwiGLU(),
nn.Linear(inner_dim, dim, bias=False),
)
# attention
class Attention(nn.Module):
def __init__(self, dim, dim_head=64, heads=8):
super().__init__()
inner_dim = dim_head * heads
self.norm = LayerNorm(dim)
self.heads = heads
self.scale = dim_head**-0.5
self.rotary_emb = RotaryEmbedding(dim_head)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# for caching causal mask and rotary embeddings
self.register_buffer("mask", None, persistent=False)
self.register_buffer("pos_emb", None, persistent=False)
def get_mask(self, n, device):
if self.mask is not None and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def get_rotary_embedding(self, n, device):
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
return self.pos_emb[:n]
pos_emb = self.rotary_emb(n, device=device)
self.register_buffer("position", pos_emb, persistent=False)
return pos_emb
def forward(self, x):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device, h = x.shape[1], x.device, self.heads
# pre layernorm
x = self.norm(x)
# queries, keys, values
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
# https://arxiv.org/abs/1911.02150
q = rearrange(q, "b n (h d) -> b h n d", h=h)
# rotary embeddings
positions = self.get_rotary_embedding(n, device)
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
# scale
q = q * self.scale
# similarity
sim = einsum("b h i d, b j d -> b h i j", q, k)
# causal mask
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# attention
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
# aggregate values
out = einsum("b h i j, b j d -> b h i d", attn, v)
# merge heads
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
# transformer
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
net = nn.Sequential(
nn.Embedding(num_tokens, dim), *[
ParallelResidual(
Attention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
) for _ in range(depth)
], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
# they used embedding weight tied projection out to logits, not common, but works
net[-1].weight = net[0].weight
nn.init.normal_(net[0].weight, std=0.02)
return net