mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-19 16:57:21 +00:00)

chore: Merge latest code

commit c67b50052d (parent 471689ba20)
@@ -60,6 +60,7 @@ def initialize_components(
     _initialize_openapi(system_app)
     # Register serve apps
     register_serve_apps(system_app, CFG, param.port)
+    _initialize_operators()


 def _initialize_model_cache(system_app: SystemApp, port: int):
@@ -128,3 +129,14 @@ def _initialize_openapi(system_app: SystemApp):
     from dbgpt.app.openapi.api_v1.editor.service import EditorService

     system_app.register(EditorService)
+
+
+def _initialize_operators():
+    from dbgpt.app.operators.converter import StringToInteger
+    from dbgpt.app.operators.datasource import (
+        HODatasourceExecutorOperator,
+        HODatasourceRetrieverOperator,
+    )
+    from dbgpt.app.operators.llm import HOLLMOperator, HOStreamingLLMOperator
+    from dbgpt.app.operators.rag import HOKnowledgeOperator
+    from dbgpt.serve.agent.resource.datasource import DatasourceResource
dbgpt/app/operators/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
"""Operators package.

This package contains all higher-order operators that are used to build workflows.
"""
dbgpt/app/operators/converter.py (new file, 186 lines)
@@ -0,0 +1,186 @@
"""Type Converter Operators."""

from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    Parameter,
    ViewMetadata,
)
from dbgpt.util.i18n_utils import _

_INPUTS_STRING = IOField.build_from(
    _("String"),
    "string",
    str,
    description=_("The string to be converted to other types."),
)
_INPUTS_INTEGER = IOField.build_from(
    _("Integer"),
    "integer",
    int,
    description=_("The integer to be converted to other types."),
)
_INPUTS_FLOAT = IOField.build_from(
    _("Float"),
    "float",
    float,
    description=_("The float to be converted to other types."),
)
_INPUTS_BOOLEAN = IOField.build_from(
    _("Boolean"),
    "boolean",
    bool,
    description=_("The boolean to be converted to other types."),
)

_OUTPUTS_STRING = IOField.build_from(
    _("String"),
    "string",
    str,
    description=_("The string converted from other types."),
)
_OUTPUTS_INTEGER = IOField.build_from(
    _("Integer"),
    "integer",
    int,
    description=_("The integer converted from other types."),
)
_OUTPUTS_FLOAT = IOField.build_from(
    _("Float"),
    "float",
    float,
    description=_("The float converted from other types."),
)
_OUTPUTS_BOOLEAN = IOField.build_from(
    _("Boolean"),
    "boolean",
    bool,
    description=_("The boolean converted from other types."),
)


class StringToInteger(MapOperator[str, int]):
    """Converts a string to an integer."""

    metadata = ViewMetadata(
        label=_("String to Integer"),
        name="default_converter_string_to_integer",
        description=_("Converts a string to an integer."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_STRING],
        outputs=[_OUTPUTS_INTEGER],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new StringToInteger operator."""
        super().__init__(map_function=lambda x: int(x), **kwargs)


class StringToFloat(MapOperator[str, float]):
    """Converts a string to a float."""

    metadata = ViewMetadata(
        label=_("String to Float"),
        name="default_converter_string_to_float",
        description=_("Converts a string to a float."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_STRING],
        outputs=[_OUTPUTS_FLOAT],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new StringToFloat operator."""
        super().__init__(map_function=lambda x: float(x), **kwargs)


class StringToBoolean(MapOperator[str, bool]):
    """Converts a string to a boolean."""

    metadata = ViewMetadata(
        label=_("String to Boolean"),
        name="default_converter_string_to_boolean",
        description=_("Converts a string to a boolean, true: 'true', '1', 'y'"),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[
            Parameter.build_from(
                _("True Values"),
                "true_values",
                str,
                optional=True,
                default="true,1,y",
                description=_("Comma-separated values that should be treated as True."),
            )
        ],
        inputs=[_INPUTS_STRING],
        outputs=[_OUTPUTS_BOOLEAN],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, true_values: str = "true,1,y", **kwargs):
        """Create a new StringToBoolean operator."""
        true_values_list = true_values.split(",")
        true_values_list = [x.strip().lower() for x in true_values_list]
        super().__init__(map_function=lambda x: x.lower() in true_values_list, **kwargs)


class IntegerToString(MapOperator[int, str]):
    """Converts an integer to a string."""

    metadata = ViewMetadata(
        label=_("Integer to String"),
        name="default_converter_integer_to_string",
        description=_("Converts an integer to a string."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_INTEGER],
        outputs=[_OUTPUTS_STRING],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new IntegerToString operator."""
        super().__init__(map_function=lambda x: str(x), **kwargs)


class FloatToString(MapOperator[float, str]):
    """Converts a float to a string."""

    metadata = ViewMetadata(
        label=_("Float to String"),
        name="default_converter_float_to_string",
        description=_("Converts a float to a string."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_FLOAT],
        outputs=[_OUTPUTS_STRING],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new FloatToString operator."""
        super().__init__(map_function=lambda x: str(x), **kwargs)


class BooleanToString(MapOperator[bool, str]):
    """Converts a boolean to a string."""

    metadata = ViewMetadata(
        label=_("Boolean to String"),
        name="default_converter_boolean_to_string",
        description=_("Converts a boolean to a string."),
        category=OperatorCategory.TYPE_CONVERTER,
        parameters=[],
        inputs=[_INPUTS_BOOLEAN],
        outputs=[_OUTPUTS_STRING],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        """Create a new BooleanToString operator."""
        super().__init__(map_function=lambda x: str(x), **kwargs)
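Note: these converters are ordinary AWEL map operators, so they compose with >> like any other node. A minimal usage sketch (the DAG name and wiring below are illustrative, not part of this commit):

import asyncio

from dbgpt.core.awel import DAG, InputOperator, SimpleCallDataInputSource

from dbgpt.app.operators.converter import StringToBoolean

with DAG("example_string_to_boolean") as dag:
    # Feed the call data straight into the converter.
    input_task = InputOperator(input_source=SimpleCallDataInputSource())
    to_bool = StringToBoolean(true_values="true,1,y,yes")
    input_task >> to_bool

# Calling the leaf node runs the whole DAG; "yes" is in the configured true values.
assert asyncio.run(to_bool.call("yes")) is True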
dbgpt/app/operators/datasource.py (new file, 336 lines)
@@ -0,0 +1,336 @@
import json
import logging
from typing import List, Optional

from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBResource
from dbgpt.core.awel import DAGContext, MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.core.operators import BaseLLM
from dbgpt.util.i18n_utils import _
from dbgpt.vis.tags.vis_chart import default_chart_type_prompt

from .llm import HOContextBody

logger = logging.getLogger(__name__)

CFG = Config()

_DEFAULT_CHART_TYPE = default_chart_type_prompt()

_DEFAULT_TEMPLATE_EN = """You are a database expert.
Please answer the user's question based on the database selected by the user and some \
of the available table structure definitions of the database.
Database name:
{db_name}
Table structure definition:
{table_info}

Constraint:
1.Please understand the user's intention based on the user's question, and use the \
given table structure definition to create a grammatically correct {dialect} sql. \
If sql is not required, answer the user's question directly.
2.Always limit the query to a maximum of {max_num_results} results unless the user \
specifies in the question the specific number of rows of data he wishes to obtain.
3.You can only use the tables provided in the table structure information to \
generate sql. If you cannot generate sql based on the provided table structure, \
please say: "The table structure information provided is not enough to generate \
sql queries." It is prohibited to fabricate information at will.
4.Please be careful not to mistake the relationship between tables and columns \
when generating SQL.
5.Please check the correctness of the SQL and ensure that the query performance is \
optimized under correct conditions.
6.Please choose the best one from the display methods given below for data \
rendering, and put the type name into the name parameter value that returns the \
required format. If you cannot find the most suitable one, use 'Table' as the \
display method. The available data display methods are as follows: {display_type}

User Question:
{user_input}
Please think step by step and respond according to the following JSON format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""

_DEFAULT_TEMPLATE_ZH = """你是一个数据库专家.
请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题.
数据库名:
{db_name}
表结构定义:
{table_info}

约束:
1. 请根据用户问题理解用户意图,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要 \
sql,则直接回答用户问题。
2. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {max_num_results} \
个结果。
3. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:\
“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
4. 请注意生成SQL时不要弄错表和列的关系
5. 请检查SQL的正确性,并保证正确的情况下优化查询性能
6.请从如下给出的展示方式种选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值种\
,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {display_type}
用户问题:
{user_input}
请一步步思考并按照以下JSON格式回复:
{response}
确保返回正确的json并且可以被Python json.loads方法解析.
"""
_DEFAULT_TEMPLATE = (
    _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)

_DEFAULT_RESPONSE = json.dumps(
    {
        "thoughts": "thoughts summary to say to user",
        "sql": "SQL Query to run",
        "display_type": "Data display method",
    },
    ensure_ascii=False,
    indent=4,
)

_PARAMETER_DATASOURCE = Parameter.build_from(
    _("Datasource"),
    "datasource",
    type=DBResource,
    description=_("The datasource to retrieve the context"),
)
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
    _("Prompt Template"),
    "prompt_template",
    type=str,
    optional=True,
    default=_DEFAULT_TEMPLATE,
    description=_("The prompt template to build a database prompt"),
    ui=ui.DefaultUITextArea(),
)
_PARAMETER_DISPLAY_TYPE = Parameter.build_from(
    _("Display Type"),
    "display_type",
    type=str,
    optional=True,
    default=_DEFAULT_CHART_TYPE,
    description=_("The display type for the data"),
    ui=ui.DefaultUITextArea(),
)
_PARAMETER_MAX_NUM_RESULTS = Parameter.build_from(
    _("Max Number of Results"),
    "max_num_results",
    type=int,
    optional=True,
    default=50,
    description=_("The maximum number of results to return"),
)
_PARAMETER_RESPONSE_FORMAT = Parameter.build_from(
    _("Response Format"),
    "response_format",
    type=str,
    optional=True,
    default=_DEFAULT_RESPONSE,
    description=_("The response format, default is a JSON format"),
    ui=ui.DefaultUITextArea(),
)

_PARAMETER_CONTEXT_KEY = Parameter.build_from(
    _("Context Key"),
    "context_key",
    type=str,
    optional=True,
    default="context",
    description=_("The key of the context, it will be used in building the prompt"),
)
_INPUTS_QUESTION = IOField.build_from(
    _("User question"),
    "query",
    str,
    description=_("The user question to retrieve table schemas from the datasource"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
    _("Retrieved context"),
    "context",
    HOContextBody,
    description=_("The retrieved context from the datasource"),
)

_INPUTS_SQL_DICT = IOField.build_from(
    _("SQL dict"),
    "sql_dict",
    dict,
    description=_("The SQL to be executed wrapped in a dictionary, generated by LLM"),
)
_OUTPUTS_SQL_RESULT = IOField.build_from(
    _("SQL result"),
    "sql_result",
    str,
    description=_("The result of the SQL execution"),
)

_INPUTS_SQL_DICT_LIST = IOField.build_from(
    _("SQL dict list"),
    "sql_dict_list",
    dict,
    description=_(
        "The SQL list to be executed wrapped in a dictionary, generated by LLM"
    ),
    is_list=True,
)


class GPTVisMixin:
    async def save_view_message(self, dag_ctx: DAGContext, view: str):
        """Save the view message."""
        await dag_ctx.save_to_share_data(BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW, view)


class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
    """Retrieve the table schemas from the datasource."""

    metadata = ViewMetadata(
        label=_("Datasource Retriever Operator"),
        name="higher_order_datasource_retriever_operator",
        description=_("Retrieve the table schemas from the datasource."),
        category=OperatorCategory.DATABASE,
        parameters=[
            _PARAMETER_DATASOURCE.new(),
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_DISPLAY_TYPE.new(),
            _PARAMETER_MAX_NUM_RESULTS.new(),
            _PARAMETER_RESPONSE_FORMAT.new(),
            _PARAMETER_CONTEXT_KEY.new(),
        ],
        inputs=[_INPUTS_QUESTION.new()],
        outputs=[_OUTPUTS_CONTEXT.new()],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(
        self,
        datasource: DBResource,
        prompt_template: str = _DEFAULT_TEMPLATE,
        display_type: str = _DEFAULT_CHART_TYPE,
        max_num_results: int = 50,
        response_format: str = _DEFAULT_RESPONSE,
        context_key: Optional[str] = "context",
        **kwargs,
    ):
        """Initialize the operator."""
        super().__init__(**kwargs)
        self._datasource = datasource
        self._prompt_template = prompt_template
        self._display_type = display_type
        self._max_num_results = max_num_results
        self._response_format = response_format
        self._context_key = context_key

    async def map(self, question: str) -> HOContextBody:
        """Retrieve the context from the datasource."""
        db_name = self._datasource._db_name
        dialect = self._datasource.dialect
        schema_info = await self.blocking_func_to_async(
            self._datasource.get_schema_link,
            db=db_name,
            question=question,
        )
        context = self._prompt_template.format(
            db_name=db_name,
            table_info=schema_info,
            dialect=dialect,
            max_num_results=self._max_num_results,
            display_type=self._display_type,
            user_input=question,
            response=self._response_format,
        )

        return HOContextBody(
            context_key=self._context_key,
            context=context,
        )


class HODatasourceExecutorOperator(GPTVisMixin, MapOperator[dict, str]):
    """Execute the context from the datasource."""

    metadata = ViewMetadata(
        label=_("Datasource Executor Operator"),
        name="higher_order_datasource_executor_operator",
        description=_("Execute the context from the datasource."),
        category=OperatorCategory.DATABASE,
        parameters=[_PARAMETER_DATASOURCE.new()],
        inputs=[_INPUTS_SQL_DICT.new()],
        outputs=[_OUTPUTS_SQL_RESULT.new()],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, datasource: DBResource, **kwargs):
        """Initialize the operator."""
        MapOperator.__init__(self, **kwargs)
        self._datasource = datasource

    async def map(self, sql_dict: dict) -> str:
        """Execute the context from the datasource."""
        from dbgpt.vis.tags.vis_chart import VisChart

        if not isinstance(sql_dict, dict):
            raise ValueError(
                "The input value of datasource executor should be a dictionary."
            )
        vis = VisChart()
        sql = sql_dict.get("sql")
        if not sql:
            return sql_dict.get("thoughts", "No SQL found in the input dictionary.")
        data_df = await self._datasource.query_to_df(sql)
        view = await vis.display(chart=sql_dict, data_df=data_df)
        await self.save_view_message(self.current_dag_context, view)
        return view


class HODatasourceDashboardOperator(GPTVisMixin, MapOperator[dict, str]):
    """Execute the context from the datasource."""

    metadata = ViewMetadata(
        label=_("Datasource Dashboard Operator"),
        name="higher_order_datasource_dashboard_operator",
        description=_("Execute the context from the datasource."),
        category=OperatorCategory.DATABASE,
        parameters=[_PARAMETER_DATASOURCE.new()],
        inputs=[_INPUTS_SQL_DICT_LIST.new()],
        outputs=[_OUTPUTS_SQL_RESULT.new()],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, datasource: DBResource, **kwargs):
        """Initialize the operator."""
        MapOperator.__init__(self, **kwargs)
        self._datasource = datasource

    async def map(self, sql_dict_list: List[dict]) -> str:
        """Execute the context from the datasource."""
        from dbgpt.vis.tags.vis_dashboard import VisDashboard

        if not isinstance(sql_dict_list, list):
            raise ValueError(
                "The input value of datasource executor should be a list of dictionaries."
            )
        vis = VisDashboard()
        chart_params = []
        for chart_item in sql_dict_list:
            chart_dict = {k: v for k, v in chart_item.items()}
            sql = chart_item.get("sql")
            try:
                data_df = await self._datasource.query_to_df(sql)
                chart_dict["data"] = data_df
            except Exception as e:
                logger.warning(f"SQL execute failed! {str(e)}")
                chart_dict["err_msg"] = str(e)
            chart_params.append(chart_dict)
        view = await vis.display(charts=chart_params)
        await self.save_view_message(self.current_dag_context, view)
        return view
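Note: together these operators cover the chat-to-SQL round trip: the retriever formats the template above into an HOContextBody for the LLM, the LLM replies in the _DEFAULT_RESPONSE JSON shape, and the executor runs the SQL and renders a chart. A rough sketch of the executor half (the JSON-parsing step and the DBResource instance are assumptions for illustration):

import asyncio
import json

from dbgpt.core.awel import DAG, InputOperator, MapOperator, SimpleCallDataInputSource

from dbgpt.app.operators.datasource import HODatasourceExecutorOperator


def build_executor_dag(datasource):
    # datasource: an already configured DBResource for your database.
    with DAG("example_sql_executor") as dag:
        input_task = InputOperator(input_source=SimpleCallDataInputSource())
        # Turn the model's JSON answer into the dict the executor expects.
        parse_task = MapOperator(lambda text: json.loads(text))
        executor = HODatasourceExecutorOperator(datasource=datasource)
        input_task >> parse_task >> executor
    return executor


# view = asyncio.run(build_executor_dag(my_db).call(
#     '{"thoughts": "...", "sql": "SELECT 1", "display_type": "Table"}'
# ))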
dbgpt/app/operators/llm.py (new file, 443 lines)
@@ -0,0 +1,443 @@
from typing import List, Literal, Optional, Tuple, Union

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
    BaseMessage,
    ChatPromptTemplate,
    LLMClient,
    ModelOutput,
    ModelRequest,
    StorageConversation,
)
from dbgpt.core.awel import (
    DAG,
    BaseOperator,
    CommonLLMHttpRequestBody,
    DAGContext,
    DefaultInputContext,
    InputOperator,
    JoinOperator,
    MapOperator,
    SimpleCallDataInputSource,
    TaskOutput,
)
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    OptionValue,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.core.interface.operators.message_operator import (
    BaseConversationOperator,
    BufferedConversationMapperOperator,
    TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer


class HOContextBody(BaseModel):
    """Higher-order context body."""

    context_key: str = Field(
        "context",
        description=_("The context key can be used as the key for formatting prompt."),
    )
    context: Union[str, List[str]] = Field(
        ...,
        description=_("The context."),
    )


class BaseHOLLMOperator(
    BaseConversationOperator,
    JoinOperator[ModelRequest],
    LLMOperator,
    StreamingLLMOperator,
):
    """Higher-order model request builder operator."""

    def __init__(
        self,
        prompt_template: ChatPromptTemplate,
        model: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        history_merge_mode: Literal["none", "window", "token"] = "window",
        user_message_key: str = "user_input",
        history_key: Optional[str] = None,
        keep_start_rounds: Optional[int] = None,
        keep_end_rounds: Optional[int] = None,
        max_token_limit: int = 2048,
        **kwargs,
    ):
        JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
        LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
        StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)

        # User must select a history merge mode
        self._history_merge_mode = history_merge_mode
        self._user_message_key = user_message_key
        self._has_history = history_merge_mode != "none"
        self._prompt_template = prompt_template
        self._model = model
        self._history_key = history_key
        self._str_history = False
        self._keep_start_rounds = keep_start_rounds if self._has_history else 0
        self._keep_end_rounds = keep_end_rounds if self._has_history else 0
        self._max_token_limit = max_token_limit
        self._sub_compose_dag = self._build_conversation_composer_dag()

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
        conv_serve = ConversationServe.get_instance(self.system_app)
        self._storage = conv_serve.conv_storage
        self._message_storage = conv_serve.message_storage

        _: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
        dag_ctx.current_task_context.set_task_input(
            DefaultInputContext([dag_ctx.current_task_context])
        )
        if dag_ctx.streaming_call:
            task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
        else:
            task_output = await LLMOperator._do_run(self, dag_ctx)

        return task_output

    async def after_dag_end(self, event_loop_task_id: int):
        model_output: Optional[
            ModelOutput
        ] = await self.current_dag_context.get_from_share_data(
            LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
        )
        model_output_view: Optional[
            str
        ] = await self.current_dag_context.get_from_share_data(
            LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
        )
        storage_conv = await self.get_storage_conversation()
        end_current_round: bool = False
        if model_output and storage_conv:
            # Save model output message to storage
            storage_conv.add_ai_message(model_output.text)
            end_current_round = True
        if model_output_view and storage_conv:
            # Save model output view to storage
            storage_conv.add_view_message(model_output_view)
            end_current_round = True
        if end_current_round:
            # End current conversation round and flush to storage
            storage_conv.end_current_round()

    async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
        dynamic_inputs = []
        for arg in args:
            if isinstance(arg, HOContextBody):
                dynamic_inputs.append(arg)
        # Load and store chat history, default use InMemoryStorage.
        storage_conv, history_messages = await self.blocking_func_to_async(
            self._build_storage, req
        )
        # Save the storage conversation to share data, for the child operators
        await self.current_dag_context.save_to_share_data(
            self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
        )

        user_input = (
            req.messages[-1] if isinstance(req.messages, list) else req.messages
        )
        prompt_dict = {
            self._user_message_key: user_input,
        }
        for dynamic_input in dynamic_inputs:
            if dynamic_input.context_key in prompt_dict:
                raise ValueError(
                    f"Duplicate context key '{dynamic_input.context_key}' in upstream "
                    f"operators."
                )
            prompt_dict[dynamic_input.context_key] = dynamic_input.context

        call_data = {
            "messages": history_messages,
            "prompt_dict": prompt_dict,
        }
        end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
        # Sub dag, use the same dag context in the parent dag
        messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
        model_request = ModelRequest.build_request(
            model=req.model,
            messages=messages,
            context=req.context,
            temperature=req.temperature,
            max_new_tokens=req.max_new_tokens,
            span_id=root_tracer.get_current_span_id(),
            echo=False,
        )
        if storage_conv:
            # Start new round
            storage_conv.start_new_round()
            storage_conv.add_user_message(user_input)
        return model_request

    def _build_storage(
        self, req: CommonLLMHttpRequestBody
    ) -> Tuple[StorageConversation, List[BaseMessage]]:
        # Create a new storage conversation, this will load the conversation from
        # storage, so we must do this async
        storage_conv: StorageConversation = StorageConversation(
            conv_uid=req.conv_uid,
            chat_mode=req.chat_mode,
            user_name=req.user_name,
            sys_code=req.sys_code,
            conv_storage=self.storage,
            message_storage=self.message_storage,
            param_type="",
            param_value=req.chat_param,
        )
        # Get history messages from storage
        history_messages: List[BaseMessage] = storage_conv.get_history_message(
            include_system_message=False
        )

        return storage_conv, history_messages

    def _build_conversation_composer_dag(self) -> DAG:
        with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
            input_task = InputOperator(input_source=SimpleCallDataInputSource())
            # History transform task
            if self._history_merge_mode == "token":
                history_transform_task = TokenBufferedConversationMapperOperator(
                    model=self._model,
                    llm_client=self.llm_client,
                    max_token_limit=self._max_token_limit,
                )
            else:
                history_transform_task = BufferedConversationMapperOperator(
                    keep_start_rounds=self._keep_start_rounds,
                    keep_end_rounds=self._keep_end_rounds,
                )
            if self._history_key:
                history_key = self._history_key
            else:
                placeholders = self._prompt_template.get_placeholders()
                if not placeholders or len(placeholders) != 1:
                    raise ValueError(
                        "The prompt template must have exactly one placeholder if "
                        "history_key is not provided."
                    )
                history_key = placeholders[0]
            history_prompt_build_task = HistoryPromptBuilderOperator(
                prompt=self._prompt_template,
                history_key=history_key,
                check_storage=False,
                save_to_storage=False,
                str_history=self._str_history,
            )
            # Build composer dag
            (
                input_task
                >> MapOperator(lambda x: x["messages"])
                >> history_transform_task
                >> history_prompt_build_task
            )
            (
                input_task
                >> MapOperator(lambda x: x["prompt_dict"])
                >> history_prompt_build_task
            )

        return composer_dag


_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
    _("Prompt Template"),
    "prompt_template",
    ChatPromptTemplate,
    description=_("The prompt template for the conversation."),
)
_PARAMETER_MODEL = Parameter.build_from(
    _("Model Name"),
    "model",
    str,
    optional=True,
    default=None,
    description=_("The model name."),
)

_PARAMETER_LLM_CLIENT = Parameter.build_from(
    _("LLM Client"),
    "llm_client",
    LLMClient,
    optional=True,
    default=None,
    description=_(
        "The LLM Client, how to connect to the LLM model, if not provided, it will use"
        " the default client deployed by DB-GPT."
    ),
)
_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from(
    _("History Message Merge Mode"),
    "history_merge_mode",
    str,
    optional=True,
    default="none",
    options=[
        OptionValue(label="No History", name="none", value="none"),
        OptionValue(label="Message Window", name="window", value="window"),
        OptionValue(label="Token Length", name="token", value="token"),
    ],
    description=_(
        "The history merge mode, supports 'none', 'window' and 'token'."
        " 'none': no history merge, 'window': merge by conversation window, 'token': "
        "merge by token length."
    ),
    ui=ui.UISelect(),
)
_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from(
    _("User Message Key"),
    "user_message_key",
    str,
    optional=True,
    default="user_input",
    description=_(
        "The key of the user message in your prompt, default is 'user_input'."
    ),
)
_PARAMETER_HISTORY_KEY = Parameter.build_from(
    _("History Key"),
    "history_key",
    str,
    optional=True,
    default=None,
    description=_(
        "The chat history key, with chat history message pass to prompt template, "
        "if not provided, it will parse the prompt template to get the key."
    ),
)
_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from(
    _("Keep Start Rounds"),
    "keep_start_rounds",
    int,
    optional=True,
    default=None,
    description=_("The start rounds to keep in the chat history."),
)
_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from(
    _("Keep End Rounds"),
    "keep_end_rounds",
    int,
    optional=True,
    default=None,
    description=_("The end rounds to keep in the chat history."),
)
_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from(
    _("Max Token Limit"),
    "max_token_limit",
    int,
    optional=True,
    default=2048,
    description=_("The max token limit to keep in the chat history."),
)

_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from(
    _("Common LLM Request Body"),
    "common_llm_request_body",
    CommonLLMHttpRequestBody,
    _("The common LLM request body."),
)
_INPUTS_EXTRA_CONTEXT = IOField.build_from(
    _("Extra Context"),
    "extra_context",
    HOContextBody,
    _(
        "Extra context for building prompt(Knowledge context, database "
        "schema, etc), you can add multiple context."
    ),
    dynamic=True,
)
_OUTPUTS_MODEL_OUTPUT = IOField.build_from(
    _("Model Output"),
    "model_output",
    ModelOutput,
    description=_("The model output."),
)
_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from(
    _("Streaming Model Output"),
    "streaming_model_output",
    ModelOutput,
    is_list=True,
    description=_("The streaming model output."),
)


class HOLLMOperator(BaseHOLLMOperator):
    metadata = ViewMetadata(
        label=_("LLM Operator"),
        name="higher_order_llm_operator",
        category=OperatorCategory.LLM,
        description=_(
            "High-level LLM operator, supports multi-round conversation "
            "(conversation window, token length and no multi-round)."
        ),
        parameters=[
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_MODEL.new(),
            _PARAMETER_LLM_CLIENT.new(),
            _PARAMETER_HISTORY_MERGE_MODE.new(),
            _PARAMETER_USER_MESSAGE_KEY.new(),
            _PARAMETER_HISTORY_KEY.new(),
            _PARAMETER_KEEP_START_ROUNDS.new(),
            _PARAMETER_KEEP_END_ROUNDS.new(),
            _PARAMETER_MAX_TOKEN_LIMIT.new(),
        ],
        inputs=[
            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
            _INPUTS_EXTRA_CONTEXT.new(),
        ],
        outputs=[
            _OUTPUTS_MODEL_OUTPUT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class HOStreamingLLMOperator(BaseHOLLMOperator):
    metadata = ViewMetadata(
        label=_("Streaming LLM Operator"),
        name="higher_order_streaming_llm_operator",
        category=OperatorCategory.LLM,
        description=_(
            "High-level streaming LLM operator, supports multi-round conversation "
            "(conversation window, token length and no multi-round)."
        ),
        parameters=[
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_MODEL.new(),
            _PARAMETER_LLM_CLIENT.new(),
            _PARAMETER_HISTORY_MERGE_MODE.new(),
            _PARAMETER_USER_MESSAGE_KEY.new(),
            _PARAMETER_HISTORY_KEY.new(),
            _PARAMETER_KEEP_START_ROUNDS.new(),
            _PARAMETER_KEEP_END_ROUNDS.new(),
            _PARAMETER_MAX_TOKEN_LIMIT.new(),
        ],
        inputs=[
            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
            _INPUTS_EXTRA_CONTEXT.new(),
        ],
        outputs=[
            _OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
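Note: a sketch of the prompt these operators expect: one key for the user message (user_message_key, default 'user_input'), exactly one placeholder for the history, plus one key per upstream HOContextBody. The template classes are the standard dbgpt.core prompt types; the key names are illustrative:

from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,
    SystemPromptTemplate,
)

prompt = ChatPromptTemplate(
    messages=[
        # "context" is filled from an upstream HOContextBody with that context_key.
        SystemPromptTemplate.from_template("Answer based on this context:\n{context}"),
        # The single placeholder is auto-detected as the history key.
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)

llm_task = HOLLMOperator(prompt_template=prompt, history_merge_mode="window")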
dbgpt/app/operators/rag.py (new file, 191 lines)
@@ -0,0 +1,191 @@
from typing import List, Optional

from dbgpt._private.config import Config
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    FunctionDynamicOptions,
    IOField,
    OperatorCategory,
    OptionValue,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.i18n_utils import _

from .llm import HOContextBody

CFG = Config()


def _load_space_name() -> List[OptionValue]:
    from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity

    spaces = KnowledgeSpaceDao().get_knowledge_space(KnowledgeSpaceEntity())
    return [
        OptionValue(label=space.name, name=space.name, value=space.name)
        for space in spaces
    ]


_PARAMETER_CONTEXT_KEY = Parameter.build_from(
    _("Context Key"),
    "context_key",
    type=str,
    optional=True,
    default="context",
    description=_("The key of the context, it will be used in building the prompt"),
)
_PARAMETER_TOP_K = Parameter.build_from(
    _("Top K"),
    "top_k",
    type=int,
    optional=True,
    default=5,
    description=_("The number of chunks to retrieve"),
)
_PARAMETER_SCORE_THRESHOLD = Parameter.build_from(
    _("Minimum Match Score"),
    "score_threshold",
    type=float,
    optional=True,
    default=0.3,
    description=_(
        "The minimum match score for the retrieved chunks, it will be dropped if "
        "the match score is less than the threshold"
    ),
    ui=ui.UISlider(attr=ui.UISlider.UIAttribute(min=0.0, max=1.0, step=0.1)),
)

_PARAMETER_RE_RANKER_ENABLED = Parameter.build_from(
    _("Reranker Enabled"),
    "reranker_enabled",
    type=bool,
    optional=True,
    default=None,
    description=_("Whether to enable the reranker"),
)
_PARAMETER_RE_RANKER_TOP_K = Parameter.build_from(
    _("Reranker Top K"),
    "reranker_top_k",
    type=int,
    optional=True,
    default=3,
    description=_("The top k for the reranker"),
)

_INPUTS_QUESTION = IOField.build_from(
    _("User question"),
    "query",
    str,
    description=_("The user question to retrieve the knowledge"),
)
_OUTPUTS_CONTEXT = IOField.build_from(
    _("Retrieved context"),
    "context",
    HOContextBody,
    description=_("The retrieved context from the knowledge space"),
)


class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
    metadata = ViewMetadata(
        label=_("Knowledge Operator"),
        name="higher_order_knowledge_operator",
        category=OperatorCategory.RAG,
        description=_(
            "Knowledge Operator, retrieve your knowledge (documents) from a knowledge"
            " space"
        ),
        parameters=[
            Parameter.build_from(
                _("Knowledge Space Name"),
                "knowledge_space",
                type=str,
                options=FunctionDynamicOptions(func=_load_space_name),
                description=_("The name of the knowledge space"),
            ),
            _PARAMETER_CONTEXT_KEY.new(),
            _PARAMETER_TOP_K.new(),
            _PARAMETER_SCORE_THRESHOLD.new(),
            _PARAMETER_RE_RANKER_ENABLED.new(),
            _PARAMETER_RE_RANKER_TOP_K.new(),
        ],
        inputs=[
            _INPUTS_QUESTION.new(),
        ],
        outputs=[
            _OUTPUTS_CONTEXT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(
        self,
        knowledge_space: str,
        context_key: Optional[str] = "context",
        top_k: Optional[int] = None,
        score_threshold: Optional[float] = None,
        reranker_enabled: Optional[bool] = None,
        reranker_top_k: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._knowledge_space = knowledge_space
        self._context_key = context_key
        self._top_k = top_k
        self._score_threshold = score_threshold
        self._reranker_enabled = reranker_enabled
        self._reranker_top_k = reranker_top_k

        from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
        from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
        from dbgpt.serve.rag.models.models import (
            KnowledgeSpaceDao,
            KnowledgeSpaceEntity,
        )

        spaces = KnowledgeSpaceDao().get_knowledge_space(
            KnowledgeSpaceEntity(name=knowledge_space)
        )
        if len(spaces) != 1:
            raise Exception(f"invalid space name: {knowledge_space}")
        space = spaces[0]

        reranker: Optional[RerankEmbeddingsRanker] = None

        if CFG.RERANK_MODEL and self._reranker_enabled:
            reranker_top_k = (
                self._reranker_top_k
                if self._reranker_top_k is not None
                else CFG.RERANK_TOP_K
            )
            rerank_embeddings = RerankEmbeddingFactory.get_instance(
                CFG.SYSTEM_APP
            ).create()
            reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=reranker_top_k)
            if self._top_k < reranker_top_k or self._top_k < 20:
                # We use reranker, so if the top_k is less than 20,
                # we need to set it to 20
                self._top_k = max(reranker_top_k, 20)

        self._space_retriever = KnowledgeSpaceRetriever(
            space_id=space.id,
            top_k=self._top_k,
            rerank=reranker,
        )

    async def map(self, query: str) -> HOContextBody:
        chunks = await self._space_retriever.aretrieve_with_scores(
            query, self._score_threshold
        )
        return HOContextBody(
            context_key=self._context_key,
            context=[chunk.content for chunk in chunks],
        )
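Note: a rough sketch of putting the knowledge operator in front of the LLM operator above (the trigger endpoint, space name, and question-extraction step are illustrative assumptions; the join order follows the connection order shown):

from dbgpt.core.awel import DAG, MapOperator
from dbgpt.core.awel.trigger.http_trigger import CommonLLMHttpTrigger

from dbgpt.app.operators.llm import HOLLMOperator
from dbgpt.app.operators.rag import HOKnowledgeOperator

with DAG("example_rag_chat") as dag:
    trigger = CommonLLMHttpTrigger("/example/rag_chat")
    knowledge_task = HOKnowledgeOperator(
        knowledge_space="my_space",  # hypothetical knowledge space name
        top_k=5,
        context_key="context",
    )
    llm_task = HOLLMOperator(prompt_template=prompt)  # prompt as sketched earlier
    trigger >> llm_task  # request body is the LLM operator's first input
    # Extract the question (assumes a plain-string messages field) for retrieval.
    trigger >> MapOperator(lambda req: req.messages) >> knowledge_task >> llm_task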
@@ -10,6 +10,7 @@ from ..util.parameter_util import (  # noqa: F401
     VariablesDynamicOptions,
 )
 from .base import (  # noqa: F401
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,
@@ -33,6 +34,7 @@ __ALL__ = [
     "ResourceCategory",
     "ResourceType",
     "OperatorType",
+    "TAGS_ORDER_HIGH",
     "IOField",
     "BaseDynamicOptions",
     "FunctionDynamicOptions",
@@ -40,6 +40,9 @@ _BASIC_TYPES = [str, int, float, bool, dict, list, set]
 T = TypeVar("T", bound="ViewMixin")
 TM = TypeVar("TM", bound="TypeMetadata")

+TAGS_ORDER_HIGH = "higher-order"
+TAGS_ORDER_FIRST = "first-order"
+

 def _get_type_name(type_: Type[Any]) -> str:
     """Get the type name of the type.
@@ -143,6 +146,8 @@ _OPERATOR_CATEGORY_DETAIL = {
     "agent": _CategoryDetail("Agent", "The agent operator"),
     "rag": _CategoryDetail("RAG", "The RAG operator"),
     "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"),
+    "database": _CategoryDetail("Database", "Interact with the database"),
+    "type_converter": _CategoryDetail("Type Converter", "Convert the type"),
     "example": _CategoryDetail("Example", "Example operator"),
 }

@@ -159,6 +164,8 @@ class OperatorCategory(str, Enum):
     AGENT = "agent"
     RAG = "rag"
     EXPERIMENTAL = "experimental"
+    DATABASE = "database"
+    TYPE_CONVERTER = "type_converter"
     EXAMPLE = "example"

     def label(self) -> str:
@@ -202,6 +209,7 @@ _RESOURCE_CATEGORY_DETAIL = {
     "embeddings": _CategoryDetail("Embeddings", "The embeddings resource"),
     "rag": _CategoryDetail("RAG", "The resource"),
     "vector_store": _CategoryDetail("Vector Store", "The vector store resource"),
+    "database": _CategoryDetail("Database", "Interact with the database"),
     "example": _CategoryDetail("Example", "The example resource"),
 }

@@ -219,6 +227,7 @@ class ResourceCategory(str, Enum):
     EMBEDDINGS = "embeddings"
     RAG = "rag"
     VECTOR_STORE = "vector_store"
+    DATABASE = "database"
     EXAMPLE = "example"

     def label(self) -> str:
@@ -372,32 +381,41 @@ class Parameter(TypeMetadata, Serializable):
             "value": values.get("value"),
             "default": values.get("default"),
         }
+        is_list = values.get("is_list") or False
         if type_cls:
             for k, v in to_handle_values.items():
                 if v:
-                    handled_v = cls._covert_to_real_type(type_cls, v)
+                    handled_v = cls._covert_to_real_type(type_cls, v, is_list)
                     values[k] = handled_v
         return values

     @classmethod
-    def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
-        if type_cls and v is not None:
-            typed_value: Any = v
+    def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any:
+        def _parse_single_value(vv: Any) -> Any:
+            typed_value: Any = vv
             try:
                 # Try to convert the value to the type.
                 if type_cls == "builtins.str":
-                    typed_value = str(v)
+                    typed_value = str(vv)
                 elif type_cls == "builtins.int":
-                    typed_value = int(v)
+                    typed_value = int(vv)
                 elif type_cls == "builtins.float":
-                    typed_value = float(v)
+                    typed_value = float(vv)
                 elif type_cls == "builtins.bool":
-                    if str(v).lower() in ["false", "0", "", "no", "off"]:
+                    if str(vv).lower() in ["false", "0", "", "no", "off"]:
                         return False
-                    typed_value = bool(v)
+                    typed_value = bool(vv)
                 return typed_value
             except ValueError:
-                raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
+                raise ValidationError(f"Value '{vv}' is not valid for type {type_cls}")
+
+        if type_cls and v is not None:
+            if not is_list:
+                return _parse_single_value(v)
+            else:
+                if not isinstance(v, list):
+                    raise ValidationError(f"Value '{v}' is not a list.")
+                return [_parse_single_value(vv) for vv in v]
         return v

     def get_typed_value(self) -> Any:
@@ -413,11 +431,11 @@ class Parameter(TypeMetadata, Serializable):
         if is_variables and self.value is not None and isinstance(self.value, str):
             return VariablesPlaceHolder(self.name, self.value)
         else:
-            return self._covert_to_real_type(self.type_cls, self.value)
+            return self._covert_to_real_type(self.type_cls, self.value, self.is_list)

     def get_typed_default(self) -> Any:
         """Get the typed default."""
-        return self._covert_to_real_type(self.type_cls, self.default)
+        return self._covert_to_real_type(self.type_cls, self.default, self.is_list)

     @classmethod
     def build_from(
@@ -499,7 +517,10 @@ class Parameter(TypeMetadata, Serializable):
             values = self.options.option_values()
             dict_value["options"] = [value.to_dict() for value in values]
         else:
-            dict_value["options"] = [value.to_dict() for value in self.options]
+            dict_value["options"] = [
+                value.to_dict() if not isinstance(value, dict) else value
+                for value in self.options
+            ]

         if self.ui:
             dict_value["ui"] = self.ui.to_dict()
@@ -594,6 +615,17 @@ class Parameter(TypeMetadata, Serializable):
             value = view_value
         return {self.name: value}

+    def new(self: TM) -> TM:
+        """Copy the metadata."""
+        new_obj = self.__class__(
+            **self.model_dump(exclude_defaults=True, exclude={"ui", "options"})
+        )
+        if self.ui:
+            new_obj.ui = self.ui
+        if self.options:
+            new_obj.options = self.options
+        return new_obj
+

 class BaseResource(Serializable, BaseModel):
     """The base resource."""
@@ -644,6 +676,17 @@ class IOField(Resource):
         description="Whether current field is list",
         examples=[True, False],
     )
+    dynamic: bool = Field(
+        default=False,
+        description="Whether current field is dynamic",
+        examples=[True, False],
+    )
+    dynamic_minimum: int = Field(
+        default=0,
+        description="The minimum count of the dynamic field, only valid when dynamic is"
+        " True",
+        examples=[0, 1, 2],
+    )

     @classmethod
     def build_from(
@@ -653,6 +696,8 @@ class IOField(Resource):
         type: Type,
         description: Optional[str] = None,
         is_list: bool = False,
+        dynamic: bool = False,
+        dynamic_minimum: int = 0,
     ):
         """Build the resource from the type."""
         type_name = type.__qualname__
@@ -664,8 +709,22 @@ class IOField(Resource):
             type_cls=type_cls,
             is_list=is_list,
             description=description or label,
+            dynamic=dynamic,
+            dynamic_minimum=dynamic_minimum,
         )

+    @model_validator(mode="before")
+    @classmethod
+    def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
+        if "dynamic" not in values:
+            values["dynamic"] = False
+        if "dynamic_minimum" not in values:
+            values["dynamic_minimum"] = 0
+        return values
+

 class BaseMetadata(BaseResource):
     """The base metadata."""
@@ -808,9 +867,40 @@ class BaseMetadata(BaseResource):
         split_ids = self.id.split("_")
         return "_".join(split_ids[:-1])

+    def _parse_ui_size(self) -> Optional[str]:
+        """Parse the ui size."""
+        if not self.parameters:
+            return None
+        parameters_size = set()
+        for parameter in self.parameters:
+            if parameter.ui and parameter.ui.size:
+                parameters_size.add(parameter.ui.size)
+        for size in ["large", "middle", "small"]:
+            if size in parameters_size:
+                return size
+        return None
+
     def to_dict(self) -> Dict:
         """Convert current metadata to json dict."""
+        from .ui import _size_to_order
+
         dict_value = model_to_dict(self, exclude={"parameters"})
+        tags = dict_value.get("tags")
+        if not tags:
+            tags = {"ui_version": "flow2.0"}
+        elif isinstance(tags, dict) and "ui_version" not in tags:
+            tags["ui_version"] = "flow2.0"
+
+        parsed_ui_size = self._parse_ui_size()
+        if parsed_ui_size:
+            exist_size = tags.get("ui_size")
+            if not exist_size or _size_to_order(parsed_ui_size) > _size_to_order(
+                exist_size
+            ):
+                # Use the higher order size as current size.
+                tags["ui_size"] = parsed_ui_size
+
+        dict_value["tags"] = tags
         dict_value["parameters"] = [
             parameter.to_dict() for parameter in self.parameters
         ]
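Note: the list-aware conversion can be sanity-checked directly (assuming the single-value branch returns the parsed value, as reconstructed above):

from dbgpt.core.awel.flow.base import Parameter

# Single values are parsed to the declared builtin type.
assert Parameter._covert_to_real_type("builtins.int", "42", False) == 42
# Lists are parsed element-wise; bool parsing treats "no" as False.
assert Parameter._covert_to_real_type("builtins.bool", ["1", "no"], True) == [True, False]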
@@ -97,6 +97,12 @@ class FlowNodeData(BaseModel):
             return ResourceMetadata(**value)
         raise ValueError("Unable to infer the type for `data`")

+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict."""
+        dict_value = model_to_dict(self, exclude={"data"})
+        dict_value["data"] = self.data.to_dict()
+        return dict_value
+

 class FlowEdgeData(BaseModel):
     """Edge data in a flow."""
@@ -166,6 +172,12 @@ class FlowData(BaseModel):
     edges: List[FlowEdgeData] = Field(..., description="Edges in the flow")
     viewport: FlowPositionData = Field(..., description="Viewport of the flow")

+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict."""
+        dict_value = model_to_dict(self, exclude={"nodes"})
+        dict_value["nodes"] = [n.to_dict() for n in self.nodes]
+        return dict_value
+

 class _VariablesRequestBase(BaseModel):
     key: str = Field(
@@ -518,9 +530,24 @@ class FlowPanel(BaseModel):
         values["name"] = name
         return values

+    def model_dump(self, **kwargs):
+        """Override the model dump method."""
+        exclude = kwargs.get("exclude", set())
+        if "flow_dag" not in exclude:
+            exclude.add("flow_dag")
+        if "flow_data" not in exclude:
+            exclude.add("flow_data")
+        kwargs["exclude"] = exclude
+        common_dict = super().model_dump(**kwargs)
+        if self.flow_dag:
+            common_dict["flow_dag"] = None
+        if self.flow_data:
+            common_dict["flow_data"] = self.flow_data.to_dict()
+        return common_dict
+
     def to_dict(self) -> Dict[str, Any]:
         """Convert to dict."""
-        return model_to_dict(self, exclude={"flow_dag"})
+        return model_to_dict(self, exclude={"flow_dag", "flow_data"})

     def get_variables_dict(self) -> List[Dict[str, Any]]:
         """Get the variables dict."""
@@ -568,6 +595,11 @@ class FlowFactory:
             key_to_resource_nodes[key] = node
             key_to_resource[key] = node.data

+        if not key_to_operator_nodes and not key_to_resource_nodes:
+            raise FlowMetadataException(
+                "No operator or resource nodes found in the flow."
+            )
+
         for edge in flow_data.edges:
             source_key = edge.source
             target_key = edge.target
@@ -943,11 +975,17 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     new_param = input_parameters[i.name]
                     i.label = new_param.label
                     i.description = new_param.description
+                    i.dynamic = new_param.dynamic
+                    i.is_list = new_param.is_list
+                    i.dynamic_minimum = new_param.dynamic_minimum
             for i in node.data.outputs:
                 if i.name in output_parameters:
                     new_param = output_parameters[i.name]
                     i.label = new_param.label
                     i.description = new_param.description
+                    i.dynamic = new_param.dynamic
+                    i.is_list = new_param.is_list
+                    i.dynamic_minimum = new_param.dynamic_minimum
         else:
             data = cast(ResourceMetadata, node.data)
             key = data.get_origin_id()
@@ -972,6 +1010,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     param.options = new_param.get_dict_options()  # type: ignore
                     param.default = new_param.default
                     param.placeholder = new_param.placeholder
+                    param.alias = new_param.alias
+                    param.ui = new_param.ui

         except (FlowException, ValueError) as e:
             logger.warning(f"Unable to fill the flow panel: {e}")
@@ -2,7 +2,7 @@

 from typing import Any, Dict, List, Literal, Optional, Union

-from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
+from dbgpt._private.pydantic import BaseModel, Field, model_to_dict, model_validator
 from dbgpt.core.interface.serialization import Serializable

 from .exceptions import FlowUIComponentException
@@ -25,6 +25,16 @@ _UI_TYPE = Literal[
     "code_editor",
 ]

+_UI_SIZE_TYPE = Literal["large", "middle", "small"]
+_SIZE_ORDER = {"large": 6, "middle": 4, "small": 2}
+
+
+def _size_to_order(size: str) -> int:
+    """Convert size to order."""
+    if size not in _SIZE_ORDER:
+        return -1
+    return _SIZE_ORDER[size]
+

 class RefreshableMixin(BaseModel):
     """Refreshable mixin."""
@@ -81,6 +91,10 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel):
     )

     ui_type: _UI_TYPE = Field(..., description="UI component type")
+    size: Optional[_UI_SIZE_TYPE] = Field(
+        None,
+        description="The size of the component(small, middle, large)",
+    )

     attr: Optional[UIAttribute] = Field(
         None,
@@ -266,6 +280,27 @@ class UITextArea(PanelEditorMixin, UIInput):
         description="The attributes of the component",
     )

+    @model_validator(mode="after")
+    def check_size(self) -> "UITextArea":
+        """Check the size.
+
+        Automatically set the size to large if the max_rows is greater than 10.
+        """
+        attr = self.attr
+        auto_size = attr.auto_size if attr else None
+        if not attr or not auto_size or isinstance(auto_size, bool):
+            return self
+        max_rows = (
+            auto_size.max_rows
+            if isinstance(auto_size, self.UIAttribute.AutoSize)
+            else None
+        )
+        size = self.size
+        if not size and max_rows and max_rows > 10:
+            # Automatically set the size to large if the max_rows is greater than 10
+            self.size = "large"
+        return self
+

 class UIAutoComplete(UIInput):
     """Auto complete component."""
@@ -450,7 +485,7 @@ class DefaultUITextArea(UITextArea):

     attr: Optional[UITextArea.UIAttribute] = Field(
         default_factory=lambda: UITextArea.UIAttribute(
-            auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40)
+            auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20)
        ),
         description="The attributes of the component",
     )
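Note: the check_size validator means a tall text area advertises itself as a large component. A quick illustrative check (constructor defaults other than attr are assumed):

from dbgpt.core.awel.flow.ui import UITextArea

big_area = UITextArea(
    attr=UITextArea.UIAttribute(
        auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40)
    )
)
assert big_area.size == "large"  # max_rows > 10 auto-upgrades the size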
@@ -29,6 +29,7 @@ from dbgpt.util.tracer import root_tracer

 from ..dag.base import DAG
 from ..flow import (
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,
@@ -965,6 +966,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
             _PARAMETER_MEDIA_TYPE.new(),
             _PARAMETER_STATUS_CODE.new(),
         ],
+        tags={"order": TAGS_ORDER_HIGH},
     )

     def __init__(
@@ -1203,6 +1205,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
             "User input parsed operator, parse the user input from request body and "
             "return as a string"
         ),
+        tags={"order": TAGS_ORDER_HIGH},
     )

     def __init__(self, key: str = "user_input", **kwargs):
@@ -195,6 +195,9 @@ class ModelRequest:
     temperature: Optional[float] = None
     """The temperature of the model inference."""

+    top_p: Optional[float] = None
+    """The top p of the model inference."""
+
     max_new_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
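With top_p now alongside temperature and max_new_tokens, a request can pin all three sampling knobs; a sketch assuming ModelRequest accepts these as keyword arguments:

    from dbgpt.core.interface.llm import ModelRequest
    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

    request = ModelRequest(
        model="chatglm3-6b",  # hypothetical model name
        messages=[ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello")],
        temperature=0.7,
        top_p=0.9,  # the newly added field
        max_new_tokens=512,
    )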
@@ -317,6 +317,25 @@ class ModelMessage(BaseModel):
         """
         return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)

+    @staticmethod
+    def parse_user_message(messages: List[ModelMessage]) -> str:
+        """Parse user message from messages.
+
+        Args:
+            messages (List[ModelMessage]): All messages in the conversation.
+
+        Returns:
+            str: The user message
+        """
+        last_user_message = None
+        for message in messages[::-1]:
+            if message.role == ModelMessageRoleType.HUMAN:
+                last_user_message = message.content
+                break
+        if not last_user_message:
+            raise ValueError("No user message")
+        return last_user_message
+

 _SingleRoundMessage = List[BaseMessage]
 _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]

@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
             content=ai_message.content,
             index=ai_message.index,
             round_index=ai_message.round_index,
-            additional_kwargs=ai_message.additional_kwargs.copy()
-            if ai_message.additional_kwargs
-            else {},
+            additional_kwargs=(
+                ai_message.additional_kwargs.copy()
+                if ai_message.additional_kwargs
+                else {}
+            ),
         )
         current_round.append(view_message)
     return sum(messages_by_round, [])
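A short usage sketch for the new helper, which walks the conversation backwards and returns the latest human message:

    from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

    messages = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is AWEL?"),
        ModelMessage(role=ModelMessageRoleType.AI, content="A workflow language."),
    ]
    assert ModelMessage.parse_user_message(messages) == "What is AWEL?"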
@@ -246,10 +246,16 @@ class BaseLLM:

     SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
     SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
+    SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"

-    def __init__(self, llm_client: Optional[LLMClient] = None):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+    ):
         """Create a new LLM operator."""
         self._llm_client = llm_client
+        self._save_model_output = save_model_output

     @property
     def llm_client(self) -> LLMClient:

@@ -262,9 +268,10 @@ class BaseLLM:
         self, current_dag_context: DAGContext, model_output: ModelOutput
     ) -> None:
         """Save the model output to the share data."""
-        await current_dag_context.save_to_share_data(
-            self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
-        )
+        if self._save_model_output:
+            await current_dag_context.save_to_share_data(
+                self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
+            )


 class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):

@@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
     This operator will generate a no streaming response.
     """

-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a new LLM operator."""
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         MapOperator.__init__(self, **kwargs)

     async def map(self, request: ModelRequest) -> ModelOutput:

@@ -309,13 +321,18 @@ class BaseStreamingLLMOperator(
     This operator will generate streaming response.
     """

-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a streaming operator for a LLM.

         Args:
             llm_client (LLMClient, optional): The LLM client. Defaults to None.
         """
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         BaseOperator.__init__(self, **kwargs)

     async def streamify(  # type: ignore
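The new save_model_output flag lets a flow skip caching the model output in DAG share data. A hypothetical downstream reader, assuming DAGContext exposes get_from_share_data as the counterpart of save_to_share_data:

    from dbgpt.core.awel import MapOperator
    from dbgpt.core.operators import BaseLLM

    class ReadModelOutputOperator(MapOperator[str, str]):
        async def map(self, text: str) -> str:
            # Present only if the upstream LLM operator kept save_model_output=True
            out = await self.current_dag_context.get_from_share_data(
                BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT
            )
            return out.text if out else text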
@@ -4,14 +4,10 @@ from abc import ABC
 from typing import Any, Dict, List, Optional, Union

 from dbgpt._private.pydantic import model_validator
-from dbgpt.core import (
-    ModelMessage,
-    ModelMessageRoleType,
-    ModelOutput,
-    StorageConversation,
-)
+from dbgpt.core import ModelMessage, ModelOutput, StorageConversation
 from dbgpt.core.awel import JoinOperator, MapOperator
 from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,

@@ -42,6 +38,7 @@ from dbgpt.util.i18n_utils import _
     name="common_chat_prompt_template",
     category=ResourceCategory.PROMPT,
     description=_("The operator to build the prompt with static prompt."),
+    tags={"order": TAGS_ORDER_HIGH},
     parameters=[
         Parameter.build_from(
             label=_("System Message"),

@@ -101,9 +98,10 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
 class BasePromptBuilderOperator(BaseConversationOperator, ABC):
     """The base prompt builder operator."""

-    def __init__(self, check_storage: bool, **kwargs):
+    def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs):
         """Create a new prompt builder operator."""
         super().__init__(check_storage=check_storage, **kwargs)
+        self._save_to_storage = save_to_storage

     async def format_prompt(
         self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]

@@ -122,8 +120,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
         messages = prompt.format_messages(**pass_kwargs)
         model_messages = ModelMessage.from_base_messages(messages)
-        # Start new round conversation, and save user message to storage
-        await self.start_new_round_conv(model_messages)
+        if self._save_to_storage:
+            # Start new round conversation, and save user message to storage
+            await self.start_new_round_conv(model_messages)
         return model_messages

     async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:

@@ -132,13 +131,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         Args:
            messages (List[ModelMessage]): The messages.
         """
-        lass_user_message = None
-        for message in messages[::-1]:
-            if message.role == ModelMessageRoleType.HUMAN:
-                lass_user_message = message.content
-                break
-        if not lass_user_message:
-            raise ValueError("No user message")
+        lass_user_message = ModelMessage.parse_user_message(messages)
         storage_conv: Optional[
             StorageConversation
         ] = await self.get_storage_conversation()

@@ -150,6 +143,8 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):

     async def after_dag_end(self, event_loop_task_id: int):
         """Execute after the DAG finished."""
+        if not self._save_to_storage:
+            return
         # Save the storage conversation to storage after the whole DAG finished
         storage_conv: Optional[
             StorageConversation

@@ -422,7 +417,7 @@ class HistoryPromptBuilderOperator(
         self._prompt = prompt
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

     @rearrange_args_by_type

@@ -455,7 +450,7 @@ class HistoryDynamicPromptBuilderOperator(
         """Create a new history dynamic prompt builder operator."""
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

     @rearrange_args_by_type
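Because **kwargs now reaches BasePromptBuilderOperator, a flow can opt out of conversation persistence. A sketch, assuming the extra keyword is tolerated by the other parent initializers; prompt stands for an existing ChatPromptTemplate:

    prompt_builder = HistoryPromptBuilderOperator(
        prompt=prompt,
        history_key="chat_history",
        check_storage=False,
        save_to_storage=False,  # skips start_new_round_conv and the save in after_dag_end
    )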
@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union

 from dbgpt.core import ModelOutput
 from dbgpt.core.awel import MapOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.util.i18n_utils import _

 T = TypeVar("T")

@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         if self.current_dag_context.streaming_call:
             return self.parse_model_stream_resp_ex(input_value, 0)
         else:
-            return self.parse_model_nostream_resp(input_value, "###")
+            return self.parse_model_nostream_resp(input_value, "#####################")


 def _parse_model_response(response: ResponseTye):

@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
 class SQLOutputParser(BaseOutputParser):
     """Parse the SQL output of an LLM call."""

+    metadata = ViewMetadata(
+        label=_("SQL Output Parser"),
+        name="default_sql_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_("Parse the SQL output of an LLM call."),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("Dict SQL Output"),
+                "dict",
+                dict,
+                description=_("The dict output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
     def __init__(self, is_stream_out: bool = False, **kwargs):
         """Create a new SQL output parser."""
         super().__init__(is_stream_out=is_stream_out, **kwargs)

@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
         model_out_text = super().parse_model_nostream_resp(response, sep)
         clean_str = super().parse_prompt_response(model_out_text)
         return json.loads(clean_str, strict=True)
+
+
+class SQLListOutputParser(BaseOutputParser):
+    """Parse the SQL list output of an LLM call."""
+
+    metadata = ViewMetadata(
+        label=_("SQL List Output Parser"),
+        name="default_sql_list_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_(
+            "Parse the SQL list output of an LLM call, mostly used for dashboard."
+        ),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("List SQL Output"),
+                "list",
+                dict,
+                is_list=True,
+                description=_("The list output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, is_stream_out: bool = False, **kwargs):
+        """Create a new SQL list output parser."""
+        super().__init__(is_stream_out=is_stream_out, **kwargs)
+
+    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
+        """Parse the output of an LLM call."""
+        from dbgpt.util.json_utils import find_json_objects
+
+        model_out_text = super().parse_model_nostream_resp(response, sep)
+        json_objects = find_json_objects(model_out_text)
+        json_count = len(json_objects)
+        if json_count < 1:
+            raise ValueError("Unable to obtain valid output.")
+
+        parsed_json_list = json_objects[0]
+        if not isinstance(parsed_json_list, list):
+            if isinstance(parsed_json_list, dict):
+                return [parsed_json_list]
+            else:
+                raise ValueError("Invalid output format.")
+        return parsed_json_list
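A behavior sketch for the new list parser: find_json_objects pulls JSON fragments out of the raw model text, and a single dict result is normalized to a one-element list. Assuming find_json_objects returns each parsed fragment in order:

    from dbgpt.util.json_utils import find_json_objects

    text = 'Charts: [{"sql": "SELECT 1"}, {"sql": "SELECT 2"}] done.'
    objs = find_json_objects(text)
    assert objs[0] == [{"sql": "SELECT 1"}, {"sql": "SELECT 2"}]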
@@ -254,6 +254,18 @@ class ChatPromptTemplate(BasePromptTemplate):
         values["input_variables"] = sorted(input_variables)
         return values

+    def get_placeholders(self) -> List[str]:
+        """Get all placeholders in the prompt template.
+
+        Returns:
+            List[str]: The placeholders.
+        """
+        placeholders = set()
+        for message in self.messages:
+            if isinstance(message, MessagesPlaceholder):
+                placeholders.add(message.variable_name)
+        return sorted(placeholders)
+

 @dataclasses.dataclass
 class PromptTemplateIdentifier(ResourceIdentifier):
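A usage sketch for get_placeholders, assuming the standard template helpers exported from dbgpt.core:

    from dbgpt.core import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        SystemPromptTemplate,
    )

    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a helpful assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )
    assert prompt.get_placeholders() == ["chat_history"]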
@@ -42,13 +42,13 @@ class DefaultLLMClient(LLMClient):

     Args:
         worker_manager (WorkerManager): worker manager instance.
-        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to False.
+        auto_convert_message (bool, optional): auto convert the message to ModelRequest. Defaults to True.
     """

     def __init__(
         self,
         worker_manager: Optional[WorkerManager] = None,
-        auto_convert_message: bool = False,
+        auto_convert_message: bool = True,
     ):
         self._worker_manager = worker_manager
         self._auto_covert_message = auto_convert_message
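The default flip means callers now get message auto-conversion unless they opt out; worker_manager stands for an existing WorkerManager instance:

    client = DefaultLLMClient(worker_manager)  # auto-converts messages now
    legacy = DefaultLLMClient(worker_manager, auto_convert_message=False)  # old behavior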
@@ -24,8 +24,13 @@ class MixinLLMOperator(BaseLLM, BaseOperator, ABC):
     This class extends BaseOperator by adding LLM capabilities.
     """

-    def __init__(self, default_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(default_client)
+    def __init__(
+        self,
+        default_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(default_client, save_model_output=save_model_output)

     @property
     def llm_client(self) -> LLMClient:

@@ -95,8 +100,13 @@ class LLMOperator(MixinLLMOperator, BaseLLMOperator):
         ],
     )

-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(llm_client, save_model_output=save_model_output)
         BaseLLMOperator.__init__(self, llm_client, **kwargs)


@@ -144,6 +154,11 @@ class StreamingLLMOperator(MixinLLMOperator, BaseStreamingLLMOperator):
         ],
     )

-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
+        super().__init__(llm_client, save_model_output=save_model_output)
         BaseStreamingLLMOperator.__init__(self, llm_client, **kwargs)
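The flag threads from these wrappers down to BaseLLM, so a flow that does not want the output cached in share data can disable it at construction time (a sketch; the import path is assumed):

    from dbgpt.model.operators import LLMOperator

    llm_task = LLMOperator(save_model_output=False, task_name="llm_task")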
@@ -16,7 +16,13 @@ from typing import (

 from dbgpt._private.pydantic import model_to_json
 from dbgpt.core.awel import TransformStreamAbsOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.core.interface.llm import ModelOutput
 from dbgpt.core.operators import BaseLLM
 from dbgpt.util.i18n_utils import _

@@ -184,6 +190,7 @@ class OpenAIStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
                 ),
             )
         ],
+        tags={"order": TAGS_ORDER_HIGH},
     )

     async def transform_stream(self, model_output: AsyncIterator[ModelOutput]):
@@ -2,6 +2,7 @@

 import logging
 import traceback
+from typing import List

 from dbgpt._private.config import Config
 from dbgpt.component import SystemApp

@@ -46,7 +47,7 @@ class DBSummaryClient:

         logger.info("db summary embedding success")

-    def get_db_summary(self, dbname, query, topk):
+    def get_db_summary(self, dbname, query, topk) -> List[str]:
         """Get user query related tables info."""
         from dbgpt.serve.rag.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig
@@ -3,14 +3,41 @@ import logging
 from typing import Any, List, Optional, Type, Union, cast

 from dbgpt._private.config import Config
-from dbgpt.agent.resource.database import DBParameters, RDBMSConnectorResource
+from dbgpt.agent.resource.database import (
+    _DEFAULT_PROMPT_TEMPLATE,
+    _DEFAULT_PROMPT_TEMPLATE_ZH,
+    DBParameters,
+    RDBMSConnectorResource,
+)
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    FunctionDynamicOptions,
+    OptionValue,
+    Parameter,
+    ResourceCategory,
+    register_resource,
+)
 from dbgpt.util import ParameterDescription
 from dbgpt.util.i18n_utils import _

 CFG = Config()

 logger = logging.getLogger(__name__)


+def _load_datasource() -> List[OptionValue]:
+    dbs = CFG.local_db_manager.get_db_list()
+    results = [
+        OptionValue(
+            label="[" + db["db_type"] + "]" + db["db_name"],
+            name=db["db_name"],
+            value=db["db_name"],
+        )
+        for db in dbs
+    ]
+    return results
+
+
 @dataclasses.dataclass
 class DatasourceDBParameters(DBParameters):
     """The DB parameters for the datasource."""

@@ -57,6 +84,44 @@ class DatasourceDBParameters(DBParameters):
         return super().from_dict(copied_data, ignore_extra_fields=ignore_extra_fields)


+@register_resource(
+    _("Datasource Resource"),
+    "datasource",
+    category=ResourceCategory.DATABASE,
+    description=_(
+        "Connect to a datasource(retrieve table schemas and execute SQL to fetch data)."
+    ),
+    tags={"order": TAGS_ORDER_HIGH},
+    parameters=[
+        Parameter.build_from(
+            _("Datasource Name"),
+            "name",
+            str,
+            optional=True,
+            default="datasource",
+            description=_("The name of the datasource, default is 'datasource'."),
+        ),
+        Parameter.build_from(
+            _("DB Name"),
+            "db_name",
+            str,
+            description=_("The name of the database."),
+            options=FunctionDynamicOptions(func=_load_datasource),
+        ),
+        Parameter.build_from(
+            _("Prompt Template"),
+            "prompt_template",
+            str,
+            optional=True,
+            default=(
+                _DEFAULT_PROMPT_TEMPLATE_ZH
+                if CFG.LANGUAGE == "zh"
+                else _DEFAULT_PROMPT_TEMPLATE
+            ),
+            description=_("The prompt template to build a database prompt."),
+        ),
+    ],
+)
+class DatasourceResource(RDBMSConnectorResource):
+    def __init__(self, name: str, db_name: Optional[str] = None, **kwargs):
+        conn = CFG.local_db_manager.get_connector(db_name)
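_load_datasource feeds the flow UI's datasource dropdown; with a configured MySQL source named dbgpt_test (hypothetical values), one option would look like:

    options = _load_datasource()
    # e.g. [OptionValue(label="[mysql]dbgpt_test", name="dbgpt_test", value="dbgpt_test")]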
@@ -64,6 +64,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource):
     """Knowledge Space retriever resource."""

     def __init__(self, name: str, space_name: str, context: Optional[dict] = None):
+        # TODO: Build the retriever in a thread pool, it will block the event loop
         retriever = KnowledgeSpaceRetriever(
             space_id=space_name,
             top_k=context.get("top_k", None) if context else 4,
@@ -133,7 +133,10 @@ async def create(
     Returns:
         ServerResponse: The response
     """
-    return Result.succ(service.create_and_save_dag(request))
+    res = await blocking_func_to_async(
+        global_system_app, service.create_and_save_dag, request
+    )
+    return Result.succ(res)


 @router.put(

@@ -154,7 +157,10 @@ async def update(
         ServerResponse: The response
     """
     try:
-        return Result.succ(service.update_flow(request))
+        res = await blocking_func_to_async(
+            global_system_app, service.update_flow, request
+        )
+        return Result.succ(res)
     except Exception as e:
         return Result.failed(msg=str(e))

@@ -176,9 +182,7 @@ async def delete(


 @router.get("/flows/{uid}")
-async def get_flows(
-    uid: str, service: Service = Depends(get_service)
-) -> Result[ServerResponse]:
+async def get_flows(uid: str, service: Service = Depends(get_service)):
     """Get a Flow entity by uid

     Args:

@@ -191,7 +195,7 @@ async def get_flows(
     flow = service.get({"uid": uid})
     if not flow:
         raise HTTPException(status_code=404, detail=f"Flow {uid} not found")
-    return Result.succ(flow)
+    return Result.succ(flow.model_dump())


 @router.get(

@@ -467,7 +471,10 @@ async def import_flow(
             status_code=400, detail=f"invalid file extension {file_extension}"
         )
    if save_flow:
-        return Result.succ(service.create_and_save_dag(flow))
+        res = await blocking_func_to_async(
+            global_system_app, service.create_and_save_dag, flow
+        )
+        return Result.succ(res)
    else:
        return Result.succ(flow)
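All three endpoints now follow the same pattern, moving the blocking DAG build off the event loop; the general form (arguments after the callable are forwarded to it):

    res = await blocking_func_to_async(global_system_app, blocking_fn, *args)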
@@ -27,7 +27,7 @@ from dbgpt.core.schema.api import (
     ChatCompletionStreamResponse,
     DeltaMessage,
 )
-from dbgpt.serve.core import BaseService
+from dbgpt.serve.core import BaseService, blocking_func_to_async
 from dbgpt.storage.metadata import BaseDao
 from dbgpt.storage.metadata._base_dao import QUERY_SPEC
 from dbgpt.util.dbgpts.loader import DBGPTsLoader

@@ -590,7 +590,11 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata

-        dag = self._flow_factory.build(request.flow)
+        dag = await blocking_func_to_async(
+            self._system_app,
+            self._flow_factory.build,
+            request.flow,
+        )
         leaf_nodes = dag.leaf_nodes
         if len(leaf_nodes) != 1:
             raise ValueError("Chat Flow just support one leaf node in dag")
@@ -223,7 +223,7 @@ class KnowledgeSpacePromptBuilderOperator(
         self._prompt = prompt
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_context, **kwargs)

     @rearrange_args_by_type