mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-25 06:53:05 +00:00
We now get the neighbors' embeddings from the disk index
This commit is contained in:
parent
240feae277
commit
cc3cc3f7e9
@@ -14,7 +14,7 @@ batch_size: 32
|
|||||||
|
|
||||||
#index
|
#index
|
||||||
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
|
index_path: "/home/paperspace/gpt4all/gpt4all/index/wiki-sample-index.bin"
|
||||||
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki-full-tokenized_embedded_with_text"
|
index_database: "/home/paperspace/gpt4all/gpt4all/index/wiki_sample_tokenized_embedded_with_text"
|
||||||
index_space: "cosine"
|
index_space: "cosine"
|
||||||
index_dim: 384
|
index_dim: 384
|
||||||
query_embedding_field: 'question'
|
query_embedding_field: 'question'
|
||||||
|
@@ -1,9 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import hnswlib
|
import hnswlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pyarrow as pa
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from pyarrow import compute as pc
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from gpt4all.utils.read import read_config
|
from gpt4all.utils.read import read_config
|
||||||
from gpt4all.index.embed_texts import embed_texts
|
from gpt4all.index.embed_texts import embed_texts
|
||||||
@@ -32,8 +34,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
#load retrieval dataset
|
#load retrieval dataset
|
||||||
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
||||||
print(type(retrieval_dataset._data))
|
|
||||||
raise
|
|
||||||
|
|
||||||
#vectorize queries
|
#vectorize queries
|
||||||
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
|
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
|
||||||
@@ -50,9 +50,8 @@ if __name__ == "__main__":
|
|||||||
chunk_end = chunk_start + CHUNK_SIZE
|
chunk_end = chunk_start + CHUNK_SIZE
|
||||||
chunk = ds[chunk_start:chunk_end]
|
chunk = ds[chunk_start:chunk_end]
|
||||||
query_vectors = np.array(chunk['embedding'])
|
query_vectors = np.array(chunk['embedding'])
|
||||||
print(query_vectors.shape)
|
|
||||||
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
||||||
raise
|
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
|
||||||
|
out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
|
||||||
|
|
||||||
#get the embeddings for each of the neighbor ids
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user