feat: Support intent detection (#1588)
@@ -18,6 +18,7 @@ from .operators.base import BaseOperator, WorkflowRunner
 from .operators.common_operator import (
     BranchFunc,
     BranchOperator,
+    BranchTaskType,
     InputOperator,
     JoinOperator,
     MapOperator,
@@ -80,6 +81,7 @@ __all__ = [
     "BranchOperator",
     "InputOperator",
     "BranchFunc",
+    "BranchTaskType",
     "WorkflowRunner",
     "TaskState",
     "is_empty_data",

@@ -3,6 +3,7 @@
 DAGLoader will load DAGs from dag_dirs or other sources.
 Now only support load DAGs from local files.
 """
 
+import hashlib
 import logging
 import os
@@ -98,7 +99,7 @@ def _load_modules_from_file(
     return parse(mod_name, filepath)
 
 
-def _process_modules(mods) -> List[DAG]:
+def _process_modules(mods, show_log: bool = True) -> List[DAG]:
     top_level_dags = (
         (o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
     )
@@ -106,7 +107,10 @@ def _process_modules(mods) -> List[DAG]:
     for dag, mod in top_level_dags:
         try:
             # TODO validate dag params
-            logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
+            if show_log:
+                logger.info(
+                    f"Found dag {dag} from mod {mod} and model file {mod.__file__}"
+                )
             found_dags.append(dag)
         except Exception:
             msg = traceback.format_exc()

@@ -6,9 +6,13 @@ from contextlib import suppress
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
 
+from typing_extensions import Annotated
+
 from dbgpt._private.pydantic import (
     BaseModel,
     ConfigDict,
     Field,
+    WithJsonSchema,
     field_validator,
     model_to_dict,
     model_validator,
@@ -255,9 +259,27 @@ class FlowCategory(str, Enum):
         raise ValueError(f"Invalid flow category value: {value}")
 
 
+_DAGModel = Annotated[
+    DAG,
+    WithJsonSchema(
+        {
+            "type": "object",
+            "properties": {
+                "task_name": {"type": "string", "description": "Dummy task name"}
+            },
+            "description": "DAG model, not used in the serialization.",
+        }
+    ),
+]
+
+
 class FlowPanel(BaseModel):
     """Flow panel."""
 
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True, json_encoders={DAG: lambda v: None}
+    )
+
     uid: str = Field(
         default_factory=lambda: str(uuid.uuid4()),
         description="Flow panel uid",
@@ -277,7 +299,8 @@ class FlowPanel(BaseModel):
         description="Flow category",
         examples=[FlowCategory.COMMON, FlowCategory.CHAT_AGENT],
     )
-    flow_data: FlowData = Field(..., description="Flow data")
+    flow_data: Optional[FlowData] = Field(None, description="Flow data")
+    flow_dag: Optional[_DAGModel] = Field(None, description="Flow DAG", exclude=True)
     description: Optional[str] = Field(
         None,
         description="Flow panel description",
@@ -305,6 +328,11 @@ class FlowPanel(BaseModel):
         description="Version of the flow panel",
         examples=["0.1.0", "0.2.0"],
     )
+    define_type: Optional[str] = Field(
+        "json",
+        description="Define type of the flow panel",
+        examples=["json", "python"],
+    )
     editable: bool = Field(
         True,
         description="Whether the flow panel is editable",
@@ -344,7 +372,7 @@ class FlowPanel(BaseModel):
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert to dict."""
-        return model_to_dict(self)
+        return model_to_dict(self, exclude={"flow_dag"})
 
 
 class FlowFactory:
@@ -356,7 +384,9 @@ class FlowFactory:
 
     def build(self, flow_panel: FlowPanel) -> DAG:
         """Build the flow."""
-        flow_data = flow_panel.flow_data
+        if not flow_panel.flow_data:
+            raise ValueError("Flow data is required.")
+        flow_data = cast(FlowData, flow_panel.flow_data)
         key_to_operator_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource: Dict[str, ResourceMetadata] = {}
@@ -610,7 +640,10 @@ class FlowFactory:
         """
         from dbgpt.util.module_utils import import_from_string
 
-        flow_data = flow_panel.flow_data
+        if not flow_panel.flow_data:
+            return
+
+        flow_data = cast(FlowData, flow_panel.flow_data)
         for node in flow_data.nodes:
             if node.data.is_operator:
                 node_data = cast(ViewMetadata, node.data)
@@ -709,6 +742,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
     Args:
        flow_panel (FlowPanel): The flow panel to fill.
    """
+    if not flow_panel.flow_data:
+        return
    for node in flow_panel.flow_data.nodes:
        try:
            parameters_map = {}

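The _DAGModel alias above combines Annotated with pydantic's WithJsonSchema so the runtime DAG object can live on FlowPanel (as the excluded flow_dag field) without breaking JSON schema generation or serialization. Below is a self-contained pydantic v2 sketch of the same pattern; "Engine" and "Panel" are placeholder names, not DB-GPT types.

# Stand-alone illustration of the Annotated + WithJsonSchema + exclude pattern;
# Engine and Panel are placeholders, not DB-GPT classes.
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, WithJsonSchema
from typing_extensions import Annotated


class Engine:
    """Arbitrary runtime object, comparable to the DAG held in flow_dag."""


_EngineModel = Annotated[
    Engine,
    WithJsonSchema(
        {"type": "object", "description": "Engine model, not used in serialization."}
    ),
]


class Panel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(..., description="Panel name")
    engine: Optional[_EngineModel] = Field(
        None, description="Runtime-only object", exclude=True
    )


panel = Panel(name="demo", engine=Engine())
# Like flow_dag, the runtime-only object never reaches the serialized form.
assert "engine" not in panel.model_dump()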
@@ -1,4 +1,5 @@
 """Base classes for operators that can be executed within a workflow."""
+
 import asyncio
 import functools
 from abc import ABC, ABCMeta, abstractmethod
@@ -265,7 +266,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
         )
-        return out_ctx.current_task_context.task_output.output_stream
+
+        task_output = out_ctx.current_task_context.task_output
+        if task_output.is_stream:
+            return out_ctx.current_task_context.task_output.output_stream
+        else:
+
+            async def _gen():
+                yield task_output.output
+
+            return _gen()
 
     def _blocking_call_stream(
         self,

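The change above lets call_stream be used on operators whose task output is not a stream: a single value is wrapped in a one-item async generator so callers always get something they can iterate. A minimal, framework-free sketch of that normalization follows; ensure_stream and the sample values are illustrative, not DB-GPT code.

# Minimal sketch of the wrap-single-output-as-stream idea used in BaseOperator above.
import asyncio
from typing import Any, AsyncIterator


def ensure_stream(output: Any, is_stream: bool) -> AsyncIterator[Any]:
    """Return the output unchanged if it already streams, else wrap it in a stream."""
    if is_stream:
        return output

    async def _gen() -> AsyncIterator[Any]:
        yield output

    return _gen()


async def main() -> None:
    # A non-streaming result can now be consumed with the same async-for loop
    # that a streaming result uses.
    async for chunk in ensure_stream("hello", is_stream=False):
        print(chunk)


asyncio.run(main())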
@@ -1,4 +1,5 @@
 """Common operators of AWEL."""
+
 import asyncio
 import logging
 from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union
@@ -171,6 +172,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
 
 
 BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
+# Function that return the task name
+BranchTaskType = Union[str, Callable[[IN], str], Callable[[IN], Awaitable[str]]]
 
 
 class BranchOperator(BaseOperator, Generic[IN, OUT]):
@@ -187,7 +190,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
 
     def __init__(
         self,
-        branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
+        branches: Optional[Dict[BranchFunc[IN], BranchTaskType]] = None,
         **kwargs,
     ):
         """Create a BranchDAGNode with a branching function.
@@ -208,6 +211,10 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
                     if not value.node_name:
                         raise ValueError("branch node name must be set")
                     branches[branch_function] = value.node_name
+                elif callable(value):
+                    raise ValueError(
+                        "BranchTaskType must be str or BaseOperator on init"
+                    )
         self._branches = branches
 
     async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -234,14 +241,31 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
         branches = await self.branches()
 
         branch_func_tasks = []
-        branch_nodes: List[Union[BaseOperator, str]] = []
+        branch_name_tasks = []
+        # branch_nodes: List[Union[BaseOperator, str]] = []
         for func, node_name in branches.items():
-            branch_nodes.append(node_name)
+            # branch_nodes.append(node_name)
             branch_func_tasks.append(
                 curr_task_ctx.task_input.predicate_map(func, failed_value=None)
             )
+            if callable(node_name):
+
+                async def map_node_name(func) -> str:
+                    input_context = await curr_task_ctx.task_input.map(func)
+                    task_name = input_context.parent_outputs[0].task_output.output
+                    return task_name
+
+                branch_name_tasks.append(map_node_name(node_name))
+
+            else:
+
+                async def _tmp_map_node_name(task_name: str) -> str:
+                    return task_name
+
+                branch_name_tasks.append(_tmp_map_node_name(node_name))
 
         branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
+        branch_nodes: List[str] = await asyncio.gather(*branch_name_tasks)
         parent_output = task_input.parent_outputs[0].task_output
         curr_task_ctx.set_task_output(parent_output)
         skip_node_names = []
@@ -258,7 +282,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
         curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
         return parent_output
 
-    async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
+    async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
         """Return branch logic based on input data."""
         raise NotImplementedError
 

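With BranchTaskType, a branch value may be a fixed task name or a sync/async callable that derives the task name from the input, which is what lets an intent-detection step pick the downstream node at run time. A hypothetical subclass sketch follows; the class, predicate, and task names are illustrative and not part of this commit.

# Illustrative subclass only; IntentBranchOperator and the task names are made up.
from typing import Dict

from dbgpt.core.awel import BranchFunc, BranchOperator, BranchTaskType


class IntentBranchOperator(BranchOperator[str, str]):
    """Route the user input to a task whose name is computed from the input."""

    async def branches(self) -> Dict[BranchFunc[str], BranchTaskType]:
        async def always_run(user_input: str) -> bool:
            # Predicate deciding whether this branch is evaluated.
            return True

        async def to_task_name(user_input: str) -> str:
            # Resolve the downstream task name from the detected intent.
            return "app_task" if "app" in user_input else "chat_task"

        # Before this change the value had to be a task name (str) or an operator;
        # now an async callable returning the task name is also accepted.
        return {always_run: to_task_name}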
@@ -298,16 +298,24 @@ class ModelMessage(BaseModel):
         return str_msg
 
     @staticmethod
-    def messages_to_string(messages: List["ModelMessage"]) -> str:
+    def messages_to_string(
+        messages: List["ModelMessage"],
+        human_prefix: str = "Human",
+        ai_prefix: str = "AI",
+        system_prefix: str = "System",
+    ) -> str:
         """Convert messages to str.
 
         Args:
             messages (List[ModelMessage]): The messages
+            human_prefix (str): The human prefix
+            ai_prefix (str): The ai prefix
+            system_prefix (str): The system prefix
 
         Returns:
             str: The str messages
         """
-        return _messages_to_str(messages)
+        return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
 
 
 _SingleRoundMessage = List[BaseMessage]
@@ -1211,9 +1219,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
             content=ai_message.content,
             index=ai_message.index,
             round_index=ai_message.round_index,
-            additional_kwargs=ai_message.additional_kwargs.copy()
-            if ai_message.additional_kwargs
-            else {},
+            additional_kwargs=(
+                ai_message.additional_kwargs.copy()
+                if ai_message.additional_kwargs
+                else {}
+            ),
         )
         current_round.append(view_message)
     return sum(messages_by_round, [])

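messages_to_string now exposes the role prefixes instead of hard-coding them. A small usage sketch follows; the conversation is made up, and it assumes the module's ModelMessageRoleType role constants.

# Illustrative call of the extended signature; the messages themselves are made up.
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You route user requests."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Open the sales dashboard"),
    ModelMessage(role=ModelMessageRoleType.AI, content="Routing to the app flow."),
]

# The new parameters default to "Human"/"AI"/"System"; passing them explicitly lets
# the caller control how roles are rendered in the flattened prompt string.
text = ModelMessage.messages_to_string(
    messages, human_prefix="User", ai_prefix="Assistant", system_prefix="System"
)
print(text)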
@@ -6,7 +6,7 @@ import dataclasses
 import json
 from abc import ABC, abstractmethod
 from string import Formatter
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
 
 from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
 from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
@@ -19,6 +19,8 @@ from dbgpt.core.interface.storage import (
 )
 from dbgpt.util.formatting import formatter, no_strict_formatter
 
+T = TypeVar("T", bound="BasePromptTemplate")
+
 
 def _jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""
@@ -34,9 +36,9 @@ def _jinja2_formatter(template: str, **kwargs: Any) -> str:
 
 
 _DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
-    "f-string": lambda is_strict: formatter.format
-    if is_strict
-    else no_strict_formatter.format,
+    "f-string": lambda is_strict: (
+        formatter.format if is_strict else no_strict_formatter.format
+    ),
     "jinja2": lambda is_strict: _jinja2_formatter,
 }
 
@@ -88,8 +90,8 @@ class PromptTemplate(BasePromptTemplate):
 
     @classmethod
     def from_template(
-        cls, template: str, template_format: str = "f-string", **kwargs: Any
-    ) -> BasePromptTemplate:
+        cls: Type[T], template: str, template_format: str = "f-string", **kwargs: Any
+    ) -> T:
         """Create a prompt template from a template string."""
         input_variables = get_template_vars(template, template_format)
         return cls(
@@ -116,14 +118,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
 
     @classmethod
     def from_template(
-        cls,
+        cls: Type[T],
         template: str,
         template_format: str = "f-string",
         response_format: Optional[str] = None,
         response_key: str = "response",
         template_is_strict: bool = True,
         **kwargs: Any,
-    ) -> BaseChatPromptTemplate:
+    ) -> T:
         """Create a prompt template from a template string."""
         prompt = PromptTemplate.from_template(
             template,

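Typing from_template with cls: Type[T] and a return type of T means subclasses get their own type back rather than the declared base class. A brief, hypothetical sketch; the subclass below is illustrative, not part of DB-GPT.

# Illustrative only; IntentPromptTemplate is not a DB-GPT class.
from dbgpt.core.interface.prompt import PromptTemplate


class IntentPromptTemplate(PromptTemplate):
    """A subclass that might add intent-specific helpers."""


# Previously from_template was annotated as returning BasePromptTemplate, which
# forced a cast for subclass-specific usage; with Type[T] -> T this assignment
# type-checks as written.
prompt: IntentPromptTemplate = IntentPromptTemplate.from_template(
    "Classify the intent of the following request: {user_input}"
)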