[example] make palm + GeminiDPP work (#2227)

This commit is contained in:
Jiarui Fang
2022-12-29 14:28:31 +08:00
committed by GitHub
parent 63cc77173b
commit 2cdecc9f38
3 changed files with 39 additions and 56 deletions

View File

@@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn, matmul
from torch import einsum, matmul, nn
# normalization
# they use layernorm without bias, something that pytorch does not offer
@@ -86,8 +86,6 @@ def FeedForward(dim, mult=4):
# attention
class Attention(nn.Module):
def __init__(self, dim, dim_head=64, heads=8):
@@ -142,8 +140,6 @@ class Attention(nn.Module):
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
# split heads
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
# they found no performance loss past a certain scale, and more efficient decoding obviously
@@ -165,7 +161,7 @@ class Attention(nn.Module):
# similarity
#sim = einsum("b h i d, b j d -> b h i j", q, k)
sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2))
sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
sim = sim.reshape(b, h, i, j)
# causal mask
@@ -183,7 +179,7 @@ class Attention(nn.Module):
# aggregate values
#out = einsum("b h i j, b j d -> b h i d", attn, v)
out = matmul(attn.reshape(b_, h_*i_, j_), v)
out = matmul(attn.reshape(b_, h_ * i_, j_), v)
out = out.reshape(b_, h_, i_, d_)
# merge heads