mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-24 06:27:22 +00:00
we now get the neighbors embeddings from the disk index
This commit is contained in:
parent
240feae277
commit
cc3cc3f7e9
@ -14,7 +14,7 @@ batch_size: 32
|
||||
|
||||
#index
|
||||
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
|
||||
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text"
|
||||
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
|
||||
index_space: "cosine"
|
||||
index_dim: 384
|
||||
query_embedding_field: 'question'
|
||||
|
@ -1,9 +1,11 @@
|
||||
import os
|
||||
import hnswlib
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from datasets import Dataset
|
||||
import torch.distributed as dist
|
||||
from datasets import load_dataset
|
||||
from pyarrow import compute as pc
|
||||
from argparse import ArgumentParser
|
||||
from gpt4all.utils.read import read_config
|
||||
from gpt4all.index.embed_texts import embed_texts
|
||||
@ -32,8 +34,6 @@ if __name__ == "__main__":
|
||||
|
||||
#load retrieval dataset
|
||||
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
||||
print(type(retrieval_dataset._data))
|
||||
raise
|
||||
|
||||
#vectorize queries
|
||||
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
|
||||
@ -50,9 +50,8 @@ if __name__ == "__main__":
|
||||
chunk_end = chunk_start + CHUNK_SIZE
|
||||
chunk = ds[chunk_start:chunk_end]
|
||||
query_vectors = np.array(chunk['embedding'])
|
||||
print(query_vectors.shape)
|
||||
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
||||
raise
|
||||
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
|
||||
out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
|
||||
|
||||
#get the embeddings for each of the neighbor ids
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user