# DB-GPT/dbgpt/app/scene/chat_knowledge/v1/chat.py
# 2024-08-28 21:05:27 +08:00
#
# 263 lines
# 9.9 KiB
# Python

import json
import os
from functools import reduce
from typing import Dict, List
from dbgpt._private.config import Config
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
from dbgpt.app.knowledge.document_db import (
KnowledgeDocumentDao,
KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import (
ChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
SystemPromptTemplate,
)
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.tracer import root_tracer, trace
# Module-level configuration object; ChatKnowledge below reads its
# KNOWLEDGE_*, RERANK_*, LANGUAGE and SYSTEM_APP attributes.
CFG = Config()
class ChatKnowledge(BaseChat):
    """KBQA Chat Module.

    Answers the user's question from the documents of a single knowledge
    space: similar chunks are retrieved with ``KnowledgeSpaceRetriever``
    (optionally with query rewrite and reranking), injected into the prompt
    as ``context``, and an XML ``<references>`` view of the matched chunks is
    appended to the final streamed output.
    """

    chat_scene: str = ChatScene.ChatKnowledge.value()

    def __init__(self, chat_param: Dict):
        """Chat Knowledge Module Initialization.

        Args:
            chat_param: Dict containing at least
                - chat_session_id: (str) chat session_id
                - current_user_input: (str) current user input
                - model_name: (str) llm model name
                - select_param: (str) knowledge space name
                Note: ``chat_param["chat_mode"]`` is overwritten in place with
                ``ChatScene.ChatKnowledge`` before being passed to BaseChat.

        Raises:
            Exception: if ``select_param`` does not match exactly one
                knowledge space.
        """
        from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory

        self.knowledge_space = chat_param["select_param"]
        chat_param["chat_mode"] = ChatScene.ChatKnowledge
        super().__init__(
            chat_param=chat_param,
        )
        self.space_context = self.get_space_context(self.knowledge_space)
        # Per-space settings override the global config; when the context or
        # the relevant section is missing, fall back to the CFG defaults.
        embedding_ctx = (self.space_context or {}).get("embedding")
        prompt_ctx = (self.space_context or {}).get("prompt")
        self.top_k = (
            self.get_knowledge_search_top_size(self.knowledge_space)
            if embedding_ctx is None
            else int(embedding_ctx["topk"])
        )
        self.recall_score = (
            CFG.KNOWLEDGE_SEARCH_RECALL_SCORE
            if embedding_ctx is None
            else float(embedding_ctx["recall_score"])
        )
        self.max_token = (
            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
            if prompt_ctx is None
            else int(prompt_ctx["max_token"])
        )

        from dbgpt.serve.rag.models.models import (
            KnowledgeSpaceDao,
            KnowledgeSpaceEntity,
        )

        spaces = KnowledgeSpaceDao().get_knowledge_space(
            KnowledgeSpaceEntity(name=self.knowledge_space)
        )
        if len(spaces) != 1:
            raise Exception(f"invalid space name:{self.knowledge_space}")
        space = spaces[0]

        query_rewrite = None
        if CFG.KNOWLEDGE_SEARCH_REWRITE:
            query_rewrite = QueryRewrite(
                llm_client=self.llm_client,
                model_name=self.llm_model,
                language=CFG.LANGUAGE,
            )

        reranker = None
        retriever_top_k = self.top_k
        if CFG.RERANK_MODEL:
            rerank_embeddings = RerankEmbeddingFactory.get_instance(
                CFG.SYSTEM_APP
            ).create()
            reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
            if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
                # We use reranker, so if the top_k is less than 20,
                # we need to set it to 20 (or RERANK_TOP_K if larger) so the
                # reranker has enough candidates to choose from.
                retriever_top_k = max(CFG.RERANK_TOP_K, 20)

        self._space_retriever = KnowledgeSpaceRetriever(
            space_id=space.id,
            top_k=retriever_top_k,
            query_rewrite=query_rewrite,
            rerank=reranker,
            llm_model=self.llm_model,
        )
        self.prompt_template.template_is_strict = False
        # Filled by generate_input_values(); initialized here so that
        # stream_call()/stream_call_reinforce_fn() never hit a missing attribute.
        self.relations = None
        self.chunks_with_score = []
        self.chunk_dao = DocumentChunkDao()
        document_dao = KnowledgeDocumentDao()
        documents = document_dao.get_documents(
            query=KnowledgeDocumentEntity(space=space.name)
        )
        # Always defined (possibly empty): ids of the documents in this space,
        # used to restrict chunk lookups to this space.
        self.document_ids = [document.id for document in documents]

    async def stream_call(self):
        """Stream the answer, then re-emit the final output with references.

        Every output from ``BaseChat.stream_call`` is yielded unchanged. Once
        the stream ends, the last output is yielded again with the
        ``<references>`` view of ``self.chunks_with_score`` appended. When
        ``CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS`` is set and the output has a
        ``text`` attribute, the source relations are added to that text as
        well. Nothing extra is yielded if the base stream produced no output.
        """
        last_output = None
        async for output in super().stream_call():
            last_output = output
            yield output

        if last_output is None:
            # Nothing was streamed; there is no output to attach references to.
            return

        if (
            CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
            and last_output
            and isinstance(self.relations, list)
            and self.relations
            and hasattr(last_output, "text")
        ):
            last_output.text = (
                last_output.text + "\n\nrelations:\n\n" + ",".join(self.relations)
            )
        reference = f"\n\n{self.parse_source_view(self.chunks_with_score)}"
        yield last_output + reference

    def stream_call_reinforce_fn(self, text):
        """Return ``text`` with the references view appended."""
        return text + f"\n\n{self.parse_source_view(self.chunks_with_score)}"

    @trace()
    async def generate_input_values(self) -> Dict:
        """Retrieve knowledge for the current question and build prompt inputs.

        Side effects: may replace ``self.prompt_template.prompt`` with the
        space's own prompt template, and sets ``self.chunks_with_score`` and
        ``self.relations`` for the later references/relations output.

        Returns:
            Dict with ``context`` (newline-joined contents of the retrieved
            chunks, or a placeholder message when nothing was retrieved),
            ``question`` (the current user input) and ``relations`` (source
            file base names of the retrieved chunks).
        """
        import logging

        if self.space_context and self.space_context.get("prompt"):
            # Not use template_define: replace the whole prompt with the
            # space's template (system) + chat history + user question.
            self.prompt_template.prompt = ChatPromptTemplate(
                messages=[
                    SystemPromptTemplate.from_template(
                        self.space_context["prompt"]["template"]
                    ),
                    MessagesPlaceholder(variable_name="chat_history"),
                    HumanPromptTemplate.from_template("{question}"),
                ]
            )
        from dbgpt.util.chat_util import run_async_tasks

        tasks = [self.execute_similar_search(self.current_user_input)]
        results = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        # Flatten the per-task result lists into one list of scored chunks,
        # skipping any task that returned nothing.
        candidates_with_scores = [
            candidate for batch in results if batch for candidate in batch
        ]

        self.chunks_with_score = []
        if not candidates_with_scores:
            logging.getLogger(__name__).info("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            # Map each retrieved chunk back to its stored DocumentChunk row
            # (within this space's documents) for the references view.
            if self.document_ids:
                for candidate in candidates_with_scores:
                    chunks = self.chunk_dao.get_document_chunks(
                        query=DocumentChunkEntity(content=candidate.content),
                        document_ids=self.document_ids,
                    )
                    if chunks:
                        self.chunks_with_score.append((chunks[0], candidate.score))
            context = "\n".join(doc.content for doc in candidates_with_scores)

        # dict.fromkeys de-duplicates while keeping retrieval order, so the
        # relations list is deterministic (set() ordering is arbitrary).
        self.relations = list(
            dict.fromkeys(
                os.path.basename(str(d.metadata.get("source", "")))
                for d in candidates_with_scores
            )
        )
        return {
            "context": context,
            "question": self.current_user_input,
            "relations": self.relations,
        }

    def parse_source_view(self, chunks_with_score: List):
        """Format the knowledge reference view message for the web UI.

        Groups ``(chunk, score)`` pairs by ``chunk.doc_name`` (first-seen
        order) and stores them, JSON-encoded, in the ``references`` attribute
        of a ``<references title="References">`` element. The JSON has the shape
        ``[{"name": doc_name, "chunks": [{"id", "content", "meta_info",
        "recall_score"}, ...]}, ...]``.

        Args:
            chunks_with_score: iterable of ``(chunk, score)`` tuples; each
                chunk must expose ``doc_name``, ``id``, ``content`` and
                ``meta_info``.

        Returns:
            The serialized element as a str, with every literal
            backslash-n sequence (JSON-escaped newline) removed.
        """
        import xml.etree.ElementTree as ET

        references_dict = {}
        for chunk, score in chunks_with_score:
            doc_refs = references_dict.setdefault(
                chunk.doc_name, {"name": chunk.doc_name, "chunks": []}
            )
            doc_refs["chunks"].append(
                {
                    "id": chunk.id,
                    "content": chunk.content,
                    "meta_info": chunk.meta_info,
                    "recall_score": score,
                }
            )

        references_ele = ET.Element("references")
        references_ele.set("title", "References")
        references_ele.set(
            "references",
            json.dumps(list(references_dict.values()), ensure_ascii=False),
        )
        reference = ET.tostring(references_ele, encoding="utf-8").decode("utf-8")
        return reference.replace("\\n", "")

    @property
    def chat_type(self) -> str:
        """Return the chat scene identifier for knowledge chat."""
        return ChatScene.ChatKnowledge.value()

    def get_space_context_by_id(self, space_id):
        """Return the space context for ``space_id`` via KnowledgeService."""
        service = KnowledgeService()
        return service.get_space_context_by_space_id(space_id)

    def get_space_context(self, space_name):
        """Return the space context for ``space_name`` via KnowledgeService."""
        service = KnowledgeService()
        return service.get_space_context(space_name)

    def get_knowledge_search_top_size(self, space_name) -> int:
        """Return the default retrieval top-k for ``space_name``.

        Uses ``CFG.KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE`` when exactly one space
        matches and its ``vector_type`` is a knowledge-graph store, otherwise
        ``CFG.KNOWLEDGE_SEARCH_TOP_SIZE``.
        """
        service = KnowledgeService()
        request = KnowledgeSpaceRequest(name=space_name)
        spaces = service.get_knowledge_space(request)
        if len(spaces) == 1:
            from dbgpt.storage import vector_store

            if spaces[0].vector_type in vector_store.__knowledge_graph__:
                return CFG.KNOWLEDGE_GRAPH_SEARCH_TOP_SIZE
        return CFG.KNOWLEDGE_SEARCH_TOP_SIZE

    async def execute_similar_search(self, query):
        """Retrieve chunks similar to ``query`` scoring at least ``recall_score``.

        The retrieval runs inside an ``execute_similar_search`` tracer span.
        """
        with root_tracer.start_span(
            "execute_similar_search", metadata={"query": query}
        ):
            return await self._space_retriever.aretrieve_with_scores(
                query, self.recall_score
            )