fix: prep index for mem attn

Zach Nussbaum 2023-05-18 19:43:38 +00:00
parent 60f2f99cd7
commit 18b04347f5
2 changed files with 56 additions and 11 deletions


@@ -0,0 +1,45 @@
# model/tokenizer
model_name: "EleutherAI/pythia-1b"
tokenizer_name: "EleutherAI/pythia-1b"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/lethe"
push_to_hub: false
memory_attn_layer: 12
# dataset
streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
max_length: 1024
batch_size: 64
pct_test: 0.05
q_column: "question"
a_column: "answers"
context_column: "neighbor_text"
num_memories_per_index: 500000
num_neighbors_to_retrieve: 32
mem_chunk_size: 64
# train dynamics
lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 100
save_every: -1
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/mem_attn"
checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
debug: false
scheduler: false
# logging
wandb: false
wandb_entity: gpt4all
wandb_project_name: mem_attn
seed: 42
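
For orientation, a minimal sketch of reading this config and pulling out the memory-attention fields. The file path is hypothetical and plain PyYAML is used so the sketch stands alone; the training code itself goes through gpt4all.utils.read.read_config (imported in the diff below), and only the keys come from the YAML above.

# Minimal sketch: load the YAML above and read a few fields.
# The path is hypothetical; only the key names come from the config.
import yaml

with open("mem_attn.yaml") as f:
    config = yaml.safe_load(f)

model_name = config["model_name"]                  # "EleutherAI/pythia-1b"
mem_layer = config["memory_attn_layer"]            # 12
k_neighbors = config["num_neighbors_to_retrieve"]  # 32
mem_chunk = config["mem_chunk_size"]               # 64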


@@ -11,7 +11,7 @@ from gpt4all.utils.read import read_config
 from gpt4all.index.embed_texts import embed_texts
 from tqdm import tqdm
 
-CHUNK_SIZE = 1024
+CHUNK_SIZE = 2048
 
 def parse_args():
     parser = ArgumentParser()
@@ -57,10 +57,10 @@ def prep_index():
     #build training dataset
     train_dataset = load_dataset(ds_path, split=args.split)
 
-    #search the index for each query
-    neighbor_embs_column = []
+    # search the index for each query
+    neighbor_text_column = []
     neighbor_ids_column = []
-    for chunk_start in tqdm(range(0, len(ds), CHUNK_SIZE)):
+    for _, chunk_start in enumerate(tqdm(range(0, len(ds), CHUNK_SIZE))):
         chunk_end = chunk_start + CHUNK_SIZE
         chunk = ds[chunk_start:chunk_end]
         query_vectors = np.array(chunk['embedding'])
@@ -70,21 +70,21 @@ def prep_index():
 
         # build mapping between indices and embeddings
        neighbor_id_list = neighbor_objs['id']
-        emb_list = neighbor_objs['embedding']
-        idx_to_embedding = {idx.as_py(): emb_list[i] for i, idx in enumerate(neighbor_id_list)}
+        documents = neighbor_objs['text']
+        idx_to_embedding = {idx.as_py(): documents[i] for i, idx in enumerate(neighbor_id_list)}
 
-        neighbor_embs = []
+        neighbor_text = []
         for cur_neighbor_ids in neighbor_ids:
             cur_embs = [idx_to_embedding[id].as_py() for id in cur_neighbor_ids]
-            neighbor_embs.append(cur_embs)
+            neighbor_text.append(cur_embs)
 
-        neighbor_embs_column.extend(neighbor_embs)
+        neighbor_text_column.extend(neighbor_text)
         neighbor_ids_column.extend(neighbor_ids)
 
     print("adding neighbor ids")
     train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
-    print("adding neighbor embeddings")
-    train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
+    print("adding neighbors")
+    train_dataset = train_dataset.add_column('neighbor_text', neighbor_text_column)
 
     supplemented_dataset_path = f"{ds_path}_supplemented_{args.split}/"
     train_dataset.save_to_disk(supplemented_dataset_path)
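
To make the effect of the change concrete, here is a short sketch (not part of this commit) of loading the supplemented dataset that prep_index() now writes: each example carries raw neighbor text rather than neighbor embeddings. The local path is hypothetical; the column names come from the diff and config above.

# Assumes prep_index() has already written the supplemented dataset to disk.
from datasets import load_from_disk

ds = load_from_disk("squad_supplemented_train/")  # hypothetical local path
example = ds[0]
print(example["question"])       # q_column from the config
print(example["neighbor_ids"])   # ids of the retrieved index entries
print(example["neighbor_text"])  # retrieved passages, one list per example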