Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-04 02:25:08 +00:00)
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
@@ -36,7 +36,7 @@ def initialize_components(
     system_app.register(
         DefaultExecutorFactory, max_workers=param.default_thread_pool_size
     )
-    system_app.register(DefaultScheduler)
+    system_app.register(DefaultScheduler, scheduler_enable=CFG.SCHEDULER_ENABLED)
     system_app.register_instance(controller)
     system_app.register(ConnectorManager)

@@ -60,6 +60,7 @@ def initialize_components(
     _initialize_openapi(system_app)
     # Register serve apps
     register_serve_apps(system_app, CFG, param.port)
+    _initialize_operators()


 def _initialize_model_cache(system_app: SystemApp, port: int):
@@ -128,3 +129,14 @@ def _initialize_openapi(system_app: SystemApp):
     from dbgpt.app.openapi.api_v1.editor.service import EditorService

     system_app.register(EditorService)
+
+
+def _initialize_operators():
+    from dbgpt.app.operators.converter import StringToInteger
+    from dbgpt.app.operators.datasource import (
+        HODatasourceExecutorOperator,
+        HODatasourceRetrieverOperator,
+    )
+    from dbgpt.app.operators.llm import HOLLMOperator, HOStreamingLLMOperator
+    from dbgpt.app.operators.rag import HOKnowledgeOperator
+    from dbgpt.serve.agent.resource.datasource import DatasourceResource
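Note that `_initialize_operators` only imports the operator modules and never calls a register function: defining an operator class with flow metadata registers it as an import side effect. Below is a minimal, self-contained sketch of this registration-by-import pattern; the registry and metaclass are illustrative only, not DB-GPT's actual implementation.

_FLOW_REGISTRY: dict = {}


class _OperatorMeta(type):
    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        metadata = namespace.get("metadata")
        if metadata is not None:
            # Register every class that declares flow metadata.
            _FLOW_REGISTRY[metadata["name"]] = cls


class _BaseOperator(metaclass=_OperatorMeta):
    pass


class _MyOperator(_BaseOperator):
    metadata = {"name": "my_operator"}


# Merely importing the module that defines _MyOperator populates the registry:
assert "my_operator" in _FLOW_REGISTRY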
@@ -19,12 +19,14 @@ class DefaultScheduler(BaseComponent):
         system_app: SystemApp,
         scheduler_delay_ms: int = 5000,
         scheduler_interval_ms: int = 1000,
+        scheduler_enable: bool = True,
     ):
         super().__init__(system_app)
         self.system_app = system_app
         self._scheduler_interval_ms = scheduler_interval_ms
         self._scheduler_delay_ms = scheduler_delay_ms
         self._stop_event = threading.Event()
+        self._scheduler_enable = scheduler_enable

     def init_app(self, system_app: SystemApp):
         self.system_app = system_app
@@ -39,7 +41,7 @@ class DefaultScheduler(BaseComponent):

     def _scheduler(self):
         time.sleep(self._scheduler_delay_ms / 1000)
-        while not self._stop_event.is_set():
+        while self._scheduler_enable and not self._stop_event.is_set():
             try:
                 schedule.run_pending()
             except Exception as e:
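The `scheduler_enable` flag short-circuits the polling loop: when it is False, the thread wakes after the initial delay, never enters the `schedule.run_pending()` loop, and exits. A minimal, self-contained sketch of the same loop using the real `schedule` and `threading` modules (all names here are local to the example, not part of DB-GPT):

import threading
import time

import schedule


def run_scheduler(stop_event: threading.Event, enable: bool, delay_ms: int = 100):
    time.sleep(delay_ms / 1000)  # initial delay, like _scheduler_delay_ms
    # When `enable` is False the loop body never runs, so registered jobs are
    # effectively disabled -- the effect of scheduler_enable=False above.
    while enable and not stop_event.is_set():
        try:
            schedule.run_pending()
        except Exception as e:
            print(f"scheduler error: {e}")
        time.sleep(0.1)  # pacing, analogous to _scheduler_interval_ms


stop = threading.Event()
schedule.every(1).seconds.do(lambda: print("tick"))
thread = threading.Thread(target=run_scheduler, args=(stop, True), daemon=True)
thread.start()
time.sleep(3)
stop.set()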
dbgpt/app/operators/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""Operators package.

This package contains all higher-order operators that are used to build workflows.
"""
dbgpt/app/operators/converter.py (new file, 186 lines)
@@ -0,0 +1,186 @@
"""Type Converter Operators."""

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    Parameter,
    ViewMetadata,
)
from dbgpt.util.i18n_utils import _

_INPUTS_STRING = IOField.build_from(
    _("String"),
    "string",
    str,
    description=_("The string to be converted to other types."),
)
_INPUTS_INTEGER = IOField.build_from(
    _("Integer"),
    "integer",
    int,
    description=_("The integer to be converted to other types."),
)
_INPUTS_FLOAT = IOField.build_from(
    _("Float"),
    "float",
    float,
    description=_("The float to be converted to other types."),
)
_INPUTS_BOOLEAN = IOField.build_from(
    _("Boolean"),
    "boolean",
    bool,
    description=_("The boolean to be converted to other types."),
)

_OUTPUTS_STRING = IOField.build_from(
    _("String"),
    "string",
    str,
    description=_("The string converted from other types."),
)
_OUTPUTS_INTEGER = IOField.build_from(
    _("Integer"),
    "integer",
    int,
    description=_("The integer converted from other types."),
)
_OUTPUTS_FLOAT = IOField.build_from(
    _("Float"),
    "float",
    float,
    description=_("The float converted from other types."),
)
_OUTPUTS_BOOLEAN = IOField.build_from(
    _("Boolean"),
    "boolean",
    bool,
    description=_("The boolean converted from other types."),
)

class StringToInteger(MapOperator[str, int]):
    """Converts a string to an integer."""

    metadata = ViewMetadata(
        label=_("String to Integer"),
        name="default_converter_string_to_integer",
        description=_("Converts a string to an integer."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_STRING],
        outputs=[_OUTPUTS_INTEGER],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new StringToInteger operator."""
        super().__init__(map_function=lambda x: int(x), **kwargs)


class StringToFloat(MapOperator[str, float]):
    """Converts a string to a float."""

    metadata = ViewMetadata(
        label=_("String to Float"),
        name="default_converter_string_to_float",
        description=_("Converts a string to a float."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_STRING],
        outputs=[_OUTPUTS_FLOAT],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new StringToFloat operator."""
        super().__init__(map_function=lambda x: float(x), **kwargs)

class StringToBoolean(MapOperator[str, bool]):
    """Converts a string to a boolean."""

    metadata = ViewMetadata(
        label=_("String to Boolean"),
        name="default_converter_string_to_boolean",
        description=_(
            "Converts a string to a boolean; values in 'True Values' "
            "(default: 'true', '1', 'y') map to True."
        ),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[
            Parameter.build_from(
                _("True Values"),
                "true_values",
                str,
                optional=True,
                default="true,1,y",
                description=_("Comma-separated values that should be treated as True."),
            )
        ],
        inputs=[_INPUTS_STRING],
        outputs=[_OUTPUTS_BOOLEAN],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, true_values: str = "true,1,y", **kwargs):
        """Create a new StringToBoolean operator."""
        true_values_list = true_values.split(",")
        true_values_list = [x.strip().lower() for x in true_values_list]
        super().__init__(map_function=lambda x: x.lower() in true_values_list, **kwargs)

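Restated as plain Python, the `true_values` handling above is a case-insensitive membership test; only the configured values are stripped, not the input string:

def to_bool(x: str, true_values: str = "true,1,y") -> bool:
    # Mirror the operator: normalize the configured values, then test membership.
    true_values_list = [v.strip().lower() for v in true_values.split(",")]
    return x.lower() in true_values_list


assert to_bool("TRUE") is True
assert to_bool(" 1 ") is False  # the input itself is not stripped
assert to_bool("no") is False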
class IntegerToString(MapOperator[int, str]):
    """Converts an integer to a string."""

    metadata = ViewMetadata(
        label=_("Integer to String"),
        name="default_converter_integer_to_string",
        description=_("Converts an integer to a string."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_INTEGER],
        outputs=[_OUTPUTS_STRING],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new IntegerToString operator."""
        super().__init__(map_function=lambda x: str(x), **kwargs)


class FloatToString(MapOperator[float, str]):
    """Converts a float to a string."""

    metadata = ViewMetadata(
        label=_("Float to String"),
        name="default_converter_float_to_string",
        description=_("Converts a float to a string."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_FLOAT],
        outputs=[_OUTPUTS_STRING],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new FloatToString operator."""
        super().__init__(map_function=lambda x: str(x), **kwargs)


class BooleanToString(MapOperator[bool, str]):
    """Converts a boolean to a string."""

    metadata = ViewMetadata(
        label=_("Boolean to String"),
        name="default_converter_boolean_to_string",
        description=_("Converts a boolean to a string."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_BOOLEAN],
        outputs=[_OUTPUTS_STRING],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new BooleanToString operator."""
        super().__init__(map_function=lambda x: str(x), **kwargs)
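A hedged sketch of composing two of these converters in an AWEL DAG. It assumes a DB-GPT environment where `dbgpt.core.awel` is importable; calling `call` on the leaf node with raw call data follows the pattern used in llm.py below (`end_node.call(call_data, ...)`).

import asyncio

from dbgpt.core.awel import DAG, InputOperator, SimpleCallDataInputSource

from dbgpt.app.operators.converter import IntegerToString, StringToInteger


async def main():
    with DAG("example_converter_dag"):
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        round_trip = IntegerToString()
        input_task >> StringToInteger() >> round_trip
    # "42" -> 42 -> "42"
    print(await round_trip.call("42"))


asyncio.run(main())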
dbgpt/app/operators/datasource.py (new file, 363 lines)
@@ -0,0 +1,363 @@
import json
import logging
from typing import List, Optional

from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBResource
from dbgpt.core import Chunk
from dbgpt.core.awel import DAGContext, MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.core.operators import BaseLLM
from dbgpt.util.i18n_utils import _
from dbgpt.vis.tags.vis_chart import default_chart_type_prompt

from .llm import HOContextBody

logger = logging.getLogger(__name__)

CFG = Config()

_DEFAULT_CHART_TYPE = default_chart_type_prompt()

_DEFAULT_TEMPLATE_EN = """You are a database expert.
Please answer the user's question based on the database selected by the user and some \
of the available table structure definitions of the database.
Database name:
{db_name}
Table structure definition:
{table_info}

Constraint:
1. Please understand the user's intention based on the user's question, and use the \
given table structure definition to create a grammatically correct {dialect} sql. \
If sql is not required, answer the user's question directly.
2. Always limit the query to a maximum of {max_num_results} results unless the user \
specifies in the question the specific number of rows of data he wishes to obtain.
3. You can only use the tables provided in the table structure information to \
generate sql. If you cannot generate sql based on the provided table structure, \
please say: "The table structure information provided is not enough to generate \
sql queries." It is prohibited to fabricate information at will.
4. Please be careful not to mistake the relationship between tables and columns \
when generating SQL.
5. Please check the correctness of the SQL and ensure that the query performance is \
optimized under correct conditions.
6. Please choose the best one from the display methods given below for data \
rendering, and put the type name into the 'name' parameter value of the required \
response format. If you cannot find the most suitable one, use 'Table' as the \
display method. The available data display methods are as follows: {display_type}

User Question:
{user_input}
Please think step by step and respond according to the following JSON format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""

_DEFAULT_TEMPLATE_ZH = """你是一个数据库专家.
请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
数据库名:
{db_name}
表结构定义:
{table_info}

约束:
1. 请根据用户问题理解用户意图,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要 \
sql,则直接回答用户问题。
2. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {max_num_results} \
个结果。
3. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构生成 sql ,请说:\
“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
4. 请注意生成SQL时不要弄错表和列的关系
5. 请检查SQL的正确性,并保证正确的情况下优化查询性能
6. 请从如下给出的展示方式中选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值中\
,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {display_type}
用户问题:
{user_input}
请一步步思考并按照以下JSON格式回复:
{response}
确保返回正确的json并且可以被Python json.loads方法解析.
"""
_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

_DEFAULT_RESPONSE = json.dumps(
    {
        "thoughts": "thoughts summary to say to user",
        "sql": "SQL Query to run",
        "display_type": "Data display method",
    },
    ensure_ascii=False,
    indent=4,
)

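The retriever operator below fills this template with plain `str.format`; the sketch here uses illustrative values only. The braces inside the `{response}` JSON are safe because `_DEFAULT_RESPONSE` is passed as an argument value, not as part of the format string:

context = _DEFAULT_TEMPLATE_EN.format(
    db_name="sales_db",
    table_info="CREATE TABLE orders (id INT, amount DECIMAL(10, 2))",
    dialect="mysql",
    max_num_results=50,
    display_type="Table, LineChart",
    user_input="What is the total order amount?",
    response=_DEFAULT_RESPONSE,
)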
_PARAMETER_DATASOURCE = Parameter.build_from(
    _("Datasource"),
    "datasource",
    type=DBResource,
    description=_("The datasource to retrieve the context"),
)
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
    _("Prompt Template"),
    "prompt_template",
    type=str,
    optional=True,
    default=_DEFAULT_TEMPLATE,
    description=_("The prompt template to build a database prompt"),
    ui=ui.DefaultUITextArea(),
)
_PARAMETER_DISPLAY_TYPE = Parameter.build_from(
    _("Display Type"),
    "display_type",
    type=str,
    optional=True,
    default=_DEFAULT_CHART_TYPE,
    description=_("The display type for the data"),
    ui=ui.DefaultUITextArea(),
)
_PARAMETER_MAX_NUM_RESULTS = Parameter.build_from(
    _("Max Number of Results"),
    "max_num_results",
    type=int,
    optional=True,
    default=50,
    description=_("The maximum number of results to return"),
)
_PARAMETER_RESPONSE_FORMAT = Parameter.build_from(
    _("Response Format"),
    "response_format",
    type=str,
    optional=True,
    default=_DEFAULT_RESPONSE,
    description=_("The response format, default is a JSON format"),
    ui=ui.DefaultUITextArea(),
)

_PARAMETER_CONTEXT_KEY = Parameter.build_from(
    _("Context Key"),
    "context_key",
    type=str,
    optional=True,
    default="context",
    description=_("The key of the context, it will be used in building the prompt"),
)
_INPUTS_QUESTION = IOField.build_from(
    _("User question"),
    "query",
    str,
    description=_("The user question to retrieve table schemas from the datasource"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
    _("Retrieved context"),
    "context",
    HOContextBody,
    description=_("The retrieved context from the datasource"),
)

_INPUTS_SQL_DICT = IOField.build_from(
    _("SQL dict"),
    "sql_dict",
    dict,
    description=_("The SQL to be executed wrapped in a dictionary, generated by LLM"),
)
_OUTPUTS_SQL_RESULT = IOField.build_from(
    _("SQL result"),
    "sql_result",
    str,
    description=_("The result of the SQL execution"),
)

_INPUTS_SQL_DICT_LIST = IOField.build_from(
    _("SQL dict list"),
    "sql_dict_list",
    dict,
    description=_(
        "The SQL list to be executed wrapped in a dictionary, generated by LLM"
    ),
    is_list=True,
)

class GPTVisMixin:
    async def save_view_message(self, dag_ctx: DAGContext, view: str):
        """Save the view message."""
        await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)

class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
    """Retrieve the table schemas from the datasource."""

    _share_data_key = "__datasource_retriever_chunks__"

    class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
        async def map(self, context: HOContextBody) -> List[Chunk]:
            schema_info = await self.current_dag_context.get_from_share_data(
                HODatasourceRetrieverOperator._share_data_key
            )
            if isinstance(schema_info, list):
                chunks = [Chunk(content=table_info) for table_info in schema_info]
            else:
                chunks = [Chunk(content=schema_info)]
            return chunks

    metadata = ViewMetadata(
        label=_("Datasource Retriever Operator"),
        name="higher_order_datasource_retriever_operator",
        description=_("Retrieve the table schemas from the datasource."),
        category=OperatorCategory.DATABASE,
        parameters=[
            _PARAMETER_DATASOURCE.new(),
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_DISPLAY_TYPE.new(),
            _PARAMETER_MAX_NUM_RESULTS.new(),
            _PARAMETER_RESPONSE_FORMAT.new(),
            _PARAMETER_CONTEXT_KEY.new(),
        ],
        inputs=[_INPUTS_QUESTION.new()],
        outputs=[
            _OUTPUTS_CONTEXT.new(),
            IOField.build_from(
                _("Retrieved schema chunks"),
                "chunks",
                Chunk,
                is_list=True,
                description=_("The retrieved schema chunks from the datasource"),
                mappers=[ChunkMapper],
            ),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(
        self,
        datasource: DBResource,
        prompt_template: str = _DEFAULT_TEMPLATE,
        display_type: str = _DEFAULT_CHART_TYPE,
        max_num_results: int = 50,
        response_format: str = _DEFAULT_RESPONSE,
        context_key: Optional[str] = "context",
        **kwargs,
    ):
        """Initialize the operator."""
        super().__init__(**kwargs)
        self._datasource = datasource
        self._prompt_template = prompt_template
        self._display_type = display_type
        self._max_num_results = max_num_results
        self._response_format = response_format
        self._context_key = context_key

    async def map(self, question: str) -> HOContextBody:
        """Retrieve the context from the datasource."""
        db_name = self._datasource._db_name
        dialect = self._datasource.dialect
        schema_info = await self.blocking_func_to_async(
            self._datasource.get_schema_link,
            db=db_name,
            question=question,
        )
        await self.current_dag_context.save_to_share_data(
            self._share_data_key, schema_info
        )
        context = self._prompt_template.format(
            db_name=db_name,
            table_info=schema_info,
            dialect=dialect,
            max_num_results=self._max_num_results,
            display_type=self._display_type,
            user_input=question,
            response=self._response_format,
        )

        return HOContextBody(
            context_key=self._context_key,
            context=context,
        )

class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
    """Execute the SQL from the LLM output against the datasource."""

    metadata = ViewMetadata(
        label=_("Datasource Executor Operator"),
        name="higher_order_datasource_executor_operator",
        description=_("Execute the SQL from the LLM output against the datasource."),
        category=OperatorCategory.DATABASE,
        parameters=[_PARAMETER_DATASOURCE.new()],
        inputs=[_INPUTS_SQL_DICT.new()],
        outputs=[_OUTPUTS_SQL_RESULT.new()],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, datasource: DBResource, **kwargs):
        """Initialize the operator."""
        MapOperator.__init__(self, **kwargs)
        self._datasource = datasource

    async def map(self, sql_dict: dict) -> str:
        """Execute the SQL and render the result as a chart view."""
        from dbgpt.vis.tags.vis_chart import VisChart

        if not isinstance(sql_dict, dict):
            raise ValueError(
                "The input value of datasource executor should be a dictionary."
            )
        vis = VisChart()
        sql = sql_dict.get("sql")
        if not sql:
            return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
        data_df = await self._datasource.query_to_df(sql)
        view = await vis.display(chart=sql_dict, data_df=data_df)
        await self.save_view_message(self.current_dag_context, view)
        return view

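The executor's input is the dictionary parsed from the model's JSON answer in the `_DEFAULT_RESPONSE` format above. A plain-Python illustration with made-up values:

import json

model_answer = """\
{
    "thoughts": "Sum the amount column of the orders table.",
    "sql": "SELECT SUM(amount) FROM orders LIMIT 50",
    "display_type": "Table"
}"""
sql_dict = json.loads(model_answer)
assert sql_dict["sql"].startswith("SELECT")
# `await executor.map(sql_dict)` runs the SQL and renders it with VisChart;
# if the "sql" field is empty, map() returns the "thoughts" text instead.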
class HODatasourceDashboardOperator(GPTVisMixin, MapOperator[dict, str]):
    """Execute multiple SQL queries and render them as a dashboard."""

    metadata = ViewMetadata(
        label=_("Datasource Dashboard Operator"),
        name="higher_order_datasource_dashboard_operator",
        description=_("Execute multiple SQL queries and render them as a dashboard."),
        category=OperatorCategory.DATABASE,
        parameters=[_PARAMETER_DATASOURCE.new()],
        inputs=[_INPUTS_SQL_DICT_LIST.new()],
        outputs=[_OUTPUTS_SQL_RESULT.new()],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, datasource: DBResource, **kwargs):
        """Initialize the operator."""
        MapOperator.__init__(self, **kwargs)
        self._datasource = datasource

    async def map(self, sql_dict_list: List[dict]) -> str:
        """Execute each SQL and assemble the results into a dashboard view."""
        from dbgpt.vis.tags.vis_dashboard import VisDashboard

        if not isinstance(sql_dict_list, list):
            raise ValueError(
                "The input value of datasource executor should be a list of dictionaries."
            )
        vis = VisDashboard()
        chart_params = []
        for chart_item in sql_dict_list:
            chart_dict = {k: v for k, v in chart_item.items()}
            sql = chart_item.get("sql")
            try:
                data_df = await self._datasource.query_to_df(sql)
                chart_dict["data"] = data_df
            except Exception as e:
                logger.warning(f"Sql execute failed! {str(e)}")
                chart_dict["err_msg"] = str(e)
            chart_params.append(chart_dict)
        view = await vis.display(charts=chart_params)
        await self.save_view_message(self.current_dag_context, view)
        return view
dbgpt/app/operators/llm.py (new file, 453 lines)
@@ -0,0 +1,453 @@
from typing import List, Literal, Optional, Tuple, Union, cast

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
    BaseMessage,
    ChatPromptTemplate,
    LLMClient,
    ModelOutput,
    ModelRequest,
    StorageConversation,
)
from dbgpt.core.awel import (
    DAG,
    BaseOperator,
    CommonLLMHttpRequestBody,
    DAGContext,
    DefaultInputContext,
    InputOperator,
    JoinOperator,
    MapOperator,
    SimpleCallDataInputSource,
    TaskOutput,
)
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    OptionValue,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.core.interface.operators.message_operator import (
    BaseConversationOperator,
    BufferedConversationMapperOperator,
    TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer

class HOContextBody(BaseModel):
    """Higher-order context body."""

    context_key: str = Field(
        "context",
        description=_("The context key can be used as the key for formatting prompt."),
    )
    context: Union[str, List[str]] = Field(
        ...,
        description=_("The context."),
    )

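Each upstream operator contributes one `HOContextBody`, and `_join_func` (below) merges them into the prompt dictionary keyed by `context_key`, rejecting duplicates. A plain-Python sketch of that merge with illustrative values:

contexts = [
    HOContextBody(context_key="context", context="table schema: orders(id, amount)"),
    HOContextBody(context_key="knowledge", context=["chunk one", "chunk two"]),
]
prompt_dict = {"user_input": "What is the total order amount?"}
for ctx in contexts:
    if ctx.context_key in prompt_dict:
        raise ValueError(f"Duplicate context key '{ctx.context_key}' in upstream operators.")
    prompt_dict[ctx.context_key] = ctx.context
# prompt_dict now carries "user_input", "context" and "knowledge" for the
# prompt template.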
class BaseHOLLMOperator(
    BaseConversationOperator,
    JoinOperator[ModelRequest],
    LLMOperator,
    StreamingLLMOperator,
):
    """Higher-order model request builder operator."""

    def __init__(
        self,
        prompt_template: ChatPromptTemplate,
        model: str = None,
        llm_client: Optional[LLMClient] = None,
        history_merge_mode: Literal["none", "window", "token"] = "window",
        user_message_key: str = "user_input",
        history_key: Optional[str] = None,
        keep_start_rounds: Optional[int] = None,
        keep_end_rounds: Optional[int] = None,
        max_token_limit: int = 2048,
        **kwargs,
    ):
        JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
        LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
        StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)

        # User must select a history merge mode
        self._history_merge_mode = history_merge_mode
        self._user_message_key = user_message_key
        self._has_history = history_merge_mode != "none"
        self._prompt_template = prompt_template
        self._model = model
        self._history_key = history_key
        self._str_history = False
        self._keep_start_rounds = keep_start_rounds if self._has_history else 0
        self._keep_end_rounds = keep_end_rounds if self._has_history else 0
        self._max_token_limit = max_token_limit
        self._sub_compose_dag: Optional[DAG] = None

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
        conv_serve = ConversationServe.get_instance(self.system_app)
        self._storage = conv_serve.conv_storage
        self._message_storage = conv_serve.message_storage

        _: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
        dag_ctx.current_task_context.set_task_input(
            DefaultInputContext([dag_ctx.current_task_context])
        )
        if dag_ctx.streaming_call:
            task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
        else:
            task_output = await LLMOperator._do_run(self, dag_ctx)

        return task_output

    async def after_dag_end(self, event_loop_task_id: int):
        model_output: Optional[
            ModelOutput
        ] = await self.current_dag_context.get_from_share_data(
            LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
        )
        model_output_view: Optional[
            str
        ] = await self.current_dag_context.get_from_share_data(
            LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
        )
        storage_conv = await self.get_storage_conversation()
        end_current_round: bool = False
        if model_output and storage_conv:
            # Save model output message to storage
            storage_conv.add_ai_message(model_output.text)
            end_current_round = True
        if model_output_view and storage_conv:
            # Save model output view to storage
            storage_conv.add_view_message(model_output_view)
            end_current_round = True
        if end_current_round:
            # End current conversation round and flush to storage
            storage_conv.end_current_round()

    async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
        dynamic_inputs = []
        for arg in args:
            if isinstance(arg, HOContextBody):
                dynamic_inputs.append(arg)
        # Load and store chat history, defaulting to InMemoryStorage.
        storage_conv, history_messages = await self.blocking_func_to_async(
            self._build_storage, req
        )
        # Save the storage conversation to share data, for the child operators
        await self.current_dag_context.save_to_share_data(
            self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
        )

        user_input = (
            req.messages[-1] if isinstance(req.messages, list) else req.messages
        )
        prompt_dict = {
            self._user_message_key: user_input,
        }
        for dynamic_input in dynamic_inputs:
            if dynamic_input.context_key in prompt_dict:
                raise ValueError(
                    f"Duplicate context key '{dynamic_input.context_key}' in upstream "
                    f"operators."
                )
            prompt_dict[dynamic_input.context_key] = dynamic_input.context

        call_data = {
            "messages": history_messages,
            "prompt_dict": prompt_dict,
        }
        end_node: BaseOperator = cast(BaseOperator, self.sub_compose_dag.leaf_nodes[0])
        # Sub dag, use the same dag context in the parent dag
        messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
        model_request = ModelRequest.build_request(
            model=req.model,
            messages=messages,
            context=req.context,
            temperature=req.temperature,
            max_new_tokens=req.max_new_tokens,
            span_id=root_tracer.get_current_span_id(),
            echo=False,
        )
        if storage_conv:
            # Start new round
            storage_conv.start_new_round()
            storage_conv.add_user_message(user_input)
        return model_request

    @property
    def sub_compose_dag(self) -> DAG:
        if not self._sub_compose_dag:
            self._sub_compose_dag = self._build_conversation_composer_dag()
        return self._sub_compose_dag

    def _build_storage(
        self, req: CommonLLMHttpRequestBody
    ) -> Tuple[StorageConversation, List[BaseMessage]]:
        # Creating a StorageConversation loads the conversation from storage,
        # so it runs in a thread executor (via blocking_func_to_async).
        storage_conv: StorageConversation = StorageConversation(
            conv_uid=req.conv_uid,
            chat_mode=req.chat_mode,
            user_name=req.user_name,
            sys_code=req.sys_code,
            conv_storage=self.storage,
            message_storage=self.message_storage,
            param_type="",
            param_value=req.chat_param,
        )
        # Get history messages from storage
        history_messages: List[BaseMessage] = storage_conv.get_history_message(
            include_system_message=False
        )

        return storage_conv, history_messages

    def _build_conversation_composer_dag(self) -> DAG:
        default_dag_variables = self.dag._default_dag_variables if self.dag else None
        with DAG(
            "dbgpt_awel_app_chat_history_prompt_composer",
            default_dag_variables=default_dag_variables,
        ) as composer_dag:
            input_task = InputOperator(input_source=SimpleCallDataInputSource())
            # History transform task
            if self._history_merge_mode == "token":
                history_transform_task = TokenBufferedConversationMapperOperator(
                    model=self._model,
                    llm_client=self.llm_client,
                    max_token_limit=self._max_token_limit,
                )
            else:
                history_transform_task = BufferedConversationMapperOperator(
                    keep_start_rounds=self._keep_start_rounds,
                    keep_end_rounds=self._keep_end_rounds,
                )
            if self._history_key:
                history_key = self._history_key
            else:
                placeholders = self._prompt_template.get_placeholders()
                if not placeholders or len(placeholders) != 1:
                    raise ValueError(
                        "The prompt template must have exactly one placeholder if "
                        "history_key is not provided."
                    )
                history_key = placeholders[0]
            history_prompt_build_task = HistoryPromptBuilderOperator(
                prompt=self._prompt_template,
                history_key=history_key,
                check_storage=False,
                save_to_storage=False,
                str_history=self._str_history,
            )
            # Build composer dag
            (
                input_task
                >> MapOperator(lambda x: x["messages"])
                >> history_transform_task
                >> history_prompt_build_task
            )
            (
                input_task
                >> MapOperator(lambda x: x["prompt_dict"])
                >> history_prompt_build_task
            )

        return composer_dag

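Because the composer DAG infers `history_key` from the template's single placeholder, a typical prompt for these operators has exactly one `MessagesPlaceholder`. A hedged sketch, assuming the template classes (`SystemPromptTemplate`, `MessagesPlaceholder`, `HumanPromptTemplate`) are importable from `dbgpt.core` as elsewhere in DB-GPT:

from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,
    SystemPromptTemplate,
)

prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template(
            "You are a helpful assistant. Context: {context}"
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)
# get_placeholders() should return ["chat_history"], so history_key can be
# inferred without being set explicitly.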
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
    _("Prompt Template"),
    "prompt_template",
    ChatPromptTemplate,
    description=_("The prompt template for the conversation."),
)
_PARAMETER_MODEL = Parameter.build_from(
    _("Model Name"),
    "model",
    str,
    optional=True,
    default=None,
    description=_("The model name."),
)

_PARAMETER_LLM_CLIENT = Parameter.build_from(
    _("LLM Client"),
    "llm_client",
    LLMClient,
    optional=True,
    default=None,
    description=_(
        "The LLM Client, how to connect to the LLM model, if not provided, it will use"
        " the default client deployed by DB-GPT."
    ),
)
_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from(
    _("History Message Merge Mode"),
    "history_merge_mode",
    str,
    optional=True,
    default="none",
    options=[
        OptionValue(label="No History", name="none", value="none"),
        OptionValue(label="Message Window", name="window", value="window"),
        OptionValue(label="Token Length", name="token", value="token"),
    ],
    description=_(
        "The history merge mode, supports 'none', 'window' and 'token'."
        " 'none': no history merge, 'window': merge by conversation window, 'token': "
        "merge by token length."
    ),
    ui=ui.UISelect(),
)
_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from(
    _("User Message Key"),
    "user_message_key",
    str,
    optional=True,
    default="user_input",
    description=_(
        "The key of the user message in your prompt, default is 'user_input'."
    ),
)
_PARAMETER_HISTORY_KEY = Parameter.build_from(
    _("History Key"),
    "history_key",
    str,
    optional=True,
    default=None,
    description=_(
        "The key under which chat history messages are passed to the prompt "
        "template; if not provided, the prompt template is parsed to get the key."
    ),
)
_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from(
    _("Keep Start Rounds"),
    "keep_start_rounds",
    int,
    optional=True,
    default=None,
    description=_("The start rounds to keep in the chat history."),
)
_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from(
    _("Keep End Rounds"),
    "keep_end_rounds",
    int,
    optional=True,
    default=None,
    description=_("The end rounds to keep in the chat history."),
)
_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from(
    _("Max Token Limit"),
    "max_token_limit",
    int,
    optional=True,
    default=2048,
    description=_("The max token limit to keep in the chat history."),
)

_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from(
    _("Common LLM Request Body"),
    "common_llm_request_body",
    CommonLLMHttpRequestBody,
    _("The common LLM request body."),
)
_INPUTS_EXTRA_CONTEXT = IOField.build_from(
    _("Extra Context"),
    "extra_context",
    HOContextBody,
    _(
        "Extra context for building the prompt (knowledge context, database "
        "schema, etc.); you can add multiple contexts."
    ),
    dynamic=True,
)
_OUTPUTS_MODEL_OUTPUT = IOField.build_from(
    _("Model Output"),
    "model_output",
    ModelOutput,
    description=_("The model output."),
)
_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from(
    _("Streaming Model Output"),
    "streaming_model_output",
    ModelOutput,
    is_list=True,
    description=_("The streaming model output."),
)

class HOLLMOperator(BaseHOLLMOperator):
    metadata = ViewMetadata(
        label=_("LLM Operator"),
        name="higher_order_llm_operator",
        category=OperatorCategory.LLM,
        description=_(
            "High-level LLM operator, supports multi-round conversation "
            "(conversation window, token length and no multi-round)."
        ),
        parameters=[
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_MODEL.new(),
            _PARAMETER_LLM_CLIENT.new(),
            _PARAMETER_HISTORY_MERGE_MODE.new(),
            _PARAMETER_USER_MESSAGE_KEY.new(),
            _PARAMETER_HISTORY_KEY.new(),
            _PARAMETER_KEEP_START_ROUNDS.new(),
            _PARAMETER_KEEP_END_ROUNDS.new(),
            _PARAMETER_MAX_TOKEN_LIMIT.new(),
        ],
        inputs=[
            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
            _INPUTS_EXTRA_CONTEXT.new(),
        ],
        outputs=[
            _OUTPUTS_MODEL_OUTPUT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class HOStreamingLLMOperator(BaseHOLLMOperator):
    metadata = ViewMetadata(
        label=_("Streaming LLM Operator"),
        name="higher_order_streaming_llm_operator",
        category=OperatorCategory.LLM,
        description=_(
            "High-level streaming LLM operator, supports multi-round conversation "
            "(conversation window, token length and no multi-round)."
        ),
        parameters=[
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_MODEL.new(),
            _PARAMETER_LLM_CLIENT.new(),
            _PARAMETER_HISTORY_MERGE_MODE.new(),
            _PARAMETER_USER_MESSAGE_KEY.new(),
            _PARAMETER_HISTORY_KEY.new(),
            _PARAMETER_KEEP_START_ROUNDS.new(),
            _PARAMETER_KEEP_END_ROUNDS.new(),
            _PARAMETER_MAX_TOKEN_LIMIT.new(),
        ],
        inputs=[
            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
            _INPUTS_EXTRA_CONTEXT.new(),
        ],
        outputs=[
            _OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
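A hedged end-to-end sketch of wiring this PR's operators together. `my_db` (a `DBResource`) and `prompt` (a `ChatPromptTemplate` such as the one sketched earlier) are placeholders that a real deployment would provide, and an HTTP trigger would normally supply the `CommonLLMHttpRequestBody`:

from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

with DAG("example_chat_with_db_dag"):
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    retriever = HODatasourceRetrieverOperator(datasource=my_db)
    llm_task = HOLLMOperator(prompt_template=prompt)
    # Branch 0: the CommonLLMHttpRequestBody is the join operator's first input.
    input_task >> llm_task
    # Branch 1: extract the question (assuming a single-string messages field),
    # retrieve table schemas, and join the context in.
    input_task >> MapOperator(lambda req: req.messages) >> retriever >> llm_task
# Parsing the model's JSON output into the dict expected by
# HODatasourceExecutorOperator is deliberately omitted here.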
dbgpt/app/operators/rag.py (new file, 210 lines)
@@ -0,0 +1,210 @@
from typing import List, Optional

from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    FunctionDynamicOptions,
    IOField,
    OperatorCategory,
    OptionValue,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.i18n_utils import _

from .llm import HOContextBody

CFG = Config()

def _load_space_name() -> List[OptionValue]:
    from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity

    spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity())
    return [
        OptionValue(label=space.name, name=space.name, value=space.name)
        for space in spaces
    ]

_PARAMETER_CONTEXT_KEY = Parameter.build_from(
    _("Context Key"),
    "context",
    type=str,
    optional=True,
    default="context",
    description=_("The key of the context, it will be used in building the prompt"),
)
_PARAMETER_TOP_K = Parameter.build_from(
    _("Top K"),
    "top_k",
    type=int,
    optional=True,
    default=5,
    description=_("The number of chunks to retrieve"),
)
_PARAMETER_SCORE_THRESHOLD = Parameter.build_from(
    _("Minimum Match Score"),
    "score_threshold",
    type=float,
    optional=True,
    default=0.3,
    description=_(
        "The minimum match score for the retrieved chunks, it will be dropped if "
        "the match score is less than the threshold"
    ),
    ui=ui.UISlider(attr=ui.UISlider.UIAttribute(min=0.0, max=1.0, step=0.1)),
)

_PARAMETER_RE_RANKER_ENABLED = Parameter.build_from(
    _("Reranker Enabled"),
    "reranker_enabled",
    type=bool,
    optional=True,
    default=None,
    description=_("Whether to enable the reranker"),
)
_PARAMETER_RE_RANKER_TOP_K = Parameter.build_from(
    _("Reranker Top K"),
    "reranker_top_k",
    type=int,
    optional=True,
    default=3,
    description=_("The top k for the reranker"),
)

_INPUTS_QUESTION = IOField.build_from(
    _("User question"),
    "query",
    str,
    description=_("The user question to retrieve the knowledge"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
    _("Retrieved context"),
    "context",
    HOContextBody,
    description=_("The retrieved context from the knowledge space"),
)

class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
    _share_data_key = "_higher_order_knowledge_operator_retriever_chunks"

    class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
        async def map(self, context: HOContextBody) -> List[Chunk]:
            chunks = await self.current_dag_context.get_from_share_data(
                HOKnowledgeOperator._share_data_key
            )
            return chunks

    metadata = ViewMetadata(
        label=_("Knowledge Operator"),
        name="higher_order_knowledge_operator",
        category=OperatorCategory.RAG,
        description=_(
            "Knowledge Operator, retrieve your knowledge (documents) from a "
            "knowledge space"
        ),
        parameters=[
            Parameter.build_from(
                _("Knowledge Space Name"),
                "knowledge_space",
                type=str,
                options=FunctionDynamicOptions(func=_load_space_name),
                description=_("The name of the knowledge space"),
            ),
            _PARAMETER_CONTEXT_KEY.new(),
            _PARAMETER_TOP_K.new(),
            _PARAMETER_SCORE_THRESHOLD.new(),
            _PARAMETER_RE_RANKER_ENABLED.new(),
            _PARAMETER_RE_RANKER_TOP_K.new(),
        ],
        inputs=[
            _INPUTS_QUESTION.new(),
        ],
        outputs=[
            _OUTPUTS_CONTEXT.new(),
            IOField.build_from(
                _("Chunks"),
                "chunks",
                Chunk,
                is_list=True,
                description=_("The retrieved chunks from the knowledge space"),
                mappers=[ChunkMapper],
            ),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(
        self,
        knowledge_space: str,
        context_key: Optional[str] = "context",
        top_k: Optional[int] = None,
        score_threshold: Optional[float] = None,
        reranker_enabled: Optional[bool] = None,
        reranker_top_k: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._knowledge_space = knowledge_space
        self._context_key = context_key
        self._top_k = top_k
        self._score_threshold = score_threshold
        self._reranker_enabled = reranker_enabled
        self._reranker_top_k = reranker_top_k

        from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
        from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
        from dbgpt.serve.rag.models.models import (
            KnowledgeSpaceDao,
            KnowledgeSpaceEntity,
        )

        spaces = KnowledgeSpaceDao().get_knowledge_space(
            KnowledgeSpaceEntity(name=knowledge_space)
        )
        if len(spaces) != 1:
            raise Exception(f"Invalid space name: {knowledge_space}")
        space = spaces[0]

        reranker: Optional[RerankEmbeddingsRanker] = None

        if CFG.RERANK_MODEL and self._reranker_enabled:
            reranker_top_k = (
                self._reranker_top_k
                if self._reranker_top_k is not None
                else CFG.RERANK_TOP_K
            )
            rerank_embeddings = RerankEmbeddingFactory.get_instance(
                CFG.SYSTEM_APP
            ).create()
            reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=reranker_top_k)
            if self._top_k is None or self._top_k < max(reranker_top_k, 20):
                # With a reranker we need a larger candidate set, so raise
                # top_k to at least 20 (and at least reranker_top_k).
                self._top_k = max(reranker_top_k, 20)

        self._space_retriever = KnowledgeSpaceRetriever(
            space_id=space.id,
            top_k=self._top_k,
            rerank=reranker,
        )

    async def map(self, query: str) -> HOContextBody:
        chunks = await self._space_retriever.aretrieve_with_scores(
            query, self._score_threshold
        )
        await self.current_dag_context.save_to_share_data(self._share_data_key, chunks)
        return HOContextBody(
            context_key=self._context_key,
            context=[chunk.content for chunk in chunks],
        )
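A hedged usage sketch of the knowledge operator on its own; it assumes a running DB-GPT instance with a knowledge space named "my_space" (the space is looked up in the constructor), and the name is illustrative:

import asyncio

from dbgpt.core.awel import DAG


async def main():
    with DAG("example_knowledge_dag"):
        knowledge = HOKnowledgeOperator(knowledge_space="my_space", top_k=5)
    result = await knowledge.call("What is AWEL?")
    # `result` is an HOContextBody whose `context` field holds the retrieved
    # chunk contents, keyed for the prompt by `context_key` ("context").
    print(result.context_key, len(result.context))


asyncio.run(main())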