From cc3cc3f7e9fba7ef8432b490316876826621fb0d Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Sun, 23 Apr 2023 21:04:43 +0000
Subject: [PATCH] we now get the neighbors embeddings from the disk index

---
 configs/train/finetune_gptjr.yaml     | 2 +-
 gpt4all/index/prep_index_for_train.py | 9 ++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/configs/train/finetune_gptjr.yaml b/configs/train/finetune_gptjr.yaml
index 291cf841..a03f8ff0 100644
--- a/configs/train/finetune_gptjr.yaml
+++ b/configs/train/finetune_gptjr.yaml
@@ -14,7 +14,7 @@ batch_size: 32
 
 #index
 index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
-index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text"
+index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
 index_space: "cosine"
 index_dim: 384
 query_embedding_field: 'question'
diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py
index dcb90da3..cf8152fd 100644
--- a/gpt4all/index/prep_index_for_train.py
+++ b/gpt4all/index/prep_index_for_train.py
@@ -1,9 +1,11 @@
 import os
 import hnswlib
 import numpy as np
+import pyarrow as pa
 from datasets import Dataset
 import torch.distributed as dist
 from datasets import load_dataset
+from pyarrow import compute as pc
 from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
 from gpt4all.index.embed_texts import embed_texts
@@ -32,8 +34,6 @@ if __name__ == "__main__":
 
     #load retrieval dataset
     retrieval_dataset = Dataset.load_from_disk(config['index_database'])
-    print(type(retrieval_dataset._data))
-    raise
 
     #vectorize queries
     query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
@@ -50,9 +50,8 @@ if __name__ == "__main__":
         chunk_end = chunk_start + CHUNK_SIZE
         chunk = ds[chunk_start:chunk_end]
         query_vectors = np.array(chunk['embedding'])
-        print(query_vectors.shape)
         neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
-        raise
+        value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
+        out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
 
-    #get the embeddings for each of the neighbor ids
 
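Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the lookup pattern the new lines introduce — run an hnswlib knn_query over a chunk of query embeddings, then pull the matching neighbor rows (and their embeddings) out of the underlying Arrow table with pyarrow.compute.is_in plus Table.filter. The toy table, random vectors, and sizes n and k are illustrative assumptions; only space="cosine" and dim=384 come from the config above, and only the is_in/filter idiom mirrors the patch itself.

    # Sketch under assumptions: a string 'index' column keys each row of the
    # Arrow table to the integer label stored in the ANN index, as in the patch.
    import hnswlib
    import numpy as np
    import pyarrow as pa
    from pyarrow import compute as pc

    dim, n, k = 384, 1000, 4  # dim/space follow the config; n and k are made up
    embeddings = np.random.rand(n, dim).astype(np.float32)

    # Stand-in for retrieval_dataset._data: one row per passage.
    table = pa.table({
        "index": pa.array([str(i) for i in range(n)]),
        "embedding": pa.array(embeddings.tolist()),
    })

    # Build a cosine-space hnswlib index keyed by the same integer ids.
    index = hnswlib.Index(space="cosine", dim=dim)
    index.init_index(max_elements=n)
    index.add_items(embeddings, np.arange(n))

    # Query a batch, then fetch the neighbor rows from the table,
    # exactly as the patched loop does per chunk.
    query_vectors = np.random.rand(8, dim).astype(np.float32)
    neighbor_ids, _ = index.knn_query(query_vectors, k=k, num_threads=-1)
    value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
    neighbors = table.filter(pc.is_in(table["index"], value_set=value_set))
    print(neighbors.num_rows)  # at most 8 * k; duplicate neighbor ids collapse

One thing worth flagging for a follow-up: filter returns the union of all neighbors in table order, so the per-query grouping (which query each row answers) is lost; if the training loop needs k aligned neighbors per query, the filtered rows have to be re-keyed against neighbor_ids afterwards.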