mirror of https://github.com/csunny/DB-GPT.git
docs: New AWEL tutorial (#1245)
@@ -30,8 +30,16 @@ from .operators.stream_operator import (
     UnstreamifyAbsOperator,
 )
 from .runner.local_runner import DefaultWorkflowRunner
-from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
+from .task.base import (
+    InputContext,
+    InputSource,
+    TaskContext,
+    TaskOutput,
+    TaskState,
+    is_empty_data,
+)
 from .task.task_impl import (
+    BaseInputSource,
     DefaultInputContext,
     DefaultTaskContext,
     SimpleCallDataInputSource,
@@ -40,6 +48,7 @@ from .task.task_impl import (
     SimpleTaskOutput,
     _is_async_iterator,
 )
+from .trigger.base import Trigger
 from .trigger.http_trigger import (
     CommonLLMHttpRequestBody,
     CommonLLMHTTPRequestContext,
@@ -73,12 +82,14 @@ __all__ = [
     "BranchFunc",
     "WorkflowRunner",
     "TaskState",
+    "is_empty_data",
     "TaskOutput",
     "TaskContext",
     "InputContext",
     "InputSource",
     "DefaultWorkflowRunner",
     "SimpleInputSource",
+    "BaseInputSource",
     "SimpleCallDataInputSource",
     "DefaultTaskContext",
     "DefaultInputContext",
@@ -87,6 +98,7 @@ __all__ = [
     "StreamifyAbsOperator",
     "UnstreamifyAbsOperator",
     "TransformStreamAbsOperator",
+    "Trigger",
     "HttpTrigger",
     "CommonLLMHTTPRequestContext",
     "CommonLLMHttpResponseBody",
@@ -136,9 +148,6 @@ def setup_dev_environment(
             Defaults to True. If True, the DAG graph will be saved to a file and open
             it automatically.
     """
-    import uvicorn
-    from fastapi import FastAPI
-
     from dbgpt.component import SystemApp
     from dbgpt.util.utils import setup_logging
 
@@ -148,7 +157,13 @@ def setup_dev_environment(
     logger_filename = "dbgpt_awel_dev.log"
     setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)
 
-    app = FastAPI()
+    start_http = _check_has_http_trigger(dags)
+    if start_http:
+        from fastapi import FastAPI
+
+        app = FastAPI()
+    else:
+        app = None
     system_app = SystemApp(app)
     DAGVar.set_current_system_app(system_app)
     trigger_manager = DefaultTriggerManager()
@@ -169,6 +184,24 @@ def setup_dev_environment(
         for trigger in dag.trigger_nodes:
             trigger_manager.register_trigger(trigger, system_app)
     trigger_manager.after_register()
-    if trigger_manager.keep_running():
+    if start_http and trigger_manager.keep_running() and app:
+        import uvicorn
+
         # Should keep running
         uvicorn.run(app, host=host, port=port)
+
+
+def _check_has_http_trigger(dags: List[DAG]) -> bool:
+    """Check whether has http trigger.
+
+    Args:
+        dags (List[DAG]): The dags.
+
+    Returns:
+        bool: Whether has http trigger.
+    """
+    for dag in dags:
+        for trigger in dag.trigger_nodes:
+            if isinstance(trigger, HttpTrigger):
+                return True
+    return False
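With this hunk, setup_dev_environment only creates a FastAPI app and starts uvicorn when at least one DAG actually contains an HttpTrigger; otherwise app stays None and the process exits once the DAGs have run. A minimal sketch of the HTTP path (the DAG body and endpoint are hypothetical; the imports follow the exports added above):

    from dbgpt.core.awel import DAG, HttpTrigger, MapOperator, setup_dev_environment

    with DAG("awel_dev_example") as dag:
        trigger = HttpTrigger("/examples/hello")
        echo_task = MapOperator(map_function=lambda req: f"Hello, {req}!")
        trigger >> echo_task

    # dag carries an HttpTrigger, so _check_has_http_trigger returns True
    # and a FastAPI app is created and served via uvicorn.
    setup_dev_environment([dag], port=5555)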
@@ -1,10 +0,0 @@
-"""Base classes for AWEL."""
-from abc import ABC, abstractmethod
-
-
-class Trigger(ABC):
-    """Base class for trigger."""
-
-    @abstractmethod
-    async def trigger(self) -> None:
-        """Trigger the workflow or a specific operation in the workflow."""
@@ -274,6 +274,8 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
             node_id = self._dag._new_node_id()
         self._node_id: Optional[str] = node_id
         self._node_name: Optional[str] = node_name
+        if self._dag:
+            self._dag._append_node(self)
 
     @property
     def node_id(self) -> str:
@@ -421,7 +423,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
     def __repr__(self):
         """Return the representation of current DAGNode."""
         cls_name = self.__class__.__name__
-        if self.node_name and self.node_name:
+        if self.node_id and self.node_name:
             return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
         if self.node_id:
             return f"{cls_name}(node_id={self.node_id})"
@@ -430,6 +432,19 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
         else:
             return f"{cls_name}"
 
+    @property
+    def graph_str(self):
+        """Return the graph string of current DAGNode."""
+        cls_name = self.__class__.__name__
+        if self.node_id and self.node_name:
+            return f"{self.node_id}({cls_name},{self.node_name})"
+        if self.node_id:
+            return f"{self.node_id}({cls_name})"
+        if self.node_name:
+            return f"{self.node_name}_{cls_name}({cls_name})"
+        else:
+            return f"{cls_name}"
+
     def __str__(self):
         """Return the string of current DAGNode."""
         return self.__repr__()
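The new graph_str property supplies the node labels for the Mermaid output added below: a stable vertex id plus a readable type/name label, rather than the full __repr__. Illustrative values only (node ids are generated at runtime):

    # node_id="task_1", node_name="double" on a MapOperator instance
    node.graph_str  # -> "task_1(MapOperator,double)"

    # node_id only
    node.graph_str  # -> "task_1(MapOperator)"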
@@ -798,12 +813,16 @@ def _handle_dag_nodes(
     _handle_dag_nodes(is_down_to_up, level, node, func)
 
 
-def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
+def _visualize_dag(
+    dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
+) -> Optional[str]:
     """Visualize the DAG.
 
     Args:
         dag (DAG): The DAG to visualize
         view (bool, optional): Whether view the DAG graph. Defaults to True.
+        generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
+            Defaults to True.
 
     Returns:
         Optional[str]: The filename of the DAG graph
@@ -815,15 +834,20 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
         return None
 
     dot = Digraph(name=dag.dag_id)
+    mermaid_str = "graph TD;\n"  # Initialize Mermaid graph definition
+    # Record the added edges to avoid adding duplicate edges
+    added_edges = set()
 
     def add_edges(node: DAGNode):
+        nonlocal mermaid_str
         if node.downstream:
             for downstream_node in node.downstream:
-                dot.edge(str(node), str(downstream_node))
+                # Check if the edge has been added
+                if (str(node), str(downstream_node)) not in added_edges:
+                    dot.edge(str(node), str(downstream_node))
+                    mermaid_str += (
+                        f"    {node.graph_str} --> {downstream_node.graph_str};\n"
+                    )
+                    added_edges.add((str(node), str(downstream_node)))
                 add_edges(downstream_node)
@@ -839,4 +863,13 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
 
     kwargs["directory"] = LOGDIR
 
+    # Generate Mermaid syntax file if requested
+    if generate_mermaid:
+        mermaid_filename = filename.replace(".gv", ".md")
+        with open(
+            f"{kwargs.get('directory', '')}/{mermaid_filename}", "w"
+        ) as mermaid_file:
+            logger.info(f"Writing Mermaid syntax to {mermaid_filename}")
+            mermaid_file.write(mermaid_str)
+
     return dot.render(filename, view=view, **kwargs)
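For a small DAG the generated .md file would contain Mermaid source along these lines (a sketch; real node ids depend on the run):

    graph TD;
        trigger_task_id(HttpTrigger) --> map_task_id(MapOperator,double);

Writing it next to the existing .gv output keeps both renderings in LOGDIR under the same base filename.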
@@ -24,7 +24,7 @@ from dbgpt.util.executor_utils import (
 )
 
 from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
-from ..task.base import OUT, T, TaskOutput
+from ..task.base import EMPTY_DATA, OUT, T, TaskOutput
 
 F = TypeVar("F", bound=FunctionType)
 
@@ -186,7 +186,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
 
     async def call(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         dag_ctx: Optional[DAGContext] = None,
     ) -> OUT:
         """Execute the node and return the output.
@@ -200,7 +200,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             OUT: The output of the node after execution.
         """
-        if call_data:
+        if call_data != EMPTY_DATA:
             call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, exist_dag_ctx=dag_ctx
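The sentinel change is behavioral, not cosmetic: with the old `if call_data:` check, falsy-but-valid inputs such as 0, "" or an empty list were never wrapped, so the DAG saw no call data at all. Comparing against EMPTY_DATA fixes that (operator name hypothetical):

    # Before this commit: call_data=0 was treated as "no data" and dropped.
    # After: 0 != EMPTY_DATA, so it is wrapped as {"data": 0} and delivered.
    result = await my_operator.call(call_data=0)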
@@ -209,7 +209,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
 
     def _blocking_call(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         loop: Optional[asyncio.BaseEventLoop] = None,
     ) -> OUT:
         """Execute the node and return the output.
@@ -232,7 +232,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
 
     async def call_stream(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         dag_ctx: Optional[DAGContext] = None,
     ) -> AsyncIterator[OUT]:
         """Execute the node and return the output as a stream.
@@ -247,7 +247,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
-        if call_data:
+        if call_data != EMPTY_DATA:
             call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
@@ -256,7 +256,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
 
     def _blocking_call_stream(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         loop: Optional[asyncio.BaseEventLoop] = None,
     ) -> Iterator[OUT]:
         """Execute the node and return the output as a stream.
@@ -1,16 +1,7 @@
 """Common operators of AWEL."""
-import asyncio
 import logging
-from typing import (
-    AsyncIterator,
-    Awaitable,
-    Callable,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Union,
-)
+from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union
 
 from ..dag.base import DAGContext
 from ..task.base import (
@@ -106,7 +97,7 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
         curr_task_ctx.set_task_output(reduce_output)
         return reduce_output
 
-    async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
+    async def reduce(self, a: IN, b: IN) -> OUT:
         """Reduce the input stream to a single value."""
         raise NotImplementedError
 
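reduce is now a binary fold applied pairwise over the stream rather than a consumer of the whole AsyncIterator, matching the corrected ReduceFunc alias in the next hunk. A minimal subclass under the new signature (assuming ReduceStreamOperator is imported from the AWEL operators package):

    class SumOperator(ReduceStreamOperator[int, int]):
        """Fold a stream of ints into their sum, two elements at a time."""

        async def reduce(self, a: int, b: int) -> int:
            return a + b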
@@ -20,13 +20,21 @@ T = TypeVar("T")
 
 
 class _EMPTY_DATA_TYPE:
+    """A special type to represent empty data."""
+
+    def __init__(self, name: str = "EMPTY_DATA"):
+        self.name = name
+
     def __bool__(self):
         return False
 
+    def __str__(self):
+        return f"EmptyData({self.name})"
+
+
-EMPTY_DATA = _EMPTY_DATA_TYPE()
-SKIP_DATA = _EMPTY_DATA_TYPE()
-PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
+EMPTY_DATA = _EMPTY_DATA_TYPE("EMPTY_DATA")
+SKIP_DATA = _EMPTY_DATA_TYPE("SKIP_DATA")
+PLACEHOLDER_DATA = _EMPTY_DATA_TYPE("PLACEHOLDER_DATA")
 
 
 def is_empty_data(data: Any):
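Giving each sentinel its own name keeps the falsy contract but makes logs and error messages distinguishable:

    bool(EMPTY_DATA)  # False, as before
    str(EMPTY_DATA)   # "EmptyData(EMPTY_DATA)"
    str(SKIP_DATA)    # "EmptyData(SKIP_DATA)" instead of an anonymous object repr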
@@ -37,7 +45,7 @@ def is_empty_data(data: Any):
 
 
 MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
-ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
+ReduceFunc = Union[Callable[[IN, IN], OUT], Callable[[IN, IN], Awaitable[OUT]]]
 StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
 UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
 TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
@@ -341,7 +349,7 @@ class InputContext(ABC):
         """
 
     @abstractmethod
-    async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
+    async def reduce(self, reduce_func: ReduceFunc) -> "InputContext":
         """Apply a reducing function to the inputs.
 
         Args:
@@ -479,7 +479,8 @@ class DefaultInputContext(InputContext):
         if apply_type == "map":
             result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
         elif apply_type == "reduce":
-            result = out.task_output.reduce(func)
+            reduce_func = cast(ReduceFunc, func)
+            result = out.task_output.reduce(reduce_func)
         elif apply_type == "check_condition":
             result = out.task_output.check_condition(func)
         else:
@@ -541,14 +542,16 @@ class DefaultInputContext(InputContext):
         )
         return DefaultInputContext([single_output])
 
-    async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
+    async def reduce(self, reduce_func: ReduceFunc) -> InputContext:
         """Apply a reduce function to all parent outputs."""
         if not self.check_stream():
             raise ValueError(
                 "The output in all tasks must has same output format of stream to apply"
                 " reduce function"
             )
-        new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
+        new_outputs, results = await self._apply_func(
+            reduce_func, apply_type="reduce"  # type: ignore
+        )
         for i, task_ctx in enumerate(new_outputs):
             task_ctx = cast(TaskContext, task_ctx)
             task_ctx.set_task_output(results[i])
@@ -410,12 +410,20 @@ class HttpTrigger(Trigger):
         """
         return self._register_to_app
 
-    def mount_to_router(self, router: "APIRouter") -> None:
+    def mount_to_router(
+        self, router: "APIRouter", global_prefix: Optional[str] = None
+    ) -> None:
         """Mount the trigger to a router.
 
         Args:
             router (APIRouter): The router to mount the trigger.
+            global_prefix (Optional[str], optional): The global prefix of the router.
         """
+        path = (
+            join_paths(global_prefix, self._endpoint)
+            if global_prefix
+            else self._endpoint
+        )
         dynamic_route_function = self._create_route_func()
         router.api_route(
             self._endpoint,
@@ -425,6 +433,8 @@ class HttpTrigger(Trigger):
             tags=self._router_tags,
         )(dynamic_route_function)
 
+        logger.info(f"Mount http trigger success, path: {path}")
+
     def mount_to_app(self, app: "FastAPI", global_prefix: Optional[str] = None) -> None:
         """Mount the trigger to a FastAPI app.
 
@@ -455,6 +465,7 @@ class HttpTrigger(Trigger):
         )
         app.openapi_schema = None
+        app.middleware_stack = None
         logger.info(f"Mount http trigger success, path: {path}")
 
     def remove_from_app(
         self, app: "FastAPI", global_prefix: Optional[str] = None
@@ -583,9 +594,12 @@ class HttpTrigger(Trigger):
             request_model = self._req_body
         elif get_origin(self._req_body) == dict and not is_query_method:
             request_model = self._req_body
+        elif is_query_method:
+            request_model = None
         else:
             err_msg = f"Unsupported request body type {self._req_body}"
             raise AWELHttpError(err_msg)
 
         dynamic_route_function = create_route_function(function_name, request_model)
         logger.info(
             f"mount router function {dynamic_route_function}({function_name}), "
@@ -88,7 +88,7 @@ class HttpTriggerManager(TriggerManager):
                 # Mount to app, support dynamic route.
                 trigger.mount_to_app(app, self._router_prefix)
             else:
-                trigger.mount_to_router(self._router)
+                trigger.mount_to_router(self._router, self._router_prefix)
             self._trigger_map[trigger_id] = trigger
         except Exception as e:
             self._unregister_route_tables(path, methods)
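Passing the router prefix through means a trigger mounted on the shared router now computes (and logs) the same full path as one mounted directly on the app. Sketch, assuming the manager's default prefix is "/api/v1/awel/trigger" (the value is illustrative):

    trigger = HttpTrigger("/echo")
    trigger.mount_to_router(router, global_prefix="/api/v1/awel/trigger")
    # logged path: join_paths("/api/v1/awel/trigger", "/echo")
    #           -> "/api/v1/awel/trigger/echo"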
@@ -174,12 +174,14 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
     def __init__(self, system_app: SystemApp | None = None):
         """Initialize a DefaultTriggerManager."""
         self.system_app = system_app
-        self.http_trigger = HttpTriggerManager()
-        super().__init__(None)
+        self._http_trigger: Optional[HttpTriggerManager] = None
+        super().__init__()
 
     def init_app(self, system_app: SystemApp):
         """Initialize the trigger manager."""
         self.system_app = system_app
+        if system_app and self.system_app.app:
+            self._http_trigger = HttpTriggerManager()
 
     def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
         """Register a trigger to current manager."""
@@ -187,9 +189,9 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
 
         if isinstance(trigger, HttpTrigger):
             logger.info(f"Register trigger {trigger}")
-            self.http_trigger.register_trigger(trigger, system_app)
-        # else:
-        #     raise ValueError(f"Unsupported trigger: {trigger}")
+            if not self._http_trigger:
+                raise ValueError("Http trigger manager not initialized")
+            self._http_trigger.register_trigger(trigger, system_app)
 
     def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None:
         """Unregister a trigger to current manager."""
@@ -197,14 +199,14 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
 
         if isinstance(trigger, HttpTrigger):
             logger.info(f"Unregister trigger {trigger}")
-            self.http_trigger.unregister_trigger(trigger, system_app)
-        # else:
-        #     raise ValueError(f"Unsupported trigger: {trigger}")
+            if not self._http_trigger:
+                raise ValueError("Http trigger manager not initialized")
+            self._http_trigger.unregister_trigger(trigger, system_app)
 
     def after_register(self) -> None:
         """After register, init the trigger manager."""
-        if self.system_app:
-            self.http_trigger._init_app(self.system_app)
+        if self.system_app and self._http_trigger:
+            self._http_trigger._init_app(self.system_app)
 
     def keep_running(self) -> bool:
         """Whether keep running.
@@ -212,4 +214,6 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         Returns:
             bool: Whether keep running, True means keep running, False means stop.
         """
-        return self.http_trigger.keep_running()
+        if not self._http_trigger:
+            return False
+        return self._http_trigger.keep_running()
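Taken together, DefaultTriggerManager now creates its HttpTriggerManager lazily in init_app (only when the SystemApp wraps a real ASGI app), and keep_running() reports False instead of failing when no HTTP manager exists. That is what lets setup_dev_environment above skip uvicorn entirely for DAGs without HTTP triggers. A sketch of the no-HTTP flow (wiring is illustrative):

    system_app = SystemApp(None)   # no FastAPI app attached
    manager = DefaultTriggerManager()
    manager.init_app(system_app)   # _http_trigger stays None
    manager.after_register()       # no-op without an HTTP trigger manager
    manager.keep_running()         # False -> caller exits instead of serving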