mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-14 14:34:28 +00:00
feat: new knowledge chat scene
1.add new knowledge chat scene 2.format file format
This commit is contained in:
parent
5599bb63ea
commit
fa73a4fae6
@ -4,28 +4,11 @@ from chromadb.errors import NotEnoughElementsException
|
|||||||
from langchain.embeddings import HuggingFaceEmbeddings
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.embedding_engine.csv_embedding import CSVEmbedding
|
|
||||||
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
|
from pilot.embedding_engine.knowledge_type import get_knowledge_embedding
|
||||||
from pilot.embedding_engine.markdown_embedding import MarkdownEmbedding
|
|
||||||
from pilot.embedding_engine.pdf_embedding import PDFEmbedding
|
|
||||||
from pilot.embedding_engine.ppt_embedding import PPTEmbedding
|
|
||||||
from pilot.embedding_engine.url_embedding import URLEmbedding
|
|
||||||
from pilot.embedding_engine.word_embedding import WordEmbedding
|
|
||||||
from pilot.vector_store.connector import VectorStoreConnector
|
from pilot.vector_store.connector import VectorStoreConnector
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
# KnowledgeEmbeddingType = {
|
|
||||||
# ".txt": (MarkdownEmbedding, {}),
|
|
||||||
# ".md": (MarkdownEmbedding, {}),
|
|
||||||
# ".pdf": (PDFEmbedding, {}),
|
|
||||||
# ".doc": (WordEmbedding, {}),
|
|
||||||
# ".docx": (WordEmbedding, {}),
|
|
||||||
# ".csv": (CSVEmbedding, {}),
|
|
||||||
# ".ppt": (PPTEmbedding, {}),
|
|
||||||
# ".pptx": (PPTEmbedding, {}),
|
|
||||||
# }
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeEmbedding:
|
class KnowledgeEmbedding:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -57,23 +40,6 @@ class KnowledgeEmbedding:
|
|||||||
|
|
||||||
def init_knowledge_embedding(self):
|
def init_knowledge_embedding(self):
|
||||||
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
|
return get_knowledge_embedding(self.knowledge_type, self.knowledge_source, self.vector_store_config)
|
||||||
# if self.file_type == "url":
|
|
||||||
# embedding = URLEmbedding(
|
|
||||||
# file_path=self.file_path,
|
|
||||||
# vector_store_config=self.vector_store_config,
|
|
||||||
# )
|
|
||||||
# return embedding
|
|
||||||
# extension = "." + self.file_path.rsplit(".", 1)[-1]
|
|
||||||
# if extension in KnowledgeEmbeddingType:
|
|
||||||
# knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
|
|
||||||
# embedding = knowledge_class(
|
|
||||||
# self.file_path,
|
|
||||||
# vector_store_config=self.vector_store_config,
|
|
||||||
# **knowledge_args
|
|
||||||
# )
|
|
||||||
# return embedding
|
|
||||||
# raise ValueError(f"Unsupported knowledge file type '{extension}'")
|
|
||||||
# return embedding
|
|
||||||
|
|
||||||
def similar_search(self, text, topk):
|
def similar_search(self, text, topk):
|
||||||
vector_client = VectorStoreConnector(
|
vector_client = VectorStoreConnector(
|
||||||
|
@ -14,6 +14,8 @@ from typing import List
|
|||||||
|
|
||||||
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
from pilot.openapi.knowledge.knowledge_service import KnowledgeService
|
||||||
|
from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.scene.chat_factory import ChatFactory
|
from pilot.scene.chat_factory import ChatFactory
|
||||||
@ -27,6 +29,7 @@ router = APIRouter()
|
|||||||
CFG = Config()
|
CFG = Config()
|
||||||
CHAT_FACTORY = ChatFactory()
|
CHAT_FACTORY = ChatFactory()
|
||||||
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
|
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
|
||||||
|
knowledge_service = KnowledgeService()
|
||||||
|
|
||||||
|
|
||||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
@ -101,9 +104,8 @@ def plugins_select_info():
|
|||||||
|
|
||||||
|
|
||||||
def knowledge_list():
|
def knowledge_list():
|
||||||
knowledge: dict = {}
|
request = KnowledgeSpaceRequest()
|
||||||
### TODO
|
return knowledge_service.get_knowledge_space(request)
|
||||||
return knowledge
|
|
||||||
|
|
||||||
|
|
||||||
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
|
@router.post('/v1/chat/mode/params/list', response_model=Result[dict])
|
||||||
@ -164,7 +166,7 @@ async def chat_completions(dialogue: ConversationVo = Body()):
|
|||||||
elif ChatScene.ChatExecution == dialogue.chat_mode:
|
elif ChatScene.ChatExecution == dialogue.chat_mode:
|
||||||
chat_param.update("plugin_selector", dialogue.select_param)
|
chat_param.update("plugin_selector", dialogue.select_param)
|
||||||
elif ChatScene.ChatKnowledge == dialogue.chat_mode:
|
elif ChatScene.ChatKnowledge == dialogue.chat_mode:
|
||||||
chat_param.update("knowledge_name", dialogue.select_param)
|
chat_param.update("knowledge_space", dialogue.select_param)
|
||||||
|
|
||||||
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
|
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
|
||||||
if not chat.prompt_template.stream_out:
|
if not chat.prompt_template.stream_out:
|
||||||
|
0
pilot/scene/chat_knowledge/v1/__init__.py
Normal file
0
pilot/scene/chat_knowledge/v1/__init__.py
Normal file
66
pilot/scene/chat_knowledge/v1/chat.py
Normal file
66
pilot/scene/chat_knowledge/v1/chat.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from chromadb.errors import NoIndexException
|
||||||
|
|
||||||
|
from pilot.scene.base_chat import BaseChat, logger, headers
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.common.sql_database import Database
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
from pilot.common.markdown_text import (
|
||||||
|
generate_markdown_table,
|
||||||
|
generate_htm_table,
|
||||||
|
datas_to_table_html,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pilot.configs.model_config import (
|
||||||
|
DATASETS_DIR,
|
||||||
|
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
|
LLM_MODEL_CONFIG,
|
||||||
|
LOGDIR,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pilot.scene.chat_knowledge.default.prompt import prompt
|
||||||
|
from pilot.embedding_engine.knowledge_embedding import KnowledgeEmbedding
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatKnowledge(BaseChat):
|
||||||
|
chat_scene: str = ChatScene.ChatKnowledge.value
|
||||||
|
|
||||||
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
|
def __init__(self, chat_session_id, user_input, knowledge_space):
|
||||||
|
""" """
|
||||||
|
super().__init__(
|
||||||
|
chat_mode=ChatScene.ChatKnowledge,
|
||||||
|
chat_session_id=chat_session_id,
|
||||||
|
current_user_input=user_input,
|
||||||
|
)
|
||||||
|
vector_store_config = {
|
||||||
|
"vector_store_name": knowledge_space,
|
||||||
|
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||||
|
}
|
||||||
|
self.knowledge_embedding_client = KnowledgeEmbedding(
|
||||||
|
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||||
|
vector_store_config=vector_store_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_input_values(self):
|
||||||
|
try:
|
||||||
|
docs = self.knowledge_embedding_client.similar_search(
|
||||||
|
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
|
||||||
|
)
|
||||||
|
context = [d.page_content for d in docs]
|
||||||
|
context = context[:2000]
|
||||||
|
input_values = {"context": context, "question": self.current_user_input}
|
||||||
|
except NoIndexException:
|
||||||
|
raise ValueError(
|
||||||
|
"you have no knowledge space, please add your knowledge space"
|
||||||
|
)
|
||||||
|
return input_values
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_type(self) -> str:
|
||||||
|
return ChatScene.ChatKnowledge.value
|
19
pilot/scene/chat_knowledge/v1/out_parser.py
Normal file
19
pilot/scene/chat_knowledge/v1/out_parser.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, NamedTuple
|
||||||
|
import pandas as pd
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
|
||||||
|
|
||||||
|
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||||
|
|
||||||
|
|
||||||
|
class NormalChatOutputParser(BaseOutputParser):
|
||||||
|
def parse_prompt_response(self, model_out_text) -> T:
|
||||||
|
return model_out_text
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
pass
|
54
pilot/scene/chat_knowledge/v1/prompt.py
Normal file
54
pilot/scene/chat_knowledge/v1/prompt.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import builtins
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.common.schema import SeparatorStyle
|
||||||
|
|
||||||
|
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
|
||||||
|
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
|
||||||
|
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
|
||||||
|
已知内容:
|
||||||
|
{context}
|
||||||
|
问题:
|
||||||
|
{question}
|
||||||
|
"""
|
||||||
|
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
|
||||||
|
known information:
|
||||||
|
{context}
|
||||||
|
question:
|
||||||
|
{question}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = (
|
||||||
|
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
|
||||||
|
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template_scene=ChatScene.ChatKnowledge.value,
|
||||||
|
input_variables=["context", "question"],
|
||||||
|
response_format=None,
|
||||||
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
|
template=_DEFAULT_TEMPLATE,
|
||||||
|
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||||
|
output_parser=NormalChatOutputParser(
|
||||||
|
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
CFG.prompt_templates.update({prompt.template_scene: prompt})
|
@ -11,6 +11,8 @@ import uuid
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
from pilot.embedding_engine.knowledge_type import KnowledgeType
|
||||||
|
|
||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
@ -664,7 +666,8 @@ def knowledge_embedding_store(vs_id, files):
|
|||||||
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
|
file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)
|
||||||
)
|
)
|
||||||
knowledge_embedding_client = KnowledgeEmbedding(
|
knowledge_embedding_client = KnowledgeEmbedding(
|
||||||
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
|
knowledge_source=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
|
||||||
|
knowledge_type=KnowledgeType.DOCUMENT.value,
|
||||||
model_name=LLM_MODEL_CONFIG["text2vec"],
|
model_name=LLM_MODEL_CONFIG["text2vec"],
|
||||||
vector_store_config={
|
vector_store_config={
|
||||||
"vector_store_name": vector_store_name["vs_name"],
|
"vector_store_name": vector_store_name["vs_name"],
|
||||||
|
Loading…
Reference in New Issue
Block a user