mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 11:39:11 +00:00
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
0
dbgpt/app/scene/chat_knowledge/v1/__init__.py
Normal file
0
dbgpt/app/scene/chat_knowledge/v1/__init__.py
Normal file
222
dbgpt/app/scene/chat_knowledge/v1/chat.py
Normal file
222
dbgpt/app/scene/chat_knowledge/v1/chat.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import json
|
||||
import os
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from dbgpt.app.scene import BaseChat, ChatScene
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
from dbgpt.configs.model_config import (
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
)
|
||||
|
||||
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
|
||||
from dbgpt.app.knowledge.document_db import (
|
||||
KnowledgeDocumentDao,
|
||||
KnowledgeDocumentEntity,
|
||||
)
|
||||
from dbgpt.app.knowledge.service import KnowledgeService
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt.util.tracer import trace
|
||||
|
||||
# Module-level configuration object. ChatKnowledge reads its retrieval
# defaults from it (KNOWLEDGE_SEARCH_TOP_SIZE, KNOWLEDGE_SEARCH_RECALL_SCORE,
# KNOWLEDGE_SEARCH_MAX_TOKEN), as well as VECTOR_STORE_TYPE, EMBEDDING_MODEL,
# SYSTEM_APP and the KNOWLEDGE_SEARCH_REWRITE / KNOWLEDGE_CHAT_SHOW_RELATIONS
# switches.
CFG = Config()
|
||||
|
||||
|
||||
class ChatKnowledge(BaseChat):
    """KBQA Chat Module.

    Answers the current user input with context retrieved from a knowledge
    space: the question (plus optional rewritten variants) is searched in the
    space's vector store, the hits are re-ranked, and the matching document
    chunks are attached to the answer as a ``<references>`` element.
    """

    chat_scene: str = ChatScene.ChatKnowledge.value()

    def __init__(self, chat_param: Dict):
        """Chat Knowledge Module Initialization

        Args:
           - chat_param: Dict
            - chat_session_id: (str) chat session_id
            - current_user_input: (str) current user input
            - model_name:(str) llm model name
            - select_param:(str) space name
        """
        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory

        self.knowledge_space = chat_param["select_param"]
        chat_param["chat_mode"] = ChatScene.ChatKnowledge
        super().__init__(
            chat_param=chat_param,
        )
        self.space_context = self.get_space_context(self.knowledge_space)
        # Per-space settings override the global defaults when a space
        # context exists.
        self.top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            if self.space_context is None
            else int(self.space_context["embedding"]["topk"])
        )
        self.recall_score = (
            CFG.KNOWLEDGE_SEARCH_RECALL_SCORE
            if self.space_context is None
            else float(self.space_context["embedding"]["recall_score"])
        )
        self.max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if self.space_context is None or self.space_context.get("prompt") is None
            else int(self.space_context["prompt"]["max_token"])
        )
        vector_store_config = {
            "vector_store_name": self.knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        self.knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )
        self.prompt_template.template_is_strict = False
        self.relations = None
        # Initialised here so stream_call / stream_call_reinforce_fn never hit
        # an AttributeError when they run before generate_input_values.
        self.chunks_with_score = []
        self.chunk_dao = DocumentChunkDao()
        document_dao = KnowledgeDocumentDao()
        documents = document_dao.get_documents(
            query=KnowledgeDocumentEntity(space=self.knowledge_space)
        )
        # Always defined, even for a space without documents; it is read
        # unconditionally by generate_input_values.
        self.document_ids = [document.id for document in documents]

    async def stream_call(self):
        """Stream the parent's outputs, then yield a final one with references.

        Every output of ``super().stream_call()`` is forwarded unchanged.
        Afterwards the last output is extended with the relations list (when
        ``CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS`` is set and the output has a
        ``text`` attribute) and with the rendered ``<references>`` element,
        and yielded once more. Nothing extra is yielded when the parent
        stream produced no output at all.
        """
        last_output = None
        async for output in super().stream_call():
            last_output = output
            yield output

        if last_output is None:
            # The parent stream was empty; there is nothing to decorate.
            return

        if (
            CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
            and last_output
            and isinstance(self.relations, list)
            and len(self.relations) > 0
            and hasattr(last_output, "text")
        ):
            last_output.text = (
                last_output.text + "\n\nrelations:\n\n" + ",".join(self.relations)
            )
        reference = f"\n\n{self.parse_source_view(self.chunks_with_score)}"
        if hasattr(last_output, "text"):
            # Text-carrying output objects are extended through their text
            # field rather than through ``+``.
            last_output.text = last_output.text + reference
        else:
            last_output = last_output + reference
        yield last_output

    def stream_call_reinforce_fn(self, text):
        """Return ``text`` followed by the rendered references element."""
        return text + f"\n\n{self.parse_source_view(self.chunks_with_score)}"

    @trace()
    async def generate_input_values(self) -> Dict:
        """Retrieve knowledge for the current question and build prompt inputs.

        Side effects: may override the prompt template from the space
        context, and sets ``self.chunks_with_score`` and ``self.relations``.

        Returns:
            Dict with the keys ``context``, ``question`` and ``relations``.
        """
        if self.space_context and self.space_context.get("prompt"):
            self.prompt_template.template_define = self.space_context["prompt"]["scene"]
            self.prompt_template.template = self.space_context["prompt"]["template"]
        from dbgpt.rag.retriever.reinforce import QueryReinforce

        # query reinforce, get similar queries
        query_reinforce = QueryReinforce(
            query=self.current_user_input, model_name=self.llm_model
        )
        queries = []
        if CFG.KNOWLEDGE_SEARCH_REWRITE:
            # Guard against a rewrite step that returns nothing.
            queries = list(await query_reinforce.rewrite() or [])
            print("rewrite queries:", queries)
        queries.append(self.current_user_input)
        from dbgpt._private.chat_util import run_async_tasks

        # similarity search from vector db
        tasks = [self.execute_similar_search(query) for query in queries]
        docs_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        # Flatten the per-query results; tolerates an empty result list and
        # empty/None entries.
        candidates_with_scores = []
        for docs in docs_with_scores or []:
            candidates_with_scores.extend(docs or [])
        # candidates document rerank
        from dbgpt.rag.retriever.rerank import DefaultRanker

        ranker = DefaultRanker(self.top_k)
        candidates_with_scores = ranker.rank(candidates_with_scores) or []
        self.chunks_with_score = []
        if not candidates_with_scores:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            for d, score in candidates_with_scores:
                chucks = self.chunk_dao.get_document_chunks(
                    query=DocumentChunkEntity(content=d.page_content),
                    document_ids=self.document_ids,
                )
                if len(chucks) > 0:
                    self.chunks_with_score.append((chucks[0], score))

            context = [doc.page_content for doc, _ in candidates_with_scores]

        # NOTE(review): this slices characters when ``context`` is the
        # fallback string but slices *documents* when it is a list, so it is
        # not a token limit in either case. Behavior kept as-is; confirm the
        # intended truncation.
        context = context[: self.max_token]
        self.relations = list(
            {
                os.path.basename(str(d.metadata.get("source", "")))
                for d, _ in candidates_with_scores
            }
        )
        input_values = {
            "context": context,
            "question": self.current_user_input,
            "relations": self.relations,
        }
        return input_values

    def parse_source_view(self, chunks_with_score: List):
        """Format the knowledge reference view message for the web.

        Chunks are grouped by ``doc_name`` and serialized as JSON into the
        ``references`` attribute of a ``<references>`` element, e.g.:
        <references title="'References'" references="'[{name:aa.pdf,chunks:[{10:text},{11:text}]},{name:bb.pdf,chunks:[{12,text}]}]'"> </references>

        Args:
            chunks_with_score: iterable of ``(chunk, score)`` pairs; each
                chunk must expose ``doc_name``, ``id``, ``content`` and
                ``meta_info``.

        Returns:
            The serialized element as a ``str``.
        """
        import xml.etree.ElementTree as ET

        references_ele = ET.Element("references")
        title = "References"
        references_ele.set("title", title)
        references_dict = {}
        for chunk, score in chunks_with_score:
            doc_name = chunk.doc_name
            # One entry per document; insertion order follows first appearance.
            entry = references_dict.setdefault(
                doc_name, {"name": doc_name, "chunks": []}
            )
            entry["chunks"].append(
                {
                    "id": chunk.id,
                    "content": chunk.content,
                    "meta_info": chunk.meta_info,
                    "recall_score": score,
                }
            )
        references_list = list(references_dict.values())
        references_ele.set("references", json.dumps(references_list))
        html = ET.tostring(references_ele, encoding="utf-8")
        return html.decode("utf-8")

    @property
    def chat_type(self) -> str:
        """Scene identifier of this chat: ``ChatScene.ChatKnowledge.value()``."""
        return ChatScene.ChatKnowledge.value()

    def get_space_context(self, space_name):
        """Return ``KnowledgeService().get_space_context(space_name)``."""
        service = KnowledgeService()
        return service.get_space_context(space_name)

    async def execute_similar_search(self, query):
        """execute similarity search

        Runs the embedding client's ``similar_search_with_scores`` with
        ``query``, ``self.top_k`` and ``self.recall_score`` through
        ``blocking_func_to_async`` on ``self._executor``.
        """
        return await blocking_func_to_async(
            self._executor,
            self.knowledge_embedding_client.similar_search_with_scores,
            query,
            self.top_k,
            self.recall_score,
        )
|
13
dbgpt/app/scene/chat_knowledge/v1/out_parser.py
Normal file
13
dbgpt/app/scene/chat_knowledge/v1/out_parser.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser, T
|
||||
|
||||
|
||||
# Module-level logger named after this module; not referenced by the code
# visible below.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
    """Output parser that hands the model's text back without changes."""

    def parse_prompt_response(self, model_out_text) -> T:
        """Return ``model_out_text`` as-is, with no post-processing."""
        return model_out_text

    def get_format_instructions(self) -> str:
        """Supply no format instructions; this implementation returns ``None``."""
        return None
|
44
dbgpt/app/scene/chat_knowledge/v1/prompt.py
Normal file
44
dbgpt/app/scene/chat_knowledge/v1/prompt.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from dbgpt.core.interface.prompt import PromptTemplate
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.scene import ChatScene
|
||||
|
||||
from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
# Module-level configuration object; supplies LANGUAGE and the prompt template
# registry used below.
CFG = Config()

# Scene definition passed to the template as ``template_define``.
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """


# Chinese template. Placeholders: {context} and {question}.
_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造, 回答的时候最好按照1.2.3.点进行总结。
已知内容:
{context}
问题:
{question},请使用和用户相同的语言进行回答.
"""
# English template. Placeholders: {context} and {question}.
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. When answering, it is best to summarize according to points 1.2.3.
known information:
{context}
question:
{question},when answering, use the same language as the "user".
"""

# The English template is used only when CFG.LANGUAGE == "en"; every other
# language setting falls back to the Chinese template.
_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)


# Shared by the template (``stream_out``) and the parser (``is_stream_out``).
PROMPT_NEED_STREAM_OUT = True

# Prompt template for the ChatKnowledge scene.
prompt = PromptTemplate(
    template_scene=ChatScene.ChatKnowledge.value(),
    input_variables=["context", "question"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

# Import-time side effect: registers this template as the default
# (is_default=True) for the configured language.
CFG.prompt_template_registry.register(prompt, language=CFG.LANGUAGE, is_default=True)
|
50
dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py
Normal file
50
dbgpt/app/scene/chat_knowledge/v1/prompt_chatglm.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from dbgpt.core.interface.prompt import PromptTemplate
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.app.scene import ChatScene
|
||||
|
||||
from dbgpt.app.scene.chat_normal.out_parser import NormalChatOutputParser
|
||||
|
||||
|
||||
# Module-level configuration object; supplies LANGUAGE and the prompt template
# registry used below.
CFG = Config()

# Scene definition text. The PromptTemplate below passes template_define=None,
# so this constant is not used by the code visible in this module.
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """


# Chinese template. Placeholders: {context} and {question}.
_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
已知内容:
{context}
问题:
{question}
"""
# English template. Placeholders: {context} and {question}.
_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
known information:
{context}
question:
{question}
"""

# The English template is used only when CFG.LANGUAGE == "en"; every other
# language setting falls back to the Chinese template.
_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)


# Shared by the template (``stream_out``) and the parser (``is_stream_out``).
PROMPT_NEED_STREAM_OUT = True

# Prompt template for the ChatKnowledge scene, built without a scene
# definition (template_define=None).
prompt = PromptTemplate(
    template_scene=ChatScene.ChatKnowledge.value(),
    input_variables=["context", "question"],
    response_format=None,
    template_define=None,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

# Import-time side effect: registers this template as a non-default entry
# bound to the listed chatglm model names.
CFG.prompt_template_registry.register(
    prompt,
    language=CFG.LANGUAGE,
    is_default=False,
    model_names=["chatglm-6b-int4", "chatglm-6b", "chatglm2-6b", "chatglm2-6b-int4"],
)
|
Reference in New Issue
Block a user