diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md new file mode 100644 index 000000000..486bf240f --- /dev/null +++ b/examples/language/palm/README.md @@ -0,0 +1,64 @@ + + +## PaLM - Pytorch + +Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways, in less than 200 lines of code. + +This model is pretty much SOTA on everything language. + +It obviously will not scale, but it is just for educational purposes. To elucidate the public how simple it all really is. + +## Install +```bash +$ pip install PaLM-pytorch +``` + +## Usage + +```python +import torch +from palm_pytorch import PaLM + +palm = PaLM( + num_tokens = 20000, + dim = 512, + depth = 12, + heads = 8, + dim_head = 64, +) + +tokens = torch.randint(0, 20000, (1, 2048)) +logits = palm(tokens) # (1, 2048, 20000) +``` + +The PaLM 540B in the paper would be + +```python +palm = PaLM( + num_tokens = 256000, + dim = 18432, + depth = 118, + heads = 48, + dim_head = 256 +) +``` + +## Test on Enwik8 + +```bash +$ python train.py +``` + +## Todo + +- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer + +## Citations + +```bibtex +@article{chowdhery2022PaLM, + title = {PaLM: Scaling Language Modeling with Pathways}, + author = {Chowdhery, Aakanksha et al}, + year = {2022} +} +``` diff --git a/examples/language/palm/data/README.md b/examples/language/palm/data/README.md new file mode 100644 index 000000000..56433b4dc --- /dev/null +++ b/examples/language/palm/data/README.md @@ -0,0 +1,3 @@ +# Data source + +The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ diff --git a/examples/language/palm/palm_pytorch/__init__.py b/examples/language/palm/palm_pytorch/__init__.py new file mode 100644 index 000000000..dab49645a --- /dev/null +++ b/examples/language/palm/palm_pytorch/__init__.py @@ -0,0 +1 @@ +from palm_pytorch.palm_pytorch import PaLM diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py new file mode 100644 index 000000000..dc4f3d856 --- /dev/null +++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py @@ -0,0 +1,77 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +# helper function + + +def exists(val): + return val is not None + + +def eval_decorator(fn): + + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +# top k filtering + + +def top_k(logits, thres=0.9): + k = int((1 - thres) * logits.shape[-1]) + val, ind = torch.topk(logits, k) + probs = torch.full_like(logits, float("-inf")) + probs.scatter_(1, ind, val) + return probs + + +class AutoregressiveWrapper(nn.Module): + + def __init__(self, net, max_seq_len=2048, pad_value=0): + super().__init__() + self.max_seq_len = max_seq_len + self.pad_value = pad_value + self.net = net + + @torch.no_grad() + @eval_decorator + def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs): + b, t, device = *start_tokens.shape, start_tokens.device + + out = start_tokens + + for _ in range(seq_len): + logits = self.net(out, **kwargs)[:, -1, :] + + filtered_logits = top_k(logits, thres=filter_thres) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + sample = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + if exists(eos_token): + is_eos_token = out == eos_token + + if is_eos_token.any(dim=-1).all(): + # mask out everything after the eos tokens + shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) + mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1 + out = out.masked_fill(mask, self.pad_value) + break + + out = out[:, t:] + return out + + def forward(self, x, **kwargs): + x_inp, x_labels = x[:, :-1], x[:, 1:] + logits = self.net(x_inp, **kwargs) + return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels) diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py new file mode 100644 index 000000000..1509dd84e --- /dev/null +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -0,0 +1,198 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import einsum, nn + +# normalization +# they use layernorm without bias, something that pytorch does not offer + + +class LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(dim)) + self.register_buffer("beta", torch.zeros(dim)) + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) + + +# parallel with residual +# discovered by Wang et al + EleutherAI from GPT-J fame + + +class ParallelResidual(nn.Module): + + def __init__(self, *fns): + super().__init__() + self.fns = nn.ModuleList(fns) + + def forward(self, x): + return x + sum([fn(x) for fn in self.fns]) + + +# rotary positional embedding +# https://arxiv.org/abs/2104.09864 + + +class RotaryEmbedding(nn.Module): + + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, *, device): + seq = torch.arange(max_seq_len, device=device) + freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + return torch.cat((freqs, freqs), dim=-1) + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(pos, t): + return (t * pos.cos()) + (rotate_half(t) * pos.sin()) + + +# feedforward +# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU +# https://arxiv.org/abs/2002.05202 + + +class SwiGLU(nn.Module): + + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + LayerNorm(dim), + nn.Linear(dim, inner_dim * 2, bias=False), + SwiGLU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +# attention + + +class Attention(nn.Module): + + def __init__(self, dim, dim_head=64, heads=8): + super().__init__() + inner_dim = dim_head * heads + self.norm = LayerNorm(dim) + self.heads = heads + self.scale = dim_head**-0.5 + self.rotary_emb = RotaryEmbedding(dim_head) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, dim_head * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + # for caching causal mask and rotary embeddings + + self.register_buffer("mask", None, persistent=False) + self.register_buffer("pos_emb", None, persistent=False) + + def get_mask(self, n, device): + if self.mask is not None and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def get_rotary_embedding(self, n, device): + if self.pos_emb is not None and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n] + + pos_emb = self.rotary_emb(n, device=device) + self.register_buffer("position", pos_emb, persistent=False) + return pos_emb + + def forward(self, x): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device, h = x.shape[1], x.device, self.heads + + # pre layernorm + + x = self.norm(x) + + # queries, keys, values + + q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + + # split heads + # they use multi-query single-key-value attention, yet another Noam Shazeer paper + # they found no performance loss past a certain scale, and more efficient decoding obviously + # https://arxiv.org/abs/1911.02150 + + q = rearrange(q, "b n (h d) -> b h n d", h=h) + + # rotary embeddings + + positions = self.get_rotary_embedding(n, device) + q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k)) + + # scale + + q = q * self.scale + + # similarity + + sim = einsum("b h i d, b j d -> b h i j", q, k) + + # causal mask + + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum("b h i j, b j d -> b h i d", attn, v) + + # merge heads + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +# transformer + + +def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4): + net = nn.Sequential( + nn.Embedding(num_tokens, dim), *[ + ParallelResidual( + Attention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ) for _ in range(depth) + ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False)) + + # they used embedding weight tied projection out to logits, not common, but works + net[-1].weight = net[0].weight + + nn.init.normal_(net[0].weight, std=0.02) + return net diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py new file mode 100644 index 000000000..ba243e507 --- /dev/null +++ b/examples/language/palm/train.py @@ -0,0 +1,109 @@ +import gzip +import random + +import numpy as np +import torch +import torch.optim as optim +import tqdm +from palm_pytorch import PaLM +from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset + +# constants + +NUM_BATCHES = int(1e5) +BATCH_SIZE = 4 +GRADIENT_ACCUMULATE_EVERY = 4 +LEARNING_RATE = 2e-4 +VALIDATE_EVERY = 100 +GENERATE_EVERY = 500 +GENERATE_LENGTH = 512 +SEQ_LEN = 1024 + +# helpers + + +def cycle(loader): + while True: + for data in loader: + yield data + + +def decode_token(token): + return str(chr(max(32, token))) + + +def decode_tokens(tokens): + return "".join(list(map(decode_token, tokens))) + + +# instantiate GPT-like decoder model + +model = PaLM(num_tokens=256, dim=512, depth=8) + +model = AutoregressiveWrapper(model, max_seq_len=2048) +model.cuda() + +# prepare enwik8 data + +with gzip.open("./data/enwik8.gz") as file: + X = np.fromstring(file.read(int(95e6)), dtype=np.uint8) + trX, vaX = np.split(X, [int(90e6)]) + data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX) + + +class TextSamplerDataset(Dataset): + + def __init__(self, data, seq_len): + super().__init__() + self.data = data + self.seq_len = seq_len + + def __getitem__(self, index): + rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) + full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long() + return full_seq.cuda() + + def __len__(self): + return self.data.size(0) // self.seq_len + + +train_dataset = TextSamplerDataset(data_train, SEQ_LEN) +val_dataset = TextSamplerDataset(data_val, SEQ_LEN) +train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE)) +val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE)) + +# optimizer + +optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) + +# training + +for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): + model.train() + + for __ in range(GRADIENT_ACCUMULATE_EVERY): + loss = model(next(train_loader)) + loss.backward() + + print(f"training loss: {loss.item()}") + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) + optim.step() + optim.zero_grad() + + if i % VALIDATE_EVERY == 0: + model.eval() + with torch.no_grad(): + loss = model(next(val_loader)) + print(f"validation loss: {loss.item()}") + + if i % GENERATE_EVERY == 0: + model.eval() + inp = random.choice(val_dataset)[:-1] + prime = decode_tokens(inp) + print(f"%s \n\n %s", (prime, "*" * 100)) + + sample = model.generate(inp[None, ...], GENERATE_LENGTH) + output_str = decode_tokens(sample[0]) + print(output_str) diff --git a/tests/test_gemini/update/test_convert_torch_module.py b/tests/test_gemini/update/test_convert_torch_module.py index c0fd94b40..160099167 100644 --- a/tests/test_gemini/update/test_convert_torch_module.py +++ b/tests/test_gemini/update/test_convert_torch_module.py @@ -1,6 +1,8 @@ +import os from functools import partial import pytest +import torch import torch.multiprocessing as mp import colossalai