diff --git a/pilot/embedding_engine/knowledge_embedding.py b/pilot/embedding_engine/knowledge_embedding.py
index 3171cee89..a1d148130 100644
--- a/pilot/embedding_engine/knowledge_embedding.py
+++ b/pilot/embedding_engine/knowledge_embedding.py
@@ -4,28 +4,11 @@
 from chromadb.errors import NotEnoughElementsException
 from langchain.embeddings import HuggingFaceEmbeddings
 from pilot.configs.config import Config
-from pilot.embedding_engine.csv_embedding import CSVEmbedding
 from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
-from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
-from pilot.embedding_engine.pdf_embedding import PDFEmbedding
-from pilot.embedding_engine.ppt_embedding import PPTEmbedding
-from pilot.embedding_engine.url_embedding import URLEmbedding
-from pilot.embedding_engine.word_embedding import WordEmbedding
 from pilot.vector_store.connector import VectorStoreConnector
 
 CFG = Config()
 
-# KnowledgeEmbeddingType = {
-#     ".txt": (MarkdownEmbedding, {}),
-#     ".md": (MarkdownEmbedding, {}),
-#     ".pdf": (PDFEmbedding, {}),
-#     ".doc": (WordEmbedding, {}),
-#     ".docx": (WordEmbedding, {}),
-#     ".csv": (CSVEmbedding, {}),
-#     ".ppt": (PPTEmbedding, {}),
-#     ".pptx": (PPTEmbedding, {}),
-# }
-
 
 class KnowledgeEmbedding:
     def __init__(
@@ -57,23 +40,6 @@ class KnowledgeEmbedding:
 
     def init_knowledge_embedding(self):
         return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
-        # if self.file_type == "url":
-        #     embedding = URLEmbedding(
-        #         file_path=self.file_path,
-        #         vector_store_config=self.vector_store_config,
-        #     )
-        #     return embedding
-        # extension = "." + self.file_path.rsplit(".", 1)[-1]
-        # if extension in KnowledgeEmbeddingType:
-        #     knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
-        #     embedding = knowledge_class(
-        #         self.file_path,
-        #         vector_store_config=self.vector_store_config,
-        #         **knowledge_args
-        #     )
-        #     return embedding
-        # raise ValueError(f"Unsupported knowledge file type '{extension}'")
-        # return embedding
 
     def similar_search(self, text, topk):
         vector_client = VectorStoreConnector(
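The commented-out extension map above is superseded by a single `get_knowledge_embedding` factory keyed on an explicit knowledge type. For illustration, a minimal sketch of that dispatch pattern follows; the enum members, registry contents, and embedding classes here are stand-ins, not the actual `pilot.embedding_engine.knowledge_type` implementation:

```python
from enum import Enum


class KnowledgeType(Enum):
    DOCUMENT = "DOCUMENT"
    URL = "URL"


class DocumentEmbedding:
    """Stand-in for the per-format embedding clients the old map named."""

    def __init__(self, file_path, vector_store_config):
        self.file_path = file_path
        self.vector_store_config = vector_store_config


class URLEmbedding:
    """Stand-in for the URL-sourced embedding client."""

    def __init__(self, url, vector_store_config):
        self.url = url
        self.vector_store_config = vector_store_config


# Dispatch on a declared knowledge type rather than sniffing file extensions.
_REGISTRY = {
    KnowledgeType.DOCUMENT.value: DocumentEmbedding,
    KnowledgeType.URL.value: URLEmbedding,
}


def get_knowledge_embedding(knowledge_type, knowledge_source, vector_store_config):
    try:
        knowledge_class = _REGISTRY[knowledge_type]
    except KeyError:
        raise ValueError(f"Unsupported knowledge type '{knowledge_type}'")
    return knowledge_class(knowledge_source, vector_store_config)
```

Dispatching on a declared type means URL and raw-text sources no longer have to be routed through file-extension matching, and new source types register in one place.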
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index b75179eaf..376e17ae4 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -14,6 +14,8 @@ from typing import List
 
 from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
 from pilot.configs.config import Config
+from pilot.openapi.knowledge.knowledge_service import KnowledgeService
+from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.scene.chat_factory import ChatFactory
@@ -27,6 +29,7 @@
 router = APIRouter()
 CFG = Config()
 CHAT_FACTORY = ChatFactory()
 logger = build_logger("api_v1", LOGDIR + "api_v1.log")
+knowledge_service = KnowledgeService()
 
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
@@ -101,9 +104,8 @@ def plugins_select_info():
 
 
 def knowledge_list():
-    knowledge: dict = {}
-    ### TODO
-    return knowledge
+    request = KnowledgeSpaceRequest()
+    return knowledge_service.get_knowledge_space(request)
 
 
 @router.post('/v1/chat/mode/params/list', response_model=Result[dict])
@@ -164,7 +166,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
     elif ChatScene.ChatExecution == dialogue.chat_mode:
         chat_param.update("plugin_selector", dialogue.select_param)
     elif ChatScene.ChatKnowledge == dialogue.chat_mode:
-        chat_param.update("knowledge_name", dialogue.select_param)
+        chat_param.update("knowledge_space", dialogue.select_param)
 
     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
     if not chat.prompt_template.stream_out:
diff --git a/pilot/scene/chat_knowledge/v1/__init__.py b/pilot/scene/chat_knowledge/v1/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
new file mode 100644
index 000000000..321c4d8eb
--- /dev/null
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -0,0 +1,66 @@
+from chromadb.errors import NoIndexException
+
+from pilot.scene.base_chat import BaseChat, logger, headers
+from pilot.scene.base import ChatScene
+from pilot.common.sql_database import Database
+from pilot.configs.config import Config
+
+from pilot.common.markdown_text import (
+    generate_markdown_table,
+    generate_htm_table,
+    datas_to_table_html,
+)
+
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+
+from pilot.scene.chat_knowledge.default.prompt import prompt
+from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
+
+CFG = Config()
+
+
+class ChatKnowledge(BaseChat):
+    chat_scene: str = ChatScene.ChatKnowledge.value
+
+    """Number of results to return from the query"""
+
+    def __init__(self, chat_session_id, user_input, knowledge_space):
+        """ """
+        super().__init__(
+            chat_mode=ChatScene.ChatKnowledge,
+            chat_session_id=chat_session_id,
+            current_user_input=user_input,
+        )
+        vector_store_config = {
+            "vector_store_name": knowledge_space,
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        self.knowledge_embedding_client = KnowledgeEmbedding(
+            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            vector_store_config=vector_store_config,
+        )
+
+    def generate_input_values(self):
+        try:
+            docs = self.knowledge_embedding_client.similar_search(
+                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+            )
+            context = [d.page_content for d in docs]
+            context = context[:2000]
+            input_values = {"context": context, "question": self.current_user_input}
+        except NoIndexException:
+            raise ValueError(
+                "You have no knowledge space. Please add a knowledge space first."
+            )
+        return input_values
+
+
+
+    @property
+    def chat_type(self) -> str:
+        return ChatScene.ChatKnowledge.value
diff --git a/pilot/scene/chat_knowledge/v1/out_parser.py b/pilot/scene/chat_knowledge/v1/out_parser.py
new file mode 100644
index 000000000..e5edc9b20
--- /dev/null
+++ b/pilot/scene/chat_knowledge/v1/out_parser.py
@@ -0,0 +1,19 @@
+import json
+import re
+from abc import ABC, abstractmethod
+from typing import Dict, NamedTuple
+import pandas as pd
+from pilot.utils import build_logger
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.model_config import LOGDIR
+
+
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
+
+class NormalChatOutputParser(BaseOutputParser):
+    def parse_prompt_response(self, model_out_text) -> T:
+        return model_out_text
+
+    def get_format_instructions(self) -> str:
+        pass
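A note on `generate_input_values` in the new `chat.py`: `context[:2000]` slices the list of retrieved chunks rather than the assembled text, so with a typical `KNOWLEDGE_SEARCH_TOP_SIZE` it is effectively a no-op. If the intent is to bound prompt size, a character-level cap is more likely what is wanted. A sketch under that assumption; `build_input_values` and `max_context_chars` are hypothetical names, not part of the codebase:

```python
def build_input_values(knowledge_embedding_client, user_input, top_k=5, max_context_chars=2000):
    """Run a similarity search and bound the context actually sent to the model."""
    docs = knowledge_embedding_client.similar_search(user_input, top_k)
    # Join chunk texts first, then truncate by characters; slicing the list of
    # chunks (as in the diff) only limits the chunk count, which top_k already does.
    context = "\n".join(d.page_content for d in docs)[:max_context_chars]
    return {"context": context, "question": user_input}
```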
diff --git a/pilot/scene/chat_knowledge/v1/prompt.py b/pilot/scene/chat_knowledge/v1/prompt.py
new file mode 100644
index 000000000..0fd9f9ff3
--- /dev/null
+++ b/pilot/scene/chat_knowledge/v1/prompt.py
@@ -0,0 +1,54 @@
+import builtins
+import importlib
+
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle
+
+from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
+
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who is very familiar with database-related knowledge.
+    The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
+
+
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+    如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
+    已知内容:
+    {context}
+    问题:
+    {question}
+"""
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." Do not make up information.
+    known information:
+    {context}
+    question:
+    {question}
+"""
+
+_DEFAULT_TEMPLATE = (
+    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
+
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = True
+
+prompt = PromptTemplate(
+    template_scene=ChatScene.ChatKnowledge.value,
+    input_variables=["context", "question"],
+    response_format=None,
+    template_define=PROMPT_SCENE_DEFINE,
+    template=_DEFAULT_TEMPLATE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=NormalChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+)
+
+
+CFG.prompt_templates.update({prompt.template_scene: prompt})
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index c2cd9d434..9513c0c5b 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -11,6 +11,8 @@ import uuid
 
 import gradio as gr
 
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
 
@@ -664,7 +666,8 @@ def knowledge_embedding_store(vs_id, files):
             file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
         )
         knowledge_embedding_client = KnowledgeEmbedding(
-            file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
+            knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
+            knowledge_type=KnowledgeType.DOCUMENT.value,
             model_name=LLM_MODEL_CONFIG["text2vec"],
             vector_store_config={
                 "vector_store_name": vector_store_name["vs_name"],
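Taken together, callers now pass an explicit source and type instead of a bare `file_path`. A usage sketch assuming the constructor signature shown in the `webserver.py` hunk; the space name `my_space` and file name `intro.pdf` are illustrative:

```python
import os

from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
from pilot.embedding_engine.knowledge_type import KnowledgeType

# Build a client for one uploaded document; the kwargs mirror the webserver.py hunk.
client = KnowledgeEmbedding(
    knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, "my_space", "intro.pdf"),
    knowledge_type=KnowledgeType.DOCUMENT.value,
    model_name=LLM_MODEL_CONFIG["text2vec"],
    vector_store_config={
        "vector_store_name": "my_space",
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    },
)

# Query the space once documents have been embedded into it.
docs = client.similar_search("How do I configure the vector store?", 5)
```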