Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-07 03:50:42 +00:00)

refactor: Refactor for core SDK (#1092)
@@ -22,6 +22,7 @@ from .operator.common_operator import (
    JoinOperator,
    MapOperator,
    ReduceStreamOperator,
    TriggerOperator,
)
from .operator.stream_operator import (
    StreamifyAbsOperator,
@@ -50,6 +51,7 @@ __all__ = [
    "BaseOperator",
    "JoinOperator",
    "ReduceStreamOperator",
    "TriggerOperator",
    "MapOperator",
    "BranchOperator",
    "InputOperator",
@@ -150,4 +152,6 @@ def setup_dev_environment(
        for trigger in dag.trigger_nodes:
            trigger_manager.register_trigger(trigger)
    trigger_manager.after_register()
-    uvicorn.run(app, host=host, port=port)
+    if trigger_manager.keep_running():
+        # Should keep running
+        uvicorn.run(app, host=host, port=port)
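For context, a hedged usage sketch of the new guard: setup_dev_environment previously started uvicorn unconditionally, and with keep_running() it only does so when at least one trigger was registered. The import path and the dags/host/port parameters below are assumptions, since the full signature is not part of this diff.

    from dbgpt.core.awel import DAG, MapOperator, setup_dev_environment

    with DAG("local_only_dag") as dag:
        # No HttpTrigger in this DAG, so nothing is registered with the trigger manager.
        task = MapOperator(map_function=lambda x: x + 1)

    # keep_running() stays False for a trigger-less DAG list, so uvicorn is not started
    # and the call returns without blocking (host/port shown here are assumed defaults).
    setup_dev_environment([dag], host="127.0.0.1", port=5555)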
@@ -28,7 +28,7 @@ from ..task.base import OUT, T, TaskOutput

F = TypeVar("F", bound=FunctionType)

-CALL_DATA = Union[Dict, Dict[str, Dict]]
+CALL_DATA = Union[Dict[str, Any], Any]


class WorkflowRunner(ABC, Generic[T]):
@@ -197,6 +197,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        Returns:
            OUT: The output of the node after execution.
        """
+        if call_data:
+            call_data = {"data": call_data}
        out_ctx = await self._runner.execute_workflow(
            self, call_data, exist_dag_ctx=dag_ctx
        )
@@ -242,6 +244,8 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
        Returns:
            AsyncIterator[OUT]: An asynchronous iterator over the output stream.
        """
+        if call_data:
+            call_data = {"data": call_data}
        out_ctx = await self._runner.execute_workflow(
            self, call_data, streaming_call=True, exist_dag_ctx=dag_ctx
        )
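The two hunks above move the {"data": ...} wrapping inside call() and call_stream(), and CALL_DATA is relaxed to accept an arbitrary payload. A minimal sketch of the resulting call-site style, assuming (as in the AWEL examples) that an operator without an upstream reads its input from call_data:

    import asyncio

    from dbgpt.core.awel import DAG, MapOperator

    with DAG("call_data_example") as dag:
        task = MapOperator(map_function=lambda payload: payload["dialect"].upper())

    # Before this change the caller built the envelope: call_data={"data": {"dialect": "mysql"}}.
    # Now the raw payload is passed and call() wraps it itself.
    print(asyncio.run(task.call(call_data={"dialect": "mysql"})))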
@@ -28,6 +28,14 @@ EMPTY_DATA = _EMPTY_DATA_TYPE()
SKIP_DATA = _EMPTY_DATA_TYPE()
PLACEHOLDER_DATA = _EMPTY_DATA_TYPE()

+
+def is_empty_data(data: Any):
+    """Check if the data is empty."""
+    if isinstance(data, _EMPTY_DATA_TYPE):
+        return data in (EMPTY_DATA, SKIP_DATA)
+    return False
+
+
MapFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
ReduceFunc = Union[Callable[[IN], OUT], Callable[[IN], Awaitable[OUT]]]
StreamFunc = Callable[[IN], Awaitable[AsyncIterator[OUT]]]
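A small sketch of the helper's contract as shown above; the import path is an assumption (file names are not visible in this diff), and the PLACEHOLDER_DATA case follows from the membership test in the function body rather than anything stated elsewhere:

    from dbgpt.core.awel.task.base import (  # assumed module path
        EMPTY_DATA,
        PLACEHOLDER_DATA,
        SKIP_DATA,
        is_empty_data,
    )

    assert is_empty_data(EMPTY_DATA)  # the "no data" sentinel counts as empty
    assert is_empty_data(SKIP_DATA)  # the "skipped" sentinel counts as empty
    assert not is_empty_data(PLACEHOLDER_DATA)  # a sentinel, but not treated as empty
    assert not is_empty_data(None)  # ordinary values, including None, are not "empty"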
@@ -24,7 +24,6 @@ from .base import (
    EMPTY_DATA,
    OUT,
    PLACEHOLDER_DATA,
    SKIP_DATA,
    InputContext,
    InputSource,
    MapFunc,
@@ -37,6 +36,7 @@ from .base import (
    TaskState,
    TransformFunc,
    UnStreamFunc,
+    is_empty_data,
)

logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ class SimpleTaskOutput(TaskOutput[T], Generic[T]):
    @property
    def is_empty(self) -> bool:
        """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)

    @property
    def is_none(self) -> bool:
@@ -171,7 +171,7 @@ class SimpleStreamTaskOutput(TaskOutput[T], Generic[T]):
    @property
    def is_empty(self) -> bool:
        """Return True if the output data is empty."""
-        return self._data == EMPTY_DATA or self._data == SKIP_DATA
+        return is_empty_data(self._data)

    @property
    def is_none(self) -> bool:
@@ -330,7 +330,7 @@ class SimpleCallDataInputSource(BaseInputSource):
        """
        call_data = task_ctx.call_data
        data = call_data.get("data", EMPTY_DATA) if call_data else EMPTY_DATA
-        if data == EMPTY_DATA:
+        if is_empty_data(data):
            raise ValueError("No call data for current SimpleCallDataInputSource")
        return data
@@ -1,12 +1,8 @@
"""Http trigger for AWEL."""
from __future__ import annotations

import logging
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast

-from starlette.requests import Request
-
-from dbgpt._private.pydantic import BaseModel
-
from ..dag.base import DAG
@@ -15,9 +11,10 @@ from .base import Trigger

if TYPE_CHECKING:
    from fastapi import APIRouter
+    from starlette.requests import Request

-RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
-StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]
+    RequestBody = Union[Type[Request], Type[BaseModel], Type[str]]
+    StreamingPredictFunc = Callable[[Union[Request, BaseModel]], bool]

logger = logging.getLogger(__name__)
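The starlette import and the aliases that reference it move under TYPE_CHECKING, with runtime imports done inside the functions that need them (see the hunks below); presumably this keeps fastapi/starlette optional until a trigger is actually mounted. The general shape of the pattern, as an illustration rather than DB-GPT code:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers; importing this module does not require starlette.
        from starlette.requests import Request


    def handle(request: "Request") -> None:
        # Import lazily at the point of use, so the dependency is needed only when called.
        from starlette.requests import Request

        assert isinstance(request, Request)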
@@ -32,9 +29,9 @@ class HttpTrigger(Trigger):
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
-        request_body: Optional[RequestBody] = None,
+        request_body: Optional["RequestBody"] = None,
        streaming_response: bool = False,
-        streaming_predict_func: Optional[StreamingPredictFunc] = None,
+        streaming_predict_func: Optional["StreamingPredictFunc"] = None,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
@@ -69,6 +66,7 @@ class HttpTrigger(Trigger):
            router (APIRouter): The router to mount the trigger.
        """
        from fastapi import Depends
+        from starlette.requests import Request

        methods = [self._methods] if isinstance(self._methods, str) else self._methods

@@ -114,8 +112,10 @@ class HttpTrigger(Trigger):


async def _parse_request_body(
-    request: Request, request_body_cls: Optional[RequestBody]
+    request: "Request", request_body_cls: Optional["RequestBody"]
):
+    from starlette.requests import Request
+
    if not request_body_cls:
        return None
    if request_body_cls == Request:
@@ -152,7 +152,7 @@ async def _trigger_dag(
        raise ValueError("HttpTrigger just support one leaf node in dag")
    end_node = cast(BaseOperator, leaf_nodes[0])
    if not streaming_response:
-        return await end_node.call(call_data={"data": body})
+        return await end_node.call(call_data=body)
    else:
        headers = response_headers
        media_type = response_media_type if response_media_type else "text/event-stream"
@@ -163,7 +163,7 @@ async def _trigger_dag(
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            }
-        generator = await end_node.call_stream(call_data={"data": body})
+        generator = await end_node.call_stream(call_data=body)
        background_tasks = BackgroundTasks()
        background_tasks.add_task(dag._after_dag_end)
        return StreamingResponse(
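Both call sites above now hand the parsed request body to the DAG as call_data directly, relying on the wrapping inside call()/call_stream(). A hedged end-to-end sketch in the style of the AWEL examples (the endpoint path and request model are illustrative):

    from dbgpt._private.pydantic import BaseModel
    from dbgpt.core.awel import DAG, HttpTrigger, MapOperator


    class TriggerReqBody(BaseModel):
        name: str


    with DAG("simple_http_dag") as dag:
        # The parsed TriggerReqBody is forwarded unwrapped as call_data to the leaf node.
        trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
        map_node = MapOperator(map_function=lambda req: f"Hello, {req.name}")
        trigger >> map_node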
@@ -24,6 +24,14 @@ class TriggerManager(ABC):
    def register_trigger(self, trigger: Any) -> None:
        """Register a trigger to current manager."""

+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return False
+

class HttpTriggerManager(TriggerManager):
    """Http trigger manager.
@@ -64,6 +72,8 @@ class HttpTriggerManager(TriggerManager):
        self._trigger_map[trigger_id] = trigger

    def _init_app(self, system_app: SystemApp):
+        if not self.keep_running():
+            return
        logger.info(
            f"Include router {self._router} to prefix path {self._router_prefix}"
        )
@@ -72,6 +82,14 @@ class HttpTriggerManager(TriggerManager):
            raise RuntimeError("System app not initialized")
        app.include_router(self._router, prefix=self._router_prefix, tags=["AWEL"])

+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return len(self._trigger_map) > 0
+

class DefaultTriggerManager(TriggerManager, BaseComponent):
    """Default trigger manager for AWEL.
@@ -105,3 +123,11 @@ class DefaultTriggerManager(TriggerManager, BaseComponent):
        """After register, init the trigger manager."""
        if self.system_app:
            self.http_trigger._init_app(self.system_app)
+
+    def keep_running(self) -> bool:
+        """Whether keep running.
+
+        Returns:
+            bool: Whether keep running, True means keep running, False means stop.
+        """
+        return self.http_trigger.keep_running()
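keep_running() gives each manager a way to say whether a long-lived server is required: the HTTP manager answers based on registered triggers, and the default manager delegates to it. A sketch of what a non-HTTP manager could look like, assuming register_trigger is the only method a subclass must provide (the full ABC is not shown in this diff); the class and module path are hypothetical:

    from typing import Any, List

    from dbgpt.core.awel.trigger.trigger_manager import TriggerManager  # assumed module path


    class InProcessTriggerManager(TriggerManager):
        """Illustrative only: triggers that run in-process and expose no HTTP endpoint."""

        def __init__(self) -> None:
            self._triggers: List[Any] = []

        def register_trigger(self, trigger: Any) -> None:
            self._triggers.append(trigger)

        def keep_running(self) -> bool:
            # Nothing to serve, so a runner consulting keep_running() need not start uvicorn.
            return False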
@@ -70,7 +70,7 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
        end_node: BaseOperator = cast(BaseOperator, self._sub_compose_dag.leaf_nodes[0])
        # Sub dag, use the same dag context in the parent dag
        return await end_node.call(
-            call_data={"data": input_value}, dag_ctx=self.current_dag_context
+            call_data=input_value, dag_ctx=self.current_dag_context
        )

    def _build_composer_dag(self) -> DAG:
@@ -150,7 +150,7 @@ class PromptBuilderOperator(
                )
            )

-            single_input = {"data": {"dialect": "mysql"}}
+            single_input = {"dialect": "mysql"}
            single_expected_messages = [
                ModelMessage(
                    content="Please write a mysql SQL count the length of a field",