diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index eb9089bd3..14e4277bd 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -14,7 +14,7 @@ from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, MODEL_NAME_PATH, VS_ROOT_PATH
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql import MySQLOperator
-from pilot.source_embedding.pdf_embedding import PDFEmbedding
+from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
 from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
@@ -539,12 +539,11 @@ def knowledge_embedding_store(vs_id, files):
         filename = os.path.split(file.name)[-1]
         shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename))
-        knowledge_embedding = PDFEmbedding(file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), model_name=MODEL_NAME_PATH,
-                                           vector_store_config={"vector_store_name": vs_id,
+        knowledge_embedding = KnowledgeEmbedding.knowledge_embedding(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), MODEL_NAME_PATH, {"vector_store_name": vs_id,
                                                                 "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
         knowledge_embedding.source_embedding()
 
     logger.info("knowledge embedding success")
-    return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename + ".vectordb")
+    return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
 
 
 if __name__ == "__main__":
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
new file mode 100644
index 000000000..adc8c430f
--- /dev/null
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -0,0 +1,32 @@
+from pilot.source_embedding.csv_embedding import CSVEmbedding
+from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
+from pilot.source_embedding.pdf_embedding import PDFEmbedding
+
+
+class KnowledgeEmbedding:
+    """Factory that selects an embedding class from a file's extension."""
+
+    @staticmethod
+    def knowledge_embedding(file_path: str, model_name, vector_store_config):
+        """Build the embedding object matching the extension of file_path.
+
+        All three arguments are forwarded by keyword, unchanged, to the
+        selected class. The extension check ignores case.
+
+        Raises:
+            ValueError: if file_path does not end in .pdf, .md or .csv.
+        """
+        lowered = file_path.lower()
+        if lowered.endswith(".pdf"):
+            embedding_cls = PDFEmbedding
+        elif lowered.endswith(".md"):
+            embedding_cls = MarkdownEmbedding
+        elif lowered.endswith(".csv"):
+            embedding_cls = CSVEmbedding
+        else:
+            # Reject unknown extensions explicitly rather than falling
+            # through to the return with no embedding selected.
+            raise ValueError("unsupported knowledge file type: {}".format(file_path))
+
+        return embedding_cls(file_path=file_path, model_name=model_name,
+                             vector_store_config=vector_store_config)
diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py
new file mode 100644
index 000000000..66f8c5aa5
--- /dev/null
+++ b/pilot/source_embedding/markdown_embedding.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from bs4 import BeautifulSoup
+from langchain.document_loaders import TextLoader
+from langchain.schema import Document
+import markdown
+
+from pilot.source_embedding import SourceEmbedding, register
+
+
+class MarkdownEmbedding(SourceEmbedding):
+    """Markdown embedding for reading markdown documents."""
+
+    def __init__(self, file_path, model_name, vector_store_config):
+        """Initialize with markdown path."""
+        # NOTE(review): SourceEmbedding.__init__ is not invoked here; confirm
+        # the base class needs no initialization of its own.
+        self.file_path = file_path
+        self.model_name = model_name
+        self.vector_store_config = vector_store_config
+
+    @register
+    def read(self):
+        """Load from markdown path."""
+        loader = TextLoader(self.file_path)
+        return loader.load()
+
+    @register
+    def data_process(self, documents: List[Document]):
+        """Replace each document's markdown source with extracted plain text.
+
+        Documents are modified in place and the same list is returned.
+        """
+        for document in documents:
+            html = markdown.markdown(document.page_content)
+            soup = BeautifulSoup(html, 'html.parser')
+            for tag in soup(['!doctype', 'meta', 'i.fa']):
+                tag.extract()
+            # Remove all spaces first, then turn line breaks into spaces.
+            text = soup.get_text().replace(" ", "").replace("\n", " ")
+            document.page_content = text
+        return documents
diff --git a/pilot/source_embedding/search_milvus.py b/pilot/source_embedding/search_milvus.py
index 25acff097..09d4a4cb0 100644
--- a/pilot/source_embedding/search_milvus.py
+++ b/pilot/source_embedding/search_milvus.py
@@ -49,7 +49,10 @@ model_name = "/Users/chenketing/Desktop/project/all-MiniLM-L6-v2"
 embeddings = HuggingFaceEmbeddings(model_name=model_name)
 
 # text_embeddings = Text2Vectors()
-mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_c"})
+mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
+
+mivuls.insert(["textc","tezt2"])
+print("success")
 # mivuls.from_texts(texts=data, embedding=embeddings)
 # docs,