we now get the neighbor embeddings from the disk index

Zach Nussbaum 2023-04-23 21:04:43 +00:00
parent 240feae277
commit cc3cc3f7e9
2 changed files with 5 additions and 6 deletions


@@ -14,7 +14,7 @@ batch_size: 32
#index
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text"
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
index_space: "cosine"
index_dim: 384
query_embedding_field: 'question'
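For reference, these settings line up with hnswlib's loading API: index_dim and index_space describe the stored vectors, and index_path points at the serialized index. A minimal sketch of loading such an index (the config filename, plain yaml loading, and the set_ef value are illustrative assumptions, not the repo's read_config code):

import hnswlib
import yaml

# Illustrative: the repo loads its config via gpt4all.utils.read.read_config;
# plain yaml.safe_load is used here only to keep the sketch self-contained.
with open("index_config.yaml") as f:   # hypothetical filename
    config = yaml.safe_load(f)

# Recreate the index wrapper with the space/dim it was built with, then load it from disk.
index = hnswlib.Index(space=config["index_space"], dim=config["index_dim"])
index.load_index(config["index_path"])
index.set_ef(64)  # query-time recall/speed trade-off; 64 is an arbitrary illustrative value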


@@ -1,9 +1,11 @@
import os
import hnswlib
import numpy as np
import pyarrow as pa
from datasets import Dataset
import torch.distributed as dist
from datasets import load_dataset
from pyarrow import compute as pc
from argparse import ArgumentParser
from gpt4all.utils.read import read_config
from gpt4all.index.embed_texts import embed_texts
@@ -32,8 +34,6 @@ if __name__ == "__main__":
#load retrieval dataset
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
print(type(retrieval_dataset._data))
raise
#vectorize queries
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
@@ -50,9 +50,8 @@ if __name__ == "__main__":
chunk_end = chunk_start + CHUNK_SIZE
chunk = ds[chunk_start:chunk_end]
query_vectors = np.array(chunk['embedding'])
print(query_vectors.shape)
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
raise
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
#get the embeddings for each of the neighbor ids
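
Taken together, the change works like this: query embeddings are read in chunks, knn_query returns the ids of the K nearest neighbors in the HNSW index, and those ids are used to filter the on-disk retrieval dataset so the neighbors' stored embeddings (and text) are read back from disk rather than recomputed. A minimal end-to-end sketch, assuming `index` and `config` from the previous sketch, a queries dataset `ds` with an 'embedding' column, and a retrieval dataset whose 'index' column holds string ids as in the diff; CHUNK_SIZE and K are illustrative:

import numpy as np
import pyarrow as pa
from datasets import Dataset
from pyarrow import compute as pc

CHUNK_SIZE = 1024  # illustrative
K = 5              # illustrative

# The retrieval dataset saved to disk during indexing; _data is its underlying
# Arrow table, as used in the diff above.
retrieval_dataset = Dataset.load_from_disk(config["index_database"])
retrieval_table = retrieval_dataset._data

for chunk_start in range(0, len(ds), CHUNK_SIZE):
    chunk = ds[chunk_start : chunk_start + CHUNK_SIZE]
    query_vectors = np.array(chunk["embedding"])

    # ids of the K approximate nearest neighbors for every query in the chunk
    neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)

    # Pull the matching rows (and hence their embeddings) back from the disk-backed dataset
    value_set = pa.array([str(i) for i in neighbor_ids.flatten()])
    mask = pc.is_in(retrieval_table["index"], value_set=value_set)
    neighbor_rows = retrieval_table.filter(mask)
    neighbor_embeddings = np.array(neighbor_rows["embedding"].to_pylist())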