mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-06 12:06:54 +00:00
nearly have the neighbor caching working, but combinging info into final dataset is challenging
This commit is contained in:
parent
cc3cc3f7e9
commit
869829a065
@ -30,9 +30,9 @@ class PadCollateInputs:
|
|||||||
return padded_inputs
|
return padded_inputs
|
||||||
|
|
||||||
|
|
||||||
def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False):
|
def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
|
||||||
rank0_print(f"World size: {dist.get_world_size()}")
|
rank0_print(f"World size: {dist.get_world_size()}")
|
||||||
dataset = load_dataset(f"{ds_path}", split="train")
|
dataset = load_dataset(f"{ds_path}", split=split)
|
||||||
rank0_print(f"Dataset size: {len(dataset)}")
|
rank0_print(f"Dataset size: {len(dataset)}")
|
||||||
|
|
||||||
model = Embedder()
|
model = Embedder()
|
||||||
|
@ -10,6 +10,7 @@ from argparse import ArgumentParser
|
|||||||
from gpt4all.utils.read import read_config
|
from gpt4all.utils.read import read_config
|
||||||
from gpt4all.index.embed_texts import embed_texts
|
from gpt4all.index.embed_texts import embed_texts
|
||||||
|
|
||||||
|
SPLIT = 'train'
|
||||||
CHUNK_SIZE = 1024
|
CHUNK_SIZE = 1024
|
||||||
K = 5
|
K = 5
|
||||||
|
|
||||||
@ -36,22 +37,39 @@ if __name__ == "__main__":
|
|||||||
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
||||||
|
|
||||||
#vectorize queries
|
#vectorize queries
|
||||||
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
|
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}"
|
||||||
if not os.path.exists(query_vector_path):
|
if not os.path.exists(query_vector_path):
|
||||||
print('Embedding dataset...')
|
print('Embedding dataset...')
|
||||||
ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False)
|
ds = embed_texts(ds_path,
|
||||||
|
config['batch_size'],
|
||||||
|
embed_on=config['query_embedding_field'],
|
||||||
|
save_to_disk=False,
|
||||||
|
split=SPLIT)
|
||||||
ds.save_to_disk(query_vector_path)
|
ds.save_to_disk(query_vector_path)
|
||||||
else:
|
else:
|
||||||
print('Found cached embedding dataset!')
|
print('Found cached embedding dataset!')
|
||||||
ds = Dataset.load_from_disk(query_vector_path)
|
ds = Dataset.load_from_disk(query_vector_path)
|
||||||
|
|
||||||
|
#build training dataset
|
||||||
|
train_dataset = load_dataset(ds_path, split=SPLIT)
|
||||||
|
|
||||||
#search the index for each query
|
#search the index for each query
|
||||||
|
neighbor_embs_column = []
|
||||||
|
neighbor_ids_column = []
|
||||||
for chunk_start in range(0, len(ds), CHUNK_SIZE):
|
for chunk_start in range(0, len(ds), CHUNK_SIZE):
|
||||||
chunk_end = chunk_start + CHUNK_SIZE
|
chunk_end = chunk_start + CHUNK_SIZE
|
||||||
chunk = ds[chunk_start:chunk_end]
|
chunk = ds[chunk_start:chunk_end]
|
||||||
query_vectors = np.array(chunk['embedding'])
|
query_vectors = np.array(chunk['embedding'])
|
||||||
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
||||||
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
|
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
|
||||||
out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
|
#TODO @nussy should be id
|
||||||
|
neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
|
||||||
|
neighbor_ids_column.extend(neighbor_objs['index']) #TODO @nussy should be id
|
||||||
|
neighbor_embs_column.extend(neighbor_objs['embedding'])
|
||||||
|
|
||||||
|
#import pdb;pdb.set_trace()
|
||||||
|
train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
|
||||||
|
train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
|
||||||
|
|
||||||
|
supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/"
|
||||||
|
train_dataset.save_to_disk(supplemented_dataset_path)
|
||||||
|
Loading…
Reference in New Issue
Block a user