mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat(awel): Modify AWEL http trigger route function (#817)
This commit is contained in:
parent
2ce77519de
commit
64165b23cf
@ -1,5 +1,7 @@
|
|||||||
"""AWEL: Simple chat dag example
|
"""AWEL: Simple chat dag example
|
||||||
|
|
||||||
|
DB-GPT will automatically load and execute the current file after startup.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""AWEL: Simple dag example
|
"""AWEL: Simple dag example
|
||||||
|
|
||||||
|
DB-GPT will automatically load and execute the current file after startup.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""AWEL: Simple rag example
|
"""AWEL: Simple rag example
|
||||||
|
|
||||||
|
DB-GPT will automatically load and execute the current file after startup.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
@ -49,6 +51,7 @@ with DAG("simple_rag_example") as dag:
|
|||||||
"/examples/simple_rag", methods="POST", request_body=ConversationVo
|
"/examples/simple_rag", methods="POST", request_body=ConversationVo
|
||||||
)
|
)
|
||||||
req_parse_task = RequestParseOperator()
|
req_parse_task = RequestParseOperator()
|
||||||
|
# TODO should register prompt template first
|
||||||
prompt_task = PromptManagerOperator()
|
prompt_task = PromptManagerOperator()
|
||||||
history_storage_task = ChatHistoryStorageOperator()
|
history_storage_task = ChatHistoryStorageOperator()
|
||||||
history_task = ChatHistoryOperator()
|
history_task = ChatHistoryOperator()
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from ..operator.base import BaseOperator
|
|
||||||
from ..operator.common_operator import TriggerOperator
|
from ..operator.common_operator import TriggerOperator
|
||||||
from ..dag.base import DAGContext
|
|
||||||
from ..task.base import TaskOutput
|
|
||||||
|
|
||||||
|
|
||||||
class Trigger(TriggerOperator, ABC):
|
class Trigger(TriggerOperator, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def trigger(self, end_operator: "BaseOperator") -> None:
|
async def trigger(self) -> None:
|
||||||
"""Trigger the workflow or a specific operation in the workflow."""
|
"""Trigger the workflow or a specific operation in the workflow."""
|
||||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .base import Trigger
|
from .base import Trigger
|
||||||
|
from ..dag.base import DAG
|
||||||
from ..operator.base import BaseOperator
|
from ..operator.base import BaseOperator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -50,46 +51,33 @@ class HttpTrigger(Trigger):
|
|||||||
|
|
||||||
def mount_to_router(self, router: "APIRouter") -> None:
|
def mount_to_router(self, router: "APIRouter") -> None:
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
|
|
||||||
methods = self._methods if isinstance(self._methods, list) else [self._methods]
|
methods = self._methods if isinstance(self._methods, list) else [self._methods]
|
||||||
|
|
||||||
def create_route_function(name):
|
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
|
||||||
async def _request_body_dependency(request: Request):
|
async def _request_body_dependency(request: Request):
|
||||||
return await _parse_request_body(request, self._req_body)
|
return await _parse_request_body(request, self._req_body)
|
||||||
|
|
||||||
async def route_function(body: Any = Depends(_request_body_dependency)):
|
async def route_function(body=Depends(_request_body_dependency)):
|
||||||
end_node = self.dag.leaf_nodes
|
return await _trigger_dag(
|
||||||
if len(end_node) != 1:
|
body,
|
||||||
raise ValueError("HttpTrigger just support one leaf node in dag")
|
self.dag,
|
||||||
end_node = end_node[0]
|
self._streaming_response,
|
||||||
if not self._streaming_response:
|
self._response_headers,
|
||||||
return await end_node.call(call_data={"data": body})
|
self._response_media_type,
|
||||||
else:
|
)
|
||||||
headers = self._response_headers
|
|
||||||
media_type = (
|
|
||||||
self._response_media_type
|
|
||||||
if self._response_media_type
|
|
||||||
else "text/event-stream"
|
|
||||||
)
|
|
||||||
if not headers:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "text/event-stream",
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
"Transfer-Encoding": "chunked",
|
|
||||||
}
|
|
||||||
return StreamingResponse(
|
|
||||||
end_node.call_stream(call_data={"data": body}),
|
|
||||||
headers=headers,
|
|
||||||
media_type=media_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
route_function.__name__ = name
|
route_function.__name__ = name
|
||||||
return route_function
|
return route_function
|
||||||
|
|
||||||
function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
|
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
|
||||||
dynamic_route_function = create_route_function(function_name)
|
request_model = (
|
||||||
|
self._req_body
|
||||||
|
if isinstance(self._req_body, type)
|
||||||
|
and issubclass(self._req_body, BaseModel)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
dynamic_route_function = create_route_function(function_name, request_model)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
|
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
|
||||||
)
|
)
|
||||||
@ -115,3 +103,35 @@ async def _parse_request_body(
|
|||||||
return request_body_cls(**request.query_params)
|
return request_body_cls(**request.query_params)
|
||||||
else:
|
else:
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
async def _trigger_dag(
|
||||||
|
body: Any,
|
||||||
|
dag: DAG,
|
||||||
|
streaming_response: Optional[bool] = False,
|
||||||
|
response_headers: Optional[Dict[str, str]] = None,
|
||||||
|
response_media_type: Optional[str] = None,
|
||||||
|
) -> Any:
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
end_node = dag.leaf_nodes
|
||||||
|
if len(end_node) != 1:
|
||||||
|
raise ValueError("HttpTrigger just support one leaf node in dag")
|
||||||
|
end_node = end_node[0]
|
||||||
|
if not streaming_response:
|
||||||
|
return await end_node.call(call_data={"data": body})
|
||||||
|
else:
|
||||||
|
headers = response_headers
|
||||||
|
media_type = response_media_type if response_media_type else "text/event-stream"
|
||||||
|
if not headers:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "text/event-stream",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
}
|
||||||
|
return StreamingResponse(
|
||||||
|
end_node.call_stream(call_data={"data": body}),
|
||||||
|
headers=headers,
|
||||||
|
media_type=media_type,
|
||||||
|
)
|
||||||
|
7
setup.py
7
setup.py
@ -421,12 +421,6 @@ def cache_requires():
|
|||||||
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
|
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
|
||||||
|
|
||||||
|
|
||||||
# def chat_scene():
|
|
||||||
# setup_spec.extras["chat"] = [
|
|
||||||
# ""
|
|
||||||
# ]
|
|
||||||
|
|
||||||
|
|
||||||
def default_requires():
|
def default_requires():
|
||||||
"""
|
"""
|
||||||
pip install "db-gpt[default]"
|
pip install "db-gpt[default]"
|
||||||
@ -445,6 +439,7 @@ def default_requires():
|
|||||||
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
|
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
|
||||||
setup_spec.extras["default"] += setup_spec.extras["torch"]
|
setup_spec.extras["default"] += setup_spec.extras["torch"]
|
||||||
setup_spec.extras["default"] += setup_spec.extras["quantization"]
|
setup_spec.extras["default"] += setup_spec.extras["quantization"]
|
||||||
|
setup_spec.extras["default"] += setup_spec.extras["cache"]
|
||||||
|
|
||||||
|
|
||||||
def all_requires():
|
def all_requires():
|
||||||
|
Loading…
Reference in New Issue
Block a user