diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md
new file mode 100644
index 000000000..486bf240f
--- /dev/null
+++ b/examples/language/palm/README.md
@@ -0,0 +1,64 @@
+
+
+## PaLM - Pytorch
+
+Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways, in less than 200 lines of code.
+
+This model is pretty much SOTA on everything language.
+
+It obviously will not scale, but it is just for educational purposes. To elucidate the public how simple it all really is.
+
+## Install
+```bash
+$ pip install PaLM-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from palm_pytorch import PaLM
+
+palm = PaLM(
+ num_tokens = 20000,
+ dim = 512,
+ depth = 12,
+ heads = 8,
+ dim_head = 64,
+)
+
+tokens = torch.randint(0, 20000, (1, 2048))
+logits = palm(tokens) # (1, 2048, 20000)
+```
+
+The PaLM 540B in the paper would be
+
+```python
+palm = PaLM(
+ num_tokens = 256000,
+ dim = 18432,
+ depth = 118,
+ heads = 48,
+ dim_head = 256
+)
+```
+
+## Test on Enwik8
+
+```bash
+$ python train.py
+```
+
+## Todo
+
+- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer
+
+## Citations
+
+```bibtex
+@article{chowdhery2022PaLM,
+ title = {PaLM: Scaling Language Modeling with Pathways},
+ author = {Chowdhery, Aakanksha et al},
+ year = {2022}
+}
+```
diff --git a/examples/language/palm/data/README.md b/examples/language/palm/data/README.md
new file mode 100644
index 000000000..56433b4dc
--- /dev/null
+++ b/examples/language/palm/data/README.md
@@ -0,0 +1,3 @@
+# Data source
+
+The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
diff --git a/examples/language/palm/palm_pytorch/__init__.py b/examples/language/palm/palm_pytorch/__init__.py
new file mode 100644
index 000000000..dab49645a
--- /dev/null
+++ b/examples/language/palm/palm_pytorch/__init__.py
@@ -0,0 +1 @@
+from palm_pytorch.palm_pytorch import PaLM
diff --git a/examples/language/palm/palm_pytorch/autoregressive_wrapper.py b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py
new file mode 100644
index 000000000..dc4f3d856
--- /dev/null
+++ b/examples/language/palm/palm_pytorch/autoregressive_wrapper.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+# helper function
+
+
+def exists(val):
+ return val is not None
+
+
+def eval_decorator(fn):
+
+ def inner(model, *args, **kwargs):
+ was_training = model.training
+ model.eval()
+ out = fn(model, *args, **kwargs)
+ model.train(was_training)
+ return out
+
+ return inner
+
+
+# top k filtering
+
+
+def top_k(logits, thres=0.9):
+ k = int((1 - thres) * logits.shape[-1])
+ val, ind = torch.topk(logits, k)
+ probs = torch.full_like(logits, float("-inf"))
+ probs.scatter_(1, ind, val)
+ return probs
+
+
+class AutoregressiveWrapper(nn.Module):
+
+ def __init__(self, net, max_seq_len=2048, pad_value=0):
+ super().__init__()
+ self.max_seq_len = max_seq_len
+ self.pad_value = pad_value
+ self.net = net
+
+ @torch.no_grad()
+ @eval_decorator
+ def generate(self, start_tokens, seq_len, eos_token=None, temperature=1.0, filter_thres=0.9, **kwargs):
+ b, t, device = *start_tokens.shape, start_tokens.device
+
+ out = start_tokens
+
+ for _ in range(seq_len):
+ logits = self.net(out, **kwargs)[:, -1, :]
+
+ filtered_logits = top_k(logits, thres=filter_thres)
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
+
+ sample = torch.multinomial(probs, 1)
+
+ out = torch.cat((out, sample), dim=-1)
+
+ if exists(eos_token):
+ is_eos_token = out == eos_token
+
+ if is_eos_token.any(dim=-1).all():
+ # mask out everything after the eos tokens
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+ mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
+ out = out.masked_fill(mask, self.pad_value)
+ break
+
+ out = out[:, t:]
+ return out
+
+ def forward(self, x, **kwargs):
+ x_inp, x_labels = x[:, :-1], x[:, 1:]
+ logits = self.net(x_inp, **kwargs)
+ return F.cross_entropy(rearrange(logits, "b c n -> b n c"), x_labels)
diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py
new file mode 100644
index 000000000..1509dd84e
--- /dev/null
+++ b/examples/language/palm/palm_pytorch/palm_pytorch.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import einsum, nn
+
+# normalization
+# they use layernorm without bias, something that pytorch does not offer
+
+
+class LayerNorm(nn.Module):
+
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.eps = eps
+ self.gamma = nn.Parameter(torch.ones(dim))
+ self.register_buffer("beta", torch.zeros(dim))
+
+ def forward(self, x):
+ return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
+
+# parallel with residual
+# discovered by Wang et al + EleutherAI from GPT-J fame
+
+
+class ParallelResidual(nn.Module):
+
+ def __init__(self, *fns):
+ super().__init__()
+ self.fns = nn.ModuleList(fns)
+
+ def forward(self, x):
+ return x + sum([fn(x) for fn in self.fns])
+
+
+# rotary positional embedding
+# https://arxiv.org/abs/2104.09864
+
+
+class RotaryEmbedding(nn.Module):
+
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, max_seq_len, *, device):
+ seq = torch.arange(max_seq_len, device=device)
+ freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
+ return torch.cat((freqs, freqs), dim=-1)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
+ x1, x2 = x.unbind(dim=-2)
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(pos, t):
+ return (t * pos.cos()) + (rotate_half(t) * pos.sin())
+
+
+# feedforward
+# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU
+# https://arxiv.org/abs/2002.05202
+
+
+class SwiGLU(nn.Module):
+
+ def forward(self, x):
+ x, gate = x.chunk(2, dim=-1)
+ return F.silu(gate) * x
+
+
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ LayerNorm(dim),
+ nn.Linear(dim, inner_dim * 2, bias=False),
+ SwiGLU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+# attention
+
+
+class Attention(nn.Module):
+
+ def __init__(self, dim, dim_head=64, heads=8):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.norm = LayerNorm(dim)
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ self.rotary_emb = RotaryEmbedding(dim_head)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ # for caching causal mask and rotary embeddings
+
+ self.register_buffer("mask", None, persistent=False)
+ self.register_buffer("pos_emb", None, persistent=False)
+
+ def get_mask(self, n, device):
+ if self.mask is not None and self.mask.shape[-1] >= n:
+ return self.mask[:n, :n]
+
+ mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
+ self.register_buffer("mask", mask, persistent=False)
+ return mask
+
+ def get_rotary_embedding(self, n, device):
+ if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
+ return self.pos_emb[:n]
+
+ pos_emb = self.rotary_emb(n, device=device)
+ self.register_buffer("position", pos_emb, persistent=False)
+ return pos_emb
+
+ def forward(self, x):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+
+ n, device, h = x.shape[1], x.device, self.heads
+
+ # pre layernorm
+
+ x = self.norm(x)
+
+ # queries, keys, values
+
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
+
+ # split heads
+ # they use multi-query single-key-value attention, yet another Noam Shazeer paper
+ # they found no performance loss past a certain scale, and more efficient decoding obviously
+ # https://arxiv.org/abs/1911.02150
+
+ q = rearrange(q, "b n (h d) -> b h n d", h=h)
+
+ # rotary embeddings
+
+ positions = self.get_rotary_embedding(n, device)
+ q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))
+
+ # scale
+
+ q = q * self.scale
+
+ # similarity
+
+ sim = einsum("b h i d, b j d -> b h i j", q, k)
+
+ # causal mask
+
+ causal_mask = self.get_mask(n, device)
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
+ # attention
+
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+ attn = sim.softmax(dim=-1)
+
+ # aggregate values
+
+ out = einsum("b h i j, b j d -> b h i d", attn, v)
+
+ # merge heads
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+# transformer
+
+
+def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
+ net = nn.Sequential(
+ nn.Embedding(num_tokens, dim), *[
+ ParallelResidual(
+ Attention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ) for _ in range(depth)
+ ], LayerNorm(dim), nn.Linear(dim, num_tokens, bias=False))
+
+ # they used embedding weight tied projection out to logits, not common, but works
+ net[-1].weight = net[0].weight
+
+ nn.init.normal_(net[0].weight, std=0.02)
+ return net
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
new file mode 100644
index 000000000..ba243e507
--- /dev/null
+++ b/examples/language/palm/train.py
@@ -0,0 +1,109 @@
+import gzip
+import random
+
+import numpy as np
+import torch
+import torch.optim as optim
+import tqdm
+from palm_pytorch import PaLM
+from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SEQ_LEN = 1024
+
+# helpers
+
+
+def cycle(loader):
+ while True:
+ for data in loader:
+ yield data
+
+
+def decode_token(token):
+ return str(chr(max(32, token)))
+
+
+def decode_tokens(tokens):
+ return "".join(list(map(decode_token, tokens)))
+
+
+# instantiate GPT-like decoder model
+
+model = PaLM(num_tokens=256, dim=512, depth=8)
+
+model = AutoregressiveWrapper(model, max_seq_len=2048)
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open("./data/enwik8.gz") as file:
+ X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+ trX, vaX = np.split(X, [int(90e6)])
+ data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+
+
+class TextSamplerDataset(Dataset):
+
+ def __init__(self, data, seq_len):
+ super().__init__()
+ self.data = data
+ self.seq_len = seq_len
+
+ def __getitem__(self, index):
+ rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+ full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
+ return full_seq.cuda()
+
+ def __len__(self):
+ return self.data.size(0) // self.seq_len
+
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
+ model.train()
+
+ for __ in range(GRADIENT_ACCUMULATE_EVERY):
+ loss = model(next(train_loader))
+ loss.backward()
+
+ print(f"training loss: {loss.item()}")
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+ optim.step()
+ optim.zero_grad()
+
+ if i % VALIDATE_EVERY == 0:
+ model.eval()
+ with torch.no_grad():
+ loss = model(next(val_loader))
+ print(f"validation loss: {loss.item()}")
+
+ if i % GENERATE_EVERY == 0:
+ model.eval()
+ inp = random.choice(val_dataset)[:-1]
+ prime = decode_tokens(inp)
+ print(f"%s \n\n %s", (prime, "*" * 100))
+
+ sample = model.generate(inp[None, ...], GENERATE_LENGTH)
+ output_str = decode_tokens(sample[0])
+ print(output_str)
diff --git a/tests/test_gemini/update/test_convert_torch_module.py b/tests/test_gemini/update/test_convert_torch_module.py
index c0fd94b40..160099167 100644
--- a/tests/test_gemini/update/test_convert_torch_module.py
+++ b/tests/test_gemini/update/test_convert_torch_module.py
@@ -1,6 +1,8 @@
+import os
from functools import partial
import pytest
+import torch
import torch.multiprocessing as mp
import colossalai