[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

commit eadcad8749 (parent 6070287b34)
@@ -1,9 +1,10 @@
-import numpy as np
 import json
-from FlagEmbedding import FlagAutoModel
-import time
-from rank_bm25 import BM25Okapi
 import hnswlib
+import numpy as np
+from FlagEmbedding import FlagAutoModel
+from rank_bm25 import BM25Okapi


 def get_list_shape(lst):
     shape = []
@@ -13,49 +14,55 @@ def get_list_shape(lst):
         current = current[0]
     return tuple(shape)


 def load_model():
     return FlagAutoModel.from_finetuned(
-        'BAAI/bge-base-en-v1.5',
+        "BAAI/bge-base-en-v1.5",
         query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
         # devices='cpu', # Uncomment this line if you want to use GPU.
-        use_fp16=True
+        use_fp16=True,
     )


 def encode_query(model, query):
     query_vectors = [np.array(model.encode(query)).tolist()]
-    print('query_vectors_shape', get_list_shape(query_vectors))
+    print("query_vectors_shape", get_list_shape(query_vectors))
     return query_vectors


 def load_data(vectors_path, docs_path):
     vectors = np.load(vectors_path).tolist()
-    with open(docs_path, 'r', encoding='utf-8') as file:
+    with open(docs_path, "r", encoding="utf-8") as file:
         docs = json.load(file)
     return vectors, docs


 def build_hnsw_index(vectors):
     # start_time = time.time()
     num_elements = len(vectors)
-    p = hnswlib.Index(space='cosine', dim=768)
+    p = hnswlib.Index(space="cosine", dim=768)
     p.init_index(max_elements=num_elements, ef_construction=200, M=16)
     # M defines the maximum number of outgoing connections in the graph. Higher M leads to higher accuracy/run_time at fixed ef/efConstruction.
     # ef_construction controls index search speed/build speed tradeoff. Increasing the efConstruction parameter may enhance index quality, but it also tends to lengthen the indexing time.
     p.add_items(np.array(vectors), np.arange(num_elements))
     # HNSW_time = time.time()
-    #print('HNSW build time:', HNSW_time - start_time)
+    # print('HNSW build time:', HNSW_time - start_time)
     p.set_ef(32)
     # ef controlling query time/accuracy trade-off. Higher ef leads to more accurate but slower search.
     return p


 def search_hnsw(index, query_vectors, docs):
     # HNSW_time = time.time()
     labels, distances = index.knn_query(np.array(query_vectors), k=10)
-    results = [docs[i]['content'] for i in labels[0]]
+    results = [docs[i]["content"] for i in labels[0]]
     # end_HNSW_time = time.time()
     # print('HNSW search time:', end_HNSW_time - HNSW_time)
     return results


 def build_bm25(docs):
-    corpus = [doc['content'] for doc in docs]
+    corpus = [doc["content"] for doc in docs]
     tokenized_corpus = [list(text.split()) for text in corpus]
     # bm25_build_start = time.time()
     bm25 = BM25Okapi(tokenized_corpus)
@@ -63,6 +70,7 @@ def build_bm25(docs):
     # print('BM25 build time:', bm25_build_end - bm25_build_start)
     return bm25, corpus


 def search_bm25(bm25, corpus, query):
     # bm25_search_start = time.time()
     tokenized_query = list(query.split())
@@ -73,6 +81,7 @@ def search_bm25(bm25, corpus, query):
     # print('BM25 search time:', bm25_search_end - bm25_search_start)
     return bm25_results


 def merge_results(results, bm25_results):
     merged_results = []
     for i in range(len(results)):
@@ -82,20 +91,23 @@ def merge_results(results, bm25_results):
     merged_results = list(set(merged_results))
     return merged_results


 def main():
     model = load_model()
     query = "This is a test query to find relevant documents."
     query_vectors = encode_query(model, query)
-    vectors, docs = load_data('PATH_TO_YOUR_EMBEDDING.npy', 'PATH_TO_YOUR_JSON.json')
+    vectors, docs = load_data("#PATH_TO_YOUR_EMBEDDING.npy#", "#PATH_TO_YOUR_JSON.json#")

     hnsw_index = build_hnsw_index(vectors)
     hnsw_results = search_hnsw(hnsw_index, query_vectors, docs)

     bm25, corpus = build_bm25(docs)
     bm25_results = search_bm25(bm25, corpus, query)

     merged_results = merge_results(hnsw_results, bm25_results)

     return merged_results


 if __name__ == "__main__":
-    retrieved_data=main()
+    retrieved_data = main()
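The retrieval script above combines dense HNSW search with lexical BM25 and keeps the de-duplicated union of the two hit lists. A minimal self-contained sketch of the same idea, with random vectors standing in for real bge-base embeddings and rank_bm25's get_top_n helper standing in for the elided search_bm25 body:

import hnswlib
import numpy as np
from rank_bm25 import BM25Okapi

docs = ["alpha beta gamma", "beta delta", "gamma epsilon zeta"]
dim = 768  # bge-base-en-v1.5 embedding size

# Dense side: random vectors stand in for model.encode(...) output.
vectors = np.random.rand(len(docs), dim).astype(np.float32)
index = hnswlib.Index(space="cosine", dim=dim)
index.init_index(max_elements=len(docs), ef_construction=200, M=16)  # build-time knobs
index.add_items(vectors, np.arange(len(docs)))
index.set_ef(32)  # query-time recall/latency knob
labels, _ = index.knn_query(np.random.rand(1, dim).astype(np.float32), k=2)
dense_hits = [docs[i] for i in labels[0]]

# Lexical side: whitespace-tokenised BM25, as in build_bm25.
bm25 = BM25Okapi([d.split() for d in docs])
lexical_hits = bm25.get_top_n("beta gamma".split(), docs, n=2)

# Hybrid merge: de-duplicated union, as in merge_results.
merged = list(set(dense_hits + lexical_hits))
print(merged)

The knob values mirror build_hnsw_index: M and ef_construction trade indexing time for graph quality, and ef trades query latency for recall.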
@@ -53,4 +53,3 @@ pip install -r requirements.txt
 - FlagEmbedding
 - haystack
 - haystack-integrations
-
@@ -1,9 +1,11 @@
 import json
+import os
+import time

 import numpy as np
 from FlagEmbedding import FlagAutoModel
-import time
 from sklearn.metrics.pairwise import cosine_similarity
-import os

 def load_model(model_name="BAAI/bge-base-en-v1.5", use_fp16=True):
     return FlagAutoModel.from_finetuned(
@@ -42,17 +44,17 @@ def load_embeddings(file_path):
 def main():
     config = {
         "model_name": "BAAI/bge-base-en-v1.5",
-        "json_path": #PATH_TO_YOUR_JSON.json#,
+        "json_path": "#PATH_TO_YOUR_JSON.json#",
-        "embedding_path": #PATH_TO_YOUR_EMBEDDING.npy#,
+        "embedding_path": "#PATH_TO_YOUR_EMBEDDING.npy#",
         "use_fp16": True,
         "use_precomputed_embeddings": False
     }

     model = load_model(
         model_name=config["model_name"],
         use_fp16=config["use_fp16"]
     )

     if config["use_precomputed_embeddings"]:
         embeddings = load_embeddings(config["embedding_path"])
         if embeddings is None:
@@ -61,17 +63,17 @@ def main():
         data = load_data(config["json_path"])
         if not data:
             return

         texts = extract_texts(data)
         embeddings = generate_embeddings(model, texts)
         save_embeddings(embeddings, config["embedding_path"])

     ##### Test demo with simple KNN cosine_similarity
     # query='This is a test query to find relevant documents.'
     # query_embedding=np.array(model.encode(query))
     # similarity_scores = cosine_similarity([query_embedding], embeddings)
     # indices = np.argsort(-similarity_scores)

     return embeddings

 if __name__ == '__main__':
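The commented-out block at the end of main() hints at a brute-force sanity check against the approximate HNSW index; written out with sklearn it would look roughly like this (random arrays stand in for the saved .npy corpus embeddings and the encoded query):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

embeddings = np.random.rand(100, 768)   # stands in for the saved corpus embeddings
query_embedding = np.random.rand(768)   # stands in for model.encode(query)

scores = cosine_similarity([query_embedding], embeddings)[0]  # shape (100,)
top10 = np.argsort(-scores)[:10]        # indices of the 10 most similar documents
print(top10, scores[top10])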
@@ -1,13 +1,14 @@
-from pathlib import Path
-import time
 import json
-from haystack import Pipeline
+import time
+from pathlib import Path

+from haystack import Document, Pipeline
 from haystack.components.converters import PyPDFToDocument
 from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
 from haystack.components.writers import DocumentWriter
-from haystack.document_stores.types import DuplicatePolicy
 from haystack.document_stores.in_memory import InMemoryDocumentStore
-from haystack import Document
+from haystack.document_stores.types import DuplicatePolicy


 def create_indexing_pipeline():
     document_store = InMemoryDocumentStore()
@@ -15,23 +16,23 @@ def create_indexing_pipeline():
     cleaner = DocumentCleaner()
     splitter = DocumentSplitter(split_by="sentence", split_length=1)
     writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)

     indexing_pipeline = Pipeline()
     indexing_pipeline.add_component("converter", converter)
     indexing_pipeline.add_component("cleaner", cleaner)
     indexing_pipeline.add_component("splitter", splitter)
     indexing_pipeline.add_component("writer", writer)

     indexing_pipeline.connect("converter", "cleaner")
     indexing_pipeline.connect("cleaner", "splitter")
     indexing_pipeline.connect("splitter", "writer")

     return indexing_pipeline, document_store


 def process_pdfs(pdf_directory, indexing_pipeline):
     papers_dir = Path(pdf_directory)
     pdf_files = list(papers_dir.glob("*.pdf"))
     for pdf_file in pdf_files:
         try:
             indexing_pipeline.run({"converter": {"sources": [pdf_file]}})
         except:
@@ -44,8 +45,8 @@ def save_to_json(document_store, output_path):
         json.dump(docs_list, f, ensure_ascii=False, indent=2)


 def main():
-    PDF_DIRECTORY = #PATH_TO_YOUR_PDF_DIRECTORY#
+    PDF_DIRECTORY = "#PATH_TO_YOUR_PDF_DIRECTORY#"
-    OUTPUT_JSON = #PATH_TO_YOUR_JSON#
+    OUTPUT_JSON = "#PATH_TO_YOUR_JSON#"

     start_time = time.time()
     indexing_pipeline, document_store = create_indexing_pipeline()
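Most of save_to_json falls outside the hunks shown here; a plausible sketch of what it needs to do, assuming only haystack's InMemoryDocumentStore.filter_documents() API (not necessarily the file's actual implementation), is:

import json

from haystack.document_stores.in_memory import InMemoryDocumentStore


def dump_store_to_json(document_store: InMemoryDocumentStore, output_path: str) -> None:
    # filter_documents() with no filters returns every stored Document.
    docs = document_store.filter_documents()
    docs_list = [{"content": doc.content, "meta": doc.meta} for doc in docs]
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(docs_list, f, ensure_ascii=False, indent=2)

Emitting a "content" key per chunk would line up with what the retrieval script expects when it later reads docs[i]["content"] in search_hnsw and build_bm25.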
@@ -4,4 +4,4 @@ hnswlib
 rank_bm25
 FlagEmbedding
 haystack
 haystack-integrations