chore: Merge latest code

Fangyin Cheng 2024-08-30 15:00:14 +08:00
parent 471689ba20
commit c67b50052d
26 changed files with 1643 additions and 68 deletions

View File

@@ -60,6 +60,7 @@ def initialize_components(
_initialize_openapi(system_app)
# Register serve apps
register_serve_apps(system_app, CFG, param.port)
+_initialize_operators()
def _initialize_model_cache(system_app: SystemApp, port: int):
@@ -128,3 +129,14 @@ def _initialize_openapi(system_app: SystemApp):
from dbgpt.app.openapi.api_v1.editor.service import EditorService
system_app.register(EditorService)
+def _initialize_operators():
+from dbgpt.app.operators.converter import StringToInteger
+from dbgpt.app.operators.datasource import (
+HODatasourceExecutorOperator,
+HODatasourceRetrieverOperator,
+)
+from dbgpt.app.operators.llm import HOLLMOperator, HOStreamingLLMOperator
+from dbgpt.app.operators.rag import HOKnowledgeOperator
+from dbgpt.serve.agent.resource.datasource import DatasourceResource
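Importing these modules appears to be sufficient to make the operators visible to the AWEL flow system, presumably because defining a class with a ViewMetadata attribute registers it as a side effect of import. A minimal sketch of that pattern (the operator below is hypothetical):

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import IOField, OperatorCategory, ViewMetadata


class EchoOperator(MapOperator[str, str]):
    """A no-op operator; importing its module makes it visible to flows."""

    metadata = ViewMetadata(
        label="Echo Operator",
        name="example_echo_operator",
        category=OperatorCategory.EXAMPLE,
        description="Returns its input unchanged.",
        parameters=[],
        inputs=[IOField.build_from("Input", "input", str)],
        outputs=[IOField.build_from("Output", "output", str)],
    )

    def __init__(self, **kwargs):
        super().__init__(map_function=lambda x: x, **kwargs)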

View File

@@ -0,0 +1,4 @@
"""Operators package.
This package contains all higher-order operators that are used to build workflows.
"""

View File

@@ -0,0 +1,186 @@
"""Type Converter Operators."""
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
Parameter,
ViewMetadata,
)
from dbgpt.util.i18n_utils import _
_INPUTS_STRING = IOField.build_from(
_("String"),
"string",
str,
description=_("The string to be converted to other types."),
)
_INPUTS_INTEGER = IOField.build_from(
_("Integer"),
"integer",
int,
description=_("The integer to be converted to other types."),
)
_INPUTS_FLOAT = IOField.build_from(
_("Float"),
"float",
float,
description=_("The float to be converted to other types."),
)
_INPUTS_BOOLEAN = IOField.build_from(
_("Boolean"),
"boolean",
bool,
description=_("The boolean to be converted to other types."),
)
_OUTPUTS_STRING = IOField.build_from(
_("String"),
"string",
str,
description=_("The string converted from other types."),
)
_OUTPUTS_INTEGER = IOField.build_from(
_("Integer"),
"integer",
int,
description=_("The integer converted from other types."),
)
_OUTPUTS_FLOAT = IOField.build_from(
_("Float"),
"float",
float,
description=_("The float converted from other types."),
)
_OUTPUTS_BOOLEAN = IOField.build_from(
_("Boolean"),
"boolean",
bool,
description=_("The boolean converted from other types."),
)
class StringToInteger(MapOperator[str, int]):
"""Converts a string to an integer."""
metadata = ViewMetadata(
label=_("String to Integer"),
name="default_converter_string_to_integer",
description=_("Converts a string to an integer."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_STRING],
outputs=[_OUTPUTS_INTEGER],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new StringToInteger operator."""
super().__init__(map_function=lambda x: int(x), **kwargs)
class StringToFloat(MapOperator[str, float]):
"""Converts a string to a float."""
metadata = ViewMetadata(
label=_("String to Float"),
name="default_converter_string_to_float",
description=_("Converts a string to a float."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_STRING],
outputs=[_OUTPUTS_FLOAT],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new StringToFloat operator."""
super().__init__(map_function=lambda x: float(x), **kwargs)
class StringToBoolean(MapOperator[str, bool]):
"""Converts a string to a boolean."""
metadata = ViewMetadata(
label=_("String to Boolean"),
name="default_converter_string_to_boolean",
description=_("Converts a string to a boolean, true: 'true', '1', 'y'"),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[
Parameter.build_from(
_("True Values"),
"true_values",
str,
optional=True,
default="true,1,y",
description=_("Comma-separated values that should be treated as True."),
)
],
inputs=[_INPUTS_STRING],
outputs=[_OUTPUTS_BOOLEAN],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, true_values: str = "true,1,y", **kwargs):
"""Create a new StringToBoolean operator."""
true_values_list = true_values.split(",")
true_values_list = [x.strip().lower() for x in true_values_list]
super().__init__(map_function=lambda x: x.lower() in true_values_list, **kwargs)
class IntegerToString(MapOperator[int, str]):
"""Converts an integer to a string."""
metadata = ViewMetadata(
label=_("Integer to String"),
name="default_converter_integer_to_string",
description=_("Converts an integer to a string."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_INTEGER],
outputs=[_OUTPUTS_STRING],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new IntegerToString operator."""
super().__init__(map_function=lambda x: str(x), **kwargs)
class FloatToString(MapOperator[float, str]):
"""Converts a float to a string."""
metadata = ViewMetadata(
label=_("Float to String"),
name="default_converter_float_to_string",
description=_("Converts a float to a string."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_FLOAT],
outputs=[_OUTPUTS_STRING],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new FloatToString operator."""
super().__init__(map_function=lambda x: str(x), **kwargs)
class BooleanToString(MapOperator[bool, str]):
"""Converts a boolean to a string."""
metadata = ViewMetadata(
label=_("Boolean to String"),
name="default_converter_boolean_to_string",
description=_("Converts a boolean to a string."),
category=OperatorCategory.TYPE_CONVERTER,
parameters=[],
inputs=[_INPUTS_BOOLEAN],
outputs=[_OUTPUTS_STRING],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
"""Create a new BooleanToString operator."""
super().__init__(map_function=lambda x: str(x), **kwargs)
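A minimal usage sketch (assuming this file is dbgpt/app/operators/converter.py, as the `_initialize_operators` imports suggest, and the usual AWEL convention that a leaf operator can be invoked with `await task.call(call_data)`):

import asyncio

from dbgpt.core.awel import DAG
from dbgpt.app.operators.converter import StringToBoolean, StringToInteger

with DAG("type_converter_example") as dag:
    to_int = StringToInteger()
    to_bool = StringToBoolean(true_values="true,1,y,yes")

async def main():
    assert await to_int.call("42") == 42
    # "YES" matches because true_values are compared lower-cased.
    assert await to_bool.call("YES") is True

asyncio.run(main())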

View File

@@ -0,0 +1,336 @@
import json
import logging
from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBResource
from dbgpt.core.awel import DAGContext, MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.core.operators import BaseLLM
from dbgpt.util.i18n_utils import _
from dbgpt.vis.tags.vis_chart import default_chart_type_prompt
from .llm import HOContextBody
logger = logging.getLogger(__name__)
CFG = Config()
_DEFAULT_CHART_TYPE = default_chart_type_prompt()
_DEFAULT_TEMPLATE_EN = """You are a database expert.
Please answer the user's question based on the database selected by the user and some \
of the available table structure definitions of the database.
Database name:
{db_name}
Table structure definition:
{table_info}
Constraint:
1.Please understand the user's intention based on the user's question, and use the \
given table structure definition to create a grammatically correct {dialect} sql. \
If sql is not required, answer the user's question directly.
2.Always limit the query to a maximum of {max_num_results} results unless the user \
specifies in the question the specific number of rows of data he wishes to obtain.
3.You can only use the tables provided in the table structure information to \
generate sql. If you cannot generate sql based on the provided table structure, \
please say: "The table structure information provided is not enough to generate \
sql queries." It is prohibited to fabricate information at will.
4.Please be careful not to mistake the relationship between tables and columns \
when generating SQL.
5.Please check the correctness of the SQL and ensure that the query performance is \
optimized under correct conditions.
6.Please choose the best one from the display methods given below for data \
rendering, and put the type name into the name parameter value that returns the \
required format. If you cannot find the most suitable one, use 'Table' as the \
display method. The available data display methods are as follows: {display_type}
User Question:
{user_input}
Please think step by step and respond according to the following JSON format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""
_DEFAULT_TEMPLATE_ZH = """你是一个数据库专家.
请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
数据库名:
{db_name}
表结构定义:
{table_info}
约束:
1. 请根据用户问题理解用户意图，使用给出的表结构定义创建一个语法正确的 {dialect} sql，如果不需要 \
sql，则直接回答用户问题。
2. 除非用户在问题中指定了他希望获得的具体数据行数，否则始终将查询限制为最多 {max_num_results} \
个结果。
3. 只能使用表结构信息中提供的表来生成 sql，如果无法根据提供的表结构生成 sql，请说：\
"提供的表结构信息不足以生成 sql 查询。" 禁止随意捏造信息。
4. 请注意生成 SQL 时不要弄错表和列的关系。
5. 请检查 SQL 的正确性，并保证在正确的情况下优化查询性能。
6. 请从如下给出的展示方式中选择最优的一种用以进行数据渲染，将类型名称放入返回要求格式的 name 参数值中，\
如果找不到最合适的，则使用 'Table' 作为展示方式。可用数据展示方式如下: {display_type}
用户问题:
{user_input}
请一步步思考，并按照以下 JSON 格式回复：
{response}
确保返回正确的 json 并且可以被 Python json.loads 方法解析.
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
_DEFAULT_RESPONSE = json.dumps(
{
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
"display_type": "Data display method",
},
ensure_ascii=False,
indent=4,
)
_PARAMETER_DATASOURCE = Parameter.build_from(
_("Datasource"),
"datasource",
type=DBResource,
description=_("The datasource to retrieve the context"),
)
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
_("Prompt Template"),
"prompt_template",
type=str,
optional=True,
default=_DEFAULT_TEMPLATE,
description=_("The prompt template to build a database prompt"),
ui=ui.DefaultUITextArea(),
)
_PARAMETER_DISPLAY_TYPE = Parameter.build_from(
_("Display Type"),
"display_type",
type=str,
optional=True,
default=_DEFAULT_CHART_TYPE,
description=_("The display type for the data"),
ui=ui.DefaultUITextArea(),
)
_PARAMETER_MAX_NUM_RESULTS = Parameter.build_from(
_("Max Number of Results"),
"max_num_results",
type=int,
optional=True,
default=50,
description=_("The maximum number of results to return"),
)
_PARAMETER_RESPONSE_FORMAT = Parameter.build_from(
_("Response Format"),
"response_format",
type=str,
optional=True,
default=_DEFAULT_RESPONSE,
description=_("The response format, default is a JSON format"),
ui=ui.DefaultUITextArea(),
)
_PARAMETER_CONTEXT_KEY = Parameter.build_from(
_("Context Key"),
"context_key",
type=str,
optional=True,
default="context",
description=_("The key of the context, it will be used in building the prompt"),
)
_INPUTS_QUESTION = IOField.build_from(
_("User question"),
"query",
str,
description=_("The user question to retrieve table schemas from the datasource"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
_("Retrieved context"),
"context",
HOContextBody,
description=_("The retrieved context from the datasource"),
)
_INPUTS_SQL_DICT = IOField.build_from(
_("SQL dict"),
"sql_dict",
dict,
description=_("The SQL to be executed wrapped in a dictionary, generated by LLM"),
)
_OUTPUTS_SQL_RESULT = IOField.build_from(
_("SQL result"),
"sql_result",
str,
description=_("The result of the SQL execution"),
)
_INPUTS_SQL_DICT_LIST = IOField.build_from(
_("SQL dict list"),
"sql_dict_list",
dict,
description=_(
"The SQL list to be executed wrapped in a dictionary, generated by LLM"
),
is_list=True,
)
class GPTVisMixin:
async def save_view_message(self, dag_ctx: DAGContext, view: str):
"""Save the view message."""
await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
"""Retrieve the table schemas from the datasource."""
metadata = ViewMetadata(
label=_("Datasource Retriever Operator"),
name="higher_order_datasource_retriever_operator",
description=_("Retrieve the table schemas from the datasource."),
category=OperatorCategory.DATABASE,
parameters=[
_PARAMETER_DATASOURCE.new(),
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_DISPLAY_TYPE.new(),
_PARAMETER_MAX_NUM_RESULTS.new(),
_PARAMETER_RESPONSE_FORMAT.new(),
_PARAMETER_CONTEXT_KEY.new(),
],
inputs=[_INPUTS_QUESTION.new()],
outputs=[_OUTPUTS_CONTEXT.new()],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(
self,
datasource: DBResource,
prompt_template: str = _DEFAULT_TEMPLATE,
display_type: str = _DEFAULT_CHART_TYPE,
max_num_results: int = 50,
response_format: str = _DEFAULT_RESPONSE,
context_key: Optional[str] = "context",
**kwargs,
):
"""Initialize the operator."""
super().__init__(**kwargs)
self._datasource = datasource
self._prompt_template = prompt_template
self._display_type = display_type
self._max_num_results = max_num_results
self._response_format = response_format
self._context_key = context_key
async def map(self, question: str) -> HOContextBody:
"""Retrieve the context from the datasource."""
db_name = self._datasource._db_name
dialect = self._datasource.dialect
schema_info = await self.blocking_func_to_async(
self._datasource.get_schema_link,
db=db_name,
question=question,
)
context = self._prompt_template.format(
db_name=db_name,
table_info=schema_info,
dialect=dialect,
max_num_results=self._max_num_results,
display_type=self._display_type,
user_input=question,
response=self._response_format,
)
return HOContextBody(
context_key=self._context_key,
context=context,
)
class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
"""Execute the context from the datasource."""
metadata = ViewMetadata(
label=_("Datasource Executor Operator"),
name="higher_order_datasource_executor_operator",
description=_("Execute the context from the datasource."),
category=OperatorCategory.DATABASE,
parameters=[_PARAMETER_DATASOURCE.new()],
inputs=[_INPUTS_SQL_DICT.new()],
outputs=[_OUTPUTS_SQL_RESULT.new()],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, datasource: DBResource, **kwargs):
"""Initialize the operator."""
MapOperator.__init__(self, **kwargs)
self._datasource = datasource
async def map(self, sql_dict: dict) -> str:
"""Execute the context from the datasource."""
from dbgpt.vis.tags.vis_chart import VisChart
if not isinstance(sql_dict, dict):
raise ValueError(
"The input value of datasource executor should be a dictionary."
)
vis = VisChart()
sql = sql_dict.get("sql")
if not sql:
return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
data_df = await self._datasource.query_to_df(sql)
view = await vis.display(chart=sql_dict, data_df=data_df)
await self.save_view_message(self.current_dag_context, view)
return view
class HODatasourceDashboardOperator(GPTVisMixin, MapOperator[dict, str]):
"""Execute the context from the datasource."""
metadata = ViewMetadata(
label=_("Datasource Dashboard Operator"),
name="higher_order_datasource_dashboard_operator",
description=_("Execute the context from the datasource."),
category=OperatorCategory.DATABASE,
parameters=[_PARAMETER_DATASOURCE.new()],
inputs=[_INPUTS_SQL_DICT_LIST.new()],
outputs=[_OUTPUTS_SQL_RESULT.new()],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, datasource: DBResource, **kwargs):
"""Initialize the operator."""
MapOperator.__init__(self, **kwargs)
self._datasource = datasource
async def map(self, sql_dict_list: List[dict]) -> str:
"""Execute the context from the datasource."""
from dbgpt.vis.tags.vis_dashboard import VisDashboard
if not isinstance(sql_dict_list, list):
raise ValueError(
"The input value of datasource executor should be a list of dictionaries."
)
vis = VisDashboard()
chart_params = []
for chart_item in sql_dict_list:
chart_dict = {k: v for k, v in chart_item.items()}
sql = chart_item.get("sql")
try:
data_df = await self._datasource.query_to_df(sql)
chart_dict["data"] = data_df
except Exception as e:
logger.warning(f"Sql execute failed{str(e)}")
chart_dict["err_msg"] = str(e)
chart_params.append(chart_dict)
view = await vis.display(charts=chart_params)
await self.save_view_message(self.current_dag_context, view)
return view
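For reference, a minimal sketch of the contract between the LLM step and HODatasourceExecutorOperator: the executor consumes the dict parsed from the LLM's JSON answer, whose shape mirrors _DEFAULT_RESPONSE above (the SQL text and `my_db_resource` below are hypothetical):

sql_dict = {
    "thoughts": "Count the signups per day.",
    "sql": "SELECT created_at, COUNT(*) FROM users GROUP BY created_at",
    "display_type": "Table",
}
# executor = HODatasourceExecutorOperator(datasource=my_db_resource)
# view = await executor.call(sql_dict)  # runs the SQL and renders it via VisChart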

dbgpt/app/operators/llm.py Normal file
View File

@@ -0,0 +1,443 @@
from typing import List, Literal, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
BaseMessage,
ChatPromptTemplate,
LLMClient,
ModelOutput,
ModelRequest,
StorageConversation,
)
from dbgpt.core.awel import (
DAG,
BaseOperator,
CommonLLMHttpRequestBody,
DAGContext,
DefaultInputContext,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
TaskOutput,
)
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OptionValue,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.core.interface.operators.message_operator import (
BaseConversationOperator,
BufferedConversationMapperOperator,
TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer
class HOContextBody(BaseModel):
"""Higher-order context body."""
context_key: str = Field(
"context",
description=_("The context key can be used as the key for formatting prompt."),
)
context: Union[str, List[str]] = Field(
...,
description=_("The context."),
)
class BaseHOLLMOperator(
BaseConversationOperator,
JoinOperator[ModelRequest],
LLMOperator,
StreamingLLMOperator,
):
"""Higher-order model request builder operator."""
def __init__(
self,
prompt_template: ChatPromptTemplate,
model: Optional[str] = None,
llm_client: Optional[LLMClient] = None,
history_merge_mode: Literal["none", "window", "token"] = "window",
user_message_key: str = "user_input",
history_key: Optional[str] = None,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
max_token_limit: int = 2048,
**kwargs,
):
JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)
# User must select a history merge mode
self._history_merge_mode = history_merge_mode
self._user_message_key = user_message_key
self._has_history = history_merge_mode != "none"
self._prompt_template = prompt_template
self._model = model
self._history_key = history_key
self._str_history = False
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
self._max_token_limit = max_token_limit
self._sub_compose_dag = self._build_conversation_composer_dag()
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
conv_serve = ConversationServe.get_instance(self.system_app)
self._storage = conv_serve.conv_storage
self._message_storage = conv_serve.message_storage
_: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
dag_ctx.current_task_context.set_task_input(
DefaultInputContext([dag_ctx.current_task_context])
)
if dag_ctx.streaming_call:
task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
else:
task_output = await LLMOperator._do_run(self, dag_ctx)
return task_output
async def after_dag_end(self, event_loop_task_id: int):
model_output: Optional[
ModelOutput
] = await self.current_dag_context.get_from_share_data(
LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
)
model_output_view: Optional[
str
] = await self.current_dag_context.get_from_share_data(
LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
)
storage_conv = await self.get_storage_conversation()
end_current_round: bool = False
if model_output and storage_conv:
# Save model output message to storage
storage_conv.add_ai_message(model_output.text)
end_current_round = True
if model_output_view and storage_conv:
# Save model output view to storage
storage_conv.add_view_message(model_output_view)
end_current_round = True
if end_current_round:
# End current conversation round and flush to storage
storage_conv.end_current_round()
async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
dynamic_inputs = []
for arg in args:
if isinstance(arg, HOContextBody):
dynamic_inputs.append(arg)
# Load and store chat history, default use InMemoryStorage.
storage_conv, history_messages = await self.blocking_func_to_async(
self._build_storage, req
)
# Save the storage conversation to share data, for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
user_input = (
req.messages[-1] if isinstance(req.messages, list) else req.messages
)
prompt_dict = {
self._user_message_key: user_input,
}
for dynamic_input in dynamic_inputs:
if dynamic_input.context_key in prompt_dict:
raise ValueError(
f"Duplicate context key '{dynamic_input.context_key}' in upstream "
f"operators."
)
prompt_dict[dynamic_input.context_key] = dynamic_input.context
call_data = {
"messages": history_messages,
"prompt_dict": prompt_dict,
}
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
# Sub dag, use the same dag context in the parent dag
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
model_request = ModelRequest.build_request(
model=req.model,
messages=messages,
context=req.context,
temperature=req.temperature,
max_new_tokens=req.max_new_tokens,
span_id=root_tracer.get_current_span_id(),
echo=False,
)
if storage_conv:
# Start new round
storage_conv.start_new_round()
storage_conv.add_user_message(user_input)
return model_request
def _build_storage(
self, req: CommonLLMHttpRequestBody
) -> Tuple[StorageConversation, List[BaseMessage]]:
# Create a new storage conversation, this will load the conversation from
# storage, so we must do this async
storage_conv: StorageConversation = StorageConversation(
conv_uid=req.conv_uid,
chat_mode=req.chat_mode,
user_name=req.user_name,
sys_code=req.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
param_type="",
param_value=req.chat_param,
)
# Get history messages from storage
history_messages: List[BaseMessage] = storage_conv.get_history_message(
include_system_message=False
)
return storage_conv, history_messages
def _build_conversation_composer_dag(self) -> DAG:
with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
input_task = InputOperator(input_source=SimpleCallDataInputSource())
# History transform task
if self._history_merge_mode == "token":
history_transform_task = TokenBufferedConversationMapperOperator(
model=self._model,
llm_client=self.llm_client,
max_token_limit=self._max_token_limit,
)
else:
history_transform_task = BufferedConversationMapperOperator(
keep_start_rounds=self._keep_start_rounds,
keep_end_rounds=self._keep_end_rounds,
)
if self._history_key:
history_key = self._history_key
else:
placeholders = self._prompt_template.get_placeholders()
if not placeholders or len(placeholders) != 1:
raise ValueError(
"The prompt template must have exactly one placeholder if "
"history_key is not provided."
)
history_key = placeholders[0]
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template,
history_key=history_key,
check_storage=False,
save_to_storage=False,
str_history=self._str_history,
)
# Build composer dag
(
input_task
>> MapOperator(lambda x: x["messages"])
>> history_transform_task
>> history_prompt_build_task
)
(
input_task
>> MapOperator(lambda x: x["prompt_dict"])
>> history_prompt_build_task
)
return composer_dag
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
_("Prompt Template"),
"prompt_template",
ChatPromptTemplate,
description=_("The prompt template for the conversation."),
)
_PARAMETER_MODEL = Parameter.build_from(
_("Model Name"),
"model",
str,
optional=True,
default=None,
description=_("The model name."),
)
_PARAMETER_LLM_CLIENT = Parameter.build_from(
_("LLM Client"),
"llm_client",
LLMClient,
optional=True,
default=None,
description=_(
"The LLM Client, how to connect to the LLM model, if not provided, it will use"
" the default client deployed by DB-GPT."
),
)
_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from(
_("History Message Merge Mode"),
"history_merge_mode",
str,
optional=True,
default="none",
options=[
OptionValue(label="No History", name="none", value="none"),
OptionValue(label="Message Window", name="window", value="window"),
OptionValue(label="Token Length", name="token", value="token"),
],
description=_(
"The history merge mode, supports 'none', 'window' and 'token'."
" 'none': no history merge, 'window': merge by conversation window, 'token': "
"merge by token length."
),
ui=ui.UISelect(),
)
_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from(
_("User Message Key"),
"user_message_key",
str,
optional=True,
default="user_input",
description=_(
"The key of the user message in your prompt, default is 'user_input'."
),
)
_PARAMETER_HISTORY_KEY = Parameter.build_from(
_("History Key"),
"history_key",
str,
optional=True,
default=None,
description=_(
"The chat history key, with chat history message pass to prompt template, "
"if not provided, it will parse the prompt template to get the key."
),
)
_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from(
_("Keep Start Rounds"),
"keep_start_rounds",
int,
optional=True,
default=None,
description=_("The start rounds to keep in the chat history."),
)
_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from(
_("Keep End Rounds"),
"keep_end_rounds",
int,
optional=True,
default=None,
description=_("The end rounds to keep in the chat history."),
)
_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from(
_("Max Token Limit"),
"max_token_limit",
int,
optional=True,
default=2048,
description=_("The max token limit to keep in the chat history."),
)
_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from(
_("Common LLM Request Body"),
"common_llm_request_body",
CommonLLMHttpRequestBody,
_("The common LLM request body."),
)
_INPUTS_EXTRA_CONTEXT = IOField.build_from(
_("Extra Context"),
"extra_context",
HOContextBody,
_(
"Extra context for building prompt(Knowledge context, database "
"schema, etc), you can add multiple context."
),
dynamic=True,
)
_OUTPUTS_MODEL_OUTPUT = IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output."),
)
_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from(
_("Streaming Model Output"),
"streaming_model_output",
ModelOutput,
is_list=True,
description=_("The streaming model output."),
)
class HOLLMOperator(BaseHOLLMOperator):
metadata = ViewMetadata(
label=_("LLM Operator"),
name="higher_order_llm_operator",
category=OperatorCategory.LLM,
description=_(
"High-level LLM operator, supports multi-round conversation "
"(conversation window, token length and no multi-round)."
),
parameters=[
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_MODEL.new(),
_PARAMETER_LLM_CLIENT.new(),
_PARAMETER_HISTORY_MERGE_MODE.new(),
_PARAMETER_USER_MESSAGE_KEY.new(),
_PARAMETER_HISTORY_KEY.new(),
_PARAMETER_KEEP_START_ROUNDS.new(),
_PARAMETER_KEEP_END_ROUNDS.new(),
_PARAMETER_MAX_TOKEN_LIMIT.new(),
],
inputs=[
_INPUTS_COMMON_LLM_REQUEST_BODY.new(),
_INPUTS_EXTRA_CONTEXT.new(),
],
outputs=[
_OUTPUTS_MODEL_OUTPUT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
class HOStreamingLLMOperator(BaseHOLLMOperator):
metadata = ViewMetadata(
label=_("Streaming LLM Operator"),
name="higher_order_streaming_llm_operator",
category=OperatorCategory.LLM,
description=_(
"High-level streaming LLM operator, supports multi-round conversation "
"(conversation window, token length and no multi-round)."
),
parameters=[
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_MODEL.new(),
_PARAMETER_LLM_CLIENT.new(),
_PARAMETER_HISTORY_MERGE_MODE.new(),
_PARAMETER_USER_MESSAGE_KEY.new(),
_PARAMETER_HISTORY_KEY.new(),
_PARAMETER_KEEP_START_ROUNDS.new(),
_PARAMETER_KEEP_END_ROUNDS.new(),
_PARAMETER_MAX_TOKEN_LIMIT.new(),
],
inputs=[
_INPUTS_COMMON_LLM_REQUEST_BODY.new(),
_INPUTS_EXTRA_CONTEXT.new(),
],
outputs=[
_OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
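A short sketch of how extra context reaches the prompt: inside _join_func each upstream HOContextBody is merged into the prompt dict under its context_key, so the prompt template must reference that key (the values below are illustrative):

from dbgpt.app.operators.llm import HOContextBody

extra = HOContextBody(context_key="context", context="CREATE TABLE users(...)")
# _join_func effectively builds:
prompt_dict = {"user_input": "How many users?", extra.context_key: extra.context}
# Two upstream operators emitting the same context_key raise a ValueError.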

dbgpt/app/operators/rag.py Normal file
View File

@@ -0,0 +1,191 @@
from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
FunctionDynamicOptions,
IOField,
OperatorCategory,
OptionValue,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.i18n_utils import _
from .llm import HOContextBody
CFG = Config()
def _load_space_name() -> List[OptionValue]:
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity())
return [
OptionValue(label=space.name, name=space.name, value=space.name)
for space in spaces
]
_PARAMETER_CONTEXT_KEY = Parameter.build_from(
_("Context Key"),
"context",
type=str,
optional=True,
default="context",
description=_("The key of the context, it will be used in building the prompt"),
)
_PARAMETER_TOP_K = Parameter.build_from(
_("Top K"),
"top_k",
type=int,
optional=True,
default=5,
description=_("The number of chunks to retrieve"),
)
_PARAMETER_SCORE_THRESHOLD = Parameter.build_from(
_("Minimum Match Score"),
"score_threshold",
type=float,
optional=True,
default=0.3,
description=_(
"The minimum match score for the retrieved chunks; chunks scoring below "
"the threshold will be dropped"
),
ui=ui.UISlider(attr=ui.UISlider.UIAttribute(min=0.0, max=1.0, step=0.1)),
)
_PARAMETER_RE_RANKER_ENABLED = Parameter.build_from(
_("Reranker Enabled"),
"reranker_enabled",
type=bool,
optional=True,
default=None,
description=_("Whether to enable the reranker"),
)
_PARAMETER_RE_RANKER_TOP_K = Parameter.build_from(
_("Reranker Top K"),
"reranker_top_k",
type=int,
optional=True,
default=3,
description=_("The top k for the reranker"),
)
_INPUTS_QUESTION = IOField.build_from(
_("User question"),
"query",
str,
description=_("The user question to retrieve the knowledge"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
_("Retrieved context"),
"context",
HOContextBody,
description=_("The retrieved context from the knowledge space"),
)
class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
metadata = ViewMetadata(
label=_("Knowledge Operator"),
name="higher_order_knowledge_operator",
category=OperatorCategory.RAG,
description=_(
"Knowledge Operator, retrieves your knowledge (documents) from the "
"knowledge space"
),
parameters=[
Parameter.build_from(
_("Knowledge Space Name"),
"knowledge_space",
type=str,
options=FunctionDynamicOptions(func=_load_space_name),
description=_("The name of the knowledge space"),
),
_PARAMETER_CONTEXT_KEY.new(),
_PARAMETER_TOP_K.new(),
_PARAMETER_SCORE_THRESHOLD.new(),
_PARAMETER_RE_RANKER_ENABLED.new(),
_PARAMETER_RE_RANKER_TOP_K.new(),
],
inputs=[
_INPUTS_QUESTION.new(),
],
outputs=[
_OUTPUTS_CONTEXT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(
self,
knowledge_space: str,
context_key: Optional[str] = "context",
top_k: Optional[int] = None,
score_threshold: Optional[float] = None,
reranker_enabled: Optional[bool] = None,
reranker_top_k: Optional[int] = None,
**kwargs,
):
super().__init__(**kwargs)
self._knowledge_space = knowledge_space
self._context_key = context_key
self._top_k = top_k
self._score_threshold = score_threshold
self._reranker_enabled = reranker_enabled
self._reranker_top_k = reranker_top_k
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.serve.rag.models.models import (
KnowledgeSpaceDao,
KnowledgeSpaceEntity,
)
spaces = KnowledgeSpaceDao().get_knowledge_space(
KnowledgeSpaceEntity(name=knowledge_space)
)
if len(spaces) != 1:
raise Exception(f"invalid space name: {knowledge_space}")
space = spaces[0]
reranker: Optional[RerankEmbeddingsRanker] = None
if CFG.RERANK_MODEL and self._reranker_enabled:
reranker_top_k = (
self._reranker_top_k
if self._reranker_top_k is not None
else CFG.RERANK_TOP_K
)
rerank_embeddings = RerankEmbeddingFactory.get_instance(
CFG.SYSTEM_APP
).create()
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=reranker_top_k)
if self._top_k < reranker_top_k or self._top_k < 20:
# We use reranker, so if the top_k is less than 20,
# we need to set it to 20
self._top_k = max(reranker_top_k, 20)
self._space_retriever = KnowledgeSpaceRetriever(
space_id=space.id,
top_k=self._top_k,
rerank=reranker,
)
async def map(self, query: str) -> HOContextBody:
chunks = await self._space_retriever.aretrieve_with_scores(
query, self._score_threshold
)
return HOContextBody(
context_key=self._context_key,
context=[chunk.content for chunk in chunks],
)
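A minimal usage sketch (the space name "my_space" is hypothetical and must match an existing knowledge space, since the operator resolves the space in its constructor):

import asyncio

from dbgpt.core.awel import DAG
from dbgpt.app.operators.rag import HOKnowledgeOperator

with DAG("knowledge_retrieve_example") as dag:
    knowledge_task = HOKnowledgeOperator(
        knowledge_space="my_space",
        top_k=5,
        score_threshold=0.3,
    )

async def main():
    context = await knowledge_task.call("What is AWEL?")
    print(context.context)  # the list of retrieved chunk contents

asyncio.run(main())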

View File

@@ -10,6 +10,7 @@ from ..util.parameter_util import ( # noqa: F401
VariablesDynamicOptions,
)
from .base import ( # noqa: F401
+TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
@@ -33,6 +34,7 @@ __ALL__ = [
"ResourceCategory",
"ResourceType",
"OperatorType",
+"TAGS_ORDER_HIGH",
"IOField",
"BaseDynamicOptions",
"FunctionDynamicOptions",

View File

@@ -40,6 +40,9 @@ _BASIC_TYPES = [str, int, float, bool, dict, list, set]
T = TypeVar("T", bound="ViewMixin")
TM = TypeVar("TM", bound="TypeMetadata")
+TAGS_ORDER_HIGH = "higher-order"
+TAGS_ORDER_FIRST = "first-order"
def _get_type_name(type_: Type[Any]) -> str:
"""Get the type name of the type.
@@ -143,6 +146,8 @@ _OPERATOR_CATEGORY_DETAIL = {
"agent": _CategoryDetail("Agent", "The agent operator"),
"rag": _CategoryDetail("RAG", "The RAG operator"),
"experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"),
+"database": _CategoryDetail("Database", "Interact with the database"),
+"type_converter": _CategoryDetail("Type Converter", "Convert the type"),
"example": _CategoryDetail("Example", "Example operator"),
}
@@ -159,6 +164,8 @@ class OperatorCategory(str, Enum):
AGENT = "agent"
RAG = "rag"
EXPERIMENTAL = "experimental"
+DATABASE = "database"
+TYPE_CONVERTER = "type_converter"
EXAMPLE = "example"
def label(self) -> str:
@@ -202,6 +209,7 @@ _RESOURCE_CATEGORY_DETAIL = {
"embeddings": _CategoryDetail("Embeddings", "The embeddings resource"),
"rag": _CategoryDetail("RAG", "The resource"),
"vector_store": _CategoryDetail("Vector Store", "The vector store resource"),
+"database": _CategoryDetail("Database", "Interact with the database"),
"example": _CategoryDetail("Example", "The example resource"),
}
@@ -219,6 +227,7 @@ class ResourceCategory(str, Enum):
EMBEDDINGS = "embeddings"
RAG = "rag"
VECTOR_STORE = "vector_store"
+DATABASE = "database"
EXAMPLE = "example"
def label(self) -> str:
@@ -372,32 +381,41 @@ class Parameter(TypeMetadata, Serializable):
"value": values.get("value"),
"default": values.get("default"),
}
+is_list = values.get("is_list") or False
if type_cls:
for k, v in to_handle_values.items():
if v:
-handled_v = cls._covert_to_real_type(type_cls, v)
+handled_v = cls._covert_to_real_type(type_cls, v, is_list)
values[k] = handled_v
return values
@classmethod
-def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
-if type_cls and v is not None:
-typed_value: Any = v
+def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any:
+def _parse_single_value(vv: Any) -> Any:
+typed_value: Any = vv
try:
# Try to convert the value to the type.
if type_cls == "builtins.str":
-typed_value = str(v)
+typed_value = str(vv)
elif type_cls == "builtins.int":
-typed_value = int(v)
+typed_value = int(vv)
elif type_cls == "builtins.float":
-typed_value = float(v)
+typed_value = float(vv)
elif type_cls == "builtins.bool":
-if str(v).lower() in ["false", "0", "", "no", "off"]:
+if str(vv).lower() in ["false", "0", "", "no", "off"]:
return False
-typed_value = bool(v)
+typed_value = bool(vv)
return typed_value
except ValueError:
-raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
+raise ValidationError(f"Value '{vv}' is not valid for type {type_cls}")
+if type_cls and v is not None:
+if not is_list:
+return _parse_single_value(v)
+else:
+if not isinstance(v, list):
+raise ValidationError(f"Value '{v}' is not a list.")
+return [_parse_single_value(vv) for vv in v]
return v
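The practical effect of the new is_list flag, sketched (the "builtins.*" strings follow the type_cls convention used in this module):

# List parameters are now converted element-wise:
Parameter._covert_to_real_type("builtins.int", ["1", "2"], is_list=True)  # -> [1, 2]
# Scalars keep the single-value conversion:
Parameter._covert_to_real_type("builtins.bool", "no", is_list=False)  # -> False
# A non-list value with is_list=True raises ValidationError.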
def get_typed_value(self) -> Any:
@@ -413,11 +431,11 @@ class Parameter(TypeMetadata, Serializable):
if is_variables and self.value is not None and isinstance(self.value, str):
return VariablesPlaceHolder(self.name, self.value)
else:
-return self._covert_to_real_type(self.type_cls, self.value)
+return self._covert_to_real_type(self.type_cls, self.value, self.is_list)
def get_typed_default(self) -> Any:
"""Get the typed default."""
-return self._covert_to_real_type(self.type_cls, self.default)
+return self._covert_to_real_type(self.type_cls, self.default, self.is_list)
@classmethod
def build_from(
@@ -499,7 +517,10 @@ class Parameter(TypeMetadata, Serializable):
values = self.options.option_values()
dict_value["options"] = [value.to_dict() for value in values]
else:
-dict_value["options"] = [value.to_dict() for value in self.options]
+dict_value["options"] = [
+value.to_dict() if not isinstance(value, dict) else value
+for value in self.options
+]
if self.ui:
dict_value["ui"] = self.ui.to_dict()
@@ -594,6 +615,17 @@ class Parameter(TypeMetadata, Serializable):
value = view_value
return {self.name: value}
+def new(self: TM) -> TM:
+"""Copy the metadata."""
+new_obj = self.__class__(
+**self.model_dump(exclude_defaults=True, exclude={"ui", "options"})
+)
+if self.ui:
+new_obj.ui = self.ui
+if self.options:
+new_obj.options = self.options
+return new_obj
class BaseResource(Serializable, BaseModel):
"""The base resource."""
@@ -644,6 +676,17 @@ class IOField(Resource):
description="Whether current field is list",
examples=[True, False],
)
+dynamic: bool = Field(
+default=False,
+description="Whether current field is dynamic",
+examples=[True, False],
+)
+dynamic_minimum: int = Field(
+default=0,
+description="The minimum count of the dynamic field, only valid when dynamic is"
+" True",
+examples=[0, 1, 2],
+)
@classmethod
def build_from(
@@ -653,6 +696,8 @@
type: Type,
description: Optional[str] = None,
is_list: bool = False,
+dynamic: bool = False,
+dynamic_minimum: int = 0,
):
"""Build the resource from the type."""
type_name = type.__qualname__
@@ -664,8 +709,22 @@
type_cls=type_cls,
is_list=is_list,
description=description or label,
+dynamic=dynamic,
+dynamic_minimum=dynamic_minimum,
)
+@model_validator(mode="before")
+@classmethod
+def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+"""Pre fill the metadata."""
+if not isinstance(values, dict):
+return values
+if "dynamic" not in values:
+values["dynamic"] = False
+if "dynamic_minimum" not in values:
+values["dynamic_minimum"] = 0
+return values
class BaseMetadata(BaseResource):
"""The base metadata."""
@@ -808,9 +867,40 @@
split_ids = self.id.split("_")
return "_".join(split_ids[:-1])
+def _parse_ui_size(self) -> Optional[str]:
+"""Parse the ui size."""
+if not self.parameters:
+return None
+parameters_size = set()
+for parameter in self.parameters:
+if parameter.ui and parameter.ui.size:
+parameters_size.add(parameter.ui.size)
+for size in ["large", "middle", "small"]:
+if size in parameters_size:
+return size
+return None
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
+from .ui import _size_to_order
dict_value = model_to_dict(self, exclude={"parameters"})
+tags = dict_value.get("tags")
+if not tags:
+tags = {"ui_version": "flow2.0"}
+elif isinstance(tags, dict) and "ui_version" not in tags:
+tags["ui_version"] = "flow2.0"
+parsed_ui_size = self._parse_ui_size()
+if parsed_ui_size:
+exist_size = tags.get("ui_size")
+if not exist_size or _size_to_order(parsed_ui_size) > _size_to_order(
+exist_size
+):
+# Use the higher order size as current size.
+tags["ui_size"] = parsed_ui_size
+dict_value["tags"] = tags
dict_value["parameters"] = [
parameter.to_dict() for parameter in self.parameters
]

View File

@@ -97,6 +97,12 @@ class FlowNodeData(BaseModel):
return ResourceMetadata(**value)
raise ValueError("Unable to infer the type for `data`")
+def to_dict(self) -> Dict[str, Any]:
+"""Convert to dict."""
+dict_value = model_to_dict(self, exclude={"data"})
+dict_value["data"] = self.data.to_dict()
+return dict_value
class FlowEdgeData(BaseModel):
"""Edge data in a flow."""
@@ -166,6 +172,12 @@ class FlowData(BaseModel):
edges: List[FlowEdgeData] = Field(..., description="Edges in the flow")
viewport: FlowPositionData = Field(..., description="Viewport of the flow")
+def to_dict(self) -> Dict[str, Any]:
+"""Convert to dict."""
+dict_value = model_to_dict(self, exclude={"nodes"})
+dict_value["nodes"] = [n.to_dict() for n in self.nodes]
+return dict_value
class _VariablesRequestBase(BaseModel):
key: str = Field(
@@ -518,9 +530,24 @@ class FlowPanel(BaseModel):
values["name"] = name
return values
+def model_dump(self, **kwargs):
+"""Override the model dump method."""
+exclude = kwargs.get("exclude", set())
+if "flow_dag" not in exclude:
+exclude.add("flow_dag")
+if "flow_data" not in exclude:
+exclude.add("flow_data")
+kwargs["exclude"] = exclude
+common_dict = super().model_dump(**kwargs)
+if self.flow_dag:
+common_dict["flow_dag"] = None
+if self.flow_data:
+common_dict["flow_data"] = self.flow_data.to_dict()
+return common_dict
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
-return model_to_dict(self, exclude={"flow_dag"})
+return model_to_dict(self, exclude={"flow_dag", "flow_data"})
def get_variables_dict(self) -> List[Dict[str, Any]]:
"""Get the variables dict."""
@@ -568,6 +595,11 @@ class FlowFactory:
key_to_resource_nodes[key] = node
key_to_resource[key] = node.data
+if not key_to_operator_nodes and not key_to_resource_nodes:
+raise FlowMetadataException(
+"No operator or resource nodes found in the flow."
+)
for edge in flow_data.edges:
source_key = edge.source
target_key = edge.target
@@ -943,11 +975,17 @@ def fill_flow_panel(flow_panel: FlowPanel):
new_param = input_parameters[i.name]
i.label = new_param.label
i.description = new_param.description
+i.dynamic = new_param.dynamic
+i.is_list = new_param.is_list
+i.dynamic_minimum = new_param.dynamic_minimum
for i in node.data.outputs:
if i.name in output_parameters:
new_param = output_parameters[i.name]
i.label = new_param.label
i.description = new_param.description
+i.dynamic = new_param.dynamic
+i.is_list = new_param.is_list
+i.dynamic_minimum = new_param.dynamic_minimum
else:
data = cast(ResourceMetadata, node.data)
key = data.get_origin_id()
@@ -972,6 +1010,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
param.options = new_param.get_dict_options()  # type: ignore
param.default = new_param.default
param.placeholder = new_param.placeholder
+param.alias = new_param.alias
+param.ui = new_param.ui
except (FlowException, ValueError) as e:
logger.warning(f"Unable to fill the flow panel: {e}")

View File

@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Literal, Optional, Union
-from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
+from dbgpt._private.pydantic import BaseModel, Field, model_to_dict, model_validator
from dbgpt.core.interface.serialization import Serializable
from .exceptions import FlowUIComponentException
@@ -25,6 +25,16 @@
"code_editor",
]
+_UI_SIZE_TYPE = Literal["large", "middle", "small"]
+_SIZE_ORDER = {"large": 6, "middle": 4, "small": 2}
+def _size_to_order(size: str) -> int:
+"""Convert size to order."""
+if size not in _SIZE_ORDER:
+return -1
+return _SIZE_ORDER[size]
class RefreshableMixin(BaseModel):
"""Refreshable mixin."""
@@ -81,6 +91,10 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel):
)
ui_type: _UI_TYPE = Field(..., description="UI component type")
+size: Optional[_UI_SIZE_TYPE] = Field(
+None,
+description="The size of the component(small, middle, large)",
+)
attr: Optional[UIAttribute] = Field(
None,
@@ -266,6 +280,27 @@ class UITextArea(PanelEditorMixin, UIInput):
description="The attributes of the component",
)
+@model_validator(mode="after")
+def check_size(self) -> "UITextArea":
+"""Check the size.
+Automatically set the size to large if the max_rows is greater than 10.
+"""
+attr = self.attr
+auto_size = attr.auto_size if attr else None
+if not attr or not auto_size or isinstance(auto_size, bool):
+return self
+max_rows = (
+auto_size.max_rows
+if isinstance(auto_size, self.UIAttribute.AutoSize)
+else None
+)
+size = self.size
+if not size and max_rows and max_rows > 10:
+# Automatically set the size to large if the max_rows is greater than 10
+self.size = "large"
+return self
class UIAutoComplete(UIInput):
"""Auto complete component."""
@@ -450,7 +485,7 @@ class DefaultUITextArea(UITextArea):
attr: Optional[UITextArea.UIAttribute] = Field(
default_factory=lambda: UITextArea.UIAttribute(
-auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40)
+auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20)
),
description="The attributes of the component",
)
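The ordering helper above gives larger components precedence when _parse_ui_size (see the BaseMetadata changes earlier in this commit) picks a node-level ui_size tag; a quick sketch:

from dbgpt.core.awel.flow.ui import _size_to_order

assert _size_to_order("large") > _size_to_order("middle") > _size_to_order("small")
assert _size_to_order("unknown") == -1  # unrecognized sizes sort lowest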

View File

@@ -29,6 +29,7 @@ from dbgpt.util.tracer import root_tracer
from ..dag.base import DAG
from ..flow import (
+TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
@@ -965,6 +966,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
_PARAMETER_MEDIA_TYPE.new(),
_PARAMETER_STATUS_CODE.new(),
],
+tags={"order": TAGS_ORDER_HIGH},
)
def __init__(
@@ -1203,6 +1205,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
"User input parsed operator, parse the user input from request body and "
"return as a string"
),
+tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, key: str = "user_input", **kwargs):

View File

@@ -195,6 +195,9 @@ class ModelRequest:
temperature: Optional[float] = None
"""The temperature of the model inference."""
+top_p: Optional[float] = None
+"""The top p of the model inference."""
max_new_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""

View File

@@ -317,6 +317,25 @@ class ModelMessage(BaseModel):
"""
return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
+@staticmethod
+def parse_user_message(messages: List[ModelMessage]) -> str:
+"""Parse the user message from messages.
+Args:
+messages (List[ModelMessage]): All messages in the conversation.
+Returns:
+str: The user message
+"""
+last_user_message = None
+for message in messages[::-1]:
+if message.role == ModelMessageRoleType.HUMAN:
+last_user_message = message.content
+break
+if not last_user_message:
+raise ValueError("No user message")
+return last_user_message
_SingleRoundMessage = List[BaseMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
content=ai_message.content,
index=ai_message.index,
round_index=ai_message.round_index,
-additional_kwargs=ai_message.additional_kwargs.copy()
-if ai_message.additional_kwargs
-else {},
+additional_kwargs=(
+ai_message.additional_kwargs.copy()
+if ai_message.additional_kwargs
+else {}
+),
)
current_round.append(view_message)
return sum(messages_by_round, [])
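A short sketch of the new helper (role strings follow ModelMessageRoleType, where HUMAN is "human"):

from dbgpt.core import ModelMessage

messages = [
    ModelMessage(role="system", content="You are a helpful assistant."),
    ModelMessage(role="human", content="How many users signed up today?"),
]
# Returns the content of the most recent human message, or raises ValueError
# if the conversation contains none.
assert ModelMessage.parse_user_message(messages) == "How many users signed up today?"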

View File

@@ -246,10 +246,16 @@ class BaseLLM:
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
+SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"
-def __init__(self, llm_client: Optional[LLMClient] = None):
+def __init__(
+self,
+llm_client: Optional[LLMClient] = None,
+save_model_output: bool = True,
+):
"""Create a new LLM operator."""
self._llm_client = llm_client
+self._save_model_output = save_model_output
@property
def llm_client(self) -> LLMClient:
@@ -262,9 +268,10 @@
self, current_dag_context: DAGContext, model_output: ModelOutput
) -> None:
"""Save the model output to the share data."""
-await current_dag_context.save_to_share_data(
-self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
-)
+if self._save_model_output:
+await current_dag_context.save_to_share_data(
+self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
+)
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
@@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
This operator will generate a no streaming response.
"""
-def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+def __init__(
+self,
+llm_client: Optional[LLMClient] = None,
+save_model_output: bool = True,
+**kwargs,
+):
"""Create a new LLM operator."""
-super().__init__(llm_client=llm_client)
+super().__init__(llm_client=llm_client, save_model_output=save_model_output)
MapOperator.__init__(self, **kwargs)
async def map(self, request: ModelRequest) -> ModelOutput:
@@ -309,13 +321,18 @@ class BaseStreamingLLMOperator(
This operator will generate streaming response.
"""
-def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+def __init__(
+self,
+llm_client: Optional[LLMClient] = None,
+save_model_output: bool = True,
+**kwargs,
+):
"""Create a streaming operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
"""
-super().__init__(llm_client=llm_client)
+super().__init__(llm_client=llm_client, save_model_output=save_model_output)
BaseOperator.__init__(self, **kwargs)
async def streamify( # type: ignore
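A hedged usage note: the new flag lets an intermediate LLM step skip publishing its raw output to the DAG share data, e.g. when a downstream operator stores its own rendered view under SHARE_DATA_KEY_MODEL_OUTPUT_VIEW instead (`MyLLMOperator` and `my_client` are hypothetical):

# Assumes MyLLMOperator subclasses BaseLLMOperator from the diff above.
llm_task = MyLLMOperator(llm_client=my_client, save_model_output=False)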

View File

@@ -4,14 +4,10 @@ from abc import ABC
 from typing import Any, Dict, List, Optional, Union

 from dbgpt._private.pydantic import model_validator
-from dbgpt.core import (
-    ModelMessage,
-    ModelMessageRoleType,
-    ModelOutput,
-    StorageConversation,
-)
+from dbgpt.core import ModelMessage, ModelOutput, StorageConversation
 from dbgpt.core.awel import JoinOperator, MapOperator
 from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,

@@ -42,6 +38,7 @@ from dbgpt.util.i18n_utils import _
     name="common_chat_prompt_template",
     category=ResourceCategory.PROMPT,
     description=_("The operator to build the prompt with static prompt."),
+    tags={"order": TAGS_ORDER_HIGH},
     parameters=[
         Parameter.build_from(
             label=_("System Message"),

@@ -101,9 +98,10 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
 class BasePromptBuilderOperator(BaseConversationOperator, ABC):
     """The base prompt builder operator."""

-    def __init__(self, check_storage: bool, **kwargs):
+    def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs):
         """Create a new prompt builder operator."""
         super().__init__(check_storage=check_storage, **kwargs)
+        self._save_to_storage = save_to_storage

     async def format_prompt(
         self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]

@@ -122,8 +120,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
         messages = prompt.format_messages(**pass_kwargs)
         model_messages = ModelMessage.from_base_messages(messages)
-        # Start new round conversation, and save user message to storage
-        await self.start_new_round_conv(model_messages)
+        if self._save_to_storage:
+            # Start new round conversation, and save user message to storage
+            await self.start_new_round_conv(model_messages)
         return model_messages

     async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:

@@ -132,13 +131,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):

         Args:
             messages (List[ModelMessage]): The messages.
         """
-        lass_user_message = None
-        for message in messages[::-1]:
-            if message.role == ModelMessageRoleType.HUMAN:
-                lass_user_message = message.content
-                break
-        if not lass_user_message:
-            raise ValueError("No user message")
+        lass_user_message = ModelMessage.parse_user_message(messages)
         storage_conv: Optional[
             StorageConversation
         ] = await self.get_storage_conversation()

@@ -150,6 +143,8 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):

     async def after_dag_end(self, event_loop_task_id: int):
         """Execute after the DAG finished."""
+        if not self._save_to_storage:
+            return
         # Save the storage conversation to storage after the whole DAG finished
         storage_conv: Optional[
             StorageConversation

@@ -422,7 +417,7 @@ class HistoryPromptBuilderOperator(
         self._prompt = prompt
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

     @rearrange_args_by_type

@@ -455,7 +450,7 @@ class HistoryDynamicPromptBuilderOperator(
         """Create a new history dynamic prompt builder operator."""
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

     @rearrange_args_by_type
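With save_to_storage, a prompt builder can format messages without opening a new conversation round and without persisting anything in after_dag_end. A hedged sketch of a transient builder, assuming save_to_storage is forwarded through **kwargs as in the hunks above (prompt_template construction omitted):

    from dbgpt.core.operators import HistoryPromptBuilderOperator

    # Build prompts for a one-off run; nothing is written to conversation storage.
    builder = HistoryPromptBuilderOperator(
        prompt=prompt_template,
        check_storage=False,
        save_to_storage=False,
    )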
View File
@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union

 from dbgpt.core import ModelOutput
 from dbgpt.core.awel import MapOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.util.i18n_utils import _

 T = TypeVar("T")

@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         if self.current_dag_context.streaming_call:
             return self.parse_model_stream_resp_ex(input_value, 0)
         else:
-            return self.parse_model_nostream_resp(input_value, "###")
+            return self.parse_model_nostream_resp(input_value, "#####################")


 def _parse_model_response(response: ResponseTye):

@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
 class SQLOutputParser(BaseOutputParser):
     """Parse the SQL output of an LLM call."""

+    metadata = ViewMetadata(
+        label=_("SQL Output Parser"),
+        name="default_sql_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_("Parse the SQL output of an LLM call."),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("Dict SQL Output"),
+                "dict",
+                dict,
+                description=_("The dict output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
     def __init__(self, is_stream_out: bool = False, **kwargs):
         """Create a new SQL output parser."""
         super().__init__(is_stream_out=is_stream_out, **kwargs)

@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
         model_out_text = super().parse_model_nostream_resp(response, sep)
         clean_str = super().parse_prompt_response(model_out_text)
         return json.loads(clean_str, strict=True)
+
+
+class SQLListOutputParser(BaseOutputParser):
+    """Parse the SQL list output of an LLM call."""
+
+    metadata = ViewMetadata(
+        label=_("SQL List Output Parser"),
+        name="default_sql_list_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_(
+            "Parse the SQL list output of an LLM call, mostly used for dashboard."
+        ),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("List SQL Output"),
+                "list",
+                dict,
+                is_list=True,
+                description=_("The list output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, is_stream_out: bool = False, **kwargs):
+        """Create a new SQL list output parser."""
+        super().__init__(is_stream_out=is_stream_out, **kwargs)
+
+    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
+        """Parse the output of an LLM call."""
+        from dbgpt.util.json_utils import find_json_objects
+
+        model_out_text = super().parse_model_nostream_resp(response, sep)
+        json_objects = find_json_objects(model_out_text)
+        json_count = len(json_objects)
+        if json_count < 1:
+            raise ValueError("Unable to obtain valid output.")
+
+        parsed_json_list = json_objects[0]
+        if not isinstance(parsed_json_list, list):
+            if isinstance(parsed_json_list, dict):
+                return [parsed_json_list]
+            else:
+                raise ValueError("Invalid output format.")
+        return parsed_json_list
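The list parser normalizes whatever find_json_objects recovers from the model text: a single dict becomes a one-element list, a list passes through, and anything else is rejected. A hedged sketch of the expected behavior (the raw model responses below are made up):

    parser = SQLListOutputParser()
    # Cleaned model text '[{"sql": "SELECT 1"}, {"sql": "SELECT 2"}]'
    #   -> returns both dicts as a Python list.
    # Cleaned model text '{"sql": "SELECT 1"}'
    #   -> returns [{"sql": "SELECT 1"}].
    # Any other top-level JSON value
    #   -> raises ValueError("Invalid output format.")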
View File
@@ -254,6 +254,18 @@ class ChatPromptTemplate(BasePromptTemplate):
         values["input_variables"] = sorted(input_variables)
         return values

+    def get_placeholders(self) -> List[str]:
+        """Get all placeholders in the prompt template.
+
+        Returns:
+            List[str]: The placeholders.
+        """
+        placeholders = set()
+        for message in self.messages:
+            if isinstance(message, MessagesPlaceholder):
+                placeholders.add(message.variable_name)
+        return sorted(placeholders)
+

 @dataclasses.dataclass
 class PromptTemplateIdentifier(ResourceIdentifier):
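Note that get_placeholders reports only MessagesPlaceholder slots (for example a chat-history slot), not ordinary {variable} inputs. A hedged usage sketch, assuming the template construction style used elsewhere in dbgpt.core:

    from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, MessagesPlaceholder

    template = ChatPromptTemplate(
        messages=[
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )
    # Only the MessagesPlaceholder is reported, sorted by variable name.
    assert template.get_placeholders() == ["chat_history"]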
View File
@@ -42,13 +42,13 @@ class DefaultLLMClient(LLMClient):

     Args:
         worker_manager (WorkerManager): worker manager instance.
-        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False.
+        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to True.
     """

     def __init__(
         self,
         worker_manager: Optional[WorkerManager] = None,
-        auto_convert_message: bool = False,
+        auto_convert_message: bool = True,
     ):
         self._worker_manager = worker_manager
         self._auto_covert_message = auto_convert_message
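Flipping auto_convert_message to True means requests are adapted to each model's expected message layout unless the caller opts out. A hedged sketch (the import path is assumed; worker_manager construction omitted):

    from dbgpt.model.cluster.client import DefaultLLMClient

    # New default: messages are auto-converted for the target model.
    client = DefaultLLMClient(worker_manager)
    # The old behavior can still be requested explicitly:
    raw_client = DefaultLLMClient(worker_manager, auto_convert_message=False)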
View File
@@ -24,8 +24,13 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
     This class extends BaseOperator by adding LLM capabilities.
     """

-    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(default_client)
+    def __init__(
+        self,
+        default_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(default_client, save_model_output=save_model_output)

     @property
     def llm_client(self) -> LLMClient:

@@ -95,8 +100,13 @@ class LLMOperator(MixinLLMOperator, BaseLLMOperator):
         ],
     )

-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(llm_client, save_model_output=save_model_output)
         BaseLLMOperator.__init__(self, llm_client, **kwargs)

@@ -144,6 +154,11 @@ class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator):
         ],
     )

-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(llm_client, save_model_output=save_model_output)
         BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs)
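The flag threads through these higher-level operators too, so streaming flows can skip the share-data write the same way. A minimal sketch mirroring the constructors above:

    from dbgpt.model.operators import StreamingLLMOperator

    # Streaming responses are not cached into DAG share data.
    streaming_task = StreamingLLMOperator(save_model_output=False)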
View File
@@ -16,7 +16,13 @@ from typing import (

 from dbgpt._private.pydantic import model_to_json
 from dbgpt.core.awel import TransformStreamAbsOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.core.operators import BaseLLM
 from dbgpt.util.i18n_utils import _

@@ -184,6 +190,7 @@ class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]
                 ),
             )
         ],
+        tags={"order": TAGS_ORDER_HIGH},
     )

     async def transform_stream(self, model_output: AsyncIterator[ModelOutput]):
View File
@@ -2,6 +2,7 @@

 import logging
 import traceback
+from typing import List

 from dbgpt._private.config import Config
 from dbgpt.component import SystemApp

@@ -46,7 +47,7 @@ class DBSummaryClient:
             logger.info("db summary embedding success")

-    def get_db_summary(self, dbname, query, topk):
+    def get_db_summary(self, dbname, query, topk) -> List[str]:
         """Get user query related tables info."""
         from dbgpt.serve.rag.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig
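The new annotation documents that get_db_summary returns a list of table-schema strings. A hedged usage sketch (the import path is assumed; database name and query are illustrative):

    from dbgpt.rag.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient(system_app)
    related_tables = client.get_db_summary(
        dbname="orders_db", query="monthly revenue by region", topk=5
    )
    for table_info in related_tables:  # each entry is a schema description string
        print(table_info)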
View File
@@ -3,14 +3,41 @@ import logging
 from typing import Any, List, Optional, Type, Union, cast

 from dbgpt._private.config import Config
-from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
+from dbgpt.agent.resource.database import (
+    _DEFAULT_PROMPT_TEMPLATE,
+    _DEFAULT_PROMPT_TEMPLATE_ZH,
+    DBParameters,
+    RDBMSConnectorResource,
+)
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    FunctionDynamicOptions,
+    OptionValue,
+    Parameter,
+    ResourceCategory,
+    register_resource,
+)
 from dbgpt.util import ParameterDescription
+from dbgpt.util.i18n_utils import _

 CFG = Config()

 logger = logging.getLogger(__name__)


+def _load_datasource() -> List[OptionValue]:
+    dbs = CFG.local_db_manager.get_db_list()
+    results = [
+        OptionValue(
+            label="[" + db["db_type"] + "]" + db["db_name"],
+            name=db["db_name"],
+            value=db["db_name"],
+        )
+        for db in dbs
+    ]
+    return results
+
+
 @dataclasses.dataclass
 class DatasourceDBParameters(DBParameters):
     """The DB parameters for the datasource."""

@@ -57,6 +84,44 @@ class DatasourceDBParameters(DBParameters):
         return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)


+@register_resource(
+    _("Datasource Resource"),
+    "datasource",
+    category=ResourceCategory.DATABASE,
+    description=_(
+        "Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
+    ),
+    tags={"order": TAGS_ORDER_HIGH},
+    parameters=[
+        Parameter.build_from(
+            _("Datasource Name"),
+            "name",
+            str,
+            optional=True,
+            default="datasource",
+            description=_("The name of the datasource, default is 'datasource'."),
+        ),
+        Parameter.build_from(
+            _("DB Name"),
+            "db_name",
+            str,
+            description=_("The name of the database."),
+            options=FunctionDynamicOptions(func=_load_datasource),
+        ),
+        Parameter.build_from(
+            _("Prompt Template"),
+            "prompt_template",
+            str,
+            optional=True,
+            default=(
+                _DEFAULT_PROMPT_TEMPLATE_ZH
+                if CFG.LANGUAGE == "zh"
+                else _DEFAULT_PROMPT_TEMPLATE
+            ),
+            description=_("The prompt template to build a database prompt."),
+        ),
+    ],
+)
 class DatasourceResource(RDBMSConnectorResource):
     def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
         conn = CFG.local_db_manager.get_connector(db_name)
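FunctionDynamicOptions makes the "DB Name" dropdown dynamic: _load_datasource runs when the flow UI renders the parameter, so newly registered databases appear without code changes. A hedged sketch of the same pattern for a hypothetical parameter, assuming the Parameter.build_from signature shown above:

    from typing import List
    from dbgpt.core.awel.flow import FunctionDynamicOptions, OptionValue, Parameter

    def _load_regions() -> List[OptionValue]:
        # Hypothetical loader; real loaders typically query a registry at render time.
        return [OptionValue(label="EU", name="eu", value="eu")]

    region_param = Parameter.build_from(
        "Region",
        "region",
        str,
        options=FunctionDynamicOptions(func=_load_regions),
    )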
View File
@@ -64,6 +64,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource):
     """Knowledge Space retriever resource."""

     def __init__(self, name: str, space_name: str, context: Optional[dict] = None):
+        # TODO: Build the retriever in a thread pool, it will block the event loop
         retriever = KnowledgeSpaceRetriever(
             space_id=space_name,
             top_k=context.get("top_k", None) if context else 4,
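One way to address the TODO above is to move the blocking constructor onto an executor thread. A minimal asyncio sketch (build_retriever_async is a hypothetical helper, not part of this commit):

    import asyncio
    from functools import partial

    async def build_retriever_async(space_name: str, top_k: int):
        # Run the blocking KnowledgeSpaceRetriever constructor off the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, partial(KnowledgeSpaceRetriever, space_id=space_name, top_k=top_k)
        )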
View File
@@ -133,7 +133,10 @@ async def create(
     Returns:
         ServerResponse: The response
     """
-    return Result.succ(service.create_and_save_dag(request))
+    res = await blocking_func_to_async(
+        global_system_app, service.create_and_save_dag, request
+    )
+    return Result.succ(res)


 @router.put(

@@ -154,7 +157,10 @@ async def update(
         ServerResponse: The response
     """
     try:
-        return Result.succ(service.update_flow(request))
+        res = await blocking_func_to_async(
+            global_system_app, service.update_flow, request
+        )
+        return Result.succ(res)
     except Exception as e:
         return Result.failed(msg=str(e))

@@ -176,9 +182,7 @@ async def delete(

 @router.get("/flows/{uid}")
-async def get_flows(
-    uid: str, service: Service = Depends(get_service)
-) -> Result[ServerResponse]:
+async def get_flows(uid: str, service: Service = Depends(get_service)):
     """Get a Flow entity by uid

     Args:

@@ -191,7 +195,7 @@ async def get_flows(
     flow = service.get({"uid": uid})
     if not flow:
         raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
-    return Result.succ(flow)
+    return Result.succ(flow.model_dump())


 @router.get(

@@ -467,7 +471,10 @@ async def import_flow(
             status_code=400, detail=f"invalid file extension {file_extension}"
         )
     if save_flow:
-        return Result.succ(service.create_and_save_dag(flow))
+        res = await blocking_func_to_async(
+            global_system_app, service.create_and_save_dag, flow
+        )
+        return Result.succ(res)
     else:
         return Result.succ(flow)
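Wrapping create_and_save_dag and update_flow in blocking_func_to_async keeps the FastAPI event loop responsive while the DAG is built synchronously. The standard-library equivalent of the pattern, as a hedged sketch:

    import asyncio

    async def create_and_save_dag_async(service, request):
        # Push the synchronous DAG build onto a worker thread (Python 3.9+).
        return await asyncio.to_thread(service.create_and_save_dag, request)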
View File
@@ -27,7 +27,7 @@ from dbgpt.core.schema.api import (
     ChatCompletionStreamResponse,
     DeltaMessage,
 )
-from dbgpt.serve.core import BaseService
+from dbgpt.serve.core import BaseService, blocking_func_to_async
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
 from dbgpt.util.dbgpts.loader import DBGPTsLoader

@@ -590,7 +590,11 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata

-        dag = self._flow_factory.build(request.flow)
+        dag = await blocking_func_to_async(
+            self._system_app,
+            self._flow_factory.build,
+            request.flow,
+        )
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
View File
@@ -223,7 +223,7 @@ class KnowledgeSpacePromptBuilderOperator(
         self._prompt = prompt
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_context, **kwargs)

     @rearrange_args_by_type