Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-05 19:46:30 +00:00)
nearly have the neighbor caching working, but combining info into final dataset is challenging
This commit is contained in:
  parent cc3cc3f7e9
  commit 869829a065
@@ -30,9 +30,9 @@ class PadCollateInputs:
         return padded_inputs


-def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False):
+def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
     rank0_print(f"World size: {dist.get_world_size()}")
-    dataset = load_dataset(f"{ds_path}", split="train")
+    dataset = load_dataset(f"{ds_path}", split=split)
     rank0_print(f"Dataset size: {len(dataset)}")

     model = Embedder()
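For context on the hunk above: the new `split` argument removes the hard-coded "train" split, so the same embedding path can be reused for validation or test data. A minimal usage sketch, assuming the repo's distributed setup is already initialized as in its launch scripts (the dataset path and batch size below are illustrative, not from this commit):

    from gpt4all.index.embed_texts import embed_texts

    # Before this change, the "train" split was always embedded.
    train_embs = embed_texts("some/hf-dataset", batch_size=32, embed_on="text")

    # After this change, any split can be requested, e.g. a hypothetical held-out set.
    test_embs = embed_texts("some/hf-dataset", batch_size=32, embed_on="text",
                            save_to_disk=False, split="test")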
@@ -10,6 +10,7 @@ from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
 from gpt4all.index.embed_texts import embed_texts

+SPLIT = 'train'
 CHUNK_SIZE = 1024
 K = 5

@@ -36,22 +37,39 @@ if __name__ == "__main__":
     retrieval_dataset = Dataset.load_from_disk(config['index_database'])

     #vectorize queries
-    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
+    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}"
     if not os.path.exists(query_vector_path):
         print('Embedding dataset...')
-        ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False)
+        ds = embed_texts(ds_path,
+                         config['batch_size'],
+                         embed_on=config['query_embedding_field'],
+                         save_to_disk=False,
+                         split=SPLIT)
         ds.save_to_disk(query_vector_path)
     else:
         print('Found cached embedding dataset!')
         ds = Dataset.load_from_disk(query_vector_path)

+    #build training dataset
+    train_dataset = load_dataset(ds_path, split=SPLIT)
+
     #search the index for each query
+    neighbor_embs_column = []
+    neighbor_ids_column = []
     for chunk_start in range(0, len(ds), CHUNK_SIZE):
         chunk_end = chunk_start + CHUNK_SIZE
         chunk = ds[chunk_start:chunk_end]
         query_vectors = np.array(chunk['embedding'])
         neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
         value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
-        out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
+        #TODO @nussy should be id
+        neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
+        neighbor_ids_column.extend(neighbor_objs['index']) #TODO @nussy should be id
+        neighbor_embs_column.extend(neighbor_objs['embedding'])
+
+    #import pdb;pdb.set_trace()
+    train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
+    train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
+
+    supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/"
+    train_dataset.save_to_disk(supplemented_dataset_path)
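On the "combining info into final dataset is challenging" note in the commit message: the `pc.is_in` filter returns the matched retrieval rows deduplicated and in table order, so `neighbor_ids_column` and `neighbor_embs_column` are neither grouped per query nor guaranteed to have one entry per training row, and the final `add_column` calls can misalign or fail on length. A minimal sketch of one way to keep the K neighbors grouped per query, reusing `ds`, `index`, `retrieval_dataset`, `train_dataset`, `CHUNK_SIZE`, and `K` from the script above (the `id_to_emb` lookup is illustrative, not part of this commit):

    import numpy as np

    # Assumes ds, index, retrieval_dataset, train_dataset, CHUNK_SIZE and K are
    # defined as in the script above (hnswlib index, Hugging Face datasets).
    id_to_emb = {str(i): e for i, e in zip(retrieval_dataset['index'],
                                           retrieval_dataset['embedding'])}

    neighbor_ids_column = []
    neighbor_embs_column = []
    for chunk_start in range(0, len(ds), CHUNK_SIZE):
        chunk = ds[chunk_start:chunk_start + CHUNK_SIZE]
        query_vectors = np.array(chunk['embedding'])
        # knn_query returns an (n_queries, K) array of labels; keep each row
        # intact so every training example gets exactly K neighbors, in rank order.
        neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
        for row in neighbor_ids:
            ids = [str(i) for i in row]
            neighbor_ids_column.append(ids)
            neighbor_embs_column.append([id_to_emb[i] for i in ids])

    # Lengths now match len(train_dataset), so the columns line up row-for-row.
    train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
    train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)

Building the lookup once trades memory for avoiding a per-chunk Arrow filter; for a very large retrieval set, storing only the neighbor ids and joining embeddings lazily at training time may be the better trade-off.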