feat(awel): Modify AWEL http trigger route function (#817)

This commit is contained in:
FangYin Cheng 2023-11-22 09:55:49 +08:00 committed by GitHub
parent 2ce77519de
commit 64165b23cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 59 additions and 41 deletions

View File

@ -1,5 +1,7 @@
"""AWEL: Simple chat dag example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell

View File

@ -1,5 +1,7 @@
"""AWEL: Simple dag example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell

View File

@ -1,5 +1,7 @@
"""AWEL: Simple rag example
DB-GPT will automatically load and execute the current file after startup.
Example:
.. code-block:: shell
@ -49,6 +51,7 @@ with DAG("simple_rag_example") as dag:
"/examples/simple_rag", methods="POST", request_body=ConversationVo
)
req_parse_task = RequestParseOperator()
# TODO should register prompt template first
prompt_task = PromptManagerOperator()
history_storage_task = ChatHistoryStorageOperator()
history_task = ChatHistoryOperator()

View File

@ -1,15 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from ..operator.base import BaseOperator
from ..operator.common_operator import TriggerOperator
from ..dag.base import DAGContext
from ..task.base import TaskOutput
class Trigger(TriggerOperator, ABC):
@abstractmethod
async def trigger(self, end_operator: "BaseOperator") -> None:
async def trigger(self) -> None:
"""Trigger the workflow or a specific operation in the workflow."""

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel
import logging
from .base import Trigger
from ..dag.base import DAG
from ..operator.base import BaseOperator
if TYPE_CHECKING:
@ -50,46 +51,33 @@ class HttpTrigger(Trigger):
def mount_to_router(self, router: "APIRouter") -> None:
from fastapi import Depends
from fastapi.responses import StreamingResponse
methods = self._methods if isinstance(self._methods, list) else [self._methods]
def create_route_function(name):
def create_route_function(name, req_body_cls: Optional[Type[BaseModel]]):
async def _request_body_dependency(request: Request):
return await _parse_request_body(request, self._req_body)
async def route_function(body: Any = Depends(_request_body_dependency)):
end_node = self.dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not self._streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = self._response_headers
media_type = (
self._response_media_type
if self._response_media_type
else "text/event-stream"
)
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
async def route_function(body=Depends(_request_body_dependency)):
return await _trigger_dag(
body,
self.dag,
self._streaming_response,
self._response_headers,
self._response_media_type,
)
route_function.__name__ = name
return route_function
function_name = f"dynamic_route_{self._endpoint.replace('/', '_')}"
dynamic_route_function = create_route_function(function_name)
function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
request_model = (
self._req_body
if isinstance(self._req_body, type)
and issubclass(self._req_body, BaseModel)
else None
)
dynamic_route_function = create_route_function(function_name, request_model)
logger.info(
f"mount router function {dynamic_route_function}({function_name}), endpoint: {self._endpoint}, methods: {methods}"
)
@ -115,3 +103,35 @@ async def _parse_request_body(
return request_body_cls(**request.query_params)
else:
return request
async def _trigger_dag(
body: Any,
dag: DAG,
streaming_response: Optional[bool] = False,
response_headers: Optional[Dict[str, str]] = None,
response_media_type: Optional[str] = None,
) -> Any:
from fastapi.responses import StreamingResponse
end_node = dag.leaf_nodes
if len(end_node) != 1:
raise ValueError("HttpTrigger just support one leaf node in dag")
end_node = end_node[0]
if not streaming_response:
return await end_node.call(call_data={"data": body})
else:
headers = response_headers
media_type = response_media_type if response_media_type else "text/event-stream"
if not headers:
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
return StreamingResponse(
end_node.call_stream(call_data={"data": body}),
headers=headers,
media_type=media_type,
)

View File

@ -421,12 +421,6 @@ def cache_requires():
setup_spec.extras["cache"] = ["rocksdict", "msgpack"]
# def chat_scene():
# setup_spec.extras["chat"] = [
# ""
# ]
def default_requires():
"""
pip install "db-gpt[default]"
@ -445,6 +439,7 @@ def default_requires():
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
setup_spec.extras["default"] += setup_spec.extras["torch"]
setup_spec.extras["default"] += setup_spec.extras["quantization"]
setup_spec.extras["default"] += setup_spec.extras["cache"]
def all_requires():