From a06342425b2a67d99088c2e6893be82d7bc69370 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 27 Jun 2023 15:29:13 +0800 Subject: [PATCH] feat: knowledge backend management 1.knowledge_type.py 2.knowledge backend api --- pilot/embedding_engine/knowledge_embedding.py | 66 ++++++++++--------- pilot/embedding_engine/knowledge_type.py | 62 +++++++++++++++++ pilot/{server => openapi}/api_v1/api_v1.py | 2 +- .../api_v1/api_view_model.py | 0 .../knowledge/document_chunk_dao.py | 0 .../knowledge/knowledge_controller.py | 35 ++++++---- .../knowledge/knowledge_document_dao.py | 2 + .../knowledge/knowledge_service.py | 21 +++--- .../knowledge/knowledge_space_dao.py | 2 +- .../knowledge/request/knowledge_request.py | 2 + pilot/scene/chat_knowledge/url/chat.py | 5 +- pilot/server/__init__.py | 2 + pilot/server/llmserver.py | 4 +- pilot/server/webserver.py | 9 ++- pilot/source_embedding/knowledge_embedding.py | 36 +--------- tools/knowlege_init.py | 5 +- 16 files changed, 153 insertions(+), 100 deletions(-) create mode 100644 pilot/embedding_engine/knowledge_type.py rename pilot/{server => openapi}/api_v1/api_v1.py (98%) rename pilot/{server => openapi}/api_v1/api_view_model.py (100%) rename pilot/{server => openapi}/knowledge/document_chunk_dao.py (100%) rename pilot/{server => openapi}/knowledge/knowledge_controller.py (77%) rename pilot/{server => openapi}/knowledge/knowledge_document_dao.py (99%) rename pilot/{server => openapi}/knowledge/knowledge_service.py (89%) rename pilot/{server => openapi}/knowledge/knowledge_space_dao.py (97%) rename pilot/{server => openapi}/knowledge/request/knowledge_request.py (97%) diff --git a/pilot/embedding_engine/knowledge_embedding.py b/pilot/embedding_engine/knowledge_embedding.py index 132b360fc..3171cee89 100644 --- a/pilot/embedding_engine/knowledge_embedding.py +++ b/pilot/embedding_engine/knowledge_embedding.py @@ -5,6 +5,7 @@ from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config from pilot.embedding_engine.csv_embedding import CSVEmbedding +from pilot.embedding_engine.knowledge_type import get_knowledge_embedding from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding from pilot.embedding_engine.pdf_embedding import PDFEmbedding from pilot.embedding_engine.ppt_embedding import PPTEmbedding @@ -14,16 +15,16 @@ from pilot.vector_store.connector import VectorStoreConnector CFG = Config() -KnowledgeEmbeddingType = { - ".txt": (MarkdownEmbedding, {}), - ".md": (MarkdownEmbedding, {}), - ".pdf": (PDFEmbedding, {}), - ".doc": (WordEmbedding, {}), - ".docx": (WordEmbedding, {}), - ".csv": (CSVEmbedding, {}), - ".ppt": (PPTEmbedding, {}), - ".pptx": (PPTEmbedding, {}), -} +# KnowledgeEmbeddingType = { +# ".txt": (MarkdownEmbedding, {}), +# ".md": (MarkdownEmbedding, {}), +# ".pdf": (PDFEmbedding, {}), +# ".doc": (WordEmbedding, {}), +# ".docx": (WordEmbedding, {}), +# ".csv": (CSVEmbedding, {}), +# ".ppt": (PPTEmbedding, {}), +# ".pptx": (PPTEmbedding, {}), +# } class KnowledgeEmbedding: @@ -31,14 +32,14 @@ class KnowledgeEmbedding: self, model_name, vector_store_config, - file_type: Optional[str] = "default", - file_path: Optional[str] = None, + knowledge_type: Optional[str], + knowledge_source: Optional[str] = None, ): - """Initialize with Loader url, model_name, vector_store_config""" - self.file_path = file_path + """Initialize with knowledge embedding client, model_name, vector_store_config, knowledge_type, knowledge_source""" + self.knowledge_source = knowledge_source self.model_name = model_name self.vector_store_config = vector_store_config - self.file_type = file_type + self.knowledge_type = knowledge_type self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.vector_store_config["embeddings"] = self.embeddings @@ -55,23 +56,24 @@ class KnowledgeEmbedding: return self.knowledge_embedding_client.read_batch() def init_knowledge_embedding(self): - if self.file_type == "url": - embedding = URLEmbedding( - file_path=self.file_path, - vector_store_config=self.vector_store_config, - ) - return embedding - extension = "." + self.file_path.rsplit(".", 1)[-1] - if extension in KnowledgeEmbeddingType: - knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] - embedding = knowledge_class( - self.file_path, - vector_store_config=self.vector_store_config, - **knowledge_args - ) - return embedding - raise ValueError(f"Unsupported knowledge file type '{extension}'") - return embedding + return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config) + # if self.file_type == "url": + # embedding = URLEmbedding( + # file_path=self.file_path, + # vector_store_config=self.vector_store_config, + # ) + # return embedding + # extension = "." + self.file_path.rsplit(".", 1)[-1] + # if extension in KnowledgeEmbeddingType: + # knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] + # embedding = knowledge_class( + # self.file_path, + # vector_store_config=self.vector_store_config, + # **knowledge_args + # ) + # return embedding + # raise ValueError(f"Unsupported knowledge file type '{extension}'") + # return embedding def similar_search(self, text, topk): vector_client = VectorStoreConnector( diff --git a/pilot/embedding_engine/knowledge_type.py b/pilot/embedding_engine/knowledge_type.py new file mode 100644 index 000000000..a2b3d563c --- /dev/null +++ b/pilot/embedding_engine/knowledge_type.py @@ -0,0 +1,62 @@ +from enum import Enum + +from pilot.embedding_engine.csv_embedding import CSVEmbedding +from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding +from pilot.embedding_engine.pdf_embedding import PDFEmbedding +from pilot.embedding_engine.ppt_embedding import PPTEmbedding +from pilot.embedding_engine.string_embedding import StringEmbedding +from pilot.embedding_engine.url_embedding import URLEmbedding +from pilot.embedding_engine.word_embedding import WordEmbedding + +DocumentEmbeddingType = { + ".txt": (MarkdownEmbedding, {}), + ".md": (MarkdownEmbedding, {}), + ".pdf": (PDFEmbedding, {}), + ".doc": (WordEmbedding, {}), + ".docx": (WordEmbedding, {}), + ".csv": (CSVEmbedding, {}), + ".ppt": (PPTEmbedding, {}), + ".pptx": (PPTEmbedding, {}), +} + + +class KnowledgeType(Enum): + DOCUMENT = "DOCUMENT" + URL = "URL" + TEXT = "TEXT" + OSS = "OSS" + NOTION = "NOTION" + + +def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config): + match knowledge_type: + case KnowledgeType.DOCUMENT.value: + extension = "." + knowledge_source.rsplit(".", 1)[-1] + if extension in DocumentEmbeddingType: + knowledge_class, knowledge_args = DocumentEmbeddingType[extension] + embedding = knowledge_class( + knowledge_source, + vector_store_config=vector_store_config, + **knowledge_args, + ) + return embedding + raise ValueError(f"Unsupported knowledge document type '{extension}'") + case KnowledgeType.URL.value: + embedding = URLEmbedding( + file_path=knowledge_source, + vector_store_config=vector_store_config, + ) + return embedding + case KnowledgeType.TEXT.value: + embedding = StringEmbedding( + file_path=knowledge_source, + vector_store_config=vector_store_config, + ) + return embedding + case KnowledgeType.OSS.value: + raise Exception("OSS have not integrate") + case KnowledgeType.NOTION.value: + raise Exception("NOTION have not integrate") + + case _: + raise Exception("unknown knowledge type") diff --git a/pilot/server/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py similarity index 98% rename from pilot/server/api_v1/api_v1.py rename to pilot/openapi/api_v1/api_v1.py index 19f4e765c..dfe7e1da6 100644 --- a/pilot/server/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -9,7 +9,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from typing import List -from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo +from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo from pilot.configs.config import Config from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene diff --git a/pilot/server/api_v1/api_view_model.py b/pilot/openapi/api_v1/api_view_model.py similarity index 100% rename from pilot/server/api_v1/api_view_model.py rename to pilot/openapi/api_v1/api_view_model.py diff --git a/pilot/server/knowledge/document_chunk_dao.py b/pilot/openapi/knowledge/document_chunk_dao.py similarity index 100% rename from pilot/server/knowledge/document_chunk_dao.py rename to pilot/openapi/knowledge/document_chunk_dao.py diff --git a/pilot/server/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py similarity index 77% rename from pilot/server/knowledge/knowledge_controller.py rename to pilot/openapi/knowledge/knowledge_controller.py index 123e9238b..cf368136f 100644 --- a/pilot/server/knowledge/knowledge_controller.py +++ b/pilot/openapi/knowledge/knowledge_controller.py @@ -1,26 +1,22 @@ -import json -import os -import sys -from typing import List +from tempfile import NamedTemporaryFile + +from fastapi import APIRouter, File, UploadFile -from fastapi import APIRouter from langchain.embeddings import HuggingFaceEmbeddings -ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(ROOT_PATH) from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG -from pilot.server.api_v1.api_view_model import Result +from pilot.openapi.api_v1.api_view_model import Result from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding -from pilot.server.knowledge.knowledge_service import KnowledgeService -from pilot.server.knowledge.request.knowledge_request import ( +from pilot.openapi.knowledge.knowledge_service import KnowledgeService +from pilot.openapi.knowledge.request.knowledge_request import ( KnowledgeQueryRequest, KnowledgeQueryResponse, KnowledgeDocumentRequest, DocumentSyncRequest, ChunkQueryRequest, DocumentQueryRequest, ) -from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest +from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest CFG = Config() router = APIRouter() @@ -74,6 +70,21 @@ def document_list(space_name: str, query_request: DocumentQueryRequest): return Result.faild(code="E000X", msg=f"document list error {e}") +@router.post("/knowledge/{space_name}/document/upload") +def document_sync(space_name: str, file: UploadFile = File(...)): + print(f"/document/upload params: {space_name}") + try: + with NamedTemporaryFile(delete=False) as tmp: + tmp.write(file.read()) + tmp_path = tmp.name + tmp_content = tmp.read() + + return {"file_path": tmp_path, "file_content": tmp_content} + Result.succ([]) + except Exception as e: + return Result.faild(code="E000X", msg=f"document sync error {e}") + + @router.post("/knowledge/{space_name}/document/sync") def document_sync(space_name: str, request: DocumentSyncRequest): print(f"Received params: {space_name}, {request}") @@ -90,7 +101,7 @@ def document_sync(space_name: str, request: DocumentSyncRequest): def document_list(space_name: str, query_request: ChunkQueryRequest): print(f"/document/list params: {space_name}, {query_request}") try: - Result.succ(knowledge_space_service.get_document_chunks( + return Result.succ(knowledge_space_service.get_document_chunks( query_request )) except Exception as e: diff --git a/pilot/server/knowledge/knowledge_document_dao.py b/pilot/openapi/knowledge/knowledge_document_dao.py similarity index 99% rename from pilot/server/knowledge/knowledge_document_dao.py rename to pilot/openapi/knowledge/knowledge_document_dao.py index 497ad6b21..276907b69 100644 --- a/pilot/server/knowledge/knowledge_document_dao.py +++ b/pilot/openapi/knowledge/knowledge_document_dao.py @@ -9,6 +9,8 @@ from pilot.configs.config import Config CFG = Config() Base = declarative_base() + + class KnowledgeDocumentEntity(Base): __tablename__ = 'knowledge_document' id = Column(Integer, primary_key=True) diff --git a/pilot/server/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py similarity index 89% rename from pilot/server/knowledge/knowledge_service.py rename to pilot/openapi/knowledge/knowledge_service.py index 7372276bd..b6713dff4 100644 --- a/pilot/server/knowledge/knowledge_service.py +++ b/pilot/openapi/knowledge/knowledge_service.py @@ -4,14 +4,15 @@ from datetime import datetime from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding +from pilot.embedding_engine.knowledge_type import KnowledgeType from pilot.logs import logger -from pilot.server.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao -from pilot.server.knowledge.knowledge_document_dao import ( +from pilot.openapi.knowledge.document_chunk_dao import DocumentChunkEntity, DocumentChunkDao +from pilot.openapi.knowledge.knowledge_document_dao import ( KnowledgeDocumentDao, KnowledgeDocumentEntity, ) -from pilot.server.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity -from pilot.server.knowledge.request.knowledge_request import ( +from pilot.openapi.knowledge.knowledge_space_dao import KnowledgeSpaceDao, KnowledgeSpaceEntity +from pilot.openapi.knowledge.request.knowledge_request import ( KnowledgeSpaceRequest, KnowledgeDocumentRequest, DocumentQueryRequest, ChunkQueryRequest, ) @@ -24,6 +25,7 @@ document_chunk_dao = DocumentChunkDao() CFG=Config() + class SyncStatus(Enum): TODO = "TODO" FAILED = "FAILED" @@ -65,7 +67,7 @@ class KnowledgeService: chunk_size=0, status=SyncStatus.TODO.name, last_sync=datetime.now(), - content="", + content=request.content, ) knowledge_document_dao.create_knowledge_document(document) return True @@ -99,8 +101,8 @@ class KnowledgeService: space=space_name, ) doc = knowledge_document_dao.get_knowledge_documents(query)[0] - client = KnowledgeEmbedding(file_path=doc.doc_name, - file_type="url", + client = KnowledgeEmbedding(knowledge_source=doc.content, + knowledge_type=doc.doc_type.upper(), model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config={ "vector_store_name": space_name, @@ -127,11 +129,6 @@ class KnowledgeService: ) for chunk_doc in chunk_docs] document_chunk_dao.create_documents_chunks(chunk_entities) - #update document status - # doc.status = SyncStatus.RUNNING.name - # doc.chunk_size = len(chunk_docs) - # doc.gmt_modified = datetime.now() - # knowledge_document_dao.update_knowledge_document(doc) return True diff --git a/pilot/server/knowledge/knowledge_space_dao.py b/pilot/openapi/knowledge/knowledge_space_dao.py similarity index 97% rename from pilot/server/knowledge/knowledge_space_dao.py rename to pilot/openapi/knowledge/knowledge_space_dao.py index 20bb578a8..31894d6ac 100644 --- a/pilot/server/knowledge/knowledge_space_dao.py +++ b/pilot/openapi/knowledge/knowledge_space_dao.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.declarative import declarative_base from pilot.configs.config import Config -from pilot.server.knowledge.request.knowledge_request import KnowledgeSpaceRequest +from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest from sqlalchemy.orm import sessionmaker CFG = Config() diff --git a/pilot/server/knowledge/request/knowledge_request.py b/pilot/openapi/knowledge/request/knowledge_request.py similarity index 97% rename from pilot/server/knowledge/request/knowledge_request.py rename to pilot/openapi/knowledge/request/knowledge_request.py index 039d84727..bf68b06f1 100644 --- a/pilot/server/knowledge/request/knowledge_request.py +++ b/pilot/openapi/knowledge/request/knowledge_request.py @@ -29,6 +29,8 @@ class KnowledgeDocumentRequest(BaseModel): doc_name: str """doc_type: doc type""" doc_type: str + """content: content""" + content: str """text_chunk_size: text_chunk_size""" # text_chunk_size: int diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index acf7bbaeb..b65bfe55d 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -1,3 +1,4 @@ +from pilot.embedding_engine.knowledge_type import KnowledgeType from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -44,8 +45,8 @@ class ChatUrlKnowledge(BaseChat): self.knowledge_embedding_client = KnowledgeEmbedding( model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config=vector_store_config, - file_type="url", - file_path=url, + knowledge_type=KnowledgeType.URL.value, + knowledge_source=url, ) # url soruce in vector diff --git a/pilot/server/__init__.py b/pilot/server/__init__.py index ac72fc637..0435c3679 100644 --- a/pilot/server/__init__.py +++ b/pilot/server/__init__.py @@ -13,3 +13,5 @@ if "pytest" in sys.argv or "pytest" in sys.modules or os.getenv("CI"): load_dotenv(verbose=True, override=True) del load_dotenv + + diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 1e6278e84..3030d1fdc 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -23,7 +23,6 @@ from pilot.configs.model_config import * from pilot.model.llm_out.vicuna_base_llm import get_embeddings from pilot.model.loader import ModelLoader from pilot.server.chat_adapter import get_llm_chat_adapter -from knowledge.knowledge_controller import router CFG = Config() @@ -106,6 +105,7 @@ worker = ModelWorker( ) app = FastAPI() +from pilot.openapi.knowledge.knowledge_controller import router app.include_router(router) origins = [ @@ -119,7 +119,7 @@ app.add_middleware( allow_origins=origins, allow_credentials=True, allow_methods=["*"], - allow_headers=["*"], + allow_headers=["*"] ) class PromptRequest(BaseModel): diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index be6224e3b..a5650dbf0 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -11,6 +11,8 @@ import uuid import gradio as gr +from pilot.embedding_engine.knowledge_type import KnowledgeType + ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -57,7 +59,7 @@ from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError from fastapi.staticfiles import StaticFiles -from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler +from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler # 加载插件 CFG = Config() @@ -652,8 +654,9 @@ def knowledge_embedding_store(vs_id, files): file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename) ) knowledge_embedding_client = KnowledgeEmbedding( - file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), - model_name=LLM_MODEL_CONFIG["text2vec"], + knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), + knowledge_type=KnowledgeType.DOCUMENT.value, + model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config={ "vector_store_name": vector_store_name["vs_name"], "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 97b515897..6caccc474 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -4,27 +4,11 @@ from chromadb.errors import NotEnoughElementsException from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config -from pilot.source_embedding.csv_embedding import CSVEmbedding -from pilot.source_embedding.markdown_embedding import MarkdownEmbedding -from pilot.source_embedding.pdf_embedding import PDFEmbedding -from pilot.source_embedding.ppt_embedding import PPTEmbedding -from pilot.source_embedding.url_embedding import URLEmbedding -from pilot.source_embedding.word_embedding import WordEmbedding +from pilot.embedding_engine.knowledge_type import get_knowledge_embedding from pilot.vector_store.connector import VectorStoreConnector CFG = Config() -KnowledgeEmbeddingType = { - ".txt": (MarkdownEmbedding, {}), - ".md": (MarkdownEmbedding, {}), - ".pdf": (PDFEmbedding, {}), - ".doc": (WordEmbedding, {}), - ".docx": (WordEmbedding, {}), - ".csv": (CSVEmbedding, {}), - ".ppt": (PPTEmbedding, {}), - ".pptx": (PPTEmbedding, {}), -} - class KnowledgeEmbedding: def __init__( @@ -54,23 +38,7 @@ class KnowledgeEmbedding: return self.knowledge_embedding_client.read_batch() def init_knowledge_embedding(self): - if self.file_type == "url": - embedding = URLEmbedding( - file_path=self.file_path, - vector_store_config=self.vector_store_config, - ) - return embedding - extension = "." + self.file_path.rsplit(".", 1)[-1] - if extension in KnowledgeEmbeddingType: - knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] - embedding = knowledge_class( - self.file_path, - vector_store_config=self.vector_store_config, - **knowledge_args, - ) - return embedding - raise ValueError(f"Unsupported knowledge file type '{extension}'") - return embedding + return get_knowledge_embedding(self.file_type.upper(), self.file_path, self.vector_store_config) def similar_search(self, text, topk): vector_client = VectorStoreConnector( diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index e72c13aeb..56889bd7f 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -4,6 +4,8 @@ import argparse import os import sys +from pilot.embedding_engine.knowledge_type import KnowledgeType + sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) from pilot.configs.config import Config @@ -30,7 +32,8 @@ class LocalKnowledgeInit: filename = os.path.join(root, file) # docs = self._load_file(filename) ke = KnowledgeEmbedding( - file_path=filename, + knowledge_source=filename, + knowledge_type=KnowledgeType.DOCUMENT.value, model_name=self.model_name, vector_store_config=self.vector_store_config, )