feat(RAG):add rag operators and rag awel examples (#1061)
Co-authored-by: csunny <cfqsunny@163.com>
@@ -2,14 +2,14 @@ from typing import List, Optional
 from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
 
 REWRITE_PROMPT_TEMPLATE_EN = """
-Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
-"original query:: {original_query}\n"
-"queries:\n"
+Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
+"original query:{original_query}\n"
+"queries:"
 """
 
-REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>':
-"original_query:{original_query}\n"
-"queries:\n"
+REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>'
+"original_query:{original_query}\n"
+"queries:"
 """
 
 
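For orientation, a minimal sketch of how the updated English template is rendered now that it carries a {context} placeholder. The template text is copied from the new version above; the context, query, and nums values are made-up samples, and the .format(...) call mirrors the one introduced in rewrite() later in this diff.

# Illustrative only: sample values, rendered the same way rewrite() does.
REWRITE_PROMPT_TEMPLATE_EN = """
Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
"original query:{original_query}\n"
"queries:"
"""

prompt = REWRITE_PROMPT_TEMPLATE_EN.format(
    context="DB-GPT builds RAG pipelines from AWEL operators.",  # assumed sample context
    original_query="how does dbgpt retrieve documents",
    nums=3,
)
print(prompt)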
@@ -29,6 +29,7 @@ class QueryRewrite:
             - query: (str), user query
             - model_name: (str), llm model name
             - llm_client: (Optional[LLMClient])
+            - language: (Optional[str]), language
         """
         self._model_name = model_name
         self._llm_client = llm_client
@@ -39,17 +40,22 @@ class QueryRewrite:
             else REWRITE_PROMPT_TEMPLATE_ZH
         )
 
-    async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]:
+    async def rewrite(
+        self, origin_query: str, context: Optional[str], nums: Optional[int] = 1
+    ) -> List[str]:
         """query rewrite
         Args:
             origin_query: str original query
+            context: Optional[str] context
             nums: Optional[int] rewrite nums
         Returns:
             queries: List[str]
         """
         from dbgpt.util.chat_util import run_async_tasks
 
-        prompt = self._prompt_template.format(original_query=origin_query, nums=nums)
+        prompt = self._prompt_template.format(
+            context=context, original_query=origin_query, nums=nums
+        )
         messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
         request = ModelRequest(model=self._model_name, messages=messages)
         tasks = [self._llm_client.generate(request)]
@@ -61,8 +67,12 @@ class QueryRewrite:
                 queries,
             )
         )
-        print("rewrite queries:", queries)
-        return self._parse_llm_output(output=queries[0])
+        if len(queries) == 0:
+            print("llm generate no rewrite queries.")
+            return queries
+        new_queries = self._parse_llm_output(output=queries[0])[0:nums]
+        print(f"rewrite queries: {new_queries}")
+        return new_queries
 
     def correct(self) -> List[str]:
         pass
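As a quick usage sketch of the new signature (not part of this diff): the import path and the OpenAILLMClient proxy below are assumptions, so swap in whatever module path and LLMClient implementation your checkout provides. With this change, rewrite() takes a context string and returns at most nums rewritten queries.

# Sketch only: exercising the context-aware rewrite() introduced in this commit.
# Import paths, model name, and sample strings are assumptions.
import asyncio

from dbgpt.model.proxy import OpenAILLMClient  # assumed: any LLMClient works
from dbgpt.rag.retriever.rewrite import QueryRewrite  # assumed module path


async def main():
    rewriter = QueryRewrite(
        llm_client=OpenAILLMClient(),
        model_name="gpt-3.5-turbo",  # assumed model name
        language="en",
    )
    # context is the new parameter; nums caps the number of returned queries.
    queries = await rewriter.rewrite(
        origin_query="how does dbgpt retrieve documents",
        context="DB-GPT builds RAG pipelines from AWEL operators.",
        nums=3,
    )
    print(queries)


asyncio.run(main())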
@@ -81,6 +91,8 @@ class QueryRewrite:
 
             if response.startswith("queries:"):
                 response = response[len("queries:") :]
+            if response.startswith("queries:"):
+                response = response[len("queries:") :]
 
             queries = response.split(",")
             if len(queries) == 1:
@@ -90,6 +102,10 @@ class QueryRewrite:
             if len(queries) == 1:
                 queries = response.split("?")
             for k in queries:
+                if k.startswith("queries:"):
+                    k = k[len("queries:") :]
+                if k.startswith("queries:"):
+                    k = response[len("queries:") :]
                 rk = k
                 if lowercase:
                     rk = rk.lower()
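To make the parsing added in the last two hunks concrete, here is a standalone, slightly simplified sketch of the same steps (not the real _parse_llm_output): strip a leading "queries:" label, split on commas with a "?" fallback, strip any per-item label, and lowercase each candidate.

# Illustrative re-implementation of the parsing steps from this commit.
from typing import List


def parse_rewrite_output(output: str, lowercase: bool = True) -> List[str]:
    response = output.strip()
    # Drop a leading "queries:" label, as the added lines do.
    if response.startswith("queries:"):
        response = response[len("queries:") :]
    # Split on commas; fall back to "?" when the model did not use commas.
    queries = response.split(",")
    if len(queries) == 1:
        queries = response.split("?")
    results = []
    for k in queries:
        if k.startswith("queries:"):
            k = k[len("queries:") :]
        rk = k.lower() if lowercase else k
        rk = rk.strip()
        if rk:
            results.append(rk)
    return results


print(parse_rewrite_output("queries: How does DB-GPT retrieve documents, What is AWEL, How to set up a knowledge space"))
# -> ['how does db-gpt retrieve documents', 'what is awel', 'how to set up a knowledge space']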