refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
@@ -11,6 +11,7 @@ from dbgpt.component import ComponentType
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.core.interface.message import OnceConversation
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace
@@ -58,6 +59,9 @@ class BaseChat(ABC):
            chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
        )
        self.llm_echo = False
        self.worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        self.model_cache_enable = chat_param.get("model_cache_enable", False)

        ### load prompt template
@@ -162,6 +166,10 @@ class BaseChat(ABC):
            "BaseChat.__call_base.prompt_template.format", metadata=metadata
        ):
            current_prompt = self.prompt_template.format(**input_values)
            ### prompt context token adapt according to llm max context length
            current_prompt = await self.prompt_context_token_adapt(
                prompt=current_prompt
            )
            self.current_message.add_system_message(current_prompt)

        llm_messages = self.generate_llm_messages()
@@ -169,6 +177,7 @@
            # Not new server mode, we convert the message format(List[ModelMessage]) to list of dict
            # fix the error of "Object of type ModelMessage is not JSON serializable" when passing the payload to request.post
            llm_messages = list(map(lambda m: m.dict(), llm_messages))

        payload = {
            "model": self.llm_model,
            "prompt": self.generate_llm_text(),
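A note on the serialization fix above: pydantic-style message objects cannot be handed to json directly, which is exactly what the m.dict() conversion guards against. A minimal self-contained illustration (this ModelMessage is a simplified stand-in, not the dbgpt class):

import json
from dataclasses import asdict, dataclass


@dataclass
class ModelMessage:  # simplified stand-in for dbgpt.core.interface.message.ModelMessage
    role: str
    content: str

    def dict(self):
        return asdict(self)


llm_messages = [ModelMessage("human", "hi"), ModelMessage("ai", "hello")]
# json.dumps(llm_messages) would raise:
# TypeError: Object of type ModelMessage is not JSON serializable
payload = {"messages": list(map(lambda m: m.dict(), llm_messages))}
print(json.dumps(payload))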
@@ -431,6 +440,39 @@ class BaseChat(ABC):
                return message.content
        return None

    async def prompt_context_token_adapt(self, prompt) -> str:
        """prompt token adapt according to llm max context length"""
        model_metadata = await self.worker_manager.get_model_metadata(
            {"model": self.llm_model}
        )
        current_token_count = await self.worker_manager.count_token(
            {"model": self.llm_model, "prompt": prompt}
        )
        if current_token_count == -1:
            logger.warning(
                "tiktoken not installed, please `pip install tiktoken` first"
            )
        template_define_token_count = 0
        if len(self.prompt_template.template_define) > 0:
            template_define_token_count = await self.worker_manager.count_token(
                {
                    "model": self.llm_model,
                    "prompt": self.prompt_template.template_define,
                }
            )
            current_token_count += template_define_token_count
        if (
            current_token_count + self.prompt_template.max_new_tokens
        ) > model_metadata.context_length:
            prompt = prompt[
                : (
                    model_metadata.context_length
                    - self.prompt_template.max_new_tokens
                    - template_define_token_count
                )
            ]
        return prompt

    def generate(self, p) -> str:
        """
        generate context for LLM input
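The new prompt_context_token_adapt keeps prompt plus reserved output inside the model's context window. A minimal sketch of the same budget arithmetic (hypothetical numbers; note the hunk truncates by character count, which only approximates a token-level cut):

def fit_prompt(prompt: str, context_length: int, max_new_tokens: int,
               template_define_tokens: int = 0) -> str:
    # budget = window size minus reserved generation tokens and template tokens
    budget = context_length - max_new_tokens - template_define_tokens
    return prompt[:budget]


assert len(fit_prompt("x" * 5000, context_length=4096, max_new_tokens=1024)) == 3072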
@@ -63,14 +63,11 @@ class ChatDashboard(BaseChat):
        try:
            table_infos = await blocking_func_to_async(
                self._executor,
                client.get_similar_tables,
                client.get_db_summary,
                self.db_name,
                self.current_user_input,
                self.top_k,
            )
            # table_infos = client.get_similar_tables(
            #     dbname=self.db_name, query=self.current_user_input, topk=self.top_k
            # )
            print("dashboard vector find tables:{}", table_infos)
        except Exception as e:
            print("db summary find error!" + str(e))
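For context, blocking_func_to_async(executor, func, *args) offloads a synchronous call to an executor so the event loop stays responsive. A rough stand-in, not the dbgpt implementation, behaves like this (get_db_summary here is a hypothetical blocking lookup):

import asyncio
from concurrent.futures import ThreadPoolExecutor


async def blocking_func_to_async(executor, func, *args):
    # delegate the blocking call to the executor and await its result
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, func, *args)


def get_db_summary(db_name, user_input, top_k):  # hypothetical blocking lookup
    return [f"table info for {db_name!r} matching {user_input!r}"][:top_k]


async def main():
    with ThreadPoolExecutor() as executor:
        table_infos = await blocking_func_to_async(
            executor, get_db_summary, "dbgpt_test", "sales by region", 5
        )
        print(table_infos)


asyncio.run(main())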
@@ -19,22 +19,14 @@ class ChatFactory(metaclass=Singleton):
        from dbgpt.app.scene.chat_dashboard.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.v1.chat import ChatKnowledge
        from dbgpt.app.scene.chat_knowledge.v1.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.inner_db_summary.chat import (
            InnerChatDBSummary,
        )
        from dbgpt.app.scene.chat_knowledge.inner_db_summary.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
        from dbgpt.app.scene.chat_knowledge.extract_triplet.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.extract_entity.chat import ExtractEntity
        from dbgpt.app.scene.chat_knowledge.extract_entity.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.summary.chat import ExtractSummary
        from dbgpt.app.scene.chat_knowledge.summary.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.refine_summary.chat import (
            ExtractRefineSummary,
        )
        from dbgpt.app.scene.chat_knowledge.refine_summary.prompt import prompt
        from dbgpt.app.scene.chat_knowledge.rewrite.chat import QueryRewrite
        from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt
        from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel
        from dbgpt.app.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
        from dbgpt.app.scene.chat_data.chat_excel.excel_learning.prompt import prompt
@@ -1,40 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.tracer import trace


class InnerChatDBSummary(BaseChat):
    chat_scene: str = ChatScene.InnerChatDBSummary.value()

    """Number of results to return from the query"""

    def __init__(
        self,
        chat_session_id,
        user_input,
        db_select,
        db_summary,
    ):
        """ """
        super().__init__(
            chat_mode=ChatScene.InnerChatDBSummary,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=db_select,
        )

        self.db_input = db_select
        self.db_summary = db_summary

    @trace()
    async def generate_input_values(self) -> Dict:
        input_values = {
            "db_input": self.db_input,
            "db_profile_summary": self.db_summary,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.InnerChatDBSummary.value
@@ -1,17 +0,0 @@
import logging

from dbgpt.core.interface.output_parser import BaseOutputParser

logger = logging.getLogger(__name__)


class NormalChatOutputParser(BaseOutputParser):
    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
        print("clean prompt response:", clean_str)
        return clean_str

    def parse_view_response(self, ai_text, data) -> str:
        return ai_text

    def get_format_instructions(self) -> str:
        pass
@@ -1,45 +0,0 @@
import json

from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene

from dbgpt.app.scene.chat_knowledge.inner_db_summary.out_parser import (
    NormalChatOutputParser,
)


CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE = """
Based on the following known database information, answer which tables are involved in the user input.
Known database information:{db_profile_summary}
Input:{db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads


"""
PROMPT_RESPONSE = """You must respond in JSON format as following format:
{response}
The response format must be JSON, and the key of JSON must be "table".
"""


RESPONSE_FORMAT = {"table": ["orders", "products"]}


PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.InnerChatDBSummary.value(),
    input_variables=["db_profile_summary", "db_input", "response"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
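The deleted template relied on the model echoing RESPONSE_FORMAT back as parseable JSON; the contract the prompt enforced can be checked in a few lines:

import json

RESPONSE_FORMAT = {"table": ["orders", "products"]}
rendered_hint = json.dumps(RESPONSE_FORMAT, indent=4)  # injected as {response}

model_reply = '{"table": ["orders"]}'  # a well-formed reply
tables = json.loads(model_reply)["table"]
assert tables == ["orders"]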
@@ -19,7 +19,6 @@ _DEFAULT_TEMPLATE_ZH = (
_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1, 2, and 3.
"""

_DEFAULT_TEMPLATE = (
@@ -1,32 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt


class QueryRewrite(BaseChat):
    chat_scene: str = ChatScene.QueryRewrite.value()

    """query rewrite by llm"""

    def __init__(self, chat_param: Dict):
        """ """
        chat_param["chat_mode"] = ChatScene.QueryRewrite
        super().__init__(
            chat_param=chat_param,
        )

        self.nums = chat_param["select_param"]
        self.current_user_input = chat_param["current_user_input"]

    async def generate_input_values(self):
        input_values = {
            "nums": self.nums,
            "original_query": self.current_user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.QueryRewrite.value
@@ -1,42 +0,0 @@
import logging

from dbgpt.core.interface.output_parser import BaseOutputParser

logger = logging.getLogger(__name__)


class QueryRewriteParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(self, response, max_length: int = 128):
        lowercase = True
        try:
            results = []
            response = response.strip()

            if response.startswith("queries:"):
                response = response[len("queries:") :]

            queries = response.split(",")
            if len(queries) == 1:
                queries = response.split("，")
            if len(queries) == 1:
                queries = response.split("?")
            if len(queries) == 1:
                queries = response.split("？")
            for k in queries:
                rk = k
                if lowercase:
                    rk = rk.lower()
                s = rk.strip()
                if s == "":
                    continue
                results.append(s)
        except Exception as e:
            logger.error(f"parse query rewrite prompt_response error: {e}")
            return []
        return results

    def parse_view_response(self, speak, data) -> str:
        return data
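A standalone trace of the deleted parser's happy path (no dbgpt imports; fullwidth separators restored above as an assumption, since extraction collapsed them into ASCII duplicates):

response = "queries: How to index a table?, When to use a covering index?"
body = response.strip()
if body.startswith("queries:"):
    body = body[len("queries:"):]
queries = [q.strip().lower() for q in body.split(",") if q.strip()]
print(queries)
# ['how to index a table?', 'when to use a covering index?']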
@@ -1,41 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from .out_parser import QueryRewriteParser

CFG = Config()


PROMPT_SCENE_DEFINE = """You are a helpful assistant that generates multiple search queries based on a single input query."""

_DEFAULT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询，这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容，确保所有生成的查询均独立于示例，仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>':
"original_query:{original_query}\n"
"queries:\n"
"""

_DEFAULT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
"original query:: {original_query}\n"
"queries:\n"
"""

_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_RESPONSE = """"""


PROMPT_NEED_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.QueryRewrite.value(),
    input_variables=["nums", "original_query"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=QueryRewriteParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
@@ -1,28 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene


class ExtractSummary(BaseChat):
    chat_scene: str = ChatScene.ExtractSummary.value()

    """get summary by llm"""

    def __init__(self, chat_param: Dict):
        """ """
        chat_param["chat_mode"] = ChatScene.ExtractSummary
        super().__init__(
            chat_param=chat_param,
        )

        self.user_input = chat_param["select_param"]

    async def generate_input_values(self):
        input_values = {
            "context": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.ExtractSummary.value
@@ -1,28 +0,0 @@
import logging
from typing import List, Tuple

from dbgpt.core.interface.output_parser import BaseOutputParser, ResponseTye

logger = logging.getLogger(__name__)


class ExtractSummaryParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(
        self, response, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        # clean_str = super().parse_prompt_response(response)
        print("clean prompt response:", response)
        return response

    def parse_view_response(self, speak, data) -> str:
        ### tool out data to table view
        return data

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str) -> str:
        try:
            return super().parse_model_nostream_resp(response, sep)
        except Exception as e:
            return str(e)
@@ -1,46 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene

from dbgpt.app.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser

CFG = Config()

# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""

PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who is very familiar with database-related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions."""

_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息进行精简地总结:
{context}
答案尽量精确和简单，不要过长，长度控制在100字左右
"""

_DEFAULT_TEMPLATE_EN = """
Write a quick summary of the following context:
{context}
The summary should be as concise as possible and not overly lengthy. Please keep the answer within approximately 200 characters.
"""

_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_RESPONSE = """"""


RESPONSE_FORMAT = """"""

PROMPT_NEED_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractSummary.value(),
    input_variables=["context"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=ExtractSummaryParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
@@ -5,6 +5,7 @@ from typing import Dict, List
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt._private.config import Config
from dbgpt.component import ComponentType

from dbgpt.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
@@ -16,7 +17,9 @@ from dbgpt.app.knowledge.document_db import (
    KnowledgeDocumentEntity,
)
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import trace

CFG = Config()
@@ -35,8 +38,7 @@ class ChatKnowledge(BaseChat):
            - model_name:(str) llm model name
            - select_param:(str) space name
        """
        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

        self.knowledge_space = chat_param["select_param"]
        chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -59,17 +61,37 @@ class ChatKnowledge(BaseChat):
            if self.space_context is None or self.space_context.get("prompt") is None
            else int(self.space_context["prompt"]["max_token"])
        )
        vector_store_config = {
            "vector_store_name": self.knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = CFG.SYSTEM_APP.get_component(
            "embedding_factory", EmbeddingFactory
        )
        self.knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
        from dbgpt.storage.vector_store.connector import VectorStoreConnector

        embedding_fn = embedding_factory.create(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
        )
        from dbgpt.storage.vector_store.base import VectorStoreConfig

        config = VectorStoreConfig(name=self.knowledge_space, embedding_fn=embedding_fn)
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        query_rewrite = None
        self.worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
        if CFG.KNOWLEDGE_SEARCH_REWRITE:
            query_rewrite = QueryRewrite(
                llm_client=self.llm_client,
                model_name=self.llm_model,
                language=CFG.LANGUAGE,
            )
        self.embedding_retriever = EmbeddingRetriever(
            top_k=self.top_k,
            vector_store_connector=vector_store_connector,
            query_rewrite=query_rewrite,
        )
        self.prompt_template.template_is_strict = False
        self.relations = None
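Condensing the wiring added above: the refactor routes retrieval through VectorStoreConnector plus EmbeddingRetriever, with LLM query rewrite as an optional extra. A sketch using only the constructor signatures visible in this hunk (the vector store type here is an example value; the app reads CFG.VECTOR_STORE_TYPE):

from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector


def build_retriever(embedding_fn, space_name, top_k=5, query_rewrite=None):
    config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn)
    connector = VectorStoreConnector(
        vector_store_type="Chroma",  # example value, not a project default
        vector_store_config=config,
    )
    return EmbeddingRetriever(
        top_k=top_k,
        vector_store_connector=connector,
        query_rewrite=query_rewrite,  # None disables LLM-based rewriting
    )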
@@ -110,50 +132,31 @@ class ChatKnowledge(BaseChat):
        if self.space_context and self.space_context.get("prompt"):
            self.prompt_template.template_define = self.space_context["prompt"]["scene"]
            self.prompt_template.template = self.space_context["prompt"]["template"]
        from dbgpt.rag.retriever.reinforce import QueryReinforce
        from dbgpt.util.chat_util import run_async_tasks

        # query reinforce, get similar queries
        query_reinforce = QueryReinforce(
            query=self.current_user_input, model_name=self.llm_model
        )
        queries = []
        if CFG.KNOWLEDGE_SEARCH_REWRITE:
            queries = await query_reinforce.rewrite()
            print("rewrite queries:", queries)
        queries.append(self.current_user_input)
        from dbgpt._private.chat_util import run_async_tasks

        # similarity search from vector db
        tasks = [self.execute_similar_search(query) for query in queries]
        docs_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        candidates_with_scores = reduce(lambda x, y: x + y, docs_with_scores)
        # candidates document rerank
        from dbgpt.rag.retriever.rerank import DefaultRanker

        ranker = DefaultRanker(self.top_k)
        candidates_with_scores = ranker.rank(candidates_with_scores)
        tasks = [self.execute_similar_search(self.current_user_input)]
        candidates_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
        candidates_with_scores = reduce(lambda x, y: x + y, candidates_with_scores)
        self.chunks_with_score = []
        if not candidates_with_scores or len(candidates_with_scores) == 0:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            self.chunks_with_score = []
            for d, score in candidates_with_scores:
            for chunk in candidates_with_scores:
                chucks = self.chunk_dao.get_document_chunks(
                    query=DocumentChunkEntity(content=d.page_content),
                    query=DocumentChunkEntity(content=chunk.content),
                    document_ids=self.document_ids,
                )
                if len(chucks) > 0:
                    self.chunks_with_score.append((chucks[0], score))
                    self.chunks_with_score.append((chucks[0], chunk.score))

            context = [doc.page_content for doc, _ in candidates_with_scores]

        context = context[: self.max_token]
        context = "\n".join([doc.content for doc in candidates_with_scores])
        self.relations = list(
            set(
                [
                    os.path.basename(str(d.metadata.get("source", "")))
                    for d, _ in candidates_with_scores
                    for d in candidates_with_scores
                ]
            )
        )
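The removed multi-query path fanned out one search task per rewritten query, then flattened the per-query result lists with reduce. The same pattern in miniature (fake_search stands in for the real vector search):

import asyncio
from functools import reduce


async def fake_search(query):
    return [(f"doc for {query}", 0.9)]


async def main():
    tasks = [fake_search(q) for q in ["q1", "q2"]]
    docs_with_scores = await asyncio.gather(*tasks)  # list of per-query lists
    candidates = reduce(lambda x, y: x + y, docs_with_scores)  # flatten
    print(candidates)  # [('doc for q1', 0.9), ('doc for q2', 0.9)]


asyncio.run(main())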
@@ -201,7 +204,8 @@ class ChatKnowledge(BaseChat):
        references_list = list(references_dict.values())
        references_ele.set("references", json.dumps(references_list))
        html = ET.tostring(references_ele, encoding="utf-8")
        return html.decode("utf-8")
        reference = html.decode("utf-8")
        return reference.replace("\\n", "")

    @property
    def chat_type(self) -> str:
@@ -213,10 +217,6 @@ class ChatKnowledge(BaseChat):

    async def execute_similar_search(self, query):
        """execute similarity search"""
        return await blocking_func_to_async(
            self._executor,
            self.knowledge_embedding_client.similar_search_with_scores,
            query,
            self.top_k,
            self.recall_score,
        return await self.embedding_retriever.aretrieve_with_scores(
            query, self.recall_score
        )
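After this change a similarity search is one awaited call; per the hunks above, the returned chunks expose .content and .score. A hedged usage sketch (retriever is an EmbeddingRetriever; only the call shape shown in the diff is assumed):

async def execute_similar_search(retriever, query, recall_score=0.3):
    # single async call replaces the executor round-trip removed above
    chunks = await retriever.aretrieve_with_scores(query, recall_score)
    return [(c.content, c.score) for c in chunks]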
@@ -2,16 +2,20 @@ from typing import Dict, Optional, List
from dataclasses import dataclass
import datetime
import os

from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.core.interface.message import OnceConversation
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

from dbgpt.rag.retriever.embedding import EmbeddingRetriever

from dbgpt.storage.chat_history.base import BaseChatHistoryMemory
from dbgpt.storage.chat_history.chat_hisotry_factory import ChatHistory
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector

# TODO move global config
CFG = Config()
@@ -184,23 +188,14 @@ class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):

    async def map(self, input_value: ChatContext) -> ChatContext:
        from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory

        # TODO, decompose the current operator into some atomic operators
        knowledge_space = input_value.select_param
        vector_store_config = {
            "vector_store_name": knowledge_space,
            "vector_store_type": CFG.VECTOR_STORE_TYPE,
        }
        embedding_factory = self.system_app.get_component(
            "embedding_factory", EmbeddingFactory
        )
        knowledge_embedding_client = EmbeddingEngine(
            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
            vector_store_config=vector_store_config,
            embedding_factory=embedding_factory,
        )

        space_context = await self._get_space_context(knowledge_space)
        top_k = (
            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
@@ -219,16 +214,28 @@ class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
            ]
            input_value.prompt_template.template = space_context["prompt"]["template"]

        config = VectorStoreConfig(
            name=knowledge_space,
            embedding_fn=embedding_factory.create(
                EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
            ),
        )
        vector_store_connector = VectorStoreConnector(
            vector_store_type=CFG.VECTOR_STORE_TYPE,
            vector_store_config=config,
        )
        embedding_retriever = EmbeddingRetriever(
            top_k=top_k, vector_store_connector=vector_store_connector
        )
        docs = await self.blocking_func_to_async(
            knowledge_embedding_client.similar_search,
            embedding_retriever.retrieve,
            input_value.current_user_input,
            top_k,
        )
        if not docs or len(docs) == 0:
            print("no relevant docs to retrieve")
            context = "no relevant docs to retrieve"
        else:
            context = [d.page_content for d in docs]
            context = [d.content for d in docs]
            context = context[:max_token]
        relations = list(
            set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
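Finally, EmbeddingEngingOperator follows the AWEL MapOperator shape: a single async map that transforms one chat context into another. A simplified stand-in (FakeChatContext and RetrieveOperator are illustrations, not dbgpt types; the real base class is dbgpt.core.awel.MapOperator):

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeChatContext:  # stand-in for the real ChatContext dataclass
    current_user_input: str
    context: List[str] = field(default_factory=list)


class RetrieveOperator:  # shape of MapOperator[ChatContext, ChatContext]
    def __init__(self, retriever):
        self._retriever = retriever

    async def map(self, input_value: FakeChatContext) -> FakeChatContext:
        # retrieve chunks for the user input and attach their text as context
        docs = self._retriever.retrieve(input_value.current_user_input)
        input_value.context = [d.content for d in docs]
        return input_value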