We now get the neighbors' embeddings from the on-disk index

This commit is contained in:
Zach Nussbaum 2023-04-23 21:04:43 +00:00
parent 240feae277
commit cc3cc3f7e9
2 changed files with 5 additions and 6 deletions

View File

@ -14,7 +14,7 @@ batch_size: 32
#index #index
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin" index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text" index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
index_space: "cosine" index_space: "cosine"
index_dim: 384 index_dim: 384
query_embedding_field: 'question' query_embedding_field: 'question'

View File

@ -1,9 +1,11 @@
import os import os
import hnswlib import hnswlib
import numpy as np import numpy as np
import pyarrow as pa
from datasets import Dataset from datasets import Dataset
import torch.distributed as dist import torch.distributed as dist
from datasets import load_dataset from datasets import load_dataset
from pyarrow import compute as pc
from argparse import ArgumentParser from argparse import ArgumentParser
from gpt4all.utils.read import read_config from gpt4all.utils.read import read_config
from gpt4all.index.embed_texts import embed_texts from gpt4all.index.embed_texts import embed_texts
@ -32,8 +34,6 @@ if __name__ == "__main__":
#load retrieval dataset #load retrieval dataset
retrieval_dataset = Dataset.load_from_disk(config['index_database']) retrieval_dataset = Dataset.load_from_disk(config['index_database'])
print(type(retrieval_dataset._data))
raise
#vectorize queries #vectorize queries
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded" query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
@ -50,9 +50,8 @@ if __name__ == "__main__":
chunk_end = chunk_start + CHUNK_SIZE chunk_end = chunk_start + CHUNK_SIZE
chunk = ds[chunk_start:chunk_end] chunk = ds[chunk_start:chunk_end]
query_vectors = np.array(chunk['embedding']) query_vectors = np.array(chunk['embedding'])
print(query_vectors.shape)
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1) neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
raise value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
#get the embeddings for each of the neighbor ids