DB-GPT/dbgpt/rag/operators/rewrite.py
2024-03-25 20:15:39 +08:00

102 lines
3.1 KiB
Python

"""The rewrite operator."""
from typing import Any, List, Optional
from dbgpt.core import LLMClient
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.i18n_utils import _
class QueryRewriteOperator(MapOperator[dict, Any]):
"""The Rewrite Operator."""
metadata = ViewMetadata(
label=_("Query Rewrite Operator"),
name="query_rewrite_operator",
category=OperatorCategory.RAG,
description=_("Query rewrite operator."),
inputs=[
IOField.build_from(
_("Query context"), "query_context", dict, _("query context")
)
],
outputs=[
IOField.build_from(
_("Rewritten queries"),
"queries",
str,
is_list=True,
description=_("Rewritten queries"),
)
],
parameters=[
Parameter.build_from(
_("LLM Client"),
"llm_client",
LLMClient,
description=_("The LLM Client."),
),
Parameter.build_from(
label=_("Model name"),
name="model_name",
type=str,
optional=True,
default="gpt-3.5-turbo",
description=_("LLM model name."),
),
Parameter.build_from(
label=_("Prompt language"),
name="language",
type=str,
optional=True,
default="en",
description=_("Prompt language."),
),
Parameter.build_from(
label=_("Number of results"),
name="nums",
type=int,
optional=True,
default=5,
description=_("rewrite query number."),
),
],
documentation_url="https://github.com/openai/openai-python",
)
def __init__(
self,
llm_client: LLMClient,
model_name: str = "gpt-3.5-turbo",
language: Optional[str] = "en",
nums: Optional[int] = 1,
**kwargs
):
"""Init the query rewrite operator.
Args:
llm_client (Optional[LLMClient]): The LLM client.
model_name (Optional[str]): The model name.
language (Optional[str]): The prompt language.
nums (Optional[int]): The number of the rewrite results.
"""
super().__init__(**kwargs)
self._nums = nums
self._rewrite = QueryRewrite(
llm_client=llm_client,
model_name=model_name,
language=language,
)
async def map(self, query_context: dict) -> List[str]:
"""Rewrite the query."""
query = query_context.get("query")
context = query_context.get("context")
if not query:
raise ValueError("query is required")
return await self._rewrite.rewrite(
origin_query=query, context=context, nums=self._nums
)