mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-07-16 16:31:30 +00:00
refactor: clean up prep index
This commit is contained in:
parent
f2161f7e59
commit
c20379f7e9
@ -9,67 +9,86 @@ from pyarrow import compute as pc
|
||||
from argparse import ArgumentParser
|
||||
from gpt4all.utils.read import read_config
|
||||
from gpt4all.index.embed_texts import embed_texts
|
||||
from tqdm import tqdm
|
||||
|
||||
SPLIT = 'train'
|
||||
CHUNK_SIZE = 1024
|
||||
K = 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="config.yaml")
|
||||
parser.add_argument("--split", type=str, default="train")
|
||||
parser.add_argument("--k", type=int, default=5)
|
||||
|
||||
args = parser.parse_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prep_index():
|
||||
args = parse_args()
|
||||
config = read_config(args.config)
|
||||
|
||||
#load index
|
||||
index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
|
||||
print("loading index")
|
||||
index.load_index(config['index_path'])
|
||||
|
||||
#load query dataset
|
||||
# load query dataset
|
||||
ds_path = config['dataset_path']
|
||||
|
||||
#load retrieval dataset
|
||||
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
||||
# load retrieval dataset
|
||||
print("loading retrieval dataset")
|
||||
print(config["index_database"])
|
||||
if os.path.exists(config['index_database']):
|
||||
retrieval_dataset = Dataset.load_from_disk(config['index_database'])
|
||||
else:
|
||||
retrieval_dataset = load_dataset(config['index_database'], split=args.split)
|
||||
|
||||
#vectorize queries
|
||||
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}"
|
||||
# vectorize queries
|
||||
query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{args.split}"
|
||||
if not os.path.exists(query_vector_path):
|
||||
print('Embedding dataset...')
|
||||
ds = embed_texts(ds_path,
|
||||
config['batch_size'],
|
||||
embed_on=config['query_embedding_field'],
|
||||
save_to_disk=False,
|
||||
split=SPLIT)
|
||||
split=args.split)
|
||||
ds.save_to_disk(query_vector_path)
|
||||
else:
|
||||
print('Found cached embedding dataset!')
|
||||
ds = Dataset.load_from_disk(query_vector_path)
|
||||
|
||||
#build training dataset
|
||||
train_dataset = load_dataset(ds_path, split=SPLIT)
|
||||
train_dataset = load_dataset(ds_path, split=args.split)
|
||||
|
||||
#search the index for each query
|
||||
neighbor_embs_column = []
|
||||
neighbor_ids_column = []
|
||||
for chunk_start in range(0, len(ds), CHUNK_SIZE):
|
||||
for chunk_start in tqdm(range(0, len(ds), CHUNK_SIZE)):
|
||||
chunk_end = chunk_start + CHUNK_SIZE
|
||||
chunk = ds[chunk_start:chunk_end]
|
||||
query_vectors = np.array(chunk['embedding'])
|
||||
neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
|
||||
neighbor_ids, _ = index.knn_query(query_vectors, k=args.k, num_threads=-1) # neighbor ids is of shape [n_queries, n_neighbors]
|
||||
value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
|
||||
#TODO @nussy should be id
|
||||
neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
|
||||
neighbor_ids_column.extend(neighbor_objs['index']) #TODO @nussy should be id
|
||||
neighbor_embs_column.extend(neighbor_objs['embedding'])
|
||||
neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['id'], value_set))
|
||||
|
||||
#import pdb;pdb.set_trace()
|
||||
# build mapping between indices and embeddings
|
||||
neighbor_id_list = neighbor_objs['id']
|
||||
emb_list = neighbor_objs['embedding']
|
||||
idx_to_embedding = {idx.as_py(): emb_list[i] for i, idx in enumerate(neighbor_id_list)}
|
||||
|
||||
neighbor_embs = []
|
||||
for cur_neighbor_ids in neighbor_ids:
|
||||
cur_embs = [idx_to_embedding[id].as_py() for id in cur_neighbor_ids]
|
||||
neighbor_embs.append(cur_embs)
|
||||
|
||||
neighbor_embs_column.extend(neighbor_embs)
|
||||
neighbor_ids_column.extend(neighbor_ids)
|
||||
|
||||
print("adding neighbor ids")
|
||||
train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
|
||||
print("adding neighbor embeddings")
|
||||
train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
|
||||
|
||||
supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/"
|
||||
supplemented_dataset_path = f"{ds_path}_supplemented_{args.split}/"
|
||||
train_dataset.save_to_disk(supplemented_dataset_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
prep_index()
|
Loading…
Reference in New Issue
Block a user