feat(core): Support higher-order operators (#1984)

Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
Author: Fangyin Cheng
Date: 2024-09-09 10:15:37 +08:00
Committed by: GitHub
Parent: f6d5fc4595
Commit: 65c875db20
62 changed files with 6281 additions and 386 deletions


@@ -36,7 +36,7 @@ def initialize_components(
     system_app.register(
         DefaultExecutorFactory, max_workers=param.default_thread_pool_size
     )
-    system_app.register(DefaultScheduler)
+    system_app.register(DefaultScheduler, scheduler_enable=CFG.SCHEDULER_ENABLED)
     system_app.register_instance(controller)
     system_app.register(ConnectorManager)
@@ -60,6 +60,7 @@ def initialize_components(
     _initialize_openapi(system_app)
     # Register serve apps
     register_serve_apps(system_app, CFG, param.port)
+    _initialize_operators()


 def _initialize_model_cache(system_app: SystemApp, port: int):
@@ -128,3 +129,14 @@ def _initialize_openapi(system_app: SystemApp):
     from dbgpt.app.openapi.api_v1.editor.service import EditorService

     system_app.register(EditorService)
+
+
+def _initialize_operators():
+    from dbgpt.app.operators.converter import StringToInteger
+    from dbgpt.app.operators.datasource import (
+        HODatasourceExecutorOperator,
+        HODatasourceRetrieverOperator,
+    )
+    from dbgpt.app.operators.llm import HOLLMOperator, HOStreamingLLMOperator
+    from dbgpt.app.operators.rag import HOKnowledgeOperator
+    from dbgpt.serve.agent.resource.datasource import DatasourceResource
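
These imports are not referenced further; importing the modules is the point, since AWEL flow operators appear to register their ViewMetadata with the flow registry at class-definition time. A minimal smoke-test sketch of one registered operator (the DAG name is illustrative; StringToInteger is defined in converter.py below):

import asyncio

from dbgpt.app.operators.converter import StringToInteger
from dbgpt.core.awel import DAG

with DAG("operator_smoke_test"):
    to_int = StringToInteger()

print(asyncio.run(to_int.call(call_data="42")))  # 42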


@@ -19,12 +19,14 @@ class DefaultScheduler(BaseComponent):
         system_app: SystemApp,
         scheduler_delay_ms: int = 5000,
         scheduler_interval_ms: int = 1000,
+        scheduler_enable: bool = True,
     ):
         super().__init__(system_app)
         self.system_app = system_app
         self._scheduler_interval_ms = scheduler_interval_ms
         self._scheduler_delay_ms = scheduler_delay_ms
         self._stop_event = threading.Event()
+        self._scheduler_enable = scheduler_enable

     def init_app(self, system_app: SystemApp):
         self.system_app = system_app
@@ -39,7 +41,7 @@ class DefaultScheduler(BaseComponent):
     def _scheduler(self):
         time.sleep(self._scheduler_delay_ms / 1000)
-        while not self._stop_event.is_set():
+        while self._scheduler_enable and not self._stop_event.is_set():
             try:
                 schedule.run_pending()
             except Exception as e:
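
For illustration, a condensed standalone sketch of the gating introduced above (stdlib threading/time plus the third-party schedule package the class already uses; the interval sleep mirrors the rest of the class, which is outside this hunk, and print stands in for real logging):

import threading
import time

import schedule


def run_pending_jobs(
    enable: bool,
    stop_event: threading.Event,
    delay_ms: int = 5000,
    interval_ms: int = 1000,
) -> None:
    """Run scheduled jobs until stopped; exit immediately when disabled."""
    time.sleep(delay_ms / 1000)
    while enable and not stop_event.is_set():
        try:
            schedule.run_pending()
        except Exception as e:
            print(f"scheduler error: {e}")
        time.sleep(interval_ms / 1000)

With scheduler_enable=False the loop condition is false on entry, so the scheduler thread exits right after the initial delay; that is what CFG.SCHEDULER_ENABLED toggles at registration time.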

dbgpt/app/operators/__init__.py (new file)

@@ -0,0 +1,4 @@
"""Operators package.
This package contains all higher-order operators that are used to build workflows.
"""

dbgpt/app/operators/converter.py (new file)

@@ -0,0 +1,186 @@
"""Type Converter Operators."""
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
Parameter,
ViewMetadata,
)
from dbgpt.util.i18n_utils import _
_INPUTS_STRING = IOField.build_from(
_("String"),
"string",
str,
description=_("The string to be converted to other types."),
)
_INPUTS_INTEGER = IOField.build_from(
_("Integer"),
"integer",
int,
description=_("The integer to be converted to other types."),
)
_INPUTS_FLOAT = IOField.build_from(
_("Float"),
"float",
float,
description=_("The float to be converted to other types."),
)
_INPUTS_BOOLEAN = IOField.build_from(
_("Boolean"),
"boolean",
bool,
description=_("The boolean to be converted to other types."),
)
_OUTPUTS_STRING = IOField.build_from(
_("String"),
"string",
str,
description=_("The string converted from other types."),
)
_OUTPUTS_INTEGER = IOField.build_from(
_("Integer"),
"integer",
int,
description=_("The integer converted from other types."),
)
_OUTPUTS_FLOAT = IOField.build_from(
_("Float"),
"float",
float,
description=_("The float converted from other types."),
)
_OUTPUTS_BOOLEAN = IOField.build_from(
_("Boolean"),
"boolean",
bool,
description=_("The boolean converted from other types."),
)
class StringToInteger(MapOperator[str, int]):
"""Converts a string to an integer."""
metadata = ViewMetadata(
label=_("String to Integer"),
name="default_converter_string_to_integer",
description=_("Converts a string to an integer."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_STRING],
outputs=[_OUTPUTS_INTEGER],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new StringToInteger operator."""
super().__init__(map_function=lambda x: int(x), **kwargs)
class StringToFloat(MapOperator[str, float]):
"""Converts a string to a float."""
metadata = ViewMetadata(
label=_("String to Float"),
name="default_converter_string_to_float",
description=_("Converts a string to a float."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_STRING],
outputs=[_OUTPUTS_FLOAT],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new StringToFloat operator."""
super().__init__(map_function=lambda x: float(x), **kwargs)
class StringToBoolean(MapOperator[str, bool]):
"""Converts a string to a boolean."""
metadata = ViewMetadata(
label=_("String to Boolean"),
name="default_converter_string_to_boolean",
description=_("Converts a string to a boolean, true: 'true', '1', 'y'"),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[
Parameter.build_from(
_("True Values"),
"true_values",
str,
optional=True,
default="true,1,y",
description=_("Comma-separated values that should be treated as True."),
)
],
inputs=[_INPUTS_STRING],
outputs=[_OUTPUTS_BOOLEAN],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, true_values: str = "true,1,y", **kwargs):
"""Create a new StringToBoolean operator."""
true_values_list = true_values.split(",")
true_values_list = [x.strip().lower() for x in true_values_list]
super().__init__(map_function=lambda x: x.lower() in true_values_list, **kwargs)
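# Example: StringToBoolean(true_values="true, 1, y, yes") maps "Yes" -> True
# and "0" -> False (configured values are stripped and lowercased; the input
# is compared lowercased).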
class IntegerToString(MapOperator[int, str]):
"""Converts an integer to a string."""
metadata = ViewMetadata(
label=_("Integer to String"),
name="default_converter_integer_to_string",
description=_("Converts an integer to a string."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_INTEGER],
outputs=[_OUTPUTS_STRING],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new IntegerToString operator."""
super().__init__(map_function=lambda x: str(x), **kwargs)
class FloatToString(MapOperator[float, str]):
"""Converts a float to a string."""
metadata = ViewMetadata(
label=_("Float to String"),
name="default_converter_float_to_string",
description=_("Converts a float to a string."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_FLOAT],
outputs=[_OUTPUTS_STRING],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new FloatToString operator."""
super().__init__(map_function=lambda x: str(x), **kwargs)
class BooleanToString(MapOperator[bool, str]):
"""Converts a boolean to a string."""
metadata = ViewMetadata(
label=_("Boolean to String"),
name="default_converter_boolean_to_string",
description=_("Converts a boolean to a string."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_BOOLEAN],
outputs=[_OUTPUTS_STRING],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new BooleanToString operator."""
super().__init__(map_function=lambda x: str(x), **kwargs)
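
A hedged usage sketch for these converters chained in an AWEL DAG (the DAG name and the doubling step are illustrative):

import asyncio

from dbgpt.app.operators.converter import StringToFloat
from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

with DAG("converter_usage_sketch"):
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    to_float = StringToFloat()
    double = MapOperator(map_function=lambda x: x * 2)
    input_task >> to_float >> double

# Calling the leaf node runs the whole chain:
print(asyncio.run(double.call(call_data="3.14")))  # 6.28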

dbgpt/app/operators/datasource.py (new file)

@@ -0,0 +1,363 @@
import json
import logging
from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBResource
from dbgpt.core import Chunk
from dbgpt.core.awel import DAGContext, MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.core.operators import BaseLLM
from dbgpt.util.i18n_utils import _
from dbgpt.vis.tags.vis_chart import default_chart_type_prompt
from .llm import HOContextBody
logger = logging.getLogger(__name__)
CFG = Config()
_DEFAULT_CHART_TYPE = default_chart_type_prompt()
_DEFAULT_TEMPLATE_EN = """You are a database expert.
Please answer the user's question based on the database selected by the user and some \
of the available table structure definitions of the database.
Database name:
{db_name}
Table structure definition:
{table_info}
Constraint:
1. Please understand the user's intention based on the user's question, and use the \
given table structure definition to create a grammatically correct {dialect} sql. \
If sql is not required, answer the user's question directly.
2. Always limit the query to a maximum of {max_num_results} results unless the user \
specifies in the question the specific number of rows of data he wishes to obtain.
3. You can only use the tables provided in the table structure information to \
generate sql. If you cannot generate sql based on the provided table structure, \
please say: "The table structure information provided is not enough to generate \
sql queries." It is prohibited to fabricate information at will.
4. Please be careful not to mistake the relationship between tables and columns \
when generating SQL.
5. Please check the correctness of the SQL and ensure that the query performance is \
optimized under correct conditions.
6. Please choose the best one from the display methods given below for data \
rendering, and put the type name into the name parameter value of the required \
return format. If you cannot find the most suitable one, use 'Table' as the \
display method. The available data display methods are as follows: {display_type}
User Question:
{user_input}
Please think step by step and respond according to the following JSON format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""
_DEFAULT_TEMPLATE_ZH = """你是一个数据库专家.
请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
数据库名:
{db_name}
表结构定义:
{table_info}
约束:
1. 请根据用户问题理解用户意图,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要 \
sql,则直接回答用户问题。
2. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {max_num_results} \
个结果。
3. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构生成 sql ,请说:\
“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
4. 请注意生成SQL时不要弄错表和列的关系
5. 请检查SQL的正确性,并保证正确的情况下优化查询性能
6. 请从如下给出的展示方式中选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值中\
,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {display_type}
用户问题:
{user_input}
请一步步思考并按照以下JSON格式回复:
{response}
确保返回正确的json并且可以被Python json.loads方法解析.
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
_DEFAULT_RESPONSE = json.dumps(
{
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
"display_type": "Data display method",
},
ensure_ascii=False,
indent=4,
)
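# Downstream contract (illustrative): the LLM is prompted to reply with this
# exact JSON shape, and HODatasourceExecutorOperator below consumes it as a
# dict once parsed, e.g.
#   {"thoughts": "...", "sql": "SELECT ...", "display_type": "Table"}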
_PARAMETER_DATASOURCE = Parameter.build_from(
_("Datasource"),
"datasource",
type=DBResource,
description=_("The datasource to retrieve the context"),
)
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
_("Prompt Template"),
"prompt_template",
type=str,
optional=True,
default=_DEFAULT_TEMPLATE,
description=_("The prompt template to build a database prompt"),
ui=ui.DefaultUITextArea(),
)
_PARAMETER_DISPLAY_TYPE = Parameter.build_from(
_("Display Type"),
"display_type",
type=str,
optional=True,
default=_DEFAULT_CHART_TYPE,
description=_("The display type for the data"),
ui=ui.DefaultUITextArea(),
)
_PARAMETER_MAX_NUM_RESULTS = Parameter.build_from(
_("Max Number of Results"),
"max_num_results",
type=int,
optional=True,
default=50,
description=_("The maximum number of results to return"),
)
_PARAMETER_RESPONSE_FORMAT = Parameter.build_from(
_("Response Format"),
"response_format",
type=str,
optional=True,
default=_DEFAULT_RESPONSE,
description=_("The response format, default is a JSON format"),
ui=ui.DefaultUITextArea(),
)
_PARAMETER_CONTEXT_KEY = Parameter.build_from(
_("Context Key"),
"context_key",
type=str,
optional=True,
default="context",
description=_("The key of the context, it will be used in building the prompt"),
)
_INPUTS_QUESTION = IOField.build_from(
_("User question"),
"query",
str,
description=_("The user question to retrieve table schemas from the datasource"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
_("Retrieved context"),
"context",
HOContextBody,
description=_("The retrieved context from the datasource"),
)
_INPUTS_SQL_DICT = IOField.build_from(
_("SQL dict"),
"sql_dict",
dict,
description=_("The SQL to be executed wrapped in a dictionary, generated by LLM"),
)
_OUTPUTS_SQL_RESULT = IOField.build_from(
_("SQL result"),
"sql_result",
str,
description=_("The result of the SQL execution"),
)
_INPUTS_SQL_DICT_LIST = IOField.build_from(
_("SQL dict list"),
"sql_dict_list",
dict,
description=_(
"The SQL list to be executed wrapped in a dictionary, generated by LLM"
),
is_list=True,
)
class GPTVisMixin:
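"""Mixin that saves the rendered view message into the DAG share data."""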
async def save_view_message(self, dag_ctx: DAGContext, view: str):
"""Save the view message."""
await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
"""Retrieve the table schemas from the datasource."""
_share_data_key = "__datasource_retriever_chunks__"
class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
async def map(self, context: HOContextBody) -> List[Chunk]:
schema_info = await self.current_dag_context.get_from_share_data(
HODatasourceRetrieverOperator._share_data_key
)
if isinstance(schema_info, list):
chunks = [Chunk(content=table_info) for table_info in schema_info]
else:
chunks = [Chunk(content=schema_info)]
return chunks
metadata = ViewMetadata(
label=_("Datasource Retriever Operator"),
name="higher_order_datasource_retriever_operator",
description=_("Retrieve the table schemas from the datasource."),
category=OperatorCategory.DATABASE,
parameters=[
_PARAMETER_DATASOURCE.new(),
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_DISPLAY_TYPE.new(),
_PARAMETER_MAX_NUM_RESULTS.new(),
_PARAMETER_RESPONSE_FORMAT.new(),
_PARAMETER_CONTEXT_KEY.new(),
],
inputs=[_INPUTS_QUESTION.new()],
outputs=[
_OUTPUTS_CONTEXT.new(),
IOField.build_from(
_("Retrieved schema chunks"),
"chunks",
Chunk,
is_list=True,
description=_("The retrieved schema chunks from the datasource"),
mappers=[ChunkMapper],
),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(
self,
datasource: DBResource,
prompt_template: str = _DEFAULT_TEMPLATE,
display_type: str = _DEFAULT_CHART_TYPE,
max_num_results: int = 50,
response_format: str = _DEFAULT_RESPONSE,
context_key: Optional[str] = "context",
**kwargs,
):
"""Initialize the operator."""
super().__init__(**kwargs)
self._datasource = datasource
self._prompt_template = prompt_template
self._display_type = display_type
self._max_num_results = max_num_results
self._response_format = response_format
self._context_key = context_key
async def map(self, question: str) -> HOContextBody:
"""Retrieve the context from the datasource."""
db_name = self._datasource._db_name
dialect = self._datasource.dialect
schema_info = await self.blocking_func_to_async(
self._datasource.get_schema_link,
db=db_name,
question=question,
)
await self.current_dag_context.save_to_share_data(
self._share_data_key, schema_info
)
context = self._prompt_template.format(
db_name=db_name,
table_info=schema_info,
dialect=dialect,
max_num_results=self._max_num_results,
display_type=self._display_type,
user_input=question,
response=self._response_format,
)
return HOContextBody(
context_key=self._context_key,
context=context,
)
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
"""Execute the context from the datasource."""
metadata = ViewMetadata(
label=_("Datasource Executor Operator"),
name="higher_order_datasource_executor_operator",
description=_("Execute the context from the datasource."),
category=OperatorCategory.DATABASE,
parameters=[_PARAMETER_DATASOURCE.new()],
inputs=[_INPUTS_SQL_DICT.new()],
outputs=[_OUTPUTS_SQL_RESULT.new()],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, datasource: DBResource, **kwargs):
"""Initialize the operator."""
MapOperator.__init__(self, **kwargs)
self._datasource = datasource
async def map(self, sql_dict: dict) -> str:
"""Execute the context from the datasource."""
from dbgpt.vis.tags.vis_chart import VisChart
if not isinstance(sql_dict, dict):
raise ValueError(
"The input value of datasource executor should be a dictionary."
)
vis = VisChart()
sql = sql_dict.get("sql")
if not sql:
return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
data_df = await self._datasource.query_to_df(sql)
view = await vis.display(chart=sql_dict, data_df=data_df)
await self.save_view_message(self.current_dag_context, view)
return view
class HODatasourceDashboardOperator(GPTVisMixin, MapOperator[dict, str]):
"""Execute the context from the datasource."""
metadata = ViewMetadata(
label=_("Datasource Dashboard Operator"),
name="higher_order_datasource_dashboard_operator",
description=_("Execute the context from the datasource."),
category=OperatorCategory.DATABASE,
parameters=[_PARAMETER_DATASOURCE.new()],
inputs=[_INPUTS_SQL_DICT_LIST.new()],
outputs=[_OUTPUTS_SQL_RESULT.new()],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, datasource: DBResource, **kwargs):
"""Initialize the operator."""
MapOperator.__init__(self, **kwargs)
self._datasource = datasource
async def map(self, sql_dict_list: List[dict]) -> str:
"""Execute the context from the datasource."""
from dbgpt.vis.tags.vis_dashboard import VisDashboard
if not isinstance(sql_dict_list, list):
raise ValueError(
"The input value of the datasource dashboard operator should be "
"a list of dictionaries."
)
vis = VisDashboard()
chart_params = []
for chart_item in sql_dict_list:
chart_dict = {k: v for k, v in chart_item.items()}
sql = chart_item.get("sql")
try:
data_df = await self._datasource.query_to_df(sql)
chart_dict["data"] = data_df
except Exception as e:
logger.warning(f"Sql execute failed{str(e)}")
chart_dict["err_msg"] = str(e)
chart_params.append(chart_dict)
view = await vis.display(charts=chart_params)
await self.save_view_message(self.current_dag_context, view)
return view
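
A hedged sketch of driving the executor directly with a hand-written dict of the shape the prompt above requests (assumes db is a configured DBResource with a users table; in a real flow the dict comes from parsing the LLM output):

import asyncio

from dbgpt.app.operators.datasource import HODatasourceExecutorOperator
from dbgpt.core.awel import DAG


async def render_sql_view(db) -> str:
    with DAG("sql_exec_sketch"):
        executor = HODatasourceExecutorOperator(datasource=db)
    sql_dict = {
        "thoughts": "Count the registered users",
        "sql": "SELECT COUNT(*) AS user_count FROM users",
        "display_type": "Table",
    }
    # Returns the rendered chart view; it is also saved to the DAG share data.
    return await executor.call(sql_dict)

# e.g. asyncio.run(render_sql_view(my_db_resource)), where my_db_resource is
# a hypothetical, already-configured DBResource.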

dbgpt/app/operators/llm.py (new file, 453 lines)

@@ -0,0 +1,453 @@
from typing import List, Literal, Optional, Tuple, Union, cast
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
BaseMessage,
ChatPromptTemplate,
LLMClient,
ModelOutput,
ModelRequest,
StorageConversation,
)
from dbgpt.core.awel import (
DAG,
BaseOperator,
CommonLLMHttpRequestBody,
DAGContext,
DefaultInputContext,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
TaskOutput,
)
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OptionValue,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.core.interface.operators.message_operator import (
BaseConversationOperator,
BufferedConversationMapperOperator,
TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer
class HOContextBody(BaseModel):
"""Higher-order context body."""
context_key: str = Field(
"context",
description=_("The context key can be used as the key for formatting prompt."),
)
context: Union[str, List[str]] = Field(
...,
description=_("The context."),
)
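# Illustrative example: an upstream retriever can emit
#   HOContextBody(context_key="context", context="<table schemas>")
# and `_join_func` below will fold it into the prompt variables as
#   {"user_input": <question>, "context": "<table schemas>"}.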
class BaseHOLLMOperator(
BaseConversationOperator,
JoinOperator[ModelRequest],
LLMOperator,
StreamingLLMOperator,
):
"""Higher-order model request builder operator."""
def __init__(
self,
prompt_template: ChatPromptTemplate,
model: Optional[str] = None,
llm_client: Optional[LLMClient] = None,
history_merge_mode: Literal["none", "window", "token"] = "window",
user_message_key: str = "user_input",
history_key: Optional[str] = None,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
max_token_limit: int = 2048,
**kwargs,
):
JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)
# User must select a history merge mode
self._history_merge_mode = history_merge_mode
self._user_message_key = user_message_key
self._has_history = history_merge_mode != "none"
self._prompt_template = prompt_template
self._model = model
self._history_key = history_key
self._str_history = False
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
self._max_token_limit = max_token_limit
self._sub_compose_dag: Optional[DAG] = None
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
conv_serve = ConversationServe.get_instance(self.system_app)
self._storage = conv_serve.conv_storage
self._message_storage = conv_serve.message_storage
_: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
dag_ctx.current_task_context.set_task_input(
DefaultInputContext([dag_ctx.current_task_context])
)
if dag_ctx.streaming_call:
task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
else:
task_output = await LLMOperator._do_run(self, dag_ctx)
return task_output
async def after_dag_end(self, event_loop_task_id: int):
model_output: Optional[
ModelOutput
] = await self.current_dag_context.get_from_share_data(
LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
)
model_output_view: Optional[
str
] = await self.current_dag_context.get_from_share_data(
LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
)
storage_conv = await self.get_storage_conversation()
end_current_round: bool = False
if model_output and storage_conv:
# Save model output message to storage
storage_conv.add_ai_message(model_output.text)
end_current_round = True
if model_output_view and storage_conv:
# Save model output view to storage
storage_conv.add_view_message(model_output_view)
end_current_round = True
if end_current_round:
# End current conversation round and flush to storage
storage_conv.end_current_round()
async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
dynamic_inputs = []
for arg in args:
if isinstance(arg, HOContextBody):
dynamic_inputs.append(arg)
# Load and store chat history (defaults to InMemoryStorage).
storage_conv, history_messages = await self.blocking_func_to_async(
self._build_storage, req
)
# Save the storage conversation to share data, for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
user_input = (
req.messages[-1] if isinstance(req.messages, list) else req.messages
)
prompt_dict = {
self._user_message_key: user_input,
}
for dynamic_input in dynamic_inputs:
if dynamic_input.context_key in prompt_dict:
raise ValueError(
f"Duplicate context key '{dynamic_input.context_key}' in upstream "
f"operators."
)
prompt_dict[dynamic_input.context_key] = dynamic_input.context
call_data = {
"messages": history_messages,
"prompt_dict": prompt_dict,
}
end_node: BaseOperator = cast(BaseOperator, self.sub_compose_dag.leaf_nodes[0])
# Sub dag, use the same dag context in the parent dag
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
model_request = ModelRequest.build_request(
model=req.model,
messages=messages,
context=req.context,
temperature=req.temperature,
max_new_tokens=req.max_new_tokens,
span_id=root_tracer.get_current_span_id(),
echo=False,
)
if storage_conv:
# Start new round
storage_conv.start_new_round()
storage_conv.add_user_message(user_input)
return model_request
@property
def sub_compose_dag(self) -> DAG:
if not self._sub_compose_dag:
self._sub_compose_dag = self._build_conversation_composer_dag()
return self._sub_compose_dag
def _build_storage(
self, req: CommonLLMHttpRequestBody
) -> Tuple[StorageConversation, List[BaseMessage]]:
# Create a new storage conversation, this will load the conversation from
# storage, so we must do this async
storage_conv: StorageConversation = StorageConversation(
conv_uid=req.conv_uid,
chat_mode=req.chat_mode,
user_name=req.user_name,
sys_code=req.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
param_type="",
param_value=req.chat_param,
)
# Get history messages from storage
history_messages: List[BaseMessage] = storage_conv.get_history_message(
include_system_message=False
)
return storage_conv, history_messages
def _build_conversation_composer_dag(self) -> DAG:
default_dag_variables = self.dag._default_dag_variables if self.dag else None
with DAG(
"dbgpt_awel_app_chat_history_prompt_composer",
default_dag_variables=default_dag_variables,
) as composer_dag:
input_task = InputOperator(input_source=SimpleCallDataInputSource())
# History transform task
if self._history_merge_mode == "token":
history_transform_task = TokenBufferedConversationMapperOperator(
model=self._model,
llm_client=self.llm_client,
max_token_limit=self._max_token_limit,
)
else:
history_transform_task = BufferedConversationMapperOperator(
keep_start_rounds=self._keep_start_rounds,
keep_end_rounds=self._keep_end_rounds,
)
if self._history_key:
history_key = self._history_key
else:
placeholders = self._prompt_template.get_placeholders()
if not placeholders or len(placeholders) != 1:
raise ValueError(
"The prompt template must have exactly one placeholder if "
"history_key is not provided."
)
history_key = placeholders[0]
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template,
history_key=history_key,
check_storage=False,
save_to_storage=False,
str_history=self._str_history,
)
# Build composer dag
(
input_task
>> MapOperator(lambda x: x["messages"])
>> history_transform_task
>> history_prompt_build_task
)
(
input_task
>> MapOperator(lambda x: x["prompt_dict"])
>> history_prompt_build_task
)
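# Note: the wiring forms a diamond. Both branches start from input_task and
# meet at history_prompt_build_task, which joins the transformed history
# with the prompt variables before rendering the final message list.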
return composer_dag
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
_("Prompt Template"),
"prompt_template",
ChatPromptTemplate,
description=_("The prompt template for the conversation."),
)
_PARAMETER_MODEL = Parameter.build_from(
_("Model Name"),
"model",
str,
optional=True,
default=None,
description=_("The model name."),
)
_PARAMETER_LLM_CLIENT = Parameter.build_from(
_("LLM Client"),
"llm_client",
LLMClient,
optional=True,
default=None,
description=_(
"The LLM client, which defines how to connect to the LLM model; if not "
"provided, the default client deployed by DB-GPT is used."
),
)
_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from(
_("History Message Merge Mode"),
"history_merge_mode",
str,
optional=True,
default="none",
options=[
OptionValue(label="No History", name="none", value="none"),
OptionValue(label="Message Window", name="window", value="window"),
OptionValue(label="Token Length", name="token", value="token"),
],
description=_(
"The history merge mode, supports 'none', 'window' and 'token'."
" 'none': no history merge, 'window': merge by conversation window, 'token': "
"merge by token length."
),
ui=ui.UISelect(),
)
_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from(
_("User Message Key"),
"user_message_key",
str,
optional=True,
default="user_input",
description=_(
"The key of the user message in your prompt, default is 'user_input'."
),
)
_PARAMETER_HISTORY_KEY = Parameter.build_from(
_("History Key"),
"history_key",
str,
optional=True,
default=None,
description=_(
"The chat history key, under which chat history messages are passed to "
"the prompt template; if not provided, the key is parsed from the prompt "
"template."
),
)
_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from(
_("Keep Start Rounds"),
"keep_start_rounds",
int,
optional=True,
default=None,
description=_("The start rounds to keep in the chat history."),
)
_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from(
_("Keep End Rounds"),
"keep_end_rounds",
int,
optional=True,
default=None,
description=_("The end rounds to keep in the chat history."),
)
_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from(
_("Max Token Limit"),
"max_token_limit",
int,
optional=True,
default=2048,
description=_("The max token limit to keep in the chat history."),
)
_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from(
_("Common LLM Request Body"),
"common_llm_request_body",
CommonLLMHttpRequestBody,
_("The common LLM request body."),
)
_INPUTS_EXTRA_CONTEXT = IOField.build_from(
_("Extra Context"),
"extra_context",
HOContextBody,
_(
"Extra context for building the prompt (knowledge context, database "
"schema, etc.); you can add multiple contexts."
),
dynamic=True,
)
_OUTPUTS_MODEL_OUTPUT = IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output."),
)
_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from(
_("Streaming Model Output"),
"streaming_model_output",
ModelOutput,
is_list=True,
description=_("The streaming model output."),
)
class HOLLMOperator(BaseHOLLMOperator):
metadata = ViewMetadata(
label=_("LLM Operator"),
name="higher_order_llm_operator",
category=OperatorCategory.LLM,
description=_(
"High-level LLM operator, supports multi-round conversation "
"(conversation window, token length and no multi-round)."
),
parameters=[
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_MODEL.new(),
_PARAMETER_LLM_CLIENT.new(),
_PARAMETER_HISTORY_MERGE_MODE.new(),
_PARAMETER_USER_MESSAGE_KEY.new(),
_PARAMETER_HISTORY_KEY.new(),
_PARAMETER_KEEP_START_ROUNDS.new(),
_PARAMETER_KEEP_END_ROUNDS.new(),
_PARAMETER_MAX_TOKEN_LIMIT.new(),
],
inputs=[
_INPUTS_COMMON_LLM_REQUEST_BODY.new(),
_INPUTS_EXTRA_CONTEXT.new(),
],
outputs=[
_OUTPUTS_MODEL_OUTPUT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
class HOStreamingLLMOperator(BaseHOLLMOperator):
metadata = ViewMetadata(
label=_("Streaming LLM Operator"),
name="higher_order_streaming_llm_operator",
category=OperatorCategory.LLM,
description=_(
"High-level streaming LLM operator, supports multi-round conversation "
"(conversation window, token length and no multi-round)."
),
parameters=[
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_MODEL.new(),
_PARAMETER_LLM_CLIENT.new(),
_PARAMETER_HISTORY_MERGE_MODE.new(),
_PARAMETER_USER_MESSAGE_KEY.new(),
_PARAMETER_HISTORY_KEY.new(),
_PARAMETER_KEEP_START_ROUNDS.new(),
_PARAMETER_KEEP_END_ROUNDS.new(),
_PARAMETER_MAX_TOKEN_LIMIT.new(),
],
inputs=[
_INPUTS_COMMON_LLM_REQUEST_BODY.new(),
_INPUTS_EXTRA_CONTEXT.new(),
],
outputs=[
_OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
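
A hedged construction sketch for HOLLMOperator (prompt content and DAG name are illustrative; the MessagesPlaceholder supplies the single history slot that _build_conversation_composer_dag resolves when history_key is not set, and "{context}" matches the default HOContextBody context key):

from dbgpt.app.operators.llm import HOLLMOperator
from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,
    SystemPromptTemplate,
)
from dbgpt.core.awel import DAG

prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template(
            "Answer the question using this context:\n{context}"
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)

with DAG("ho_llm_sketch"):
    llm_task = HOLLMOperator(
        prompt_template=prompt,
        history_merge_mode="window",
        keep_end_rounds=5,
    )
    # Wire a CommonLLMHttpRequestBody source and any HOContextBody producers
    # (e.g. HOKnowledgeOperator) into llm_task here.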

dbgpt/app/operators/rag.py (new file, 210 lines)

@@ -0,0 +1,210 @@
from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
FunctionDynamicOptions,
IOField,
OperatorCategory,
OptionValue,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.i18n_utils import _
from .llm import HOContextBody
CFG = Config()
def _load_space_name() -> List[OptionValue]:
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity())
return [
OptionValue(label=space.name, name=space.name, value=space.name)
for space in spaces
]
_PARAMETER_CONTEXT_KEY = Parameter.build_from(
_("Context Key"),
"context",
type=str,
optional=True,
default="context",
description=_("The key of the context, it will be used in building the prompt"),
)
_PARAMETER_TOP_K = Parameter.build_from(
_("Top K"),
"top_k",
type=int,
optional=True,
default=5,
description=_("The number of chunks to retrieve"),
)
_PARAMETER_SCORE_THRESHOLD = Parameter.build_from(
_("Minimum Match Score"),
"score_threshold",
type=float,
optional=True,
default=0.3,
description=_(
"The minimum match score for the retrieved chunks; a chunk is dropped "
"if its match score is less than the threshold"
),
ui=ui.UISlider(attr=ui.UISlider.UIAttribute(min=0.0, max=1.0, step=0.1)),
)
_PARAMETER_RE_RANKER_ENABLED = Parameter.build_from(
_("Reranker Enabled"),
"reranker_enabled",
type=bool,
optional=True,
default=None,
description=_("Whether to enable the reranker"),
)
_PARAMETER_RE_RANKER_TOP_K = Parameter.build_from(
_("Reranker Top K"),
"reranker_top_k",
type=int,
optional=True,
default=3,
description=_("The top k for the reranker"),
)
_INPUTS_QUESTION = IOField.build_from(
_("User question"),
"query",
str,
description=_("The user question to retrieve the knowledge"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
_("Retrieved context"),
"context",
HOContextBody,
description=_("The retrieved context from the knowledge space"),
)
class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
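"""Retrieve knowledge chunks from a knowledge space and wrap them as prompt context."""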
_share_data_key = "_higher_order_knowledge_operator_retriever_chunks"
class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
async def map(self, context: HOContextBody) -> List[Chunk]:
chunks = await self.current_dag_context.get_from_share_data(
HOKnowledgeOperator._share_data_key
)
return chunks
metadata = ViewMetadata(
label=_("Knowledge Operator"),
name="higher_order_knowledge_operator",
category=OperatorCategory.RAG,
description=_(
"Knowledge operator, retrieve your knowledge (documents) from a "
"knowledge space"
),
parameters=[
Parameter.build_from(
_("Knowledge Space Name"),
"knowledge_space",
type=str,
options=FunctionDynamicOptions(func=_load_space_name),
description=_("The name of the knowledge space"),
),
_PARAMETER_CONTEXT_KEY.new(),
_PARAMETER_TOP_K.new(),
_PARAMETER_SCORE_THRESHOLD.new(),
_PARAMETER_RE_RANKER_ENABLED.new(),
_PARAMETER_RE_RANKER_TOP_K.new(),
],
inputs=[
_INPUTS_QUESTION.new(),
],
outputs=[
_OUTPUTS_CONTEXT.new(),
IOField.build_from(
_("Chunks"),
"chunks",
Chunk,
is_list=True,
description=_("The retrieved chunks from the knowledge space"),
mappers=[ChunkMapper],
),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(
self,
knowledge_space: str,
context_key: Optional[str] = "context",
top_k: int = 5,
score_threshold: float = 0.3,
reranker_enabled: Optional[bool] = None,
reranker_top_k: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)
self._knowledge_space = knowledge_space
self._context_key = context_key
self._top_k = top_k
self._score_threshold = score_threshold
self._reranker_enabled = reranker_enabled
self._reranker_top_k = reranker_top_k
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.serve.rag.models.models import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
spaces = KnowledgeSpaceDao().get_knowledge_space(
KnowledgeSpaceEntity(name=knowledge_space)
)
if len(spaces) != 1:
raise ValueError(f"Invalid knowledge space name: {knowledge_space}")
space = spaces[0]
reranker: Optional[RerankEmbeddingsRanker] = None
if CFG.RERANK_MODEL and self._reranker_enabled:
reranker_top_k = (
self._reranker_top_k
if self._reranker_top_k is not None
else CFG.RERANK_TOP_K
)
rerank_embeddings = RerankEmbeddingFactory.get_instance(
CFG.SYSTEM_APP
).create()
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=reranker_top_k)
if self._top_k < reranker_top_k or self._top_k < 20:
# When a reranker is enabled, retrieve at least 20 candidates (and at
# least reranker_top_k) so the reranker has enough chunks to rank.
self._top_k = max(reranker_top_k, 20)
self._space_retriever = KnowledgeSpaceRetriever(
space_id=space.id,
top_k=self._top_k,
rerank=reranker,
)
async def map(self, query: str) -> HOContextBody:
chunks = await self._space_retriever.aretrieve_with_scores(
query, self._score_threshold
)
await self.current_dag_context.save_to_share_data(self._share_data_key, chunks)
return HOContextBody(
context_key=self._context_key,
context=[chunk.content for chunk in chunks],
)
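
A hedged usage sketch (assumes a knowledge space named "my_space" already exists and the serve database is reachable; the question is illustrative):

import asyncio

from dbgpt.app.operators.rag import HOKnowledgeOperator
from dbgpt.core.awel import DAG

with DAG("knowledge_sketch"):
    knowledge_task = HOKnowledgeOperator(
        knowledge_space="my_space",
        top_k=5,
        score_threshold=0.3,
    )

# Returns a HOContextBody whose `context` is the list of retrieved chunk
# contents:
context = asyncio.run(knowledge_task.call(call_data="What is AWEL?"))
print(context.context_key, len(context.context))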