feat(RAG):add rag operators and rag awel examples (#1061)

Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-13 16:14:48 +08:00
committed by GitHub
parent 99ea6ac1a4
commit a035433170
29 changed files with 1010 additions and 102 deletions

View File

@@ -2,14 +2,14 @@ from typing import List, Optional
from dbgpt.core import LLMClient, ModelMessage, ModelRequest, ModelMessageRoleType
REWRITE_PROMPT_TEMPLATE_EN = """
Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'\n":
"original query:: {original_query}\n"
"queries:\n"
Based on the given context {context}, Generate {nums} search queries related to: {original_query}, Provide following comma-separated format: 'queries: <queries>'":
"original query:{original_query}\n"
"queries:"
"""
REWRITE_PROMPT_TEMPLATE_ZH = """请根据原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries<queries>'
"original_query{original_query}\n"
"queries\n"
REWRITE_PROMPT_TEMPLATE_ZH = """请根据上下文{context}, 将原问题优化生成{nums}个相关的搜索查询,这些查询应与原始查询相似并且是人们可能会提出的可回答的搜索问题。请勿使用任何示例中提到的内容,确保所有生成的查询均独立于示例,仅基于提供的原始查询。请按照以下逗号分隔的格式提供: 'queries:<queries>'
"original_query:{original_query}\n"
"queries:"
"""
@@ -29,6 +29,7 @@ class QueryRewrite:
- query: (str), user query
- model_name: (str), llm model name
- llm_client: (Optional[LLMClient])
- language: (Optional[str]), language
"""
self._model_name = model_name
self._llm_client = llm_client
@@ -39,17 +40,22 @@ class QueryRewrite:
else REWRITE_PROMPT_TEMPLATE_ZH
)
async def rewrite(self, origin_query: str, nums: Optional[int] = 1) -> List[str]:
async def rewrite(
self, origin_query: str, context: Optional[str], nums: Optional[int] = 1
) -> List[str]:
"""query rewrite
Args:
origin_query: str original query
context: Optional[str] context
nums: Optional[int] rewrite nums
Returns:
queries: List[str]
"""
from dbgpt.util.chat_util import run_async_tasks
prompt = self._prompt_template.format(original_query=origin_query, nums=nums)
prompt = self._prompt_template.format(
context=context, original_query=origin_query, nums=nums
)
messages = [ModelMessage(role=ModelMessageRoleType.SYSTEM, content=prompt)]
request = ModelRequest(model=self._model_name, messages=messages)
tasks = [self._llm_client.generate(request)]
@@ -61,8 +67,12 @@ class QueryRewrite:
queries,
)
)
print("rewrite queries:", queries)
return self._parse_llm_output(output=queries[0])
if len(queries) == 0:
print("llm generate no rewrite queries.")
return queries
new_queries = self._parse_llm_output(output=queries[0])[0:nums]
print(f"rewrite queries: {new_queries}")
return new_queries
def correct(self) -> List[str]:
    """Placeholder; not implemented yet.

    The body is a bare ``pass``, so calling this method currently returns
    ``None`` rather than the ``List[str]`` the annotation promises.
    """
    pass
@@ -81,6 +91,8 @@ class QueryRewrite:
if response.startswith("queries:"):
response = response[len("queries:") :]
if response.startswith("queries"):
response = response[len("queries") :]
queries = response.split(",")
if len(queries) == 1:
@@ -90,6 +102,10 @@ class QueryRewrite:
if len(queries) == 1:
queries = response.split("")
for k in queries:
# Strip a leading "queries:" / "queries" label from this individual item.
# Both branches must slice `k` itself: slicing `response` here would replace
# the item with the remainder of the whole LLM response.
if k.startswith("queries:"):
    k = k[len("queries:") :]
if k.startswith("queries"):
    k = k[len("queries") :]
rk = k
if lowercase:
rk = rk.lower()