Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-18 17:32:00 +00:00)
refactor: clean up prep index
parent f2161f7e59
commit c20379f7e9
@@ -9,67 +9,86 @@ from pyarrow import compute as pc
 from argparse import ArgumentParser
 from gpt4all.utils.read import read_config
 from gpt4all.index.embed_texts import embed_texts
+from tqdm import tqdm
 
-SPLIT = 'train'
 CHUNK_SIZE = 1024
-K = 5
 
 
-if __name__ == "__main__":
-    dist.init_process_group("nccl")
 
+def parse_args():
     parser = ArgumentParser()
     parser.add_argument("--config", type=str, default="config.yaml")
+    parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--k", type=int, default=5)
-    args = parser.parse_args()
+    return parser.parse_args()
 
 
+def prep_index():
+    args = parse_args()
     config = read_config(args.config)
 
-    #load index
     index = hnswlib.Index(space=config['index_space'], dim=config['index_dim'])
+    print("loading index")
     index.load_index(config['index_path'])
 
     # load query dataset
     ds_path = config['dataset_path']
 
     # load retrieval dataset
+    print("loading retrieval dataset")
+    print(config["index_database"])
+    if os.path.exists(config['index_database']):
-    retrieval_dataset = Dataset.load_from_disk(config['index_database'])
+        retrieval_dataset = Dataset.load_from_disk(config['index_database'])
+    else:
+        retrieval_dataset = load_dataset(config['index_database'], split=args.split)
 
     # vectorize queries
-    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{SPLIT}"
+    query_vector_path = f"{ds_path}_queries_embedded/{ds_path}_embedded_{args.split}"
     if not os.path.exists(query_vector_path):
         print('Embedding dataset...')
         ds = embed_texts(ds_path,
                          config['batch_size'],
                          embed_on=config['query_embedding_field'],
                          save_to_disk=False,
-                         split=SPLIT)
+                         split=args.split)
         ds.save_to_disk(query_vector_path)
     else:
         print('Found cached embedding dataset!')
         ds = Dataset.load_from_disk(query_vector_path)
 
     #build training dataset
-    train_dataset = load_dataset(ds_path, split=SPLIT)
+    train_dataset = load_dataset(ds_path, split=args.split)
 
     #search the index for each query
     neighbor_embs_column = []
     neighbor_ids_column = []
-    for chunk_start in range(0, len(ds), CHUNK_SIZE):
+    for chunk_start in tqdm(range(0, len(ds), CHUNK_SIZE)):
         chunk_end = chunk_start + CHUNK_SIZE
         chunk = ds[chunk_start:chunk_end]
         query_vectors = np.array(chunk['embedding'])
-        neighbor_ids, _ = index.knn_query(query_vectors, k=K, num_threads=-1)
+        neighbor_ids, _ = index.knn_query(query_vectors, k=args.k, num_threads=-1) # neighbor ids is of shape [n_queries, n_neighbors]
         value_set = pa.array([str(e) for e in neighbor_ids.flatten()])
-        #TODO @nussy should be id
-        neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['index'], value_set))
-        neighbor_ids_column.extend(neighbor_objs['index']) #TODO @nussy should be id
-        neighbor_embs_column.extend(neighbor_objs['embedding'])
 
-        #import pdb;pdb.set_trace()
+        neighbor_objs = retrieval_dataset._data.filter(pc.is_in(retrieval_dataset._data['id'], value_set))
+        # build mapping between indices and embeddings
+        neighbor_id_list = neighbor_objs['id']
+        emb_list = neighbor_objs['embedding']
+        idx_to_embedding = {idx.as_py(): emb_list[i] for i, idx in enumerate(neighbor_id_list)}
+        neighbor_embs = []
+        for cur_neighbor_ids in neighbor_ids:
+            cur_embs = [idx_to_embedding[id].as_py() for id in cur_neighbor_ids]
+            neighbor_embs.append(cur_embs)
+        neighbor_embs_column.extend(neighbor_embs)
+        neighbor_ids_column.extend(neighbor_ids)
 
+    print("adding neighbor ids")
     train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
+    print("adding neighbor embeddings")
     train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
 
-    supplemented_dataset_path = f"{ds_path}_supplemented_{SPLIT}/"
+    supplemented_dataset_path = f"{ds_path}_supplemented_{args.split}/"
     train_dataset.save_to_disk(supplemented_dataset_path)
+
+
+if __name__ == "__main__":
+    prep_index()
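
The refactored loop queries the hnswlib index CHUNK_SIZE embeddings at a time, with tqdm (added in this commit) reporting progress per chunk. A minimal, self-contained sketch of that pattern, with a randomly populated index standing in for the one loaded from config['index_path'] (dimension, sizes, and data below are illustrative assumptions, not values from the repo):

import numpy as np
import hnswlib

DIM, N_DOCS, N_QUERIES, K, CHUNK_SIZE = 64, 10_000, 1_000, 5, 256

# Stand-in for index.load_index(config['index_path']): build a small index.
index = hnswlib.Index(space='cosine', dim=DIM)
index.init_index(max_elements=N_DOCS, ef_construction=100, M=16)
index.add_items(np.random.rand(N_DOCS, DIM).astype(np.float32), np.arange(N_DOCS))

queries = np.random.rand(N_QUERIES, DIM).astype(np.float32)

all_neighbor_ids = []
for chunk_start in range(0, len(queries), CHUNK_SIZE):
    chunk = queries[chunk_start:chunk_start + CHUNK_SIZE]
    # knn_query returns (labels, distances), each of shape [n_queries, k].
    neighbor_ids, _ = index.knn_query(chunk, k=K, num_threads=-1)
    all_neighbor_ids.extend(neighbor_ids)

knn_query accepts a whole batch at once, so chunking here mainly bounds peak memory for the query and result arrays.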
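The commit also replaces the TODO-laden filter on the 'index' column with a filter on 'id' plus an explicit id-to-embedding mapping. The mapping is needed because a single pyarrow filter returns matching rows in table order, not in per-query neighbor order. A sketch of that join, with a toy table standing in for retrieval_dataset._data (ids and embeddings are made up):

import pyarrow as pa
from pyarrow import compute as pc

# Toy retrieval table; string ids match the commit's stringified value_set.
table = pa.table({
    'id': ['0', '1', '2', '3'],
    'embedding': [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]],
})

neighbor_ids = [['2', '0'], ['3', '2']]  # per-query neighbors; order matters
value_set = pa.array(sorted({i for row in neighbor_ids for i in row}))

# One filter fetches every needed row, but in table order, not query order...
neighbor_objs = table.filter(pc.is_in(table['id'], value_set=value_set))

# ...so a dict keyed by id restores the per-query ordering, as in the commit.
emb_list = neighbor_objs['embedding']
id_to_emb = {i.as_py(): emb_list[n].as_py() for n, i in enumerate(neighbor_objs['id'])}
neighbor_embs = [[id_to_emb[i] for i in row] for row in neighbor_ids]
print(neighbor_embs)  # [[[2.0, 2.1], [0.0, 0.1]], [[3.0, 3.1], [2.0, 2.1]]]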
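Finally, the retrieved neighbors are attached to the training split as new columns and the result is written to disk. A sketch with placeholder data showing the datasets.Dataset calls the script relies on (the rows, neighbor values, and output path here are hypothetical):

from datasets import Dataset

# Toy queries standing in for the real training split.
train_dataset = Dataset.from_dict({'text': ['first query', 'second query']})

neighbor_ids_column = [[2, 0], [3, 2]]
neighbor_embs_column = [[[2.0, 2.1], [0.0, 0.1]], [[3.0, 3.1], [2.0, 2.1]]]

# add_column requires one entry per row, so each row carries its own k neighbors.
train_dataset = train_dataset.add_column('neighbor_ids', neighbor_ids_column)
train_dataset = train_dataset.add_column('neighbor_embeddings', neighbor_embs_column)
train_dataset.save_to_disk('toy_supplemented_train/')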