diff --git a/gpt4all/index/embed_texts.py b/gpt4all/index/embed_texts.py
index a139f5df..c2007fde 100644
--- a/gpt4all/index/embed_texts.py
+++ b/gpt4all/index/embed_texts.py
@@ -30,9 +30,9 @@ class PadCollateInputs:
 
     return padded_inputs
 
-def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False):
+def embed_texts(ds_path, batch_size, embed_on='text', save_to_disk=False, split='train'):
     rank0_print(f"World size: {dist.get_world_size()}")
-    dataset = load_dataset(f"{ds_path}", split="train")
+    dataset = load_dataset(f"{ds_path}", split=split)
     rank0_print(f"Dataset size: {len(dataset)}")
 
     model = Embedder()
diff --git a/gpt4all/index/prep_index_for_train.py b/gpt4all/index/prep_index_for_train.py
index cf8152fd..f442e995 100644
--- a/gpt4all/index/prep_index_for_train.py
+++ b/gpt4all/index/prep_index_for_train.py
@@ -10,6 +10,7 @@ from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
 from gpt4all.index.embed_texts import embed_texts
 
+SPLIT = 'train'
 CHUNK_SIZE = 1024
 K = 5
 
@@ -36,22 +37,39 @@ if __name__ == "__main__":
     retrieval_dataset = Dataset.load_from_disk(config['index_database'])
 
     #vectorize queries
-    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded"
+    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}"
     if not os.path.exists(query_vector_path):
         print('Embedding dataset...')
-        ds = embed_texts(ds_path, config['batch_size'], embed_on=config['query_embedding_field'], save_to_disk=False)
+        ds = embed_texts(ds_path,
+                         config['batch_size'],
+                         embed_on=config['query_embedding_field'],
+                         save_to_disk=False,
+                         split=SPLIT)
         ds.save_to_disk(query_vector_path)
     else:
         print('Found cached embedding dataset!')
         ds = Dataset.load_from_disk(query_vector_path)
 
+    #build training dataset
+    train_dataset = load_dataset(ds_path, split=SPLIT)
+
     #search the index for each query
+    neighbor_embs_column = []
+    neighbor_ids_column = []
     for chunk_start in range(0, len(ds), CHUNK_SIZE):
         chunk_end = chunk_start + CHUNK_SIZE
         chunk = ds[chunk_start:chunk_end]
         query_vectors = np.array(chunk['embedding'])
         neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
 
         value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
-        out = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
+        #TODO @nussy should be id
+        neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
+        neighbor_ids_column.extend(neighbor_objs['index'])  #TODO @nussy should be id
+        neighbor_embs_column.extend(neighbor_objs['embedding'])
+
+    #import pdb;pdb.set_trace()
+    train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
+    train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
+
+    supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/"
+    train_dataset.save_to_disk(supplemented_dataset_path)
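
Note (not part of the patch): `Dataset.add_column` needs exactly one entry per row of the train split, while `knn_query` returns a `(num_queries, K)` id matrix and the PyArrow filter may return a different number of rows than there are queries. The following is a minimal, self-contained sketch of keeping the K neighbours grouped per query so the new columns line up row-for-row; the toy data and names (`toy_train`, `corpus_embs`) are illustrative assumptions, not code from this repo.

    # Minimal sketch (illustrative only, not this patch's code): keep the K
    # nearest-neighbour ids/embeddings grouped per query so each new column has
    # exactly one entry per training row, which Dataset.add_column requires.
    import hnswlib
    import numpy as np
    from datasets import Dataset

    K, DIM = 2, 8
    rng = np.random.default_rng(0)

    # Toy retrieval corpus: one embedding per indexed document.
    corpus_embs = rng.normal(size=(100, DIM)).astype(np.float32)
    index = hnswlib.Index(space="cosine", dim=DIM)
    index.init_index(max_elements=len(corpus_embs))
    index.add_items(corpus_embs, ids=np.arange(len(corpus_embs)))

    # Toy train split with one query embedding per example.
    toy_train = Dataset.from_dict({"text": [f"query {i}" for i in range(10)]})
    query_vectors = rng.normal(size=(len(toy_train), DIM)).astype(np.float32)

    # knn_query returns a (num_queries, K) id matrix; keep that grouping instead
    # of flattening, so row i of the train split gets its own K neighbours.
    neighbor_ids, _ = index.knn_query(query_vectors, k=K)
    neighbor_ids_column = [row.tolist() for row in neighbor_ids]
    neighbor_embs_column = [corpus_embs[row].tolist() for row in neighbor_ids]

    toy_train = toy_train.add_column("neighbor_ids", neighbor_ids_column)
    toy_train = toy_train.add_column("neighbor_embeddings", neighbor_embs_column)
    print(toy_train)  # 10 rows, each carrying K neighbour ids and embeddings

Storing the neighbours as nested lists per row keeps the supplemented dataset aligned with the original split regardless of duplicate or missing ids in the retrieval corpus.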