mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 23:28:35 +00:00
feat(awel): Modify AWEL http trigger route function (#817)
This commit is contained in:
parent
2ce77519de
commit
64165b23cf
@ -1,5 +1,7 @@
|
||||
"""AWEL: Simple chat dag example
|
||||
|
||||
DB-GPT will automatically load and execute the current file after startup.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: shell
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""AWEL: Simple dag example
|
||||
|
||||
DB-GPT will automatically load and execute the current file after startup.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: shell
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""AWEL: Simple rag example
|
||||
|
||||
DB-GPT will automatically load and execute the current file after startup.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: shell
|
||||
@ -49,6 +51,7 @@ with DAG("simple_rag_example") as dag:
|
||||
"/examples/simple_rag", methods="POST", request_body=ConversationVo
|
||||
)
|
||||
req_parse_task = RequestParseOperator()
|
||||
# TODO should register prompt template first
|
||||
prompt_task = PromptManagerOperator()
|
||||
history_storage_task = ChatHistoryStorageOperator()
|
||||
history_task = ChatHistoryOperator()
|
||||
|
@ -1,15 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..operator.base import BaseOperator
|
||||
from ..operator.common_operator import TriggerOperator
|
||||
from ..dag.base import DAGContext
|
||||
from ..task.base import TaskOutput
|
||||
|
||||
|
||||
class Trigger(TriggerOperator, ABC):
|
||||
@abstractmethod
|
||||
async def trigger(self, end_operator: "BaseOperator") -> None:
|
||||
async def trigger(self) -> None:
|
||||
"""Trigger the workflow or a specific operation in the workflow."""
|
||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from .base import Trigger
|
||||
from ..dag.base import DAG
|
||||
from ..operator.base import BaseOperator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -50,46 +51,33 @@ class HttpTrigger(Trigger):
|
||||
|
||||
def mount_to_router(self, router: "APIRouter") -> None:
|
||||
from fastapi import Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
methods = self._methods if isinstance(self._methods, list) else [self._methods]
|
||||
|
||||
def create_route_function(name):
|
||||
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
|
||||
async def _request_body_dependency(request: Request):
|
||||
return await _parse_request_body(request, self._req_body)
|
||||
|
||||
async def route_function(body: Any = Depends(_request_body_dependency)):
|
||||
end_node = self.dag.leaf_nodes
|
||||
if len(end_node) != 1:
|
||||
raise ValueError("HttpTrigger just support one leaf node in dag")
|
||||
end_node = end_node[0]
|
||||
if not self._streaming_response:
|
||||
return await end_node.call(call_data={"data": body})
|
||||
else:
|
||||
headers = self._response_headers
|
||||
media_type = (
|
||||
self._response_media_type
|
||||
if self._response_media_type
|
||||
else "text/event-stream"
|
||||
)
|
||||
if not headers:
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
return StreamingResponse(
|
||||
end_node.call_stream(call_data={"data": body}),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
)
|
||||
async def route_function(body=Depends(_request_body_dependency)):
|
||||
return await _trigger_dag(
|
||||
body,
|
||||
self.dag,
|
||||
self._streaming_response,
|
||||
self._response_headers,
|
||||
self._response_media_type,
|
||||
)
|
||||
|
||||
route_function.__name__ = name
|
||||
return route_function
|
||||
|
||||
function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
|
||||
dynamic_route_function = create_route_function(function_name)
|
||||
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
|
||||
request_model = (
|
||||
self._req_body
|
||||
if isinstance(self._req_body, type)
|
||||
and issubclass(self._req_body, BaseModel)
|
||||
else None
|
||||
)
|
||||
dynamic_route_function = create_route_function(function_name, request_model)
|
||||
logger.info(
|
||||
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
|
||||
)
|
||||
@ -115,3 +103,35 @@ async def _parse_request_body(
|
||||
return request_body_cls(**request.query_params)
|
||||
else:
|
||||
return request
|
||||
|
||||
|
||||
async def _trigger_dag(
|
||||
body: Any,
|
||||
dag: DAG,
|
||||
streaming_response: Optional[bool] = False,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
response_media_type: Optional[str] = None,
|
||||
) -> Any:
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
end_node = dag.leaf_nodes
|
||||
if len(end_node) != 1:
|
||||
raise ValueError("HttpTrigger just support one leaf node in dag")
|
||||
end_node = end_node[0]
|
||||
if not streaming_response:
|
||||
return await end_node.call(call_data={"data": body})
|
||||
else:
|
||||
headers = response_headers
|
||||
media_type = response_media_type if response_media_type else "text/event-stream"
|
||||
if not headers:
|
||||
headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
return StreamingResponse(
|
||||
end_node.call_stream(call_data={"data": body}),
|
||||
headers=headers,
|
||||
media_type=media_type,
|
||||
)
|
||||
|
7
setup.py
7
setup.py
@ -421,12 +421,6 @@ def cache_requires():
|
||||
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
|
||||
|
||||
|
||||
# def chat_scene():
|
||||
# setup_spec.extras["chat"] = [
|
||||
# ""
|
||||
# ]
|
||||
|
||||
|
||||
def default_requires():
|
||||
"""
|
||||
pip install "db-gpt[default]"
|
||||
@ -445,6 +439,7 @@ def default_requires():
|
||||
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
|
||||
setup_spec.extras["default"] += setup_spec.extras["torch"]
|
||||
setup_spec.extras["default"] += setup_spec.extras["quantization"]
|
||||
setup_spec.extras["default"] += setup_spec.extras["cache"]
|
||||
|
||||
|
||||
def all_requires():
|
||||
|
Loading…
Reference in New Issue
Block a user