Mirror of https://github.com/csunny/DB-GPT.git
Synced 2025-09-02 01:27:14 +00:00
feat: Support intent detection (#1588)

@@ -20,6 +20,7 @@ else:
     PositiveInt,
     PrivateAttr,
     ValidationError,
+    WithJsonSchema,
     field_validator,
     model_validator,
     root_validator,

@@ -257,13 +257,6 @@ class BaseChat(ABC):
     def stream_call_reinforce_fn(self, text):
         return text
 
-    async def check_iterator_end(iterator):
-        try:
-            await asyncio.anext(iterator)
-            return False  # the iterator has more elements
-        except StopAsyncIteration:
-            return True  # the iterator is exhausted
-
     def _get_span_metadata(self, payload: Dict) -> Dict:
         metadata = {k: v for k, v in payload.items()}
         del metadata["prompt"]

@@ -18,6 +18,7 @@ from .operators.base import BaseOperator, WorkflowRunner
 from .operators.common_operator import (
     BranchFunc,
     BranchOperator,
+    BranchTaskType,
     InputOperator,
     JoinOperator,
     MapOperator,

@@ -80,6 +81,7 @@ __all__ = [
     "BranchOperator",
     "InputOperator",
     "BranchFunc",
+    "BranchTaskType",
     "WorkflowRunner",
     "TaskState",
     "is_empty_data",

@@ -3,6 +3,7 @@
 DAGLoader will load DAGs from dag_dirs or other sources.
 Now only support load DAGs from local files.
 """
+
 import hashlib
 import logging
 import os

@@ -98,7 +99,7 @@ def _load_modules_from_file(
     return parse(mod_name, filepath)
 
 
-def _process_modules(mods) -> List[DAG]:
+def _process_modules(mods, show_log: bool = True) -> List[DAG]:
     top_level_dags = (
         (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
     )

@@ -106,7 +107,10 @@ def _process_modules(mods) -> List[DAG]:
     for dag, mod in top_level_dags:
         try:
             # TODO validate dag params
-            logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
+            if show_log:
+                logger.info(
+                    f"Found dag {dag} from mod {mod} and model file {mod.__file__}"
+                )
             found_dags.append(dag)
         except Exception:
             msg = traceback.format_exc()

@@ -6,9 +6,13 @@ from contextlib import suppress
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
 
+from typing_extensions import Annotated
+
 from dbgpt._private.pydantic import (
     BaseModel,
     ConfigDict,
     Field,
+    WithJsonSchema,
     field_validator,
     model_to_dict,
     model_validator,

@@ -255,9 +259,27 @@ class FlowCategory(str, Enum):
         raise ValueError(f"Invalid flow category value: {value}")
 
 
+_DAGModel = Annotated[
+    DAG,
+    WithJsonSchema(
+        {
+            "type": "object",
+            "properties": {
+                "task_name": {"type": "string", "description": "Dummy task name"}
+            },
+            "description": "DAG model, not used in the serialization.",
+        }
+    ),
+]
+
+
 class FlowPanel(BaseModel):
     """Flow panel."""
 
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True, json_encoders={DAG: lambda v: None}
+    )
+
     uid: str = Field(
         default_factory=lambda: str(uuid.uuid4()),
         description="Flow panel uid",

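The `Annotated[..., WithJsonSchema(...)]` pattern above lets an arbitrary runtime type such as `DAG` carry a hand-written JSON schema, so pydantic can still emit a model schema without inspecting the type. A minimal standalone sketch of the same pattern (the `Engine` type below is an invented stand-in, not part of this commit):

    from typing import Optional

    from pydantic import BaseModel, ConfigDict, WithJsonSchema
    from typing_extensions import Annotated


    class Engine:
        """Invented stand-in for an arbitrary, non-pydantic type like DAG."""


    _EngineModel = Annotated[
        Engine,
        WithJsonSchema({"type": "object", "description": "Engine placeholder schema."}),
    ]


    class Panel(BaseModel):
        # arbitrary_types_allowed lets pydantic accept the plain Engine class
        model_config = ConfigDict(arbitrary_types_allowed=True)

        engine: Optional[_EngineModel] = None


    # The hand-written schema appears under "engine" instead of a validation error.
    print(Panel.model_json_schema())
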
@@ -277,7 +299,8 @@ class FlowPanel(BaseModel):
         description="Flow category",
         examples=[FlowCategory.COMMON, FlowCategory.CHAT_AGENT],
     )
-    flow_data: FlowData = Field(..., description="Flow data")
+    flow_data: Optional[FlowData] = Field(None, description="Flow data")
+    flow_dag: Optional[_DAGModel] = Field(None, description="Flow DAG", exclude=True)
     description: Optional[str] = Field(
         None,
         description="Flow panel description",

@@ -305,6 +328,11 @@
         description="Version of the flow panel",
         examples=["0.1.0", "0.2.0"],
     )
+    define_type: Optional[str] = Field(
+        "json",
+        description="Define type of the flow panel",
+        examples=["json", "python"],
+    )
     editable: bool = Field(
         True,
         description="Whether the flow panel is editable",

@@ -344,7 +372,7 @@
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert to dict."""
-        return model_to_dict(self)
+        return model_to_dict(self, exclude={"flow_dag"})
 
 
 class FlowFactory:

@@ -356,7 +384,9 @@ class FlowFactory:
 
     def build(self, flow_panel: FlowPanel) -> DAG:
         """Build the flow."""
-        flow_data = flow_panel.flow_data
+        if not flow_panel.flow_data:
+            raise ValueError("Flow data is required.")
+        flow_data = cast(FlowData, flow_panel.flow_data)
         key_to_operator_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource: Dict[str, ResourceMetadata] = {}

@@ -610,7 +640,10 @@
         """
         from dbgpt.util.module_utils import import_from_string
 
-        flow_data = flow_panel.flow_data
+        if not flow_panel.flow_data:
+            return
+
+        flow_data = cast(FlowData, flow_panel.flow_data)
         for node in flow_data.nodes:
             if node.data.is_operator:
                 node_data = cast(ViewMetadata, node.data)

@@ -709,6 +742,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
     Args:
         flow_panel (FlowPanel): The flow panel to fill.
     """
+    if not flow_panel.flow_data:
+        return
     for node in flow_panel.flow_data.nodes:
         try:
             parameters_map = {}

@@ -1,4 +1,5 @@
 """Base classes for operators that can be executed within a workflow."""
+
 import asyncio
 import functools
 from abc import ABC, ABCMeta, abstractmethod

@@ -265,7 +266,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
         )
-        return out_ctx.current_task_context.task_output.output_stream
+
+        task_output = out_ctx.current_task_context.task_output
+        if task_output.is_stream:
+            return out_ctx.current_task_context.task_output.output_stream
+        else:
+
+            async def _gen():
+                yield task_output.output
+
+            return _gen()
 
     def _blocking_call_stream(
         self,

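The new branch normalizes non-stream task outputs so callers of the streaming entry point always get an async iterator back. A standalone sketch of the wrapping idiom used by `_gen()` above (names here are illustrative):

    import asyncio
    from typing import Any, AsyncIterator


    def ensure_stream(value: Any) -> AsyncIterator[Any]:
        """Wrap an already-computed value as a one-item async generator."""

        async def _gen() -> AsyncIterator[Any]:
            yield value

        return _gen()


    async def main() -> None:
        async for chunk in ensure_stream("hello"):
            print(chunk)  # prints "hello" exactly once


    asyncio.run(main())
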
@@ -1,4 +1,5 @@
 """Common operators of AWEL."""
+
 import asyncio
 import logging
 from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union

@@ -171,6 +172,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
 
 
 BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
+# Function that returns the task name
+BranchTaskType = Union[str, Callable[[IN], str], Callable[[IN], Awaitable[str]]]
 
 
 class BranchOperator(BaseOperator, Generic[IN, OUT]):

@@ -187,7 +190,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
 
     def __init__(
         self,
-        branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
+        branches: Optional[Dict[BranchFunc[IN], BranchTaskType]] = None,
         **kwargs,
     ):
         """Create a BranchDAGNode with a branching function.

@@ -208,6 +211,10 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
                 if not value.node_name:
                     raise ValueError("branch node name must be set")
                 branches[branch_function] = value.node_name
+            elif callable(value):
+                raise ValueError(
+                    "BranchTaskType must be str or BaseOperator on init"
+                )
         self._branches = branches
 
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:

@@ -234,14 +241,31 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
         branches = await self.branches()
 
         branch_func_tasks = []
-        branch_nodes: List[Union[BaseOperator, str]] = []
+        branch_name_tasks = []
+        # branch_nodes: List[Union[BaseOperator, str]] = []
         for func, node_name in branches.items():
-            branch_nodes.append(node_name)
+            # branch_nodes.append(node_name)
             branch_func_tasks.append(
                 curr_task_ctx.task_input.predicate_map(func, failed_value=None)
             )
+            if callable(node_name):
+
+                async def map_node_name(func) -> str:
+                    input_context = await curr_task_ctx.task_input.map(func)
+                    task_name = input_context.parent_outputs[0].task_output.output
+                    return task_name
+
+                branch_name_tasks.append(map_node_name(node_name))
+
+            else:
+
+                async def _tmp_map_node_name(task_name: str) -> str:
+                    return task_name
+
+                branch_name_tasks.append(_tmp_map_node_name(node_name))
+
         branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
+        branch_nodes: List[str] = await asyncio.gather(*branch_name_tasks)
         parent_output = task_input.parent_outputs[0].task_output
         curr_task_ctx.set_task_output(parent_output)
         skip_node_names = []

@@ -258,7 +282,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
         curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
         return parent_output
 
-    async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
+    async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
         """Return branch logic based on input data."""
         raise NotImplementedError

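With `BranchTaskType`, a branch target may be a literal task name or a sync/async callable that computes the name from the input at runtime. Note that the `__init__` check above still rejects callables passed at construction time; they are only supported from an overridden `branches()`. A hedged sketch (the subclass and task names are invented for illustration):

    from typing import Dict

    from dbgpt.core.awel import BranchFunc, BranchOperator, BranchTaskType


    class LengthBranchOperator(BranchOperator[str, str]):
        """Invented example: route by input length, or by a computed task name."""

        async def branches(self) -> Dict[BranchFunc[str], BranchTaskType]:
            async def is_short(text: str) -> bool:
                return len(text) < 10

            async def is_long(text: str) -> bool:
                return len(text) >= 10

            async def pick_task(text: str) -> str:
                # Callable BranchTaskType: the task name is computed from the input.
                return "upper_task" if text.isupper() else "lower_task"

            return {
                is_short: "small_task",  # literal task name
                is_long: pick_task,  # callable, resolved via map_node_name above
            }
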
@@ -298,16 +298,24 @@ class ModelMessage(BaseModel):
         return str_msg
 
     @staticmethod
-    def messages_to_string(messages: List["ModelMessage"]) -> str:
+    def messages_to_string(
+        messages: List["ModelMessage"],
+        human_prefix: str = "Human",
+        ai_prefix: str = "AI",
+        system_prefix: str = "System",
+    ) -> str:
         """Convert messages to str.
 
         Args:
             messages (List[ModelMessage]): The messages
+            human_prefix (str): The human prefix
+            ai_prefix (str): The ai prefix
+            system_prefix (str): The system prefix
 
         Returns:
             str: The str messages
         """
-        return _messages_to_str(messages)
+        return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
 
 
 _SingleRoundMessage = List[BaseMessage]

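The new prefix parameters let callers control how each role is rendered; the intent-detection code below uses them to build an OpenAI-style transcript. A small usage sketch (the `ModelMessageRoleType` constructor arguments and the exact "Prefix: content" rendering are assumed from DB-GPT's message schema and the defaults shown above):

    from dbgpt.core import ModelMessage, ModelMessageRoleType

    messages = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello, how can I help?"),
    ]

    # Defaults, e.g. "Human: Hi\nAI: Hello, how can I help?"
    print(ModelMessage.messages_to_string(messages))
    # Custom prefixes, as used for the intent detection history below
    print(
        ModelMessage.messages_to_string(
            messages, human_prefix="user", ai_prefix="assistant"
        )
    )
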
@@ -1211,9 +1219,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
             content=ai_message.content,
             index=ai_message.index,
             round_index=ai_message.round_index,
-            additional_kwargs=ai_message.additional_kwargs.copy()
-            if ai_message.additional_kwargs
-            else {},
+            additional_kwargs=(
+                ai_message.additional_kwargs.copy()
+                if ai_message.additional_kwargs
+                else {}
+            ),
         )
         current_round.append(view_message)
     return sum(messages_by_round, [])

@@ -6,7 +6,7 @@ import dataclasses
 import json
 from abc import ABC, abstractmethod
 from string import Formatter
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
 
 from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
 from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage

@@ -19,6 +19,8 @@ from dbgpt.core.interface.storage import (
 )
 from dbgpt.util.formatting import formatter, no_strict_formatter
 
+T = TypeVar("T", bound="BasePromptTemplate")
+
 
 def _jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""

@@ -34,9 +36,9 @@
 
 
 _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
-    "f-string": lambda is_strict: formatter.format
-    if is_strict
-    else no_strict_formatter.format,
+    "f-string": lambda is_strict: (
+        formatter.format if is_strict else no_strict_formatter.format
+    ),
     "jinja2": lambda is_strict: _jinja2_formatter,
 }

@@ -88,8 +90,8 @@ class PromptTemplate(BasePromptTemplate):
 
     @classmethod
     def from_template(
-        cls, template: str, template_format: str = "f-string", **kwargs: Any
-    ) -> BasePromptTemplate:
+        cls: Type[T], template: str, template_format: str = "f-string", **kwargs: Any
+    ) -> T:
         """Create a prompt template from a template string."""
         input_variables = get_template_vars(template, template_format)
         return cls(

@@ -116,14 +118,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
 
     @classmethod
     def from_template(
-        cls,
+        cls: Type[T],
         template: str,
         template_format: str = "f-string",
         response_format: Optional[str] = None,
         response_key: str = "response",
         template_is_strict: bool = True,
         **kwargs: Any,
-    ) -> BaseChatPromptTemplate:
+    ) -> T:
         """Create a prompt template from a template string."""
         prompt = PromptTemplate.from_template(
             template,

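Typing `cls` as `Type[T]` with return type `T` means `from_template` now reports the concrete subclass instead of the base class to type checkers. A small sketch (the subclass is invented for illustration):

    from dbgpt.core import PromptTemplate


    class MyPromptTemplate(PromptTemplate):
        """Invented subclass for illustration."""


    tmpl = MyPromptTemplate.from_template("Tell me about {topic}.")
    # Previously typed as BasePromptTemplate; checkers now see MyPromptTemplate.
    assert isinstance(tmpl, MyPromptTemplate)
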
dbgpt/experimental/__init__.py (new file, +4)
@@ -0,0 +1,4 @@
"""Experimental features for DB-GPT.

Warning: These features are experimental and may change or be removed in the future.
"""

dbgpt/experimental/intent/__init__.py (new file, +1)
@@ -0,0 +1 @@
"""Intent detection module."""

dbgpt/experimental/intent/base.py (new file, +186)
@@ -0,0 +1,186 @@
"""Base class for intent detection."""

import json
from abc import ABC, abstractmethod
from typing import List, Optional, Type

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
    BaseOutputParser,
    LLMClient,
    ModelMessage,
    ModelRequest,
    PromptTemplate,
)

_DEFAULT_PROMPT = """Please select the most matching intent from the intent definitions below based on the user's question,
and return the complete intent information according to the requirements and output format.
1. Strictly follow the given intent definition for output; do not create intents or slot attributes on your own. If an intent has no defined slots, the output should not include slots either.
2. Extract slot attribute values from the user's input and historical dialogue information according to the intent definition. If the corresponding target information for the slot attribute cannot be obtained, the slot value should be empty.
3. When extracting slot values, ensure to only obtain the effective value part. Do not include auxiliary descriptions or modifiers. Ensure that all slot attributes defined in the intent are output, regardless of whether values are obtained. If no values are found, output the slot name with an empty value.
4. Ensure that if the user's question does not provide the content defined in the intent slots, the slot values must be empty. Do not fill slots with invalid information such as 'user did not provide'.
5. If the information extracted from the user's question does not fully correspond to the matched intent slots, generate a new question to ask the user, prompting them to provide the missing slot data.

{response}

You can refer to the following examples:
{example}

The known intent information is defined as follows:
{intent_definitions}

Here are the known historical dialogue messages. If they are not relevant to the user's question, they can be ignored (sometimes you can extract useful intent and slot information from the historical dialogue messages).
{history}

User question: {user_input}
"""  # noqa

_DEFAULT_PROMPT_ZH = """从下面的意图定义中选择一个和用户问题最匹配的意图,并根据要求和输出格式返回意图完整信息。
1. 严格根据给出的意图定义输出,不要自行生成意图和槽位属性,意图没有定义槽位则输出也不应该包含槽位。
2. 从用户输入和历史对话信息中提取意图定义中槽位属性的值,如果无法获取到槽位属性对应的目标信息,则槽位值输出空。
3. 槽位值提取时请注意只获取有效值部分,不要填入辅助描述或定语。确保意图定义的槽位属性不管是否获取到值,都要输出全部定义给出的槽位属性,没有找到值的输出槽位名和空值。
4. 请确保如果用户问题中未提供意图槽位定义的内容,则槽位值必须为空,不要在槽位里填'用户未提供'这类无效信息。
5. 如果用户问题内容提取的信息和匹配到的意图槽位无法完全对应,则生成新的问题向用户提问,提示用户补充缺少的槽位数据。

{response}

可以参考下面的例子:
{example}

已知的意图信息定义如下:
{intent_definitions}

以下是已知的历史对话消息,如果和用户问题无关可以忽略(有时可以从历史对话消息中提取有用的意图和槽位信息)。
{history}

用户问题:{user_input}
"""  # noqa


class IntentDetectionResponse(BaseModel):
    """Response schema for intent detection."""

    intent: str = Field(
        ...,
        description="The intent of user question.",
    )
    thought: str = Field(
        ...,
        description="Logic and rationale for selecting the current application.",
    )
    task_name: str = Field(
        ...,
        description="The task name of the intent.",
    )
    slots: Optional[dict] = Field(
        None,
        description="The slots of user question.",
    )
    user_input: str = Field(
        ...,
        description="Instructions generated based on intent and slot.",
    )
    ask_user: Optional[str] = Field(
        None,
        description="Questions to users.",
    )

    def has_empty_slot(self):
        """Check if the response has an empty slot."""
        if self.slots:
            for key, value in self.slots.items():
                if not value or len(value) <= 0:
                    return True
        return False

    @classmethod
    def to_response_format(cls) -> str:
        """Get the response format."""
        schema_dict = {
            "intent": "[Intent placeholder]",
            "thought": "Your reasoning idea here.",
            "task_name": "[Task name of the intent]",
            "slots": {
                "Slot attribute 1 in the intention definition": "[Slot value 1]",
                "Slot attribute 2 in the intention definition": "[Slot value 2]",
            },
            "ask_user": "If you want the user to supplement the slot data, the problem"
            " is raised to the user, please use the same language as the user.",
            "user_input": "Complete instructions generated according to the intention "
            "and slot, please use the same language as the user.",
        }
        # How to integrate the streaming json
        schema_str = json.dumps(schema_dict, indent=2, ensure_ascii=False)
        response_format = (
            f"Please output in the following JSON format: \n{schema_str}"
            f"\nMake sure the response is correct json and can be parsed by Python "
            f"json.loads."
        )
        return response_format


class BaseIntentDetection(ABC):
    """Base class for intent detection."""

    def __init__(
        self,
        intent_definitions: str,
        prompt_template: Optional[str] = None,
        response_format: Optional[str] = None,
        examples: Optional[str] = None,
    ):
        """Create a new intent detection instance."""
        self._intent_definitions = intent_definitions
        self._prompt_template = prompt_template
        self._response_format = response_format
        self._examples = examples

    @property
    @abstractmethod
    def llm_client(self) -> LLMClient:
        """Get the LLM client."""

    @property
    def response_schema(self) -> Type[IntentDetectionResponse]:
        """Return the response schema."""
        return IntentDetectionResponse

    async def detect_intent(
        self,
        messages: List[ModelMessage],
        model: Optional[str] = None,
        language: str = "en",
    ) -> IntentDetectionResponse:
        """Detect intent from messages."""
        default_prompt = _DEFAULT_PROMPT if language == "en" else _DEFAULT_PROMPT_ZH

        models = await self.llm_client.models()
        if not models:
            raise Exception("No models available.")
        model = model or models[0].model
        history_messages = ModelMessage.messages_to_string(
            messages[:-1], human_prefix="user", ai_prefix="assistant"
        )

        prompt_template = self._prompt_template or default_prompt
        template: PromptTemplate = PromptTemplate.from_template(prompt_template)
        response_schema = self.response_schema
        response_format = self._response_format or response_schema.to_response_format()
        formatted_message = template.format(
            response=response_format,
            example=self._examples,
            intent_definitions=self._intent_definitions,
            history=history_messages,
            user_input=messages[-1].content,
        )
        model_messages = ModelMessage.build_human_message(formatted_message)
        model_request = ModelRequest.build_request(model, messages=[model_messages])
        model_output = await self.llm_client.generate(model_request)
        output_parser = BaseOutputParser()
        str_out = output_parser.parse_model_nostream_resp(
            model_output, "#########################"
        )
        json_out = output_parser.parse_prompt_response(str_out)
        dict_out = json.loads(json_out)
        return response_schema.model_validate(dict_out)

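A hedged usage sketch of the new base class. The concrete subclass, the intent definition text, and `OpenAILLMClient` are illustrative assumptions, not part of this commit:

    import asyncio

    from dbgpt.core import LLMClient, ModelMessage
    from dbgpt.experimental.intent.base import BaseIntentDetection


    class MyIntentDetection(BaseIntentDetection):
        """Minimal subclass: only the llm_client property is required."""

        def __init__(self, intent_definitions: str, llm_client: LLMClient, **kwargs):
            super().__init__(intent_definitions, **kwargs)
            self._llm_client = llm_client

        @property
        def llm_client(self) -> LLMClient:
            return self._llm_client


    async def main():
        from dbgpt.model.proxy import OpenAILLMClient  # assumed client; any LLMClient works

        detector = MyIntentDetection(
            intent_definitions=(
                "intent: book_flight; task_name: book_flight_task; "
                "slots: origin, destination, date"  # invented definition format
            ),
            llm_client=OpenAILLMClient(),
        )
        resp = await detector.detect_intent(
            [ModelMessage.build_human_message("Book me a flight to Paris tomorrow")]
        )
        print(resp.intent, resp.slots, resp.ask_user)


    asyncio.run(main())
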
dbgpt/experimental/intent/operators.py (new file, +92)
@@ -0,0 +1,92 @@
"""Operators for intent detection."""

from typing import Dict, List, Optional, cast

from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext
from dbgpt.core.awel import BranchFunc, BranchOperator, BranchTaskType, MapOperator
from dbgpt.model.operators.llm_operator import MixinLLMOperator

from .base import BaseIntentDetection, IntentDetectionResponse


class IntentDetectionOperator(
    MixinLLMOperator, BaseIntentDetection, MapOperator[ModelRequest, ModelRequest]
):
    """The intent detection operator."""

    def __init__(
        self,
        intent_definitions: str,
        prompt_template: Optional[str] = None,
        response_format: Optional[str] = None,
        examples: Optional[str] = None,
        **kwargs
    ):
        """Create the intent detection operator."""
        MixinLLMOperator.__init__(self)
        MapOperator.__init__(self, **kwargs)
        BaseIntentDetection.__init__(
            self,
            intent_definitions=intent_definitions,
            prompt_template=prompt_template,
            response_format=response_format,
            examples=examples,
        )

    async def map(self, input_value: ModelRequest) -> ModelRequest:
        """Detect the intent.

        Merge the intent detection result into the context.
        """
        language = "en"
        if self.system_app:
            language = self.system_app.config.get_current_lang()
        messages = self.parse_messages(input_value)
        ic = await self.detect_intent(
            messages,
            input_value.model,
            language=language,
        )
        if not input_value.context:
            input_value.context = ModelRequestContext()
        if not input_value.context.extra:
            input_value.context.extra = {}
        input_value.context.extra["intent_detection"] = ic
        return input_value

    def parse_messages(self, request: ModelRequest) -> List[ModelMessage]:
        """Parse the messages from the request."""
        return request.get_messages()


class IntentDetectionBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
    """The intent detection branch operator."""

    def __init__(self, end_task_name: str, **kwargs):
        """Create the intent detection branch operator."""
        super().__init__(**kwargs)
        self._end_task_name = end_task_name

    async def branches(
        self,
    ) -> Dict[BranchFunc[ModelRequest], BranchTaskType]:
        """Branch the intent detection result to different tasks."""
        download_task_names = set(task.node_name for task in self.downstream)  # noqa
        branch_func_map = {}
        for task_name in download_task_names:

            def check(r: ModelRequest, outer_task_name=task_name):
                if not r.context or not r.context.extra:
                    return False
                ic_result = r.context.extra.get("intent_detection")
                if not ic_result:
                    return False
                ic: IntentDetectionResponse = cast(IntentDetectionResponse, ic_result)
                if ic.has_empty_slot():
                    return self._end_task_name == outer_task_name
                else:
                    return outer_task_name == ic.task_name

            branch_func_map[check] = task_name

        return branch_func_map  # type: ignore

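A hedged sketch of wiring both operators into an AWEL DAG. The task names, downstream operators, and constructor arguments are invented; the branch routes to the matched intent's `task_name`, or to `end_task_name` when required slots are still empty:

    from dbgpt.core.awel import DAG, MapOperator

    from dbgpt.experimental.intent.operators import (
        IntentDetectionBranchOperator,
        IntentDetectionOperator,
    )

    _INTENTS = "intent: book_flight; task_name: book_flight_task; slots: origin, destination, date"

    with DAG("intent_detection_example"):
        detect = IntentDetectionOperator(intent_definitions=_INTENTS)
        branch = IntentDetectionBranchOperator(end_task_name="ask_user_task")
        # task_name is assumed to set the node_name the branch matches against.
        book_flight = MapOperator(lambda req: req, task_name="book_flight_task")
        ask_user = MapOperator(lambda req: req, task_name="ask_user_task")
        detect >> branch
        branch >> book_flight
        branch >> ask_user
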
@@ -1,6 +1,7 @@
 """This is an auto-generated model file
 You can define your own models and DAOs here
 """
+
 import json
 from datetime import datetime
 from typing import Any, Dict, Union

@@ -33,6 +34,12 @@ class ServeEntity(Model):
     source = Column(String(64), nullable=True, comment="Flow source")
     source_url = Column(String(512), nullable=True, comment="Flow source url")
     version = Column(String(32), nullable=True, comment="Flow version")
+    define_type = Column(
+        String(32),
+        default="json",
+        nullable=True,
+        comment="Flow define type(json or python)",
+    )
     editable = Column(
         Integer, nullable=True, comment="Editable, 0: editable, 1: not editable"
     )

@@ -103,6 +110,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             "source": request_dict.get("source"),
             "source_url": request_dict.get("source_url"),
             "version": request_dict.get("version"),
+            "define_type": request_dict.get("define_type"),
             "editable": ServeEntity.parse_editable(request_dict.get("editable")),
             "description": request_dict.get("description"),
             "user_name": request_dict.get("user_name"),

@@ -133,6 +141,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             source=entity.source,
             source_url=entity.source_url,
             version=entity.version,
+            define_type=entity.define_type,
             editable=ServeEntity.to_bool_editable(entity.editable),
             description=entity.description,
             user_name=entity.user_name,

@@ -165,6 +174,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             source_url=entity.source_url,
             version=entity.version,
             editable=ServeEntity.to_bool_editable(entity.editable),
+            define_type=entity.define_type,
             user_name=entity.user_name,
             sys_code=entity.sys_code,
             gmt_created=gmt_created_str,

@@ -203,6 +213,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
             if update_request.version:
                 entry.version = update_request.version
             entry.editable = ServeEntity.parse_editable(update_request.editable)
+            if update_request.define_type:
+                entry.define_type = update_request.define_type
             if update_request.user_name:
                 entry.user_name = update_request.user_name
             if update_request.sys_code:

@@ -138,7 +138,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         """
         try:
             # Build DAG from request
-            dag = self._flow_factory.build(request)
+            if request.define_type == "json":
+                dag = self._flow_factory.build(request)
+            else:
+                dag = request.flow_dag
             request.dag_id = dag.dag_id
             # Save DAG to storage
             request.flow_category = self._parse_flow_category(dag)

@@ -149,7 +152,9 @@
                 request.dag_id = ""
                 return self.dao.create(request)
             else:
-                raise e
+                raise ValueError(
+                    f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}"
+                ) from e
         res = self.dao.create(request)
 
         state = request.state

@@ -193,6 +198,8 @@
         entities = self.dao.get_list({})
         for entity in entities:
             try:
+                if entity.define_type != "json":
+                    continue
                 dag = self._flow_factory.build(entity)
                 if entity.state in [State.DEPLOYED, State.RUNNING] or (
                     entity.version == "0.1.0" and entity.state == State.INITIALIZING

@@ -213,7 +220,8 @@
         flows = self.dbgpts_loader.get_flows()
         for flow in flows:
             try:
-                self._flow_factory.pre_load_requirements(flow)
+                if flow.define_type == "json":
+                    self._flow_factory.pre_load_requirements(flow)
             except Exception as e:
                 logger.warning(
                     f"Pre load requirements for DAG({flow.name}) from "

@@ -225,6 +233,8 @@
         flows = self.dbgpts_loader.get_flows()
         for flow in flows:
             try:
+                if flow.define_type == "python" and flow.flow_dag is None:
+                    continue
                 # Set state to DEPLOYED
                 flow.state = State.DEPLOYED
                 exist_inst = self.get({"name": flow.name})

@@ -260,7 +270,10 @@
         new_state = request.state
         try:
             # Try to build the dag from the request
-            dag = self._flow_factory.build(request)
+            if request.define_type == "json":
+                dag = self._flow_factory.build(request)
+            else:
+                dag = request.flow_dag
             request.flow_category = self._parse_flow_category(dag)
         except Exception as e:
             if save_failed_flow:

@@ -295,6 +308,7 @@
                 raise HTTPException(
                     status_code=404, detail=f"Flow detail {request.uid} not found"
                 )
+            update_obj.flow_dag = request.flow_dag
             return self.create_and_save_dag(update_obj)
         except Exception as e:
             if old_data and old_data.state == State.RUNNING:

@@ -10,6 +10,7 @@ import tomlkit
 
 from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
 from dbgpt.component import BaseComponent, SystemApp
+from dbgpt.core.awel import DAG
 from dbgpt.core.awel.flow.flow_factory import FlowPanel
 from dbgpt.util.dbgpts.base import (
     DBGPTS_METADATA_FILE,

@@ -77,7 +78,7 @@ class BasePackage(BaseModel):
         values: Dict[str, Any],
         expected_cls: Type[T],
         predicates: Optional[List[Callable[..., bool]]] = None,
-    ) -> Tuple[List[Type[T]], List[Any]]:
+    ) -> Tuple[List[Type[T]], List[Any], List[Any]]:
         import importlib.resources as pkg_resources
 
         from dbgpt.core.awel.dag.loader import _load_modules_from_file

@@ -101,7 +102,7 @@ class BasePackage(BaseModel):
         for c in list_cls:
             if issubclass(c, expected_cls):
                 module_cls.append(c)
-        return module_cls, all_predicate_results
+        return module_cls, all_predicate_results, mods
 
 
 class FlowPackage(BasePackage):

@@ -113,6 +114,24 @@ class FlowPackage(BasePackage):
     ) -> "FlowPackage":
         if values["definition_type"] == "json":
             return FlowJsonPackage.build_from(values, ext_dict)
+        return FlowPythonPackage.build_from(values, ext_dict)
+
+
+class FlowPythonPackage(FlowPackage):
+    dag: DAG = Field(..., description="The DAG of the package")
+
+    @classmethod
+    def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
+        from dbgpt.core.awel.dag.loader import _process_modules
+
+        _, _, mods = cls.load_module_class(values, DAG)
+
+        dags = _process_modules(mods, show_log=False)
+        if not dags:
+            raise ValueError("No DAGs found in the package")
+        if len(dags) > 1:
+            raise ValueError("Only support one DAG in the package")
+        values["dag"] = dags[0]
+        return cls(**values)

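Under this scheme, a Python-defined flow package ships a module exposing exactly one top-level `DAG`; `build_from` imports the module and collects it via `_process_modules(mods, show_log=False)`. A hedged sketch of such a package's `__init__.py` (matching the stub that `_write_flow_define_python_file` generates below; the DAG body is invented):

    """my_flow flow package"""
    from dbgpt.core.awel import DAG, MapOperator

    # Exactly one DAG may be defined at module top level; more than one
    # raises "Only support one DAG in the package".
    with DAG("my_flow_dag"):
        task = MapOperator(lambda x: x)
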
@@ -144,7 +163,7 @@ class OperatorPackage(BasePackage):
     def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
         from dbgpt.core.awel import BaseOperator
 
-        values["operators"], _ = cls.load_module_class(values, BaseOperator)
+        values["operators"], _, _ = cls.load_module_class(values, BaseOperator)
         return cls(**values)

@@ -159,7 +178,7 @@ class AgentPackage(BasePackage):
     def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
         from dbgpt.agent import ConversableAgent
 
-        values["agents"], _ = cls.load_module_class(values, ConversableAgent)
+        values["agents"], _, _ = cls.load_module_class(values, ConversableAgent)
         return cls(**values)

@@ -190,7 +209,7 @@ class ResourcePackage(BasePackage):
         else:
             return False
 
-        _, predicted_cls = cls.load_module_class(values, Resource, [_predicate])
+        _, predicted_cls, _ = cls.load_module_class(values, Resource, [_predicate])
         resource_instances = []
         resources = []
         for o in predicted_cls:

@@ -353,7 +372,7 @@ class DBGPTsLoader(BaseComponent):
         for package in self._packages.values():
             if package.package_type != "flow":
                 continue
-            package = cast(FlowJsonPackage, package)
+            package = cast(FlowPackage, package)
             dict_value = {
                 "name": package.name,
                 "label": package.label,

@@ -361,8 +380,24 @@
                 "editable": False,
                 "description": package.description,
                 "source": package.repo,
-                "flow_data": package.read_definition_json(),
+                "define_type": "json",
             }
+            if isinstance(package, FlowJsonPackage):
+                dict_value["flow_data"] = package.read_definition_json()
+            elif isinstance(package, FlowPythonPackage):
+                dict_value["flow_data"] = {
+                    "nodes": [],
+                    "edges": [],
+                    "viewport": {
+                        "x": 213,
+                        "y": 269,
+                        "zoom": 0,
+                    },
+                }
+                dict_value["flow_dag"] = package.dag
+                dict_value["define_type"] = "python"
+            else:
+                raise ValueError(f"Unsupported package type: {package}")
             panels.append(FlowPanel(**dict_value))
         return panels

@@ -97,9 +97,7 @@ def _create_flow_template(
     if definition_type == "json":
         _write_flow_define_json_file(working_directory, name, mod_name)
     else:
-        raise click.ClickException(
-            f"Unsupported definition type: {definition_type} for dbgpts type: {dbgpts_type}"
-        )
+        _write_flow_define_python_file(working_directory, name, mod_name)
 
 
 def _create_operator_template(

@@ -222,6 +220,16 @@ def _write_flow_define_json_file(working_directory: str, name: str, mod_name: str):
     print("Please write your flow json to the file: ", def_file)
 
 
+def _write_flow_define_python_file(working_directory: str, name: str, mod_name: str):
+    """Write the flow define python file"""
+
+    init_file = Path(working_directory) / name / mod_name / "__init__.py"
+    content = ""
+
+    with open(init_file, "w") as f:
+        f.write(f'"""{name} flow package"""\n{content}')
+
+
 def _write_operator_init_file(working_directory: str, name: str, mod_name: str):
     """Write the operator __init__.py file"""