feat: Support intent detection (#1588)

This commit is contained in:
Fangyin Cheng
2024-05-30 18:51:57 +08:00
committed by GitHub
parent 73d175a127
commit a88af6f87d
22 changed files with 881 additions and 54 deletions

View File

@@ -18,6 +18,7 @@ from .operators.base import BaseOperator, WorkflowRunner
from .operators.common_operator import (
BranchFunc,
BranchOperator,
BranchTaskType,
InputOperator,
JoinOperator,
MapOperator,
@@ -80,6 +81,7 @@ __all__ = [
"BranchOperator",
"InputOperator",
"BranchFunc",
"BranchTaskType",
"WorkflowRunner",
"TaskState",
"is_empty_data",

View File

@@ -3,6 +3,7 @@
DAGLoader will load DAGs from dag_dirs or other sources.
Currently, only loading DAGs from local files is supported.
"""
import hashlib
import logging
import os
@@ -98,7 +99,7 @@ def _load_modules_from_file(
return parse(mod_name, filepath)
def _process_modules(mods) -> List[DAG]:
def _process_modules(mods, show_log: bool = True) -> List[DAG]:
top_level_dags = (
(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)
)
@@ -106,7 +107,10 @@ def _process_modules(mods) -> List[DAG]:
for dag, mod in top_level_dags:
try:
# TODO validate dag params
logger.info(f"Found dag {dag} from mod {mod} and model file {mod.__file__}")
if show_log:
logger.info(
f"Found dag {dag} from mod {mod} and model file {mod.__file__}"
)
found_dags.append(dag)
except Exception:
msg = traceback.format_exc()

View File

@@ -6,9 +6,13 @@ from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing_extensions import Annotated
from dbgpt._private.pydantic import (
BaseModel,
ConfigDict,
Field,
WithJsonSchema,
field_validator,
model_to_dict,
model_validator,
@@ -255,9 +259,27 @@ class FlowCategory(str, Enum):
raise ValueError(f"Invalid flow category value: {value}")
# DAG type annotated with a fixed, hand-written JSON schema: an object with a
# single dummy string property, `task_name`. As the schema's own description
# states, the DAG model is not used in the serialization.
# NOTE(review): presumably the explicit schema exists because a JSON schema
# cannot be derived automatically for the arbitrary DAG type — confirm.
_DAGModel = Annotated[
DAG,
WithJsonSchema(
{
"type": "object",
"properties": {
"task_name": {"type": "string", "description": "Dummy task name"}
},
"description": "DAG model, not used in the serialization.",
}
),
]
class FlowPanel(BaseModel):
"""Flow panel."""
model_config = ConfigDict(
arbitrary_types_allowed=True, json_encoders={DAG: lambda v: None}
)
uid: str = Field(
default_factory=lambda: str(uuid.uuid4()),
description="Flow panel uid",
@@ -277,7 +299,8 @@ class FlowPanel(BaseModel):
description="Flow category",
examples=[FlowCategory.COMMON, FlowCategory.CHAT_AGENT],
)
flow_data: FlowData = Field(..., description="Flow data")
flow_data: Optional[FlowData] = Field(None, description="Flow data")
flow_dag: Optional[_DAGModel] = Field(None, description="Flow DAG", exclude=True)
description: Optional[str] = Field(
None,
description="Flow panel description",
@@ -305,6 +328,11 @@ class FlowPanel(BaseModel):
description="Version of the flow panel",
examples=["0.1.0", "0.2.0"],
)
define_type: Optional[str] = Field(
"json",
description="Define type of the flow panel",
examples=["json", "python"],
)
editable: bool = Field(
True,
description="Whether the flow panel is editable",
@@ -344,7 +372,7 @@ class FlowPanel(BaseModel):
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self)
return model_to_dict(self, exclude={"flow_dag"})
class FlowFactory:
@@ -356,7 +384,9 @@ class FlowFactory:
def build(self, flow_panel: FlowPanel) -> DAG:
"""Build the flow."""
flow_data = flow_panel.flow_data
if not flow_panel.flow_data:
raise ValueError("Flow data is required.")
flow_data = cast(FlowData, flow_panel.flow_data)
key_to_operator_nodes: Dict[str, FlowNodeData] = {}
key_to_resource_nodes: Dict[str, FlowNodeData] = {}
key_to_resource: Dict[str, ResourceMetadata] = {}
@@ -610,7 +640,10 @@ class FlowFactory:
"""
from dbgpt.util.module_utils import import_from_string
flow_data = flow_panel.flow_data
if not flow_panel.flow_data:
return
flow_data = cast(FlowData, flow_panel.flow_data)
for node in flow_data.nodes:
if node.data.is_operator:
node_data = cast(ViewMetadata, node.data)
@@ -709,6 +742,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
Args:
flow_panel (FlowPanel): The flow panel to fill.
"""
if not flow_panel.flow_data:
return
for node in flow_panel.flow_data.nodes:
try:
parameters_map = {}

View File

@@ -1,4 +1,5 @@
"""Base classes for operators that can be executed within a workflow."""
import asyncio
import functools
from abc import ABC, ABCMeta, abstractmethod
@@ -265,7 +266,16 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
out_ctx = await self._runner.execute_workflow(
self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
)
return out_ctx.current_task_context.task_output.output_stream
task_output = out_ctx.current_task_context.task_output
if task_output.is_stream:
return out_ctx.current_task_context.task_output.output_stream
else:
async def _gen():
yield task_output.output
return _gen()
def _blocking_call_stream(
self,

View File

@@ -1,4 +1,5 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Union
@@ -171,6 +172,8 @@ class MapOperator(BaseOperator, Generic[IN, OUT]):
BranchFunc = Union[Callable[[IN], bool], Callable[[IN], Awaitable[bool]]]
# Task name, or a function (sync or async) that returns the task name
BranchTaskType = Union[str, Callable[[IN], str], Callable[[IN], Awaitable[str]]]
class BranchOperator(BaseOperator, Generic[IN, OUT]):
@@ -187,7 +190,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
def __init__(
self,
branches: Optional[Dict[BranchFunc[IN], Union[BaseOperator, str]]] = None,
branches: Optional[Dict[BranchFunc[IN], BranchTaskType]] = None,
**kwargs,
):
"""Create a BranchDAGNode with a branching function.
@@ -208,6 +211,10 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
if not value.node_name:
raise ValueError("branch node name must be set")
branches[branch_function] = value.node_name
elif callable(value):
raise ValueError(
"BranchTaskType must be str or BaseOperator on init"
)
self._branches = branches
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
@@ -234,14 +241,31 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
branches = await self.branches()
branch_func_tasks = []
branch_nodes: List[Union[BaseOperator, str]] = []
branch_name_tasks = []
# branch_nodes: List[Union[BaseOperator, str]] = []
for func, node_name in branches.items():
branch_nodes.append(node_name)
# branch_nodes.append(node_name)
branch_func_tasks.append(
curr_task_ctx.task_input.predicate_map(func, failed_value=None)
)
if callable(node_name):
async def map_node_name(func) -> str:
input_context = await curr_task_ctx.task_input.map(func)
task_name = input_context.parent_outputs[0].task_output.output
return task_name
branch_name_tasks.append(map_node_name(node_name))
else:
async def _tmp_map_node_name(task_name: str) -> str:
return task_name
branch_name_tasks.append(_tmp_map_node_name(node_name))
branch_input_ctxs: List[InputContext] = await asyncio.gather(*branch_func_tasks)
branch_nodes: List[str] = await asyncio.gather(*branch_name_tasks)
parent_output = task_input.parent_outputs[0].task_output
curr_task_ctx.set_task_output(parent_output)
skip_node_names = []
@@ -258,7 +282,7 @@ class BranchOperator(BaseOperator, Generic[IN, OUT]):
curr_task_ctx.update_metadata("skip_node_names", skip_node_names)
return parent_output
async def branches(self) -> Dict[BranchFunc[IN], Union[BaseOperator, str]]:
async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
"""Return branch logic based on input data."""
raise NotImplementedError

View File

@@ -298,16 +298,24 @@ class ModelMessage(BaseModel):
return str_msg
@staticmethod
def messages_to_string(messages: List["ModelMessage"]) -> str:
def messages_to_string(
messages: List["ModelMessage"],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
) -> str:
"""Convert messages to str.
Args:
messages (List[ModelMessage]): The messages
human_prefix (str): The human prefix
ai_prefix (str): The ai prefix
system_prefix (str): The system prefix
Returns:
str: The str messages
"""
return _messages_to_str(messages)
return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
_SingleRoundMessage = List[BaseMessage]
@@ -1211,9 +1219,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
content=ai_message.content,
index=ai_message.index,
round_index=ai_message.round_index,
additional_kwargs=ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {},
additional_kwargs=(
ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {}
),
)
current_round.append(view_message)
return sum(messages_by_round, [])

View File

@@ -6,7 +6,7 @@ import dataclasses
import json
from abc import ABC, abstractmethod
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
@@ -19,6 +19,8 @@ from dbgpt.core.interface.storage import (
)
from dbgpt.util.formatting import formatter, no_strict_formatter
T = TypeVar("T", bound="BasePromptTemplate")
def _jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
@@ -34,9 +36,9 @@ def _jinja2_formatter(template: str, **kwargs: Any) -> str:
_DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": lambda is_strict: formatter.format
if is_strict
else no_strict_formatter.format,
"f-string": lambda is_strict: (
formatter.format if is_strict else no_strict_formatter.format
),
"jinja2": lambda is_strict: _jinja2_formatter,
}
@@ -88,8 +90,8 @@ class PromptTemplate(BasePromptTemplate):
@classmethod
def from_template(
cls, template: str, template_format: str = "f-string", **kwargs: Any
) -> BasePromptTemplate:
cls: Type[T], template: str, template_format: str = "f-string", **kwargs: Any
) -> T:
"""Create a prompt template from a template string."""
input_variables = get_template_vars(template, template_format)
return cls(
@@ -116,14 +118,14 @@ class BaseChatPromptTemplate(BaseModel, ABC):
@classmethod
def from_template(
cls,
cls: Type[T],
template: str,
template_format: str = "f-string",
response_format: Optional[str] = None,
response_key: str = "response",
template_is_strict: bool = True,
**kwargs: Any,
) -> BaseChatPromptTemplate:
) -> T:
"""Create a prompt template from a template string."""
prompt = PromptTemplate.from_template(
template,