refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Aries-ckt
2024-01-03 09:45:26 +08:00
committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

@@ -1,40 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.util.tracer import trace


class InnerChatDBSummary(BaseChat):
    """Inner chat scene that asks the LLM which tables are relevant to the user input."""

    chat_scene: str = ChatScene.InnerChatDBSummary.value()

    def __init__(
        self,
        chat_session_id,
        user_input,
        db_select,
        db_summary,
    ):
        """Create an InnerChatDBSummary chat for the selected db and its summary."""
        super().__init__(
            chat_mode=ChatScene.InnerChatDBSummary,
            chat_session_id=chat_session_id,
            current_user_input=user_input,
            select_param=db_select,
        )
        self.db_input = db_select
        self.db_summary = db_summary

    @trace()
    async def generate_input_values(self) -> Dict:
        input_values = {
            "db_input": self.db_input,
            "db_profile_summary": self.db_summary,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.InnerChatDBSummary.value

View File

@@ -1,17 +0,0 @@
import logging

from dbgpt.core.interface.output_parser import BaseOutputParser

logger = logging.getLogger(__name__)


class NormalChatOutputParser(BaseOutputParser):
    def parse_prompt_response(self, model_out_text):
        clean_str = super().parse_prompt_response(model_out_text)
        logger.info("clean prompt response: %s", clean_str)
        return clean_str

    def parse_view_response(self, ai_text, data) -> str:
        return ai_text

    def get_format_instructions(self) -> str:
        pass

View File

@@ -1,45 +0,0 @@
import json

from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_knowledge.inner_db_summary.out_parser import (
    NormalChatOutputParser,
)

CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE = """
Based on the following known database information, answer which tables are involved in the user input.
Known database information: {db_profile_summary}
Input: {db_input}
You should only respond in JSON format as described below and ensure the response can be parsed by Python json.loads
"""

PROMPT_RESPONSE = """You must respond in JSON, in the following format:
{response}
The response format must be JSON, and the key of the JSON must be "table".
"""

RESPONSE_FORMAT = {"table": ["orders", "products"]}

PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.InnerChatDBSummary.value(),
    input_variables=["db_profile_summary", "db_input", "response"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
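
The template above pairs with NormalChatOutputParser: the model is instructed to reply with a single JSON object keyed by "table", which downstream code can hand straight to json.loads. A minimal sketch of that round trip, using a made-up model reply (the string below is illustrative, not output from the commit):

    import json

    # Hypothetical model reply matching RESPONSE_FORMAT above.
    model_out_text = '{"table": ["orders", "products"]}'

    # The template's json.loads requirement means one parse is all the
    # caller needs to recover the list of involved tables.
    tables = json.loads(model_out_text)["table"]
    assert tables == ["orders", "products"]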

View File

@@ -19,7 +19,6 @@ _DEFAULT_TEMPLATE_ZH = (
_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1, 2 and 3.
"""
_DEFAULT_TEMPLATE = (

View File

@@ -1,32 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.app.scene.chat_knowledge.rewrite.prompt import prompt


class QueryRewrite(BaseChat):
    """Query rewrite by LLM."""

    chat_scene: str = ChatScene.QueryRewrite.value()

    def __init__(self, chat_param: Dict):
        """Create a QueryRewrite chat; select_param carries the number of queries to generate."""
        chat_param["chat_mode"] = ChatScene.QueryRewrite
        super().__init__(
            chat_param=chat_param,
        )
        self.nums = chat_param["select_param"]
        self.current_user_input = chat_param["current_user_input"]

    async def generate_input_values(self):
        input_values = {
            "nums": self.nums,
            "original_query": self.current_user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.QueryRewrite.value
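
For reference, a minimal chat_param for this scene would look roughly like the following; the values are illustrative, and only the keys read above are shown:

    chat_param = {
        "chat_session_id": "sess-001",                      # session identifier (made up)
        "current_user_input": "how do I speed up postgres?",
        "select_param": 3,                                  # number of rewritten queries to request
        "model_name": "vicuna-13b-v1.5",                    # assumed model name
    }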

View File

@@ -1,42 +0,0 @@
import logging

from dbgpt.core.interface.output_parser import BaseOutputParser

logger = logging.getLogger(__name__)


class QueryRewriteParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(self, response, max_length: int = 128):
        lowercase = True
        try:
            results = []
            response = response.strip()
            if response.startswith("queries:"):
                response = response[len("queries:") :]
            # Split on ASCII separators first, then fall back to the
            # full-width Chinese equivalents.
            queries = response.split(",")
            if len(queries) == 1:
                queries = response.split("，")
            if len(queries) == 1:
                queries = response.split("?")
            if len(queries) == 1:
                queries = response.split("？")
            for k in queries:
                rk = k
                if lowercase:
                    rk = rk.lower()
                s = rk.strip()
                if s == "":
                    continue
                results.append(s)
        except Exception as e:
            logger.error(f"parse query rewrite prompt_response error: {e}")
            return []
        return results

    def parse_view_response(self, speak, data) -> str:
        return data
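
A quick illustration of the parser contract, with a made-up model reply (the input string is an assumption, not output from the commit):

    parser = QueryRewriteParser(is_stream_out=True)

    # The "queries:" prefix is stripped, the remainder splits on commas,
    # entries are lowercased, and blanks are dropped.
    queries = parser.parse_prompt_response(
        "queries: How to index a large table?, When to use partial indexes?"
    )
    # -> ["how to index a large table?", "when to use partial indexes?"]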

View File

@@ -1,41 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene

from .out_parser import QueryRewriteParser

CFG = Config()

PROMPT_SCENE_DEFINE = """You are a helpful assistant that generates multiple search queries based on a single input query."""

_DEFAULT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询，这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容，确保所有生成的查询均独立于示例，仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries: <queries>'
    "original_query: {original_query}\n"
    "queries:\n"
"""
_DEFAULT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}. Provide the following comma-separated format: 'queries: <queries>'\n
    "original query: {original_query}\n"
    "queries:\n"
"""

_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_RESPONSE = """"""

PROMPT_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.QueryRewrite.value(),
    input_variables=["nums", "original_query"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=QueryRewriteParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)
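
To make the contract concrete, this is roughly what the EN template expands to for a given query; a plain str.format sketch, while PromptTemplate's actual rendering may also prepend PROMPT_SCENE_DEFINE and other wrapping:

    rendered = _DEFAULT_TEMPLATE_EN.format(
        nums=3, original_query="how to speed up postgres?"
    )
    # "Generate 3 search queries related to: how to speed up postgres? ..."
    # The prompt ends with 'queries:' so the model completes in the
    # comma-separated form QueryRewriteParser expects.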

View File

@@ -1,28 +0,0 @@
from typing import Dict

from dbgpt.app.scene import BaseChat, ChatScene


class ExtractSummary(BaseChat):
    """Get a summary from the LLM."""

    chat_scene: str = ChatScene.ExtractSummary.value()

    def __init__(self, chat_param: Dict):
        """Create an ExtractSummary chat; select_param carries the text to summarize."""
        chat_param["chat_mode"] = ChatScene.ExtractSummary
        super().__init__(
            chat_param=chat_param,
        )
        self.user_input = chat_param["select_param"]

    async def generate_input_values(self):
        input_values = {
            "context": self.user_input,
        }
        return input_values

    @property
    def chat_type(self) -> str:
        return ChatScene.ExtractSummary.value

View File

@@ -1,28 +0,0 @@
import logging
from typing import List, Tuple

from dbgpt.core.interface.output_parser import BaseOutputParser, ResponseTye

logger = logging.getLogger(__name__)


class ExtractSummaryParser(BaseOutputParser):
    def __init__(self, is_stream_out: bool, **kwargs):
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_prompt_response(
        self, response, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        # The summary text is returned as-is; no cleaning is applied.
        # clean_str = super().parse_prompt_response(response)
        logger.info("clean prompt response: %s", response)
        return response

    def parse_view_response(self, speak, data) -> str:
        # Pass the tool output straight through to the view.
        return data

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str) -> str:
        try:
            return super().parse_model_nostream_resp(response, sep)
        except Exception as e:
            return str(e)

View File

@@ -1,46 +0,0 @@
from dbgpt.core.interface.prompt import PromptTemplate
from dbgpt._private.config import Config
from dbgpt.app.scene import ChatScene
from dbgpt.app.scene.chat_knowledge.summary.out_parser import ExtractSummaryParser

CFG = Config()

# PROMPT_SCENE_DEFINE = """You are an expert Q&A system that is trusted around the world.\nAlways answer the query using the provided context information, and not prior knowledge.\nSome rules to follow:\n1. Never directly reference the given context in your answer.\n2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines."""

PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who is very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions."""

_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息进行精简的总结:
{context}
答案尽量精确和简单，不要过长，长度控制在100字左右
"""

_DEFAULT_TEMPLATE_EN = """
Write a quick summary of the following context:
{context}
The summary should be as concise as possible and not overly lengthy. Please keep the answer within approximately 200 characters.
"""

_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

PROMPT_RESPONSE = """"""

RESPONSE_FORMAT = """"""

PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractSummary.value(),
    input_variables=["context"],
    response_format=None,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=ExtractSummaryParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
)

CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -5,6 +5,7 @@ from typing import Dict, List
 from dbgpt.app.scene import BaseChat, ChatScene
 from dbgpt._private.config import Config
+from dbgpt.component import ComponentType
 from dbgpt.configs.model_config import (
     EMBEDDING_MODEL_CONFIG,
@@ -16,7 +17,9 @@ from dbgpt.app.knowledge.document_db import (
     KnowledgeDocumentEntity,
 )
 from dbgpt.app.knowledge.service import KnowledgeService
-from dbgpt.util.executor_utils import blocking_func_to_async
+from dbgpt.model import DefaultLLMClient
+from dbgpt.model.cluster import WorkerManagerFactory
+from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.util.tracer import trace
 
 CFG = Config()
@@ -35,8 +38,7 @@ class ChatKnowledge(BaseChat):
         - model_name:(str) llm model name
         - select_param:(str) space name
         """
-        from dbgpt.rag.embedding_engine.embedding_engine import EmbeddingEngine
-        from dbgpt.rag.embedding_engine.embedding_factory import EmbeddingFactory
+        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
 
         self.knowledge_space = chat_param["select_param"]
         chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -59,17 +61,37 @@ class ChatKnowledge(BaseChat):
             if self.space_context is None or self.space_context.get("prompt") is None
             else int(self.space_context["prompt"]["max_token"])
         )
-        vector_store_config = {
-            "vector_store_name": self.knowledge_space,
-            "vector_store_type": CFG.VECTOR_STORE_TYPE,
-        }
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
         )
-        self.knowledge_embedding_client = EmbeddingEngine(
-            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            vector_store_config=vector_store_config,
-            embedding_factory=embedding_factory,
+        from dbgpt.rag.retriever.embedding import EmbeddingRetriever
+        from dbgpt.storage.vector_store.connector import VectorStoreConnector
+
+        embedding_fn = embedding_factory.create(
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
         )
+        from dbgpt.storage.vector_store.base import VectorStoreConfig
+
+        config = VectorStoreConfig(name=self.knowledge_space, embedding_fn=embedding_fn)
+        vector_store_connector = VectorStoreConnector(
+            vector_store_type=CFG.VECTOR_STORE_TYPE,
+            vector_store_config=config,
+        )
+        query_rewrite = None
+        self.worker_manager = CFG.SYSTEM_APP.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        self.llm_client = DefaultLLMClient(worker_manager=self.worker_manager)
+        if CFG.KNOWLEDGE_SEARCH_REWRITE:
+            query_rewrite = QueryRewrite(
+                llm_client=self.llm_client,
+                model_name=self.llm_model,
+                language=CFG.LANGUAGE,
+            )
+        self.embedding_retriever = EmbeddingRetriever(
+            top_k=self.top_k,
+            vector_store_connector=vector_store_connector,
+            query_rewrite=query_rewrite,
+        )
         self.prompt_template.template_is_strict = False
         self.relations = None
@@ -110,50 +132,31 @@ class ChatKnowledge(BaseChat):
         if self.space_context and self.space_context.get("prompt"):
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        from dbgpt.rag.retriever.reinforce import QueryReinforce
-        from dbgpt.util.chat_util import run_async_tasks
-
-        # query reinforce, get similar queries
-        query_reinforce = QueryReinforce(
-            query=self.current_user_input, model_name=self.llm_model
-        )
-        queries = []
-        if CFG.KNOWLEDGE_SEARCH_REWRITE:
-            queries = await query_reinforce.rewrite()
-            print("rewrite queries:", queries)
-        queries.append(self.current_user_input)
+        from dbgpt._private.chat_util import run_async_tasks
+
         # similarity search from vector db
-        tasks = [self.execute_similar_search(query) for query in queries]
-        docs_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
-        candidates_with_scores = reduce(lambda x, y: x + y, docs_with_scores)
-        # candidates document rerank
-        from dbgpt.rag.retriever.rerank import DefaultRanker
-
-        ranker = DefaultRanker(self.top_k)
-        candidates_with_scores = ranker.rank(candidates_with_scores)
+        tasks = [self.execute_similar_search(self.current_user_input)]
+        candidates_with_scores = await run_async_tasks(tasks=tasks, concurrency_limit=1)
+        candidates_with_scores = reduce(lambda x, y: x + y, candidates_with_scores)
+        self.chunks_with_score = []
         if not candidates_with_scores or len(candidates_with_scores) == 0:
             print("no relevant docs to retrieve")
             context = "no relevant docs to retrieve"
         else:
-            self.chunks_with_score = []
-            for d, score in candidates_with_scores:
+            for chunk in candidates_with_scores:
                 chucks = self.chunk_dao.get_document_chunks(
-                    query=DocumentChunkEntity(content=d.page_content),
+                    query=DocumentChunkEntity(content=chunk.content),
                     document_ids=self.document_ids,
                 )
                 if len(chucks) > 0:
-                    self.chunks_with_score.append((chucks[0], score))
-            context = [doc.page_content for doc, _ in candidates_with_scores]
-            context = context[: self.max_token]
+                    self.chunks_with_score.append((chucks[0], chunk.score))
+            context = "\n".join([doc.content for doc in candidates_with_scores])
         self.relations = list(
             set(
                 [
                     os.path.basename(str(d.metadata.get("source", "")))
-                    for d, _ in candidates_with_scores
+                    for d in candidates_with_scores
                 ]
             )
         )
@@ -201,7 +204,8 @@ class ChatKnowledge(BaseChat):
         references_list = list(references_dict.values())
         references_ele.set("references", json.dumps(references_list))
         html = ET.tostring(references_ele, encoding="utf-8")
-        return html.decode("utf-8")
+        reference = html.decode("utf-8")
+        return reference.replace("\\n", "")
 
     @property
     def chat_type(self) -> str:
@@ -213,10 +217,6 @@ class ChatKnowledge(BaseChat):
     async def execute_similar_search(self, query):
        """execute similarity search"""
-        return await blocking_func_to_async(
-            self._executor,
-            self.knowledge_embedding_client.similar_search_with_scores,
-            query,
-            self.top_k,
-            self.recall_score,
+        return await self.embedding_retriever.aretrieve_with_scores(
+            query, self.recall_score
         )
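
Taken together, the refactor swaps the old EmbeddingEngine/QueryReinforce pair for the new dbgpt.rag retriever stack. A condensed sketch of the new flow as a standalone coroutine follows; literals such as the space name, vector store type, score threshold, and model name are assumptions, while the class names and call signatures come from the diff above:

    from dbgpt.rag.retriever.embedding import EmbeddingRetriever
    from dbgpt.rag.retriever.rewrite import QueryRewrite
    from dbgpt.storage.vector_store.base import VectorStoreConfig
    from dbgpt.storage.vector_store.connector import VectorStoreConnector


    async def retrieve(llm_client, embedding_fn, question: str):
        # llm_client and embedding_fn are taken as given, exactly as
        # ChatKnowledge obtains them from DefaultLLMClient and EmbeddingFactory.
        config = VectorStoreConfig(name="my_space", embedding_fn=embedding_fn)
        connector = VectorStoreConnector(
            vector_store_type="Chroma",  # CFG.VECTOR_STORE_TYPE in the diff
            vector_store_config=config,
        )
        # Query rewrite is optional; ChatKnowledge leaves it as None unless
        # CFG.KNOWLEDGE_SEARCH_REWRITE is set.
        rewrite = QueryRewrite(
            llm_client=llm_client, model_name="vicuna-13b-v1.5", language="en"
        )
        retriever = EmbeddingRetriever(
            top_k=5,
            vector_store_connector=connector,
            query_rewrite=rewrite,
        )
        # The second argument is the recall score threshold, mirroring
        # self.recall_score in execute_similar_search above.
        return await retriever.aretrieve_with_scores(question, 0.3)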