docs: New AWEL tutorial (#1245)

Fangyin Cheng authored on 2024-03-04 17:06:42 +08:00, committed by GitHub
parent 7a38edcaed
commit 3c93fe589a
42 changed files with 15325 additions and 8779 deletions

View File

@@ -30,8 +30,16 @@ from .operators.stream_operator import (
     UnstreamifyAbsOperator,
 )
 from .runner.local_runner import DefaultWorkflowRunner
-from .task.base import InputContext, InputSource, TaskContext, TaskOutput, TaskState
+from .task.base import (
+    InputContext,
+    InputSource,
+    TaskContext,
+    TaskOutput,
+    TaskState,
+    is_empty_data,
+)
 from .task.task_impl import (
+    BaseInputSource,
     DefaultInputContext,
     DefaultTaskContext,
     SimpleCallDataInputSource,
@@ -40,6 +48,7 @@ from .task.task_impl import (
     SimpleTaskOutput,
     _is_async_iterator,
 )
+from .trigger.base import Trigger
 from .trigger.http_trigger import (
     CommonLLMHttpRequestBody,
     CommonLLMHTTPRequestContext,
@@ -73,12 +82,14 @@ __all__ = [
"BranchFunc",
"WorkflowRunner",
"TaskState",
"is_empty_data",
"TaskOutput",
"TaskContext",
"InputContext",
"InputSource",
"DefaultWorkflowRunner",
"SimpleInputSource",
"BaseInputSource",
"SimpleCallDataInputSource",
"DefaultTaskContext",
"DefaultInputContext",
@@ -87,6 +98,7 @@ __all__ = [
"StreamifyAbsOperator",
"UnstreamifyAbsOperator",
"TransformStreamAbsOperator",
"Trigger",
"HttpTrigger",
"CommonLLMHTTPRequestContext",
"CommonLLMHttpResponseBody",
@@ -136,9 +148,6 @@ def setup_dev_environment(
             Defaults to True. If True, the DAG graph will be saved to a file and open
             it automatically.
     """
-    import uvicorn
-    from fastapi import FastAPI
-
     from dbgpt.component import SystemApp
     from dbgpt.util.utils import setup_logging
@@ -148,7 +157,13 @@ def setup_dev_environment(
     logger_filename = "dbgpt_awel_dev.log"
     setup_logging("dbgpt", logging_level=logging_level, logger_filename=logger_filename)
-    app = FastAPI()
+    start_http = _check_has_http_trigger(dags)
+    if start_http:
+        from fastapi import FastAPI
+
+        app = FastAPI()
+    else:
+        app = None
     system_app = SystemApp(app)
     DAGVar.set_current_system_app(system_app)
     trigger_manager = DefaultTriggerManager()
@@ -169,6 +184,24 @@ def setup_dev_environment(
         for trigger in dag.trigger_nodes:
             trigger_manager.register_trigger(trigger, system_app)
     trigger_manager.after_register()
-    if trigger_manager.keep_running():
+    if start_http and trigger_manager.keep_running() and app:
+        import uvicorn
+
         # Should keep running
         uvicorn.run(app, host=host, port=port)
+
+
+def _check_has_http_trigger(dags: List[DAG]) -> bool:
+    """Check whether has http trigger.
+
+    Args:
+        dags (List[DAG]): The dags.
+
+    Returns:
+        bool: Whether has http trigger.
+    """
+    for dag in dags:
+        for trigger in dag.trigger_nodes:
+            if isinstance(trigger, HttpTrigger):
+                return True
+    return False
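
With this change, `setup_dev_environment` only imports FastAPI and starts uvicorn when at least one DAG actually contains an `HttpTrigger`; otherwise `app` stays `None` and the function returns after registering triggers. A minimal sketch of the non-HTTP path, assuming the public `dbgpt.core.awel` exports (`DAG`, `MapOperator`, `setup_dev_environment`); the DAG name is made up:

```python
from dbgpt.core.awel import DAG, MapOperator, setup_dev_environment

# A DAG with no HttpTrigger: _check_has_http_trigger(dags) returns False,
# so no FastAPI app is created and uvicorn never binds a port.
with DAG("awel_dev_no_http") as dag:  # hypothetical DAG id
    task = MapOperator(map_function=lambda x: x + 1)

setup_dev_environment([dag])  # registers triggers (none here) and returns
```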

View File

@@ -1,10 +0,0 @@
"""Base classes for AWEL."""
from abc import ABC, abstractmethod
class Trigger(ABC):
"""Base class for trigger."""
@abstractmethod
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@@ -274,6 +274,8 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
             node_id = self._dag._new_node_id()
         self._node_id: Optional[str] = node_id
         self._node_name: Optional[str] = node_name
+        if self._dag:
+            self._dag._append_node(self)
 
     @property
     def node_id(self) -> str:
@@ -421,7 +423,7 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
     def __repr__(self):
         """Return the representation of current DAGNode."""
         cls_name = self.__class__.__name__
-        if self.node_name and self.node_name:
+        if self.node_id and self.node_name:
             return f"{cls_name}(node_id={self.node_id}, node_name={self.node_name})"
         if self.node_id:
             return f"{cls_name}(node_id={self.node_id})"
@@ -430,6 +432,19 @@ class DAGNode(DAGLifecycle, DependencyMixin, ViewMixin, ABC):
         else:
             return f"{cls_name}"
 
+    @property
+    def graph_str(self):
+        """Return the graph string of current DAGNode."""
+        cls_name = self.__class__.__name__
+        if self.node_id and self.node_name:
+            return f"{self.node_id}({cls_name},{self.node_name})"
+        if self.node_id:
+            return f"{self.node_id}({cls_name})"
+        if self.node_name:
+            return f"{self.node_name}_{cls_name}({cls_name})"
+        else:
+            return f"{cls_name}"
+
     def __str__(self):
         """Return the string of current DAGNode."""
         return self.__repr__()
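
The new `graph_str` property produces the node labels used for the Mermaid output below. A standalone mirror of its three label shapes, with made-up ids and names (real ids come from `DAG._new_node_id()`):

```python
def graph_str(cls_name: str, node_id=None, node_name=None) -> str:
    # Mirrors DAGNode.graph_str for illustration.
    if node_id and node_name:
        return f"{node_id}({cls_name},{node_name})"
    if node_id:
        return f"{node_id}({cls_name})"
    if node_name:
        return f"{node_name}_{cls_name}({cls_name})"
    return cls_name

assert graph_str("MapOperator", "n1", "double") == "n1(MapOperator,double)"
assert graph_str("MapOperator", "n1") == "n1(MapOperator)"
assert graph_str("MapOperator", node_name="d") == "d_MapOperator(MapOperator)"
```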
@@ -798,12 +813,16 @@ def _handle_dag_nodes(
         _handle_dag_nodes(is_down_to_up, level, node, func)
 
 
-def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
+def _visualize_dag(
+    dag: DAG, view: bool = True, generate_mermaid: bool = True, **kwargs
+) -> Optional[str]:
     """Visualize the DAG.
 
     Args:
         dag (DAG): The DAG to visualize
         view (bool, optional): Whether view the DAG graph. Defaults to True.
+        generate_mermaid (bool, optional): Whether to generate a Mermaid syntax file.
+            Defaults to True.
 
     Returns:
         Optional[str]: The filename of the DAG graph
@@ -815,15 +834,20 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
         return None
 
     dot = Digraph(name=dag.dag_id)
+    mermaid_str = "graph TD;\n"  # Initialize Mermaid graph definition
     # Record the added edges to avoid adding duplicate edges
     added_edges = set()
 
     def add_edges(node: DAGNode):
+        nonlocal mermaid_str
         if node.downstream:
             for downstream_node in node.downstream:
                 # Check if the edge has been added
                 if (str(node), str(downstream_node)) not in added_edges:
                     dot.edge(str(node), str(downstream_node))
+                    mermaid_str += (
+                        f" {node.graph_str} --> {downstream_node.graph_str};\n"
+                    )
                     added_edges.add((str(node), str(downstream_node)))
                 add_edges(downstream_node)
@@ -839,4 +863,13 @@ def _visualize_dag(dag: DAG, view: bool = True, **kwargs) -> Optional[str]:
kwargs["directory"] = LOGDIR
# Generate Mermaid syntax file if requested
if generate_mermaid:
mermaid_filename = filename.replace(".gv", ".md")
with open(
f"{kwargs.get('directory', '')}/{mermaid_filename}", "w"
) as mermaid_file:
logger.info(f"Writing Mermaid syntax to {mermaid_filename}")
mermaid_file.write(mermaid_str)
return dot.render(filename, view=view, **kwargs)
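
`_visualize_dag` now writes a Mermaid `graph TD` file alongside the Graphviz output. A sketch of the edge-emission logic on plain string labels instead of `DAGNode` objects (labels follow `graph_str` above):

```python
edges = [
    ("n1(HttpTrigger)", "n2(MapOperator,double)"),
    ("n2(MapOperator,double)", "n3(ReduceStreamOperator)"),
    ("n1(HttpTrigger)", "n2(MapOperator,double)"),  # duplicate, skipped
]

mermaid_str = "graph TD;\n"
added_edges = set()  # dedup, as in the original
for src, dst in edges:
    if (src, dst) not in added_edges:
        mermaid_str += f" {src} --> {dst};\n"
        added_edges.add((src, dst))
print(mermaid_str)  # paste into any Mermaid renderer
```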

View File

@@ -24,7 +24,7 @@ from dbgpt.util.executor_utils import (
 )
 
 from ..dag.base import DAG, DAGContext, DAGNode, DAGVar
-from ..task.base import OUT, T, TaskOutput
+from ..task.base import EMPTY_DATA, OUT, T, TaskOutput
 
 F = TypeVar("F", bound=FunctionType)
@@ -186,7 +186,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     async def call(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         dag_ctx: Optional[DAGContext] = None,
     ) -> OUT:
         """Execute the node and return the output.
@@ -200,7 +200,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             OUT: The output of the node after execution.
         """
-        if call_data:
+        if call_data != EMPTY_DATA:
             call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, exist_dag_ctx=dag_ctx
@@ -209,7 +209,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def _blocking_call(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         loop: Optional[asyncio.BaseEventLoop] = None,
     ) -> OUT:
         """Execute the node and return the output.
@@ -232,7 +232,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     async def call_stream(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         dag_ctx: Optional[DAGContext] = None,
     ) -> AsyncIterator[OUT]:
         """Execute the node and return the output as a stream.
@@ -247,7 +247,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         Returns:
             AsyncIterator[OUT]: An asynchronous iterator over the output stream.
         """
-        if call_data:
+        if call_data != EMPTY_DATA:
             call_data = {"data": call_data}
         out_ctx = await self._runner.execute_workflow(
             self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
@@ -256,7 +256,7 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
     def _blocking_call_stream(
         self,
-        call_data: Optional[CALL_DATA] = None,
+        call_data: Optional[CALL_DATA] = EMPTY_DATA,
         loop: Optional[asyncio.BaseEventLoop] = None,
     ) -> Iterator[OUT]:
         """Execute the node and return the output as a stream.

View File

@@ -1,16 +1,7 @@
"""Common operators of AWEL."""
import asyncio
import logging
from typing import (
AsyncIterator,
Awaitable,
Callable,
Dict,
Generic,
List,
Optional,
Union,
)
from typing import Awaitable, Callable, Dict, Generic, List, Optional, Union
from ..dag.base import DAGContext
from ..task.base import (
@@ -106,7 +97,7 @@ class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
         curr_task_ctx.set_task_output(reduce_output)
         return reduce_output
 
-    async def reduce(self, input_value: AsyncIterator[IN]) -> OUT:
+    async def reduce(self, a: IN, b: IN) -> OUT:
         """Reduce the input stream to a single value."""
         raise NotImplementedError
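
`ReduceStreamOperator.reduce` changes from consuming the whole `AsyncIterator` to a binary `(a, b)` fold step, matching the new two-argument `ReduceFunc` (see the `task/base.py` hunk below). A sketch of a subclass under the new contract, assuming `ReduceStreamOperator` is exported from `dbgpt.core.awel`:

```python
from dbgpt.core.awel import ReduceStreamOperator

class SumOperator(ReduceStreamOperator[int, int]):
    """Folds an int stream into its sum, one pair at a time."""

    async def reduce(self, a: int, b: int) -> int:
        return a + b
```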

View File

@@ -20,13 +20,21 @@ T = TypeVar("T")
 class _EMPTY_DATA_TYPE:
     """A special type to represent empty data."""
 
+    def __init__(self, name: str = "EMPTY_DATA"):
+        self.name = name
+
     def __bool__(self):
         return False
 
+    def __str__(self):
+        return f"EmptyData({self.name})"
+
 
-EMPTY_DATA = _EMPTY_DATA_TYPE()
-SKIP_DATA = _EMPTY_DATA_TYPE()
-PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()
+EMPTY_DATA = _EMPTY_DATA_TYPE("EMPTY_DATA")
+SKIP_DATA = _EMPTY_DATA_TYPE("SKIP_DATA")
+PLACEHOLDER_DATA = _EMPTY_DATA_TYPE("PLACEHOLDER_DATA")
 
 
 def is_empty_data(data: Any):
@@ -37,7 +45,7 @@ def is_empty_data(data: Any):
 MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
-ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
+ReduceFunc = Union[Callable[[IN, IN], OUT], Callable[[IN, IN], Awaitable[OUT]]]
 StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
 UnStreamFunc = Callable[[AsyncIterator[IN]], OUT]
 TransformFunc = Callable[[AsyncIterator[IN]], Awaitable[AsyncIterator[OUT]]]
@@ -341,7 +349,7 @@ class InputContext(ABC):
"""
@abstractmethod
async def reduce(self, reduce_func: Callable[[Any], Any]) -> "InputContext":
async def reduce(self, reduce_func: ReduceFunc) -> "InputContext":
"""Apply a reducing function to the inputs.
Args:

View File

@@ -479,7 +479,8 @@ class DefaultInputContext(InputContext):
         if apply_type == "map":
             result: Coroutine[Any, Any, TaskOutput[Any]] = out.task_output.map(func)
         elif apply_type == "reduce":
-            result = out.task_output.reduce(func)
+            reduce_func = cast(ReduceFunc, func)
+            result = out.task_output.reduce(reduce_func)
         elif apply_type == "check_condition":
             result = out.task_output.check_condition(func)
         else:
@@ -541,14 +542,16 @@ class DefaultInputContext(InputContext):
         )
         return DefaultInputContext([single_output])
 
-    async def reduce(self, reduce_func: Callable[[Any], Any]) -> InputContext:
+    async def reduce(self, reduce_func: ReduceFunc) -> InputContext:
         """Apply a reduce function to all parent outputs."""
         if not self.check_stream():
             raise ValueError(
                 "The output in all tasks must has same output format of stream to apply"
                 " reduce function"
             )
-        new_outputs, results = await self._apply_func(reduce_func, apply_type="reduce")
+        new_outputs, results = await self._apply_func(
+            reduce_func, apply_type="reduce"  # type: ignore
+        )
         for i, task_ctx in enumerate(new_outputs):
             task_ctx = cast(TaskContext, task_ctx)
             task_ctx.set_task_output(results[i])
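
With `ReduceFunc` now two-argument, both sync and async binary reducers satisfy the annotation `DefaultInputContext.reduce` accepts; the `cast`/`# type: ignore` above only reconciles the shared `map`/`reduce` plumbing in `_apply_func`. Conforming reducers look like:

```python
from typing import Awaitable, Callable, TypeVar, Union

IN = TypeVar("IN")
OUT = TypeVar("OUT")
ReduceFunc = Union[Callable[[IN, IN], OUT], Callable[[IN, IN], Awaitable[OUT]]]

def sum_sync(a: int, b: int) -> int:        # Callable[[IN, IN], OUT]
    return a + b

async def sum_async(a: int, b: int) -> int:  # the Awaitable variant
    return a + b
```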

View File

@@ -410,12 +410,20 @@ class HttpTrigger(Trigger):
"""
return self._register_to_app
def mount_to_router(self, router: "APIRouter") -> None:
def mount_to_router(
self, router: "APIRouter", global_prefix: Optional[str] = None
) -> None:
"""Mount the trigger to a router.
Args:
router (APIRouter): The router to mount the trigger.
global_prefix (Optional[str], optional): The global prefix of the router.
"""
path = (
join_paths(global_prefix, self._endpoint)
if global_prefix
else self._endpoint
)
dynamic_route_function = self._create_route_func()
router.api_route(
self._endpoint,
@@ -425,6 +433,8 @@ class HttpTrigger(Trigger):
             tags=self._router_tags,
         )(dynamic_route_function)
+        logger.info(f"Mount http trigger success, path: {path}")
 
     def mount_to_app(self, app: "FastAPI", global_prefix: Optional[str] = None) -> None:
         """Mount the trigger to a FastAPI app.
@@ -455,6 +465,7 @@ class HttpTrigger(Trigger):
         )
         app.openapi_schema = None
         app.middleware_stack = None
+        logger.info(f"Mount http trigger success, path: {path}")
 
     def remove_from_app(
         self, app: "FastAPI", global_prefix: Optional[str] = None
@@ -583,9 +594,12 @@ class HttpTrigger(Trigger):
             request_model = self._req_body
         elif get_origin(self._req_body) == dict and not is_query_method:
             request_model = self._req_body
+        elif is_query_method:
+            request_model = None
         else:
             err_msg = f"Unsupported request body type {self._req_body}"
             raise AWELHttpError(err_msg)
+
         dynamic_route_function = create_route_function(function_name, request_model)
         logger.info(
             f"mount router function {dynamic_route_function}({function_name}), "

View File

@@ -88,7 +88,7 @@ class HttpTriggerManager(TriggerManager):
                 # Mount to app, support dynamic route.
                 trigger.mount_to_app(app, self._router_prefix)
             else:
-                trigger.mount_to_router(self._router)
+                trigger.mount_to_router(self._router, self._router_prefix)
             self._trigger_map[trigger_id] = trigger
         except Exception as e:
             self._unregister_route_tables(path, methods)
@@ -174,12 +174,14 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
     def __init__(self, system_app: SystemApp | None = None):
         """Initialize a DefaultTriggerManager."""
         self.system_app = system_app
-        self.http_trigger = HttpTriggerManager()
-        super().__init__(None)
+        self._http_trigger: Optional[HttpTriggerManager] = None
+        super().__init__()
 
     def init_app(self, system_app: SystemApp):
         """Initialize the trigger manager."""
         self.system_app = system_app
+        if system_app and self.system_app.app:
+            self._http_trigger = HttpTriggerManager()
 
     def register_trigger(self, trigger: Any, system_app: SystemApp) -> None:
         """Register a trigger to current manager."""
@@ -187,9 +189,9 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         if isinstance(trigger, HttpTrigger):
             logger.info(f"Register trigger {trigger}")
-            self.http_trigger.register_trigger(trigger, system_app)
-        # else:
-        #     raise ValueError(f"Unsupported trigger: {trigger}")
+            if not self._http_trigger:
+                raise ValueError("Http trigger manager not initialized")
+            self._http_trigger.register_trigger(trigger, system_app)
 
     def unregister_trigger(self, trigger: Any, system_app: SystemApp) -> None:
         """Unregister a trigger to current manager."""
@@ -197,14 +199,14 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         if isinstance(trigger, HttpTrigger):
             logger.info(f"Unregister trigger {trigger}")
-            self.http_trigger.unregister_trigger(trigger, system_app)
-        # else:
-        #     raise ValueError(f"Unsupported trigger: {trigger}")
+            if not self._http_trigger:
+                raise ValueError("Http trigger manager not initialized")
+            self._http_trigger.unregister_trigger(trigger, system_app)
 
     def after_register(self) -> None:
         """After register, init the trigger manager."""
-        if self.system_app:
-            self.http_trigger._init_app(self.system_app)
+        if self.system_app and self._http_trigger:
+            self._http_trigger._init_app(self.system_app)
 
     def keep_running(self) -> bool:
         """Whether keep running.
@@ -212,4 +214,6 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
         Returns:
             bool: Whether keep running, True means keep running, False means stop.
         """
-        return self.http_trigger.keep_running()
+        if not self._http_trigger:
+            return False
+        return self._http_trigger.keep_running()
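
`DefaultTriggerManager` now creates its `HttpTriggerManager` lazily, only when the `SystemApp` wraps a real web app; without one, `keep_running()` returns `False` so callers such as `setup_dev_environment` skip uvicorn, and registering an `HttpTrigger` raises "Http trigger manager not initialized". A sketch of the no-app path, assuming `SystemApp(None)` is a valid construction:

```python
from dbgpt.component import SystemApp
from dbgpt.core.awel.trigger.trigger_manager import DefaultTriggerManager

manager = DefaultTriggerManager()
manager.init_app(SystemApp(None))  # no ASGI app attached

# No HttpTriggerManager was created, so there is nothing to serve.
assert manager.keep_running() is False
```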