mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 00:07:29 +00:00
[example] add palm pytorch version (#2172)
This commit is contained in:
parent
12e7bcd720
commit
27327a4c90
64
examples/language/palm/README.md
Normal file
64
examples/language/palm/README.md
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
<img src="./palm.gif" width="450px"></img>
|
||||||
|
|
||||||
|
## PaLM - Pytorch
|
||||||
|
|
||||||
|
Implementation of the specific Transformer architecture from <a href="https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html">PaLM - Scaling Language Modeling with Pathways</a>, in less than 200 lines of code.
|
||||||
|
|
||||||
|
This model is pretty much SOTA on everything language.
|
||||||
|
|
||||||
|
It obviously will not scale, but it is just for educational purposes. To elucidate the public how simple it all really is.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
```bash
|
||||||
|
$ pip install PaLM-pytorch
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from palm_pytorch import PaLM
|
||||||
|
|
||||||
|
palm = PaLM(
|
||||||
|
num_tokens = 20000,
|
||||||
|
dim = 512,
|
||||||
|
depth = 12,
|
||||||
|
heads = 8,
|
||||||
|
dim_head = 64,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = torch.randint(0, 20000, (1, 2048))
|
||||||
|
logits = palm(tokens) # (1, 2048, 20000)
|
||||||
|
```
|
||||||
|
|
||||||
|
The PaLM 540B in the paper would be
|
||||||
|
|
||||||
|
```python
|
||||||
|
palm = PaLM(
|
||||||
|
num_tokens = 256000,
|
||||||
|
dim = 18432,
|
||||||
|
depth = 118,
|
||||||
|
heads = 48,
|
||||||
|
dim_head = 256
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test on Enwik8
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ python train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Todo
|
||||||
|
|
||||||
|
- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer
|
||||||
|
|
||||||
|
## Citations
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{chowdhery2022PaLM,
|
||||||
|
title = {PaLM: Scaling Language Modeling with Pathways},
|
||||||
|
author = {Chowdhery, Aakanksha et al},
|
||||||
|
year = {2022}
|
||||||
|
}
|
||||||
|
```
|
3
examples/language/palm/data/README.md
Normal file
3
examples/language/palm/data/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Data source
|
||||||
|
|
||||||
|
The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
|
1
examples/language/palm/palm_pytorch/__init__.py
Normal file
1
examples/language/palm/palm_pytorch/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from palm_pytorch.palm_pytorch import PaLM
|
@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
# helper function
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def eval_decorator(fn):
|
||||||
|
|
||||||
|
def inner(model, *args, **kwargs):
|
||||||
|
was_training = model.training
|
||||||
|
model.eval()
|
||||||
|
out = fn(model, *args, **kwargs)
|
||||||
|
model.train(was_training)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
# top k filtering
|
||||||
|
|
||||||
|
|
||||||
|
def top_k(logits, thres=0.9):
|
||||||
|
k = int((1 - thres) * logits.shape[-1])
|
||||||
|
val, ind = torch.topk(logits, k)
|
||||||
|
probs = torch.full_like(logits, float("-inf"))
|
||||||
|
probs.scatter_(1, ind, val)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class AutoregressiveWrapper(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, net, max_seq_len=2048, pad_value=0):
|
||||||
|
super().__init__()
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.pad_value = pad_value
|
||||||
|
self.net = net
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@eval_decorator
|
||||||
|
def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs):
|
||||||
|
b, t, device = *start_tokens.shape, start_tokens.device
|
||||||
|
|
||||||
|
out = start_tokens
|
||||||
|
|
||||||
|
for _ in range(seq_len):
|
||||||
|
logits = self.net(out, **kwargs)[:, -1, :]
|
||||||
|
|
||||||
|
filtered_logits = top_k(logits, thres=filter_thres)
|
||||||
|
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
||||||
|
|
||||||
|
sample = torch.multinomial(probs, 1)
|
||||||
|
|
||||||
|
out = torch.cat((out, sample), dim=-1)
|
||||||
|
|
||||||
|
if exists(eos_token):
|
||||||
|
is_eos_token = out == eos_token
|
||||||
|
|
||||||
|
if is_eos_token.any(dim=-1).all():
|
||||||
|
# mask out everything after the eos tokens
|
||||||
|
shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
|
||||||
|
mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
|
||||||
|
out = out.masked_fill(mask, self.pad_value)
|
||||||
|
break
|
||||||
|
|
||||||
|
out = out[:, t:]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
x_inp, x_labels = x[:, :-1], x[:, 1:]
|
||||||
|
logits = self.net(x_inp, **kwargs)
|
||||||
|
return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels)
|
198
examples/language/palm/palm_pytorch/palm_pytorch.py
Normal file
198
examples/language/palm/palm_pytorch/palm_pytorch.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import einsum, nn
|
||||||
|
|
||||||
|
# normalization
|
||||||
|
# they use layernorm without bias, something that pytorch does not offer
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.gamma = nn.Parameter(torch.ones(dim))
|
||||||
|
self.register_buffer("beta", torch.zeros(dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||||
|
|
||||||
|
|
||||||
|
# parallel with residual
|
||||||
|
# discovered by Wang et al + EleutherAI from GPT-J fame
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelResidual(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *fns):
|
||||||
|
super().__init__()
|
||||||
|
self.fns = nn.ModuleList(fns)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x + sum([fn(x) for fn in self.fns])
|
||||||
|
|
||||||
|
|
||||||
|
# rotary positional embedding
|
||||||
|
# https://arxiv.org/abs/2104.09864
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
|
||||||
|
self.register_buffer("inv_freq", inv_freq)
|
||||||
|
|
||||||
|
def forward(self, max_seq_len, *, device):
|
||||||
|
seq = torch.arange(max_seq_len, device=device)
|
||||||
|
freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
|
||||||
|
return torch.cat((freqs, freqs), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||||
|
x1, x2 = x.unbind(dim=-2)
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(pos, t):
|
||||||
|
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
|
||||||
|
|
||||||
|
|
||||||
|
# feedforward
|
||||||
|
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU
|
||||||
|
# https://arxiv.org/abs/2002.05202
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLU(nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = x.chunk(2, dim=-1)
|
||||||
|
return F.silu(gate) * x
|
||||||
|
|
||||||
|
|
||||||
|
def FeedForward(dim, mult=4):
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
return nn.Sequential(
|
||||||
|
LayerNorm(dim),
|
||||||
|
nn.Linear(dim, inner_dim * 2, bias=False),
|
||||||
|
SwiGLU(),
|
||||||
|
nn.Linear(inner_dim, dim, bias=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim, dim_head=64, heads=8):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.norm = LayerNorm(dim)
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = dim_head**-0.5
|
||||||
|
self.rotary_emb = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||||
|
|
||||||
|
# for caching causal mask and rotary embeddings
|
||||||
|
|
||||||
|
self.register_buffer("mask", None, persistent=False)
|
||||||
|
self.register_buffer("pos_emb", None, persistent=False)
|
||||||
|
|
||||||
|
def get_mask(self, n, device):
|
||||||
|
if self.mask is not None and self.mask.shape[-1] >= n:
|
||||||
|
return self.mask[:n, :n]
|
||||||
|
|
||||||
|
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
|
||||||
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def get_rotary_embedding(self, n, device):
|
||||||
|
if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
|
||||||
|
return self.pos_emb[:n]
|
||||||
|
|
||||||
|
pos_emb = self.rotary_emb(n, device=device)
|
||||||
|
self.register_buffer("position", pos_emb, persistent=False)
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
n, i, j - sequence length (base sequence length, source, target)
|
||||||
|
d - feature dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
n, device, h = x.shape[1], x.device, self.heads
|
||||||
|
|
||||||
|
# pre layernorm
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# queries, keys, values
|
||||||
|
|
||||||
|
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
|
||||||
|
|
||||||
|
# split heads
|
||||||
|
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
||||||
|
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
||||||
|
# https://arxiv.org/abs/1911.02150
|
||||||
|
|
||||||
|
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
||||||
|
|
||||||
|
# rotary embeddings
|
||||||
|
|
||||||
|
positions = self.get_rotary_embedding(n, device)
|
||||||
|
q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
|
||||||
|
|
||||||
|
# scale
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
|
||||||
|
# similarity
|
||||||
|
|
||||||
|
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||||
|
|
||||||
|
# causal mask
|
||||||
|
|
||||||
|
causal_mask = self.get_mask(n, device)
|
||||||
|
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
|
||||||
|
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||||
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
# aggregate values
|
||||||
|
|
||||||
|
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
|
||||||
|
out = rearrange(out, "b h n d -> b n (h d)")
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
|
||||||
|
|
||||||
|
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
|
||||||
|
net = nn.Sequential(
|
||||||
|
nn.Embedding(num_tokens, dim), *[
|
||||||
|
ParallelResidual(
|
||||||
|
Attention(dim=dim, dim_head=dim_head, heads=heads),
|
||||||
|
FeedForward(dim=dim, mult=ff_mult),
|
||||||
|
) for _ in range(depth)
|
||||||
|
], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
|
||||||
|
|
||||||
|
# they used embedding weight tied projection out to logits, not common, but works
|
||||||
|
net[-1].weight = net[0].weight
|
||||||
|
|
||||||
|
nn.init.normal_(net[0].weight, std=0.02)
|
||||||
|
return net
|
109
examples/language/palm/train.py
Normal file
109
examples/language/palm/train.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import gzip
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import tqdm
|
||||||
|
from palm_pytorch import PaLM
|
||||||
|
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
# constants
|
||||||
|
|
||||||
|
NUM_BATCHES = int(1e5)
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
GRADIENT_ACCUMULATE_EVERY = 4
|
||||||
|
LEARNING_RATE = 2e-4
|
||||||
|
VALIDATE_EVERY = 100
|
||||||
|
GENERATE_EVERY = 500
|
||||||
|
GENERATE_LENGTH = 512
|
||||||
|
SEQ_LEN = 1024
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
|
||||||
|
def cycle(loader):
|
||||||
|
while True:
|
||||||
|
for data in loader:
|
||||||
|
yield data
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(token):
|
||||||
|
return str(chr(max(32, token)))
|
||||||
|
|
||||||
|
|
||||||
|
def decode_tokens(tokens):
|
||||||
|
return "".join(list(map(decode_token, tokens)))
|
||||||
|
|
||||||
|
|
||||||
|
# instantiate GPT-like decoder model
|
||||||
|
|
||||||
|
model = PaLM(num_tokens=256, dim=512, depth=8)
|
||||||
|
|
||||||
|
model = AutoregressiveWrapper(model, max_seq_len=2048)
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
# prepare enwik8 data
|
||||||
|
|
||||||
|
with gzip.open("./data/enwik8.gz") as file:
|
||||||
|
X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
|
||||||
|
trX, vaX = np.split(X, [int(90e6)])
|
||||||
|
data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
|
||||||
|
|
||||||
|
|
||||||
|
class TextSamplerDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, data, seq_len):
|
||||||
|
super().__init__()
|
||||||
|
self.data = data
|
||||||
|
self.seq_len = seq_len
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
|
||||||
|
full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
|
||||||
|
return full_seq.cuda()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.data.size(0) // self.seq_len
|
||||||
|
|
||||||
|
|
||||||
|
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
|
||||||
|
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
|
||||||
|
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
|
||||||
|
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
|
||||||
|
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
# training
|
||||||
|
|
||||||
|
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
for __ in range(GRADIENT_ACCUMULATE_EVERY):
|
||||||
|
loss = model(next(train_loader))
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
print(f"training loss: {loss.item()}")
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
|
||||||
|
optim.step()
|
||||||
|
optim.zero_grad()
|
||||||
|
|
||||||
|
if i % VALIDATE_EVERY == 0:
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
loss = model(next(val_loader))
|
||||||
|
print(f"validation loss: {loss.item()}")
|
||||||
|
|
||||||
|
if i % GENERATE_EVERY == 0:
|
||||||
|
model.eval()
|
||||||
|
inp = random.choice(val_dataset)[:-1]
|
||||||
|
prime = decode_tokens(inp)
|
||||||
|
print(f"%s \n\n %s", (prime, "*" * 100))
|
||||||
|
|
||||||
|
sample = model.generate(inp[None, ...], GENERATE_LENGTH)
|
||||||
|
output_str = decode_tokens(sample[0])
|
||||||
|
print(output_str)
|
@ -1,6 +1,8 @@
|
|||||||
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
Loading…
Reference in New Issue
Block a user