[example] make palm + GeminiDPP work (#2227)
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from torch import einsum, nn, matmul
+from torch import einsum, matmul, nn

 # normalization
 # they use layernorm without bias, something that pytorch does not offer
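
The comment above refers to a layer normalization variant with a learnable scale but no learnable bias. As a rough sketch (not necessarily the exact module in this file), such a bias-free LayerNorm might look like:

import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm(nn.Module):
    # LayerNorm with a learnable scale (gamma) but a fixed zero bias (beta),
    # since plain nn.LayerNorm does not offer a bias-free variant here.
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta, self.eps)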
@@ -86,8 +86,6 @@ def FeedForward(dim, mult=4):


 # attention
-
-
 class Attention(nn.Module):

     def __init__(self, dim, dim_head=64, heads=8):
@@ -142,8 +140,6 @@ class Attention(nn.Module):

         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

-
-
         # split heads
         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
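
The multi-query attention referenced in these comments keeps separate per-head queries but shares a single key/value head across all query heads. A minimal illustration of that split, with made-up sizes (the constructor above suggests heads=8 and dim_head=64 as defaults):

import torch
from einops import rearrange

# hypothetical sizes, for illustration only
b, n, heads, dim_head = 2, 16, 8, 64

q = torch.randn(b, n, heads * dim_head)    # per-head queries
k = torch.randn(b, n, dim_head)            # one shared key head
v = torch.randn(b, n, dim_head)            # one shared value head

# only q is split into heads; k and v stay single-headed (multi-query attention)
q = rearrange(q, "b n (h d) -> b h n d", h=heads)

# every query head attends over the same shared keys
sim = torch.einsum("b h i d, b j d -> b h i j", q, k)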
@@ -165,7 +161,7 @@ class Attention(nn.Module):
         # similarity

         #sim = einsum("b h i d, b j d -> b h i j", q, k)
-        sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2))
+        sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
         sim = sim.reshape(b, h, i, j)

         # causal mask
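
The commented-out einsum and the matmul/reshape pair above are intended to compute the same similarity tensor; the matmul form is presumably there to sidestep einsum in this setting. A quick, self-contained equivalence check with hypothetical shapes:

import torch
from torch import einsum, matmul

b, h, i, j, d = 2, 8, 16, 16, 64
q = torch.randn(b, h, i, d)
k = torch.randn(b, j, d)

sim_einsum = einsum("b h i d, b j d -> b h i j", q, k)

# flatten heads into the row dimension, batch-multiply by k^T, then restore the head axis
sim_matmul = matmul(q.reshape(b, h * i, d), k.transpose(1, 2)).reshape(b, h, i, j)

assert torch.allclose(sim_einsum, sim_matmul, atol=1e-5)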
@@ -183,7 +179,7 @@ class Attention(nn.Module):
         # aggregate values

         #out = einsum("b h i j, b j d -> b h i d", attn, v)
-        out = matmul(attn.reshape(b_, h_*i_, j_), v)
+        out = matmul(attn.reshape(b_, h_ * i_, j_), v)
         out = out.reshape(b_, h_, i_, d_)

         # merge heads
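
Likewise, the aggregation step pairs a commented-out einsum with a batched matmul over the attention weights and the shared values. A minimal check of that equivalence, plus the head merge the trailing comment points to, again with made-up shapes:

import torch
from torch import einsum, matmul
from einops import rearrange

b_, h_, i_, j_, d_ = 2, 8, 16, 16, 64
attn = torch.softmax(torch.randn(b_, h_, i_, j_), dim=-1)
v = torch.randn(b_, j_, d_)

out_einsum = einsum("b h i j, b j d -> b h i d", attn, v)

# collapse heads into the row dimension, multiply by the shared values, then restore the head axis
out_matmul = matmul(attn.reshape(b_, h_ * i_, j_), v).reshape(b_, h_, i_, d_)

assert torch.allclose(out_einsum, out_matmul, atol=1e-5)

# merge heads back into the feature dimension (what a "merge heads" step typically does)
out_merged = rearrange(out_matmul, "b h n d -> b n (h d)")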