diff --git a/configs/train/finetune_memory.yaml b/configs/train/finetune_memory.yaml
new file mode 100644
index 00000000..9cff8058
--- /dev/null
+++ b/configs/train/finetune_memory.yaml
@@ -0,0 +1,45 @@
+# model/tokenizer
+model_name: "EleutherAI/pythia-1b"
+tokenizer_name: "EleutherAI/pythia-1b"
+version: null
+gradient_checkpointing: true
+save_name: "nomic-ai/lethe"
+push_to_hub: false
+memory_attn_layer: 12
+
+# dataset
+streaming: false
+num_proc: 64
+dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
+max_length: 1024
+batch_size: 64
+pct_test: 0.05
+q_column: "question"
+a_column: "answers"
+context_column: "neighbor_text"
+num_memories_per_index: 500000
+num_neighbors_to_retrieve: 32
+mem_chunk_size: 64
+
+# train dynamics
+lr: 1.0e-4
+min_lr: 0
+weight_decay: 0.0
+eval_every: 100
+save_every: -1
+log_grads_every: 100
+log_lr_every: 10
+output_dir: "ckpts/mem_attn"
+checkpoint: null
+lora: false
+warmup_steps: 500
+num_epochs: 5
+debug: false
+scheduler: false
+
+# logging
+wandb: false
+wandb_entity: gpt4all
+wandb_project_name: mem_attn
+seed: 42
+
diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py
index b0d32711..a674b3b5 100644
--- a/gpt4all/index/prep_index_for_train.py
+++ b/gpt4all/index/prep_index_for_train.py
@@ -11,7 +11,7 @@ from gpt4all.utils.read import read_config
 from gpt4all.index.embed_texts import embed_texts
 from tqdm import tqdm
 
-CHUNK_SIZE = 1024
+CHUNK_SIZE = 2048
 
 def parse_args():
     parser = ArgumentParser()
@@ -57,10 +57,10 @@ def prep_index():
     #build training dataset
     train_dataset = load_dataset(ds_path, split=args.split)
 
-    #search the index for each query
-    neighbor_embs_column = []
+    # search the index for each query
+    neighbor_text_column = []
     neighbor_ids_column = []
     for chunk_start in tqdm(range(0, len(ds), CHUNK_SIZE)):
         chunk_end = chunk_start + CHUNK_SIZE
         chunk = ds[chunk_start:chunk_end]
         query_vectors = np.array(chunk['embedding'])
@@ -70,21 +70,21 @@ def prep_index():
 
-        # build mapping between indices and embeddings
+        # build mapping between neighbor ids and document text
         neighbor_id_list = neighbor_objs['id']
-        emb_list = neighbor_objs['embedding']
-        idx_to_embedding = {idx.as_py(): emb_list[i] for i, idx in enumerate(neighbor_id_list)}
+        documents = neighbor_objs['text']
+        idx_to_text = {idx.as_py(): documents[i] for i, idx in enumerate(neighbor_id_list)}
 
-        neighbor_embs = []
+        neighbor_text = []
         for cur_neighbor_ids in neighbor_ids:
-            cur_embs = [idx_to_embedding[id].as_py() for id in cur_neighbor_ids]
-            neighbor_embs.append(cur_embs)
+            cur_texts = [idx_to_text[id].as_py() for id in cur_neighbor_ids]
+            neighbor_text.append(cur_texts)
 
-        neighbor_embs_column.extend(neighbor_embs)
+        neighbor_text_column.extend(neighbor_text)
         neighbor_ids_column.extend(neighbor_ids)
 
     print("adding neighbor ids")
     train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
-    print("adding neighbor embeddings")
-    train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
+    print("adding neighbor text")
+    train_dataset = train_dataset.add_column('neighbor_text', neighbor_text_column)
 
     supplemented_dataset_path = f"{ds_path}_supplemented_{args.split}/"
     train_dataset.save_to_disk(supplemented_dataset_path)
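
The new config keys can be exercised in isolation to check the wiring. The loader below is a generic `yaml.safe_load` sketch, not the repo's `gpt4all.utils.read.read_config` helper, whose exact return shape isn't shown in this diff; the assumption is that it also yields a plain dict of these keys:

```python
import yaml

# Generic stand-in for gpt4all.utils.read.read_config (assumed to
# return a plain dict of the keys in finetune_memory.yaml).
with open("configs/train/finetune_memory.yaml") as f:
    config = yaml.safe_load(f)

# q_column/a_column name the supervised QA pair; context_column points at
# the retrieved-neighbor text that prep_index_for_train.py now adds.
assert config["q_column"] == "question"
assert config["context_column"] == "neighbor_text"
assert config["num_neighbors_to_retrieve"] == 32
assert config["memory_attn_layer"] == 12  # layer that attends over memories
```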
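
The core of the `prep_index` change is a chunked id-to-text join: search the index `CHUNK_SIZE` queries at a time, then map each row of neighbor ids back to raw documents instead of embeddings. Below is a minimal, self-contained sketch of that pattern; the brute-force NumPy `search_index` helper and the toy corpus are illustrative assumptions, not the repo's actual index API:

```python
import numpy as np

CHUNK_SIZE = 2048   # matches the new value in prep_index_for_train.py
K = 32              # matches num_neighbors_to_retrieve in the config

def search_index(query_vectors: np.ndarray, doc_embs: np.ndarray, k: int) -> np.ndarray:
    """Brute-force top-k by dot product; stand-in for the real index lookup."""
    scores = query_vectors @ doc_embs.T          # (n_queries, n_docs)
    return np.argsort(-scores, axis=1)[:, :k]    # (n_queries, k) neighbor ids

rng = np.random.default_rng(0)
doc_embs = rng.standard_normal((10_000, 384), dtype=np.float32)
id_to_text = {i: f"document {i}" for i in range(len(doc_embs))}  # id -> raw text
queries = rng.standard_normal((5_000, 384), dtype=np.float32)

neighbor_text_column, neighbor_ids_column = [], []
for chunk_start in range(0, len(queries), CHUNK_SIZE):
    chunk = queries[chunk_start:chunk_start + CHUNK_SIZE]
    neighbor_ids = search_index(chunk, doc_embs, K)
    # The join prep_index now performs: ids -> text, one list per query.
    neighbor_text_column.extend([[id_to_text[i] for i in row] for row in neighbor_ids])
    neighbor_ids_column.extend(neighbor_ids.tolist())

assert len(neighbor_text_column) == len(queries)
assert len(neighbor_text_column[0]) == K
```

Doubling `CHUNK_SIZE` to 2048 halves the number of index round trips; the per-chunk memory cost grows with chunk size times neighbors retrieved, which is why the two knobs trade off against each other.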
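
After `save_to_disk`, a quick way to confirm the supplemented split carries the new columns (the path shown is for the SQuAD example used in the config; substitute your own `ds_path` and split):

```python
from datasets import load_from_disk

# Path follows f"{ds_path}_supplemented_{args.split}/" from prep_index.
ds = load_from_disk("squad_supplemented_train")
print(ds.column_names)               # should now include 'neighbor_ids' and 'neighbor_text'
print(len(ds[0]["neighbor_text"]))   # number of retrieved neighbors per example
```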