mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-29 14:30:40 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
eadcad8749
commit
3fdd4e7733
@ -1,10 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from FlagEmbedding import FlagAutoModel
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
|
||||
def load_model(model_name="BAAI/bge-base-en-v1.5", use_fp16=True):
|
||||
@ -12,27 +10,32 @@ def load_model(model_name="BAAI/bge-base-en-v1.5", use_fp16=True):
|
||||
model_name,
|
||||
query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
|
||||
# device='cpu', # Uncomment this line if you want to use GPU.
|
||||
use_fp16=use_fp16
|
||||
use_fp16=use_fp16,
|
||||
)
|
||||
|
||||
|
||||
def load_data(file_path):
    """Read a JSON list of documents from *file_path*.

    Returns the parsed object on success, or an empty list when the file
    is missing or does not contain valid JSON (a message is printed so the
    failure is not silent).
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            parsed = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError):
        print("Error loading data from", file_path)
        return []
    return parsed
|
||||
|
||||
|
||||
def extract_texts(data):
    """Extract the text body of every document for embedding.

    Args:
        data: Iterable of document dicts, as produced by ``load_data``.

    Returns:
        List of whitespace-stripped "content" strings, one per document.
        A missing or null "content" field yields "" instead of raising
        AttributeError (the original ``doc.get("content", "").strip()``
        crashed when "content" was present but None), so one malformed
        record cannot abort the whole batch.
    """
    return [(doc.get("content") or "").strip() for doc in data]
|
||||
|
||||
|
||||
def generate_embeddings(model, texts):
    """Encode *texts* with *model* and return the vectors as a NumPy array.

    The model's ``encode`` output (typically a list of vectors) is converted
    with ``np.array`` so downstream code can rely on ndarray semantics.
    """
    encoded = model.encode(texts)
    return np.array(encoded)
|
||||
|
||||
|
||||
def save_embeddings(embeddings, output_path):
    """Persist *embeddings* to *output_path* in NumPy ``.npy`` format.

    Missing parent directories are created first. The original
    ``os.makedirs(os.path.dirname(output_path), exist_ok=True)`` raised
    FileNotFoundError for a bare filename, because ``dirname`` returns ""
    and ``os.makedirs("")`` fails; a bare filename now saves to the
    current working directory.
    """
    parent = os.path.dirname(output_path)
    if parent:  # os.makedirs("") raises, so only create a real parent dir
        os.makedirs(parent, exist_ok=True)
    np.save(output_path, embeddings)
|
||||
|
||||
|
||||
def load_embeddings(file_path):
|
||||
try:
|
||||
return np.load(file_path)
|
||||
@ -47,13 +50,10 @@ def main():
|
||||
"json_path": "#PATH_TO_YOUR_JSON.json#",
|
||||
"embedding_path": "#PATH_TO_YOUR_EMBEDDING.npy#",
|
||||
"use_fp16": True,
|
||||
"use_precomputed_embeddings": False
|
||||
"use_precomputed_embeddings": False,
|
||||
}
|
||||
|
||||
model = load_model(
|
||||
model_name=config["model_name"],
|
||||
use_fp16=config["use_fp16"]
|
||||
)
|
||||
model = load_model(model_name=config["model_name"], use_fp16=config["use_fp16"])
|
||||
|
||||
if config["use_precomputed_embeddings"]:
|
||||
embeddings = load_embeddings(config["embedding_path"])
|
||||
@ -68,7 +68,7 @@ def main():
|
||||
embeddings = generate_embeddings(model, texts)
|
||||
save_embeddings(embeddings, config["embedding_path"])
|
||||
|
||||
##### Test demo with simple KNN cosine_similarity
|
||||
##### Test demo with simple KNN cosine_similarity
|
||||
# query='This is a test query to find relevant documents.'
|
||||
# query_embedding=np.array(model.encode(query))
|
||||
# similarity_scores = cosine_similarity([query_embedding], embeddings)
|
||||
@ -76,5 +76,6 @@ def main():
|
||||
|
||||
return embeddings
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from haystack import Document, Pipeline
|
||||
from haystack import Pipeline
|
||||
from haystack.components.converters import PyPDFToDocument
|
||||
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
|
||||
from haystack.components.writers import DocumentWriter
|
||||
@ -29,6 +29,7 @@ def create_indexing_pipeline():
|
||||
|
||||
return indexing_pipeline, document_store
|
||||
|
||||
|
||||
def process_pdfs(pdf_directory, indexing_pipeline):
|
||||
papers_dir = Path(pdf_directory)
|
||||
pdf_files = list(papers_dir.glob("*.pdf"))
|
||||
@ -38,20 +39,23 @@ def process_pdfs(pdf_directory, indexing_pipeline):
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def save_to_json(document_store, output_path):
    """Serialize every document in *document_store* to *output_path* as JSON.

    Documents are fetched with ``filter_documents()`` (no filter = all),
    converted via their ``to_dict()`` method, and written UTF-8 encoded
    with non-ASCII characters preserved and a 2-space indent.
    """
    records = [document.to_dict() for document in document_store.filter_documents()]
    with open(output_path, "w", encoding="utf-8") as sink:
        json.dump(records, sink, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
    """Index every PDF under PDF_DIRECTORY and dump the documents to JSON.

    Replace the ``#...#`` placeholder paths before running. Removed the
    leftover bare ``time.time()`` call whose result was discarded — dead
    residue from deleted timing code.
    """
    PDF_DIRECTORY = "#PATH_TO_YOUR_PDF_DIRECTORY#"
    OUTPUT_JSON = "#PATH_TO_YOUR_JSON#"

    # Build the pipeline, run each PDF through it, then persist the store.
    indexing_pipeline, document_store = create_indexing_pipeline()
    process_pdfs(PDF_DIRECTORY, indexing_pipeline)
    save_to_json(document_store, OUTPUT_JSON)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user