Merge branch 'junior' of https://github.com/nomic-ai/gpt4all into junior

Zach Nussbaum 2023-05-02 14:11:06 +00:00
commit 201505dd62
38 changed files with 2301 additions and 109 deletions

5
.gitignore vendored
View File

@ -1,3 +1,8 @@
*.arrow
*.swp
# ignore knn index
gpt4all/index/**/
.DS_Store
*.pkl
ckpts*
.deepspeed_env

View File

@ -0,0 +1,18 @@
# model/tokenizer
model_name: "/home/paperspace/gpt4all/gpt4all/train/ckpts/epoch_2"
tokenizer_name: "EleutherAI/gpt-j-6B"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/gpt-jr"
encoder_dim: 384
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_validation"
max_length: 1024
batch_size: 32
pct_test: 0.05
q_column: "question"
a_column: "answers"
encoder_column: "neighbor_embeddings"

View File

@ -0,0 +1,13 @@
# dataset
streaming: false
num_proc: 64
dataset_path: "squad"
max_length: 1024
batch_size: 32
# index
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-cohere-index.bin"
index_database: "nomic-ai/cohere-wiki-sbert"
index_space: "cosine"
index_dim: 384
query_embedding_field: 'question'

View File

@ -0,0 +1,40 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6B"
tokenizer_name: "EleutherAI/gpt-j-6B"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/gpt-jr-decay-alpha"
encoder_dim: 384
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
max_length: 1024
batch_size: 32
pct_test: 0.05
q_column: "question"
a_column: "answers"
encoder_column: "neighbor_embeddings"
# train dynamics
lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 50
save_every: 500
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/decay_alpha"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
# logging
wandb: true
wandb_entity: gpt4all
wandb_project_name: retrieval
seed: 42

View File

@ -1,8 +0,0 @@
#!/bin/bash
export WORKER_IP=$1
N_GPUS=8
# create dir if doesn't exist
sudo mkdir -p /job
printf "localhost slots=$N_GPUS\n$WORKER_IP slots=$N_GPUS" | sudo tee /job/hostfile
echo /job/hostfile

0
gpt4all/__init__.py Normal file
View File

2
gpt4all/data/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .instruction_tuning_dataloader import *
from .retrieval_dataloader import *

View File

@ -1,61 +1,11 @@
import glob
import torch
from datasets import load_dataset, concatenate_datasets
from datasets import load_dataset
import os
import hnswlib
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator
def tokenize_inputs(config, tokenizer, examples):
max_length = config["max_length"]
# hacky backward compatible
different_eos = tokenizer.eos_token != "</s>"
out = {"labels": [], "input_ids": []}
for prompt, response in zip(examples["prompt"], examples["response"]):
if different_eos:
if response.count("</s> \n") > 0:
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
# hack if our prompt is super long
# we need to include some labels so we arbitrarily truncate at max_length // 2
# if the length is too long
if prompt_len >= max_length // 2:
# if prompt is too long, truncate
# but make sure to truncate to at most 1024 tokens
new_len = min(max_length // 2, len(prompt) // 2)
prompt = prompt[:new_len]
# get new prompt length
prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
labels = input_tokens.clone()
labels[:prompt_len] = -100
if len(labels) < max_length:
# pad to max_length with -100
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
if (labels == -100).sum() == len(labels) - 1:
print(prompt)
print(response)
raise ValueError("labels are all -100, something wrong")
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
out["labels"].append(labels)
out["input_ids"].append(input_tokens)
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
return out
from .preprocess import tokenize_inputs
def load_data(config, tokenizer):
@ -85,13 +35,13 @@ def load_data(config, tokenizer):
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
batched=True,
remove_columns=["source", "prompt"],
**kwargs
)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
batched=True,
remove_columns=["source", "prompt"],
**kwargs
@ -116,6 +66,7 @@ def load_data(config, tokenizer):
return train_dataloader, val_dataloader
def load_data_for_inference(config, tokenizer):
dataset_path = config["dataset_path"]
@ -152,12 +103,12 @@ def load_data_for_inference(config, tokenizer):
# tokenize inputs and return labels and attention mask
train_dataset = train_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
batched=True,
**kwargs
)
val_dataset = val_dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele),
lambda ele: tokenize_inputs(config, tokenizer, ele, "prompt", "response"),
batched=True,
**kwargs
)

View File

@ -0,0 +1,51 @@
import torch
def tokenize_inputs(config, tokenizer, examples, input_col, target_col):
max_length = config["max_length"]
# hacky backward compatible
different_eos = tokenizer.eos_token != "</s>"
out = {"labels": [], "input_ids": []}
for prompt, response in zip(examples[input_col], examples[target_col]):
if different_eos:
if response.count("</s> \n") > 0:
response = response.replace("</s> \n", f"{tokenizer.eos_token} \n")
prompt_len = len(tokenizer(prompt + "\n", return_tensors="pt")["input_ids"][0])
# hack if our prompt is super long
# we need to include some labels so we arbitrarily truncate at max_length // 2
# if the length is too long
if prompt_len >= max_length // 2:
# if prompt is too long, truncate
# but make sure to truncate to at most 1024 tokens
new_len = min(max_length // 2, len(prompt) // 2)
prompt = prompt[:new_len]
# get new prompt length
prompt_len = tokenizer(prompt + "\n", return_tensors="pt", max_length=max_length // 2, truncation=True).input_ids.ne(tokenizer.pad_token_id).sum().item()
assert prompt_len <= max_length // 2, f"prompt length {prompt_len} exceeds max length {max_length}"
input_tokens = tokenizer(prompt + "\n" + response + tokenizer.eos_token,
truncation=True, max_length=max_length, return_tensors="pt")["input_ids"].squeeze()
labels = input_tokens.clone()
labels[:prompt_len] = -100
if len(labels) < max_length:
# pad to max_length with -100
labels = torch.cat([labels, torch.full((max_length - len(labels),), -100)])
assert (labels == -100).sum() < len(labels), f"Labels are all -100, something wrong. prompt length {prompt_len} exceeds max length {max_length}"
if (labels == -100).sum() == len(labels) - 1:
print(prompt)
print(response)
raise ValueError("labels are all -100, something wrong")
input_tokens = tokenizer.pad({"input_ids": input_tokens}, padding="max_length", max_length=max_length)["input_ids"]
out["labels"].append(labels)
out["input_ids"].append(input_tokens)
out = {k: torch.stack(v) if isinstance(v, list) else v for k, v in out.items()}
return out

View File

@ -0,0 +1,75 @@
from datasets import load_dataset, Dataset
import os
from torch.utils.data import DataLoader
from .preprocess import tokenize_inputs
from transformers import DefaultDataCollator
def load_retrieval_augmented_data(config, tokenizer, split="train", split_dataset=True):
dataset_path = config["dataset_path"]
if os.path.exists(dataset_path):
dataset = Dataset.load_from_disk(dataset_path)
else:
dataset = load_dataset(dataset_path, split=split)
question_col = config["q_column"]
answer_col = config["a_column"]
encoder_column = config["encoder_column"]
if config["streaming"] is False:
kwargs = {"num_proc": config["num_proc"]}
else:
kwargs = {}
# strip any unnecessary whitespace
# there's one question that includes a ton of whitespace
dataset = dataset.map(lambda ele: {question_col: [q.strip() for q in ele[question_col]]}, batched=True)
# in squad, each element of the answers column is a dict whose "text" key holds
# a list of the answer strings
dataset = dataset.map(lambda ele: {answer_col: [t["text"][0] for t in ele[answer_col]]}, batched=True)
dataset = dataset.map(
lambda ele: tokenize_inputs(config, tokenizer, ele, question_col, answer_col),
batched=True,
**kwargs
)
# tokenize inputs + labels in teacher-force format
# rename encoder hidden states if not already called that
if encoder_column != "encoder_hidden_states":
dataset = dataset.rename_column(encoder_column, "encoder_hidden_states")
columns_to_keep = ["input_ids", "labels", "encoder_hidden_states"]
col_names_to_rm = [col for col in dataset.column_names if col not in columns_to_keep]
dataset = dataset.remove_columns(col_names_to_rm)
if split_dataset:
dataset = dataset.train_test_split(test_size=config["pct_test"], seed=config["seed"])
train_dataset, val_dataset = dataset["train"], dataset["test"]
train_dataloader = DataLoader(
train_dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
val_dataloader = DataLoader(
val_dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
return train_dataloader, val_dataloader
else:
dataloader = DataLoader(
dataset,
batch_size=config["batch_size"],
collate_fn=DefaultDataCollator(),
)
return dataloader

0
gpt4all/eval/__init__.py Normal file
View File

View File

@ -3,7 +3,7 @@ import torch
import pickle
import numpy as np
from tqdm import tqdm
from read import read_config
from gpt4all.utils.read import read_config
from argparse import ArgumentParser
from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

View File

@ -0,0 +1,54 @@
import torch
from gpt4all.models import GPTJRForCausalLM
from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
from gpt4all.train.metrics import f1_score, exact_match_score
from gpt4all.utils.read import read_config
from transformers import AutoTokenizer
from argparse import ArgumentParser
from tqdm import tqdm
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
dataloader = load_retrieval_augmented_data(config, tokenizer, split_dataset=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTJRForCausalLM.from_pretrained(config["model_name"], use_cache=False)
model.to(device)
model.eval()
# Evaluate the model on the SQUAD dataset
f1s = []
exact_matches = []
with torch.no_grad():
for batch in tqdm(dataloader):
outputs = model(input_ids=batch["input_ids"].to(device),
labels=batch["labels"].to(device),
encoder_hidden_states=batch["encoder_hidden_states"].to(device))
predicted_tokens = outputs.logits.argmax(dim=-1)
predicted = tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)
labels = batch["labels"]
mask = labels == -100
labels[mask] = tokenizer.pad_token_id
ground_truth = tokenizer.batch_decode(labels, skip_special_tokens=True)
f1 = f1_score(predicted, ground_truth)
exact_match = exact_match_score(predicted, ground_truth)
f1s.extend(f1)
exact_matches.extend(exact_match)
print(torch.tensor(f1s).mean())
print(torch.tensor(exact_matches).to(torch.float32).mean())

34
gpt4all/index/README.md Normal file
View File

@ -0,0 +1,34 @@
# How to Tokenize and Embed
Split text into chunks
```
python tokenize_texts.py
```
Embed Texts
```
torchrun --master_port=29085 --nproc-per-node 8 embed_texts.py --ds_path=tokenized --batch_size=2048
```
Combine Embeddings and Build Index
```
python build_index.py --ds_path=wiki_sample_tokenized --embed_folder=wiki_sample_embedded
```
To use the Index
```
import hnswlib
index = hnswlib.Index(space='cosine', dim=384)
index.load_index(<path to index>)
```
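The loaded index can then be queried with `knn_query`, the same call `prep_index_for_train.py` uses. A minimal sketch, assuming the index was built by `build_index.py` above (cosine space, 384-dim all-MiniLM-L6-v2 embeddings); the index path and query string are placeholders:
```
import hnswlib
from gpt4all.index.embed import Embedder

# embed the query with the same sentence-transformer used for the passages
embedder = Embedder("sentence-transformers/all-MiniLM-L6-v2")
query_vector = embedder("an example question about the corpus").detach().numpy()

index = hnswlib.Index(space='cosine', dim=384)
index.load_index("wiki-index.bin")  # placeholder: whatever --index_path you built

# labels are the ids the passages were added with, distances are cosine distances
neighbor_ids, distances = index.knn_query(query_vector, k=5)
```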
Prep index for train
```
CUDA_VISIBLE_DEVICES=7 torchrun --master_port=29086 --nproc-per-node 1 prep_index_for_train.py --config=../../configs/train/finetune_gptjr.yaml
```
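After this step the supplemented dataset carries the retrieved context alongside each example: `prep_index_for_train.py` saves `{dataset_path}_supplemented_{split}/` with extra `neighbor_ids` and `neighbor_embeddings` columns, and the training dataloader renames `neighbor_embeddings` to `encoder_hidden_states`. A minimal sketch for inspecting the output, assuming the squad config shipped in this commit (the path simply follows that naming pattern):
```
from datasets import Dataset

ds = Dataset.load_from_disk("squad_supplemented_train")
print(ds.column_names)                    # includes "neighbor_ids" and "neighbor_embeddings"
print(len(ds[0]["neighbor_embeddings"]))  # k retrieved neighbors, each a 384-dim embedding
```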

View File

View File

@ -0,0 +1,102 @@
import os
from datasets import Dataset, concatenate_datasets
import glob
from argparse import ArgumentParser
import hnswlib
import pyarrow as pa
import pyarrow.compute as pc
from tqdm import tqdm
def parse_args():
parser = ArgumentParser()
parser.add_argument("--ds_path", type=str, required=True)
parser.add_argument("--embed_folder", type=str, required=True)
parser.add_argument("--index_path", type=str, default="wiki-index")
return parser.parse_args()
def concat_embedded(folder):
files = glob.glob(f"{folder}/*")
all_embeddings = []
for file in files:
all_embeddings.append(Dataset.load_from_disk(file))
all_embeddings = concatenate_datasets(all_embeddings)
return all_embeddings
def join(original_ds, embedded_ds):
embedded_ds = embedded_ds.add_column("index", range(len(embedded_ds)))
embed_table = embedded_ds.data.table
seen = set()
indices = []
for i, id in enumerate(original_ds["id"]):
if id not in seen:
indices.append(i)
seen.add(id)
mask = pc.is_in(embed_table["index"], value_set=pa.array(indices, pa.int32()))
filtered_table = embed_table.filter(mask)
# sort to make sure we're adding in right order
filtered_table = filtered_table.sort_by("id")
original_table = original_ds.data.table
original_table = original_table.sort_by("id")
original_table = original_table.append_column(
"embedding", filtered_table["embedding"]
)
# there's definitely a better way to do this but
# Dataset(original_table) throws `KeyError: 'embedding'`
joined = Dataset.from_dict(original_table.to_pydict())
return joined
def build_index(orig_path, embed_folder_path, index_path):
if not os.path.exists(orig_path + "_embedded_with_text"):
# TODO: this doesn't work for large datasets!
# just convert to pandas and do this manually
ds = Dataset.load_from_disk(orig_path)
embed_ds = concat_embedded(embed_folder_path)
print("Concatenated embeddings")
print(f"Length: {len(ds)}")
print(f"Length: {len(embed_ds)}")
ds = join(ds, embed_ds)
ds = ds.add_column("index", range(len(ds)))
print("Saving to disk")
ds.save_to_disk(f"{orig_path}_embedded_with_text")
else:
ds = Dataset.load_from_disk(orig_path + "_embedded_with_text")
print(f"Length of ds: {len(ds)}")
print("Building index")
index = hnswlib.Index(space="cosine", dim=384)
# not sure what we should set M and ef_construction to
index.init_index(max_elements=len(ds), M=64, ef_construction=200)
print("Adding items")
chunk_size = 50_000
num_chunks = len(ds) // chunk_size
progbar = tqdm(total=num_chunks)
start = 0
while start < len(ds):
chunk = ds[start:start + chunk_size]
index.add_items(chunk["embedding"], chunk["id"], num_threads=64)
progbar.update(1)
start += chunk_size
print("Saving index")
index.save_index(index_path + ".bin")
if __name__ == "__main__":
args = parse_args()
build_index(args.ds_path, args.embed_folder, args.index_path)

70
gpt4all/index/embed.py Normal file
View File

@ -0,0 +1,70 @@
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
class Embedder:
def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.embedder = AutoModel.from_pretrained(model_name)
# hack
self.offset = self.tokenizer.model_max_length // 2
def _mean_pool(self, model_output, attention_mask):
token_embeddings = model_output[
0
] # First element of model_output contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
sentence_embeddings = torch.sum(
token_embeddings * input_mask_expanded, 1
) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return F.normalize(sentence_embeddings, p=2, dim=1)
def chunk_text(self, text):
tokenized_text = {"input_ids": [], "attention_mask": []}
tokenized = self.tokenizer(text)
tokenized_len = len(tokenized["input_ids"])
max_len = self.tokenizer.model_max_length
if tokenized_len > max_len:
start = 0
while start < tokenized_len:
tokenized_text["input_ids"].append(
tokenized["input_ids"][start : start + max_len]
)
tokenized_text["attention_mask"].append(
tokenized["attention_mask"][start : start + max_len]
)
# this could probably be done better
start += self.offset
else:
tokenized_text["input_ids"].append(tokenized["input_ids"])
tokenized_text["attention_mask"].append(tokenized["attention_mask"])
return tokenized_text
def tokenize(self, text):
return self.tokenizer(text, return_tensors="pt", truncation=True, padding="max_length")
def __call__(self, batch):
if isinstance(batch, str):
tokenized = self.tokenizer(batch, return_tensors="pt", truncation=True)
return self._mean_pool(
self.embedder(
input_ids=tokenized["input_ids"],
attention_mask=tokenized["attention_mask"],
),
tokenized["attention_mask"],
)
else:
outputs = self.embedder(
input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
embedding = self._mean_pool(outputs, batch["attention_mask"])
return {"id": batch["id"], "embedding": embedding}
def to(self, device):
self.embedder.to(device)

View File

@ -0,0 +1,116 @@
import os
import torch.distributed as dist
from argparse import ArgumentParser
from datasets import Dataset
from gpt4all.index.embed import Embedder
from gpt4all.utils.distributed_utils import rank0_print
from torch.utils.data import DataLoader, DistributedSampler
from transformers.trainer_pt_utils import nested_numpify
from transformers import BatchEncoding
from tqdm import tqdm
import numpy as np
import torch
from datasets import load_dataset
# this isn't used but keeping in case we need it in the future
# collate and pad inputs to the right shape
class PadCollateInputs:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, batch):
mapped_inputs = {"input_ids": [], "attention_mask": []}
mapped_inputs["input_ids"] = [b["tokenized_chunk"] for b in batch]
mapped_inputs["attention_mask"] = [b["tokenized_attn_mask"] for b in batch]
encoding = BatchEncoding(mapped_inputs)
padded_inputs = self.tokenizer.pad(
encoding, padding="max_length", return_tensors="pt"
)
padded_inputs["id"] = [b["id"] for b in batch]
return padded_inputs
def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank0_print(f"World size: {world_size}")
if os.path.exists(ds_path):
dataset = Dataset.load_from_disk(ds_path)
else:
dataset = load_dataset(ds_path, split=split)
rank0_print(f"Dataset size: {len(dataset)}")
model = Embedder()
if "input_ids" not in dataset.column_names:
dataset = dataset.map(lambda x: model.tokenize(x[embed_on]), batched=True, num_proc=64)
columns_to_keep = ["id", "input_ids", "attention_mask"]
to_remove = [e for e in dataset.column_names if e not in columns_to_keep]
dataset = dataset.remove_columns(to_remove)
dataset = dataset.with_format("torch")
num_processes = dist.get_world_size() if dist.is_initialized() else 1
local_rank = dist.get_rank() if dist.is_initialized() else 0
if num_processes > 1:
sampler = DistributedSampler(
dataset,
shuffle=False,
drop_last=False,
num_replicas=num_processes,
rank=local_rank,
)
else:
sampler = None
dataloader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=False,
)
model.to(f"cuda:{local_rank}")
with torch.no_grad():
embedded_outputs = {"id": [], "embedding": []}
for batch in tqdm(dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["attention_mask"] = batch["attention_mask"].to(f"cuda:{local_rank}")
outputs = model(batch)
embedded_outputs["id"].extend(batch["id"])
embedded_outputs["embedding"].extend(outputs["embedding"])
embedded_outputs["embedding"] = nested_numpify(embedded_outputs["embedding"])
embedded_outputs["id"] = np.stack(embedded_outputs["id"])
embedded_outputs["embedding"] = np.stack(embedded_outputs["embedding"])
ds = Dataset.from_dict(embedded_outputs)
# feeling lazy, don't want to wait for all_gather to finish
# will load and concat in a single process in another script
if save_to_disk:
ds.save_to_disk(f"{ds_path}_embedded/{ds_path}_embedded_rank_{local_rank}")
return ds
def main():
dist.init_process_group("nccl")
parser = ArgumentParser()
parser.add_argument("--ds_path", type=str, default="tokenized")
parser.add_argument("--batch_size", type=int, default=1)
args = parser.parse_args()
embed_texts(args.ds_path, args.batch_size, save_to_disk=True)
if __name__ == "__main__":
# parse arguments by reading in a config
main()

View File

@ -0,0 +1,94 @@
import os
import hnswlib
import numpy as np
import pyarrow as pa
from datasets import Dataset
import torch.distributed as dist
from datasets import load_dataset
from pyarrow import compute as pc
from argparse import ArgumentParser
from gpt4all.utils.read import read_config
from gpt4all.index.embed_texts import embed_texts
from tqdm import tqdm
CHUNK_SIZE = 1024
def parse_args():
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--k", type=int, default=5)
return parser.parse_args()
def prep_index():
args = parse_args()
config = read_config(args.config)
index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
print("loading index")
index.load_index(config['index_path'])
# load query dataset
ds_path = config['dataset_path']
# load retrieval dataset
print("loading retrieval dataset")
print(config["index_database"])
if os.path.exists(config['index_database']):
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
else:
retrieval_dataset = load_dataset(config['index_database'], split="train")
# vectorize queries
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{args.split}"
if not os.path.exists(query_vector_path):
print('Embedding dataset...')
ds = embed_texts(ds_path,
config['batch_size'],
embed_on=config['query_embedding_field'],
save_to_disk=False,
split=args.split)
ds.save_to_disk(query_vector_path)
else:
print('Found cached embedding dataset!')
ds = Dataset.load_from_disk(query_vector_path)
# build training dataset
train_dataset = load_dataset(ds_path, split=args.split)
# search the index for each query
neighbor_embs_column = []
neighbor_ids_column = []
for chunk_start in tqdm(range(0, len(ds), CHUNK_SIZE)):
chunk_end = chunk_start + CHUNK_SIZE
chunk = ds[chunk_start:chunk_end]
query_vectors = np.array(chunk['embedding'])
neighbor_ids, _ = index.knn_query(query_vectors, k=args.k, num_threads=-1) # neighbor ids is of shape [n_queries, n_neighbors]
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['id'], value_set))
# build mapping between indices and embeddings
neighbor_id_list = neighbor_objs['id']
emb_list = neighbor_objs['embedding']
idx_to_embedding = {idx.as_py(): emb_list[i] for i, idx in enumerate(neighbor_id_list)}
neighbor_embs = []
for cur_neighbor_ids in neighbor_ids:
cur_embs = [idx_to_embedding[id].as_py() for id in cur_neighbor_ids]
neighbor_embs.append(cur_embs)
neighbor_embs_column.extend(neighbor_embs)
neighbor_ids_column.extend(neighbor_ids)
print("adding neighbor ids")
train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
print("adding neighbor embeddings")
train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
supplemented_dataset_path = f"{ds_path}_supplemented_{args.split}/"
train_dataset.save_to_disk(supplemented_dataset_path)
if __name__ == "__main__":
prep_index()

View File

View File

@ -0,0 +1,72 @@
from datasets import load_dataset
from argparse import ArgumentParser
from gpt4all.index.embed import Embedder
def parse_args():
parser = ArgumentParser()
# fmt: off
parser.add_argument("--tokenized_save_path", type=str, default="tokenized")
parser.add_argument("--ds_name", type=str, default="wikipedia")
parser.add_argument("--ds_version", type=str, default="20220301.simple")
parser.add_argument("--sbert_model", type=str, default="sentence-transformers/all-MiniLM-L6-v2")
# fmt: on
return parser.parse_args()
def tokenize_texts(examples, embedder):
split_data = {k: [] for k in examples.keys()}
split_data["tokenized_chunk"] = []
split_data["tokenized_attn_mask"] = []
keys = [k for k in examples.keys() if k != "text"]
for i, text in enumerate(examples["text"]):
tokenized_text = embedder.chunk_text(text)
# do we want to add sep/cls tokens to beginning and end?
decoded_text = embedder.tokenizer.batch_decode(
sequences=tokenized_text["input_ids"]
)
num_splits = len(tokenized_text["input_ids"])
split_data["id"].extend(
[f"{examples['id'][i]}_split_{j}" for j in range(num_splits)]
)
for col in keys:
if col != "id":
split_data[col].extend(
[examples[col][i]] * len(tokenized_text["input_ids"])
)
split_data["text"].extend(decoded_text)
split_data["tokenized_chunk"].extend(tokenized_text["input_ids"])
split_data["tokenized_attn_mask"].extend(tokenized_text["attention_mask"])
return split_data
def chunk_dataset(
ds_name="wikipedia",
version="20220301.simple",
sbert_model="sentence-transformers/all-MiniLM-L6-v2",
save_path="tokenized",
):
dataset = load_dataset(ds_name, version, split="train")
print(len(dataset))
embedder = Embedder(sbert_model)
dataset = dataset.map(
lambda x: tokenize_texts(x, embedder), batched=True, num_proc=64
)
dataset.save_to_disk(save_path)
if __name__ == "__main__":
args = parse_args()
chunked_dataset = chunk_dataset(
ds_name=args.ds_name,
version=args.ds_version,
sbert_model=args.sbert_model,
save_path=args.tokenized_save_path,
)

View File

View File

@ -1,6 +1,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModelForCausalLM
from read import read_config
from gpt4all.utils.read import read_config
from argparse import ArgumentParser
import torch
import time

View File

@ -2,13 +2,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
from argparse import ArgumentParser
from read import read_config
from accelerate.utils import set_seed
from data import load_data_for_inference
from gpt4all.utils.read import read_config
from accelerate.utils import set_seed
from gpt4all.data.instruction_tuning_dataloader import load_data_for_inference
from gpt4all.utils.distributed_utils import rank0_print
from tqdm import tqdm
from datasets import Dataset
from datasets import Dataset
import torch.distributed as dist
from transformers.trainer_pt_utils import nested_numpify
from transformers.trainer_pt_utils import nested_numpify
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
@ -21,56 +22,64 @@ def calc_cross_entropy_no_reduction(lm_logits, labels):
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(reduction='none')
loss_fct = nn.CrossEntropyLoss(reduction="none")
loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels).mean(dim=1)
return loss
def rank0_print(msg):
if dist.get_rank() == 0:
print(msg)
def inference(config):
set_seed(config['seed'])
set_seed(config["seed"])
rank0_print(f"World size: {dist.get_world_size()}")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
tokenizer = AutoTokenizer.from_pretrained(
config["tokenizer_name"], model_max_length=config["max_length"]
)
# llama has no pad token, set it to new token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
train_dataset, val_dataset = load_data_for_inference(config, tokenizer)
num_processes = dist.get_world_size()
local_rank = dist.get_rank()
train_sampler = DistributedSampler(train_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
train_sampler = DistributedSampler(
train_dataset,
shuffle=False,
drop_last=True,
num_replicas=num_processes,
rank=local_rank,
)
train_dataloader = DataLoader(
train_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
sampler=train_sampler,
drop_last=True
drop_last=True,
)
val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=True, num_replicas=num_processes, rank=local_rank)
val_sampler = DistributedSampler(
val_dataset,
shuffle=False,
drop_last=True,
num_replicas=num_processes,
rank=local_rank,
)
val_dataloader = DataLoader(
val_dataset,
collate_fn=DefaultDataCollator(),
batch_size=config["batch_size"],
sampler=val_sampler,
drop_last=True
drop_last=True,
)
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
config["model_name"],
trust_remote_code=True,
torch_dtype=torch.bfloat16,
)
model.to(f"cuda:{local_rank}")
with torch.no_grad():
@ -78,14 +87,18 @@ def inference(config):
for batch in tqdm(train_dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
outputs = model(
input_ids=batch["input_ids"],
labels=batch["labels"],
output_hidden_states=True,
)
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
train_outputs["loss"].extend(loss)
embeddings = outputs.hidden_states[-1]
batch_size = batch["input_ids"].shape[0]
sequence_lengths = []
# since we use multi-turn with multiple <|endoftext|>, we need to find the place where
# since we use multi-turn with multiple <|endoftext|>, we need to find the place where
# <|endoftext|> is repeated
for item in batch["input_ids"]:
indices = torch.where(item == tokenizer.pad_token_id)[0]
@ -101,7 +114,9 @@ def inference(config):
sequence_lengths.append(len(item) - 1)
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
pooled_logits = embeddings[
torch.arange(batch_size, device=embeddings.device), sequence_lengths
]
train_outputs["embeddings"].append(pooled_logits)
train_outputs["index"].extend(batch["index"].to(model.device))
@ -120,29 +135,40 @@ def inference(config):
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())
filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"])
filtered_train = filtered_train.add_column("is_train", [True] * len(filtered_train))
filtered_train = filtered_train.add_column(
"is_train", [True] * len(filtered_train)
)
filtered_train.to_json(f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
filtered_train.to_json(
f"inference/epoch_2_embeddings_train_shard_{local_rank}.jsonl",
lines=True,
orient="records",
num_proc=64,
)
val_outputs = {"loss": [], "embeddings": [], "index": []}
for batch in tqdm(val_dataloader, disable=local_rank != 0):
batch["input_ids"] = batch["input_ids"].to(f"cuda:{local_rank}")
batch["labels"] = batch["labels"].to(f"cuda:{local_rank}")
outputs = model(input_ids=batch["input_ids"], labels=batch["labels"], output_hidden_states=True)
outputs = model(
input_ids=batch["input_ids"],
labels=batch["labels"],
output_hidden_states=True,
)
loss = calc_cross_entropy_no_reduction(outputs.logits, batch["labels"])
val_outputs["loss"].extend(loss)
embeddings = outputs.hidden_states[-1]
batch_size = batch["input_ids"].shape[0]
sequence_lengths = []
# since we use multi-turn with multiple <|endoftext|>, we need to find the place where
# since we use multi-turn with multiple <|endoftext|>, we need to find the place where
# <|endoftext|> is repeated
for item in batch["input_ids"]:
indices = torch.where(item == tokenizer.pad_token_id)[0]
@ -158,7 +184,9 @@ def inference(config):
sequence_lengths.append(len(item) - 1)
sequence_lengths = torch.tensor(sequence_lengths)
pooled_logits = embeddings[torch.arange(batch_size, device=embeddings.device), sequence_lengths]
pooled_logits = embeddings[
torch.arange(batch_size, device=embeddings.device), sequence_lengths
]
val_outputs["embeddings"].append(pooled_logits)
val_outputs["index"].extend(batch["index"].to(model.device))
@ -176,7 +204,7 @@ def inference(config):
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
mask = pc.is_in(table["index"], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())
@ -184,8 +212,13 @@ def inference(config):
filtered_val = filtered_val.add_column("loss", df_val["loss"])
filtered_val = filtered_val.add_column("is_train", [False] * len(filtered_val))
filtered_val.to_json(f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl", lines=True, orient="records", num_proc=64)
filtered_val.to_json(
f"inference/epoch_2_embeddings_val_shard_{local_rank}.jsonl",
lines=True,
orient="records",
num_proc=64,
)
def main():
dist.init_process_group("nccl")
@ -201,4 +234,3 @@ def main():
if __name__ == "__main__":
# parse arguments by reading in a config
main()

View File

@ -0,0 +1,8 @@
from .configuration_gpt_jr import GPTJRConfig
from .modeling_gpt_jr import GPTJRForCausalLM
__all__ = [
"GPTJRConfig",
"GPTJRForCausalLM"
]

View File

@ -0,0 +1,145 @@
# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" GPT-J model configuration"""
from collections import OrderedDict
from typing import Any, List, Mapping, Optional
from transformers import PreTrainedTokenizer, TensorType, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers.utils import logging
logger = logging.get_logger(__name__)
GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json",
# See all GPT-J models at https://huggingface.co/models?filter=gpt_j
}
class GPTJRConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the GPT-J
[EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
[`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50400):
Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GPTJModel`].
n_positions (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
n_embd (`int`, *optional*, defaults to 4096):
Dimensionality of the embeddings and hidden states.
n_layer (`int`, *optional*, defaults to 28):
Number of hidden layers in the Transformer encoder.
n_head (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
rotary_dim (`int`, *optional*, defaults to 64):
Number of dimensions in the embedding that Rotary Position Embedding is applied to.
n_inner (`int`, *optional*, defaults to None):
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
activation_function (`str`, *optional*, defaults to `"gelu_new"`):
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
resid_pdrop (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (`int`, *optional*, defaults to 0.1):
The dropout ratio for the embeddings.
attn_pdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon to use in the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_attn_weights (`bool`, *optional*, defaults to `True`):
Scale attention weights by dividing by sqrt(hidden_size).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
Example:
```python
>>> from transformers import GPTJModel, GPTJConfig
>>> # Initializing a GPT-J 6B configuration
>>> configuration = GPTJConfig()
>>> # Initializing a model from the configuration
>>> model = GPTJModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gptj"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
}
def __init__(
self,
vocab_size=50400,
n_positions=2048,
n_embd=4096,
n_layer=28,
n_head=16,
rotary_dim=64,
n_inner=None,
activation_function="gelu_new",
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
scale_attn_weights=True,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
tie_word_embeddings=False,
encoder_dim=4096,
encoder_path=None,
**kwargs
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.n_inner = n_inner
self.rotary_dim = rotary_dim
self.activation_function = activation_function
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.encoder_dim = encoder_dim
super().__init__(
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)

View File

@ -0,0 +1,916 @@
# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPT-J model."""
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from gpt4all.models.configuration_gpt_jr import GPTJRConfig
logger = logging.get_logger(__name__)
GPTJR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"EleutherAI/gpt-j-6B",
# See all GPT-J models at https://huggingface.co/models?filter=gptj
]
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
dim = x.shape[-1]
if seq_len is None:
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float()
)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
return (x * cos) + (rotate_every_two(x) * sin)
class GPTJRAttention(nn.Module):
def __init__(self, config):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.rotary_dim = None
if config.rotary_dim is not None:
self.rotary_dim = config.rotary_dim
def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
attn_weights = attn_weights / self.scale_attn
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
# if we are doing cross attention
if encoder_hidden_states is not None:
key = self.k_proj(encoder_hidden_states)
value = self.v_proj(encoder_hidden_states)
else:
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
seq_len = key.shape[1]
offset = 0
if layer_past is not None:
offset = layer_past[0].shape[-2]
seq_len += offset
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
key = apply_rotary_pos_emb(key, sincos, offset=offset)
query = apply_rotary_pos_emb(query, sincos, offset=offset)
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
class GPTJRCrossAttention(GPTJRAttention):
def __init__(self, config):
super().__init__(config)
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.k_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 1, 3, 2)
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor -> (bs, seq_len, num_attention_heads, head_dim)
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# query -> (bs, num_attention_heads, seq_len, head_dim)
# key -> (bs, num_attention_heads, head_dim, neighbors)
# attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
attn_weights = torch.matmul(query, key)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
# value -> (bs, num_attention_heads, seq_len, head_dim)
# attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
# attn_output -> (bs, num_attention_heads, seq_len, head_dim)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
# cross attention: keys and values always come from the encoder hidden states
key = self.k_proj(encoder_hidden_states)
value = self.v_proj(encoder_hidden_states)
# (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, False)
# (bs, dim) -> (bs, num_attention_heads, head_dim)
key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
value = value.permute(0, 3, 1, 2)
key = key.permute(0, 3, 2, 1)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
# compute cross-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
class GPTJRMLP(nn.Module):
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
super().__init__()
embed_dim = config.n_embd
self.fc_in = nn.Linear(embed_dim, intermediate_size)
self.fc_out = nn.Linear(intermediate_size, embed_dim)
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
hidden_states = self.fc_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc_out(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class GPTJRBlock(nn.Module):
def __init__(self, config):
super().__init__()
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = GPTJRAttention(config)
self.mlp = GPTJRMLP(inner_dim, config)
self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
self.cross_attn = GPTJRCrossAttention(config)
self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
step: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
# shape (bs, seq_len, hidden_dim)
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
feed_forward_hidden_states = self.mlp(hidden_states)
self_attention_residual = attn_output + feed_forward_hidden_states + residual
# encoder_hidden_states -> (bs, knn, encoder_dim)
if encoder_hidden_states.dtype != hidden_states.dtype:
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
encoder_normed = self.ln_2(encoder_hidden_states)
# cross_attn_outputs -> (bs, seq_len, dim)
cross_attn_output = self.cross_attn(
hidden_states,
encoder_hidden_states=encoder_normed,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
# gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
cross_attn_ff = self.cross_attn_mlp(
cross_attn_output[0]
)
if step is not None:
alpha = self._update_alpha(step)
alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
else:
alpha = 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions)
def _update_alpha(self, iteration):
return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
class GPTJRPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GPTJRConfig
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["GPTJRBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear,)):
# Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPTJRModel):
module.gradient_checkpointing = value
class GPTJRModel(GPTJRPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embed_dim = config.n_embd
self.vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPTJRBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def parallelize(self, device_map=None):
# Check validity of device_map
self.device_map = (
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
)
assert_device_map(self.device_map, len(self.h))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
self.wte = self.wte.to(self.first_device)
# Load onto devices
for k, v in self.device_map.items():
for block in v:
cuda_device = "cuda:" + str(k)
self.h[block] = self.h[block].to(cuda_device)
# ln_f to last
self.ln_f = self.ln_f.to(self.last_device)
def deparallelize(self):
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
self.last_device = "cpu"
self.wte = self.wte.to("cpu")
for index in range(len(self.h)):
self.h[index] = self.h[index].to("cpu")
self.ln_f = self.ln_f.to("cpu")
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.wte
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x num_attention_heads x N x N
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions, step)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
None,
attention_mask,
head_mask[i],
)
else:
outputs = block(
hidden_states,
encoder_hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
step=step
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class GPTJRForCausalLM(GPTJRPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
def __init__(self, config):
super().__init__(config)
self.transformer = GPTJRModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
# Model parallel
self.model_parallel = False
self.device_map = None
# Initialize weights and apply final processing
self.post_init()
def parallelize(self, device_map=None):
self.device_map = (
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
assert_device_map(self.device_map, len(self.transformer.h))
self.transformer.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.transformer.first_device)
self.model_parallel = True
def deparallelize(self):
self.transformer.deparallelize()
self.transformer = self.transformer.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
torch.cuda.empty_cache()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
step=step,
)
hidden_states = transformer_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.transformer.first_device)
hidden_states = hidden_states.to(self.lm_head.weight.device)
# make sure sampling in fp16 works correctly and
# compute loss in fp32 to match with mesh-tf version
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
# TODO: do we need to do conversion to fp32 if training in bf16?
lm_logits = self.lm_head(hidden_states).to(torch.float32)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss.to(hidden_states.dtype)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
[`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@ -0,0 +1,50 @@
import torch
from gpt4all.models import GPTJRForCausalLM, GPTJRConfig
from transformers import AutoTokenizer, AutoModel
# seed torch
torch.manual_seed(0)
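# Build a small, randomly initialized 4-layer variant as a quick smoke test;
# encoder_dim=384 matches the all-MiniLM-L6-v2 sentence embedding size.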
config = GPTJRConfig(encoder_dim=384, n_layer=4)
print("loaded config")
print("loading model")
model = GPTJRForCausalLM(config)
print("loaded model")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
tokenizer.pad_token = tokenizer.eos_token
encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
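# Standard sentence-transformers mean pooling: average token embeddings weighted by the attention mask, yielding one (bs, dim) vector per input.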
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
text = "The quick brown fox jumps over the lazy dog."
print("Encoded knn")
tokenized = encoder_tokenizer(text, return_tensors="pt")
# bs, seq_len, dim
encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
inputs = "What did the fox do?"
print("Encoded inputs")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
print("Running model")
outputs = model(**tokenized_input, encoder_hidden_states=encoder_outputs)
print(outputs)
print(outputs[0].shape)
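# Expected: logits of shape (batch_size, seq_len, vocab_size) from the randomly initialized model,
# e.g. (1, 2048, 50400) assuming GPT-J's default vocab size and max-length padding.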


gpt4all/train/metrics.py Normal file

@ -0,0 +1,50 @@
from collections import Counter
import string
import re
# adapted from huggingface
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(predictions, ground_truths):
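    # SQuAD-style token-level F1: precision/recall over the multiset of normalized answer tokens, per example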
total_f1 = []
for prediction, ground_truth in zip(predictions, ground_truths):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
total_f1.append(0)
continue
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
total_f1.append(f1)
return total_f1
def exact_match_score(predictions, ground_truths):
exact_scores = []
for prediction, ground_truth in zip(predictions, ground_truths):
exact_scores.append(normalize_answer(prediction) == normalize_answer(ground_truth))
return exact_scores
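# Hypothetical usage sketch (not part of the original file):
#   f1_score(["the quick brown fox", "Paris"], ["The quick brown fox.", "in Paris"])
#   -> [1.0, ~0.67]  (second pair shares "paris" but misses "in")
#   exact_match_score(["the quick brown fox", "Paris"], ["The quick brown fox.", "in Paris"])
#   -> [True, False]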


@ -1,13 +1,13 @@
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from read import read_config
from gpt4all.utils.read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from peft import get_peft_model, LoraConfig, TaskType
from data import load_data
from gpt4all.data.instruction_tuning_dataloader import load_data
from torchmetrics import MeanMetric
from tqdm import tqdm
import wandb
@ -104,7 +104,7 @@ def train(accelerator, config):
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=config["warmup_steps"], warmup_num_steps=config["warmup_steps"]
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(


@ -0,0 +1,256 @@
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
import torch
from torch.optim import AdamW
from argparse import ArgumentParser
from gpt4all.utils.read import read_config
from accelerate import Accelerator
from accelerate.utils import DummyScheduler, DummyOptim, set_seed
from peft import get_peft_model, LoraConfig, TaskType
from gpt4all.data.retrieval_dataloader import load_retrieval_augmented_data
from torchmetrics import MeanMetric
from tqdm import tqdm
from gpt4all.models import GPTJRForCausalLM
from gpt4all.train.metrics import f1_score, exact_match_score
import wandb
import torch.distributed as dist
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix
log += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
return log
def evaluate(model, val_dataloader, step, main_process=False):
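    # Average cross-entropy over the validation set; per-batch losses are gathered across processes via the module-level `accelerator` before accumulating.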
model.eval()
val_loss = MeanMetric(nan_strategy="error").to(model.device)
with torch.no_grad():
for batch in tqdm(val_dataloader, disable=not main_process):
outputs = model(input_ids=batch["input_ids"],
labels=batch["labels"],
encoder_hidden_states=batch["encoder_hidden_states"],
step=step)
loss_values = accelerator.gather_for_metrics({"loss": outputs["loss"].detach()})
val_loss.update(loss_values["loss"])
return val_loss
def train(accelerator, config):
set_seed(config['seed'])
accelerator.print(config)
accelerator.print(f"Using {accelerator.num_processes} GPUs")
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
# if no pad token, set it to eos
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
with accelerator.main_process_first():
train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
checkpoint = config["gradient_checkpointing"]
    # ensures backward compatibility with non-retrieval models
if 'encoder_dim' in config:
with accelerator.main_process_first():
model = GPTJRForCausalLM.from_pretrained(config["model_name"],
revision=config['version'] if 'version' in config else None,
use_cache=False if checkpoint else True,
encoder_dim=config["encoder_dim"],
)
else:
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
use_cache=False if checkpoint else True,
trust_remote_code=True)
if checkpoint:
model.gradient_checkpointing_enable()
if config["lora"]:
peft_config = LoraConfig(
# should R be configurable?
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
optimizer_cls = (
AdamW
if accelerator.state.deepspeed_plugin is None
or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
else DummyOptim
)
    # karpathy doesn't apply weight decay to embeddings, maybe we should exclude them too
# https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
    if accelerator.state.deepspeed_plugin is not None:
        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
            "gradient_accumulation_steps"
        ]
    else:
        # no DeepSpeed config to read from, so default to no gradient accumulation
        gradient_accumulation_steps = 1
# decay to min_lr instead of 0
lr_ratio = config["min_lr"] / config["lr"]
accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
# instead of decaying to zero, decay to ratio of min_lr / lr
total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
accelerator.print(f"Total training steps: {total_num_steps}")
    # Create a DummyScheduler if `scheduler` is specified in the DeepSpeed config, otherwise use a cosine scheduler
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
scheduler = get_scheduler(
name="cosine",
optimizer=optimizer,
num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
num_training_steps=total_num_steps,
)
else:
scheduler = DummyScheduler(
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader, scheduler
)
# setup for saving training states in case preemption
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
accelerator.load_state(config["checkpoint"])
accelerator.print(f"Resumed from checkpoint: {config['checkpoint']}")
        path = os.path.basename(config["checkpoint"])
        training_difference = os.path.splitext(path)[0]
        resume_step = int(training_difference.replace("step_", ""))
        # skip_first_batches returns a new dataloader, so keep the result
        train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
accelerator.print(f"Resuming from step {resume_step}")
# log gradients
if accelerator.is_main_process and config["wandb"]:
wandb.watch(model, log_freq=config["log_grads_every"], log="all")
main_process = accelerator.is_main_process
for epoch in range(config["num_epochs"]):
train_loss = MeanMetric(nan_strategy="error").to(model.device)
for step, batch in enumerate(tqdm(train_dataloader, disable=not main_process)):
model.train()
outputs = model(input_ids=batch["input_ids"],
labels=batch["labels"],
encoder_hidden_states=batch["encoder_hidden_states"],
step=step)
loss = outputs.loss
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
train_loss.update(loss_values["loss"])
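            # scale the loss so gradients accumulated over `gradient_accumulation_steps` micro-batches
            # average to a full-batch gradient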
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
# get gradient norm of all params
# log LR in case something weird happens
if config["wandb"]:
                if step > 0 and step % config["log_lr_every"] == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step > 0 and step % config["save_every"] == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
curr_step = step + epoch * len(train_dataloader)
val_loss = evaluate(model, val_dataloader, step=curr_step, main_process=main_process)
log_train = {
"train_loss": train_loss.compute()
}
log_val = {
"val_loss": val_loss.compute(),
}
if config["wandb"]:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
f"{config['output_dir']}/final",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.end_training()
if __name__ == "__main__":
# parse arguments by reading in a config
parser = ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args = parser.parse_args()
config = read_config(args.config)
if config["wandb"]:
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(
project_name=config["wandb_project_name"],
config=config,
init_kwargs={"wandb": {"entity": config["wandb_entity"]}},
)
else:
accelerator = Accelerator()
train(accelerator, config=config)


@ -0,0 +1,9 @@
import torch.distributed as dist
def rank0_print(msg):
if dist.is_initialized():
if dist.get_rank() == 0:
print(msg)
else:
print(msg)


@ -12,4 +12,7 @@ sentencepiece
jsonlines
nomic
scikit-learn
matplotlib
matplotlib
apache_beam
mwparserfromhell
hnswlib

setup.py Normal file

@ -0,0 +1,34 @@
from setuptools import setup, find_packages
with open('README.md', 'r', encoding='utf-8') as f:
long_description = f.read()
with open('requirements.txt', 'r', encoding='utf-8') as f:
requirements = [line.strip() for line in f if line.strip()]
setup(
name='gpt4all',
version='0.0.1',
author='nomic-ai',
author_email='zach@nomic-ai',
    description='an ecosystem of open-source chatbots trained on a massive collection of clean assistant data including code, stories and dialogue',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/nomic-ai/gpt4all',
packages=find_packages(),
install_requires=requirements,
classifiers=[
'Development Status :: 3 - Alpha',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Topic :: Text Processing :: Linguistic',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Intended Audience :: Science/Research',
'Operating System :: OS Independent',
],
python_requires='>=3.6',
)