Files
DB-GPT/dbgpt/core/awel/trigger/http_trigger.py
2024-08-30 15:02:53 +08:00

1259 lines
41 KiB
Python

"""Http trigger for AWEL."""
import json
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Type,
Union,
cast,
get_origin,
)
from dbgpt._private.pydantic import (
BaseModel,
Field,
field_is_required,
field_outer_type,
model_fields,
model_to_dict,
)
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer
from ..dag.base import DAG
from ..flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
OptionValue,
Parameter,
ResourceCategory,
ResourceType,
ViewMetadata,
register_resource,
)
from ..operators.base import BaseOperator
from ..operators.common_operator import MapOperator
from ..util._typing_util import _parse_bool
from ..util.http_util import join_paths
from .base import Trigger, TriggerMetadata
if TYPE_CHECKING:
    # These imports exist only for static type checkers; they are never
    # executed at runtime, so fastapi/starlette stay optional at import time.
    from fastapi import APIRouter, FastAPI
    from starlette.requests import Request

    from dbgpt.core.interface.llm import ModelRequestContext

    # Type aliases built from the guarded imports above. They must stay inside
    # this guard (``Request`` is undefined at runtime); the module refers to
    # them through string (forward-reference) annotations.
    RequestBody = Union[Type[Request], Type[BaseModel], Type[Dict[str, Any]], Type[str]]
    CommonRequestType = Union[Request, BaseModel, Dict[str, Any], str, None]
    StreamingPredictFunc = Callable[[CommonRequestType], bool]

# Placeholder allowed inside an endpoint path; it is substituted with the id of
# the DAG that owns the trigger when the route is mounted.
ENDPOINT_PLACEHOLDER_DAG_ID = "{dag_id}"

logger = logging.getLogger(__name__)
class AWELHttpError(RuntimeError):
    """Error raised by the AWEL HTTP trigger machinery.

    Attributes:
        msg: The human-readable message (also the ``RuntimeError`` argument).
        code: An optional error code; ``None`` when none was given.
    """

    def __init__(self, msg: str, code: Optional[str] = None):
        """Create the error from a message and an optional error code."""
        super().__init__(msg)
        self.code = code
        self.msg = msg
def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
    """Guess whether a request body asks for a streaming response.

    The ``streaming`` key is consulted first, then ``stream``; the value found is
    interpreted with ``_parse_bool``.

    Args:
        body: The request body. A pydantic model is converted to a dict and a
            string is decoded as JSON. Any other non-dict value (for example a
            starlette ``Request`` or ``None``) means "not streaming".

    Returns:
        bool: Whether the body requests a streaming response.
    """
    if isinstance(body, BaseModel):
        body = model_to_dict(body)
    elif isinstance(body, str):
        try:
            body = json.loads(body)
        except (ValueError, RecursionError):
            # Not decodable JSON, so there is no streaming flag to read.
            return False
    if not isinstance(body, dict):
        # Also covers JSON that decodes to a list/number/etc., which has no
        # keys to look up (calling ``.get`` on it would crash).
        return False
    streaming = body.get("streaming") or body.get("stream")
    return _parse_bool(streaming)
class HttpTriggerMetadata(TriggerMetadata):
    """Trigger metadata.

    Describes a mounted HTTP trigger; ``HttpTrigger.mount_to_router`` and
    ``HttpTrigger.mount_to_app`` return an instance of it.
    """

    # The mounted path: the endpoint, joined with the global prefix when one
    # is given.
    path: str = Field(..., description="The path of the trigger")
    # HTTP methods the route accepts.
    methods: List[str] = Field(..., description="The methods of the trigger")
    # "chat" when the request body is a CommonLLMHttpRequestBody, otherwise
    # "command".
    trigger_mode: str = Field(
        default="command", description="The mode of the trigger, command or chat"
    )
    trigger_type: Optional[str] = Field(
        default="http", description="The type of the trigger"
    )
class BaseHttpBody(BaseModel):
    """Base class for typed HTTP bodies.

    Subclasses describe either an HTTP request body or an HTTP response body.
    """

    def get_body(self) -> Any:
        """Return the body value; by default this is the model instance itself."""
        return self

    def streaming_response(self) -> bool:
        """Report whether the response should be streamed.

        Returns:
            bool: Always ``False`` for the base class.
        """
        return False

    @classmethod
    def get_body_class(cls) -> Type:
        """Return the type used to declare this body on a route.

        Returns:
            Type: By default, the model class itself.
        """
        return cls

    @classmethod
    def streaming_predict_func(cls) -> Optional["StreamingPredictFunc"]:
        """Return the function that decides per request whether to stream."""
        predictor = _default_streaming_predict_func
        return predictor
@register_resource(
    label=_("Dict Http Body"),
    name="dict_http_body",
    category=ResourceCategory.HTTP_BODY,
    resource_type=ResourceType.CLASS,
    description=_("Parse the request body as a dict or response body as a dict"),
)
class DictHttpBody(BaseHttpBody):
    """HTTP body carried as a plain ``Dict[str, Any]``."""

    _default_body: Optional[Dict[str, Any]] = None

    def get_body(self) -> Dict[str, Any]:
        """Return the stored dict body.

        Raises:
            AWELHttpError: If no body has been stored.
        """
        body = self._default_body
        if body is not None:
            return body
        raise AWELHttpError("DictHttpBody is not set")

    @classmethod
    def get_body_class(cls) -> Type[Dict[str, Any]]:
        """Return ``Dict[str, Any]`` as the route-level body type.

        Returns:
            Type[Dict[str, Any]]: The body class.
        """
        return Dict[str, Any]
@register_resource(
    label=_("String Http Body"),
    name="string_http_body",
    category=ResourceCategory.HTTP_BODY,
    resource_type=ResourceType.CLASS,
    description=_("Parse the request body as a string or response body as string"),
)
class StringHttpBody(BaseHttpBody):
    """HTTP body carried as a plain ``str``."""

    _default_body: Optional[str] = None

    def get_body(self) -> str:
        """Return the stored string body.

        Raises:
            AWELHttpError: If no body has been stored.
        """
        body = self._default_body
        if body is not None:
            return body
        raise AWELHttpError("StringHttpBody is not set")

    @classmethod
    def get_body_class(cls) -> Type[str]:
        """Return ``str`` as the route-level body type.

        Returns:
            Type[str]: The body class.
        """
        return str
@register_resource(
    label=_("Request Http Body"),
    name="request_http_body",
    category=ResourceCategory.HTTP_BODY,
    resource_type=ResourceType.CLASS,
    description=_("Parse the request body as a starlette Request"),
)
class RequestHttpBody(BaseHttpBody):
    """HTTP body exposed as the raw starlette ``Request``."""

    _default_body: Optional["Request"] = None

    def get_body(self) -> "Request":
        """Return the stored request.

        Raises:
            AWELHttpError: If no request has been stored.
        """
        body = self._default_body
        if body is not None:
            return body
        raise AWELHttpError("RequestHttpBody is not set")

    @classmethod
    def get_body_class(cls) -> Type["Request"]:
        """Return starlette's ``Request`` class as the route-level body type.

        The import is local so starlette is only needed when this is called.

        Returns:
            Type[Request]: The request body type.
        """
        from starlette.requests import Request as _StarletteRequest

        return _StarletteRequest
@register_resource(
    label=_("Common LLM Http Request Body"),
    name="common_llm_http_request_body",
    category=ResourceCategory.HTTP_BODY,
    resource_type=ResourceType.CLASS,
    description=_("Parse the request body as a common LLM http body"),
)
class CommonLLMHttpRequestBody(BaseHttpBody):
    """Common LLM http request body.

    A trigger whose request body is this class (or a subclass) reports the
    "chat" trigger mode. Field declaration order is kept stable because it
    defines the generated schema order.
    """

    model: str = Field(
        ..., description="The model name", examples=["gpt-3.5-turbo", "proxyllm"]
    )
    messages: Union[str, List[str]] = Field(
        ..., description="User input messages", examples=["Hello", "How are you?"]
    )
    stream: bool = Field(default=False, description="Whether return stream")
    temperature: Optional[float] = Field(
        default=None,
        description="What sampling temperature to use, between 0 and 2. Higher values "
        "like 0.8 will make the output more random, while lower values like 0.2 will "
        "make it more focused and deterministic.",
    )
    max_new_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens that can be generated in the chat "
        "completion.",
    )
    conv_uid: Optional[str] = Field(
        default=None, description="The conversation id of the model inference"
    )
    span_id: Optional[str] = Field(
        default=None, description="The span id of the model inference"
    )
    chat_mode: Optional[str] = Field(
        default="chat_normal",
        description="The chat mode",
        examples=["chat_awel_flow", "chat_normal"],
    )
    chat_param: Optional[str] = Field(
        default=None,
        description="The chat param of chat mode",
    )
    user_name: Optional[str] = Field(
        default=None, description="The user name of the model inference"
    )
    sys_code: Optional[str] = Field(
        default=None, description="The system code of the model inference"
    )
    # The default is True, i.e. incremental output; the description is kept in
    # agreement with that default.
    incremental: bool = Field(
        default=True,
        description="Used to control whether the content is returned incrementally "
        "or in full each time. "
        "If this parameter is not provided, the default is incremental return.",
    )
    enable_vis: bool = Field(
        default=True, description="response content whether to output vis label"
    )
    extra: Optional[Dict[str, Any]] = Field(
        default=None, description="The extra info of the model inference"
    )

    @property
    def context(self) -> "ModelRequestContext":
        """Build a ``ModelRequestContext`` from this request.

        Copies ``stream``, ``user_name``, ``sys_code``, ``conv_uid``,
        ``span_id``, ``chat_mode``, ``chat_param`` and ``extra``.
        NOTE(review): ``incremental``/``enable_vis``/``temperature`` are not
        forwarded here — confirm that is intended.
        """
        from dbgpt.core.interface.llm import ModelRequestContext

        return ModelRequestContext(
            stream=self.stream,
            user_name=self.user_name,
            sys_code=self.sys_code,
            conv_uid=self.conv_uid,
            span_id=self.span_id,
            chat_mode=self.chat_mode,
            chat_param=self.chat_param,
            extra=self.extra,
        )
@register_resource(
    label=_("Common LLM Http Response Body"),
    name="common_llm_http_response_body",
    category=ResourceCategory.HTTP_BODY,
    resource_type=ResourceType.CLASS,
    description=_("Parse the response body as a common LLM http body"),
)
class CommonLLMHttpResponseBody(BaseHttpBody):
    """Common LLM http response body.

    Carries the generated text plus an error code (0 means success) and
    optional model metrics.
    """

    # Required: the generated text.
    text: str = Field(
        ..., description="The generated text", examples=["Hello", "How are you?"]
    )
    # 0 means no error; any other value signals a failure.
    error_code: int = Field(
        default=0, description="The error code, 0 means no error", examples=[0, 1]
    )
    # Free-form metrics dict; its keys are not defined in this module.
    metrics: Optional[Dict[str, Any]] = Field(
        default=None,
        description="The metrics of the model, like the number of tokens generated",
    )
class HttpTrigger(Trigger):
    """Http trigger for AWEL.

    Http trigger is used to trigger a DAG by http request. It builds a FastAPI
    route function for its endpoint and mounts it on a router or an app; the
    incoming body is handed to the DAG's single leaf node.
    """

    metadata = ViewMetadata(
        label="Http Trigger",
        name="http_trigger",
        category=OperatorCategory.TRIGGER,
        operator_type=OperatorType.INPUT,
        description="Trigger your workflow by http request",
        inputs=[],
        outputs=[],
        parameters=[
            Parameter.build_from(
                "API Endpoint", "endpoint", str, description="The API endpoint"
            ),
            Parameter.build_from(
                "Http Methods",
                "methods",
                str,
                optional=True,
                default="GET",
                description="The methods of the API endpoint",
                options=[
                    OptionValue(label="HTTP Method GET", name="http_get", value="GET"),
                    OptionValue(label="HTTP Method PUT", name="http_put", value="PUT"),
                    OptionValue(
                        label="HTTP Method POST", name="http_post", value="POST"
                    ),
                    OptionValue(
                        label="HTTP Method DELETE", name="http_delete", value="DELETE"
                    ),
                ],
            ),
            Parameter.build_from(
                "Http Request Trigger Body",
                "http_trigger_body",
                BaseHttpBody,
                optional=True,
                default=None,
                description="The request body of the API endpoint",
                resource_type=ResourceType.CLASS,
            ),
            Parameter.build_from(
                "Streaming Response",
                "streaming_response",
                bool,
                optional=True,
                default=False,
                description="Whether the response is streaming",
            ),
            Parameter.build_from(
                "Http Response Body",
                "http_response_body",
                BaseHttpBody,
                optional=True,
                default=None,
                description="The response body of the API endpoint",
                resource_type=ResourceType.CLASS,
            ),
            Parameter.build_from(
                "Response Media Type",
                "response_media_type",
                str,
                optional=True,
                default=None,
                description="The response media type",
            ),
            Parameter.build_from(
                "Http Status Code",
                "status_code",
                int,
                optional=True,
                default=200,
                description="The http status code",
            ),
        ],
    )

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "GET",
        request_body: Optional["RequestBody"] = None,
        http_trigger_body: Optional[Type[BaseHttpBody]] = None,
        streaming_response: bool = False,
        streaming_predict_func: Optional["StreamingPredictFunc"] = None,
        http_response_body: Optional[Type[BaseHttpBody]] = None,
        response_model: Optional[Type] = None,
        response_headers: Optional[Dict[str, str]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str | Enum]] = None,
        register_to_app: bool = False,
        **kwargs,
    ) -> None:
        """Initialize a HttpTrigger.

        Args:
            endpoint: Route path; a leading "/" is added if missing. It may
                contain the ``{dag_id}`` placeholder, resolved at mount time.
            methods: One method or a list of methods. Normalized to an
                upper-case list; ``None`` means ``["GET"]`` (FastAPI's default).
            request_body: Type used to parse the request body. When omitted it
                is taken from ``http_trigger_body``.
            http_trigger_body: Body class providing ``request_body`` and, unless
                ``streaming_predict_func`` is given, the streaming predictor.
            streaming_response: Default streaming flag (parsed as a bool).
            streaming_predict_func: Per-request predictor that overrides
                ``streaming_response``.
            http_response_body: Body class providing ``response_model`` when
                that is omitted.
            response_model: FastAPI response model.
            response_headers: Headers for streaming responses.
            response_media_type: Media type for streaming responses.
            status_code: HTTP status code of the route.
            router_tags: OpenAPI tags of the route.
            register_to_app: Mount on the FastAPI app instead of a router.
        """
        super().__init__(**kwargs)
        if not endpoint.startswith("/"):
            endpoint = "/" + endpoint
        if not request_body and http_trigger_body:
            request_body = http_trigger_body.get_body_class()
            # Only fall back to the body's predictor; never discard one the
            # caller passed explicitly.
            if streaming_predict_func is None:
                streaming_predict_func = http_trigger_body.streaming_predict_func()
        if not response_model and http_response_body:
            response_model = http_response_body.get_body_class()
        self._endpoint = endpoint
        self._methods = self._normalize_methods(methods)
        self._req_body = request_body
        self._streaming_response = _parse_bool(streaming_response)
        self._streaming_predict_func = streaming_predict_func
        self._response_model = response_model
        self._status_code = status_code
        self._router_tags = router_tags
        self._response_headers = response_headers
        self._response_media_type = response_media_type
        self._end_node: Optional[BaseOperator] = None
        self._register_to_app = register_to_app

    @staticmethod
    def _normalize_methods(methods: Optional[Union[str, List[str]]]) -> List[str]:
        """Return ``methods`` as an upper-case list (``None`` -> ``["GET"]``).

        FastAPI upper-cases methods itself, so upper-casing here keeps the
        GET/DELETE "query method" detection consistent with actual routing.
        """
        if methods is None:
            return ["GET"]
        if isinstance(methods, str):
            methods = [methods]
        return [method.upper() for method in methods]

    async def trigger(self, **kwargs) -> Any:
        """Trigger the DAG. Not used in HttpTrigger."""
        raise NotImplementedError("HttpTrigger does not support trigger directly")

    def register_to_app(self) -> bool:
        """Register the trigger to a FastAPI app.

        Returns:
            bool: Whether register to app, if not register to app, will register to
                router.
        """
        return self._register_to_app

    def mount_to_router(
        self, router: "APIRouter", global_prefix: Optional[str] = None
    ) -> HttpTriggerMetadata:
        """Mount the trigger to a router.

        Args:
            router (APIRouter): The router to mount the trigger.
            global_prefix (Optional[str], optional): The global prefix of the router.

        Returns:
            HttpTriggerMetadata: Metadata of the mounted route.
        """
        endpoint = self._resolved_endpoint()
        path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
        dynamic_route_function = self._create_route_func()
        router.api_route(
            endpoint,
            methods=self._methods,
            response_model=self._response_model,
            status_code=self._status_code,
            tags=self._router_tags,
        )(dynamic_route_function)

        logger.info(f"Mount http trigger success, path: {path}")
        return HttpTriggerMetadata(
            path=path, methods=self._methods, trigger_mode=self._trigger_mode()
        )

    def mount_to_app(
        self, app: "FastAPI", global_prefix: Optional[str] = None
    ) -> HttpTriggerMetadata:
        """Mount the trigger to a FastAPI app.

        TODO: The performance of this method is not good, need to be optimized.

        Args:
            app (FastAPI): The FastAPI app.
            global_prefix (Optional[str], optional): The global prefix of the app.
                Defaults to None.

        Returns:
            HttpTriggerMetadata: Metadata of the mounted route.
        """
        from dbgpt.util.fastapi import PriorityAPIRouter

        endpoint = self._resolved_endpoint()
        path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
        dynamic_route_function = self._create_route_func()
        router = cast(PriorityAPIRouter, app.router)
        router.add_api_route(
            path,
            dynamic_route_function,
            methods=self._methods,
            response_model=self._response_model,
            status_code=self._status_code,
            tags=self._router_tags,
            priority=10,
        )
        # Force FastAPI/Starlette to rebuild the cached schema and middleware.
        app.openapi_schema = None
        app.middleware_stack = None
        logger.info(f"Mount http trigger success, path: {path}")
        return HttpTriggerMetadata(
            path=path, methods=self._methods, trigger_mode=self._trigger_mode()
        )

    def remove_from_app(
        self, app: "FastAPI", global_prefix: Optional[str] = None
    ) -> None:
        """Remove the trigger from a FastAPI app.

        Every route whose ``path_format`` equals the resolved path is removed.

        Args:
            app (FastAPI): The FastAPI app.
            global_prefix (Optional[str], optional): The global prefix of the app.
                Defaults to None.
        """
        from fastapi import APIRouter

        endpoint = self._resolved_endpoint()
        path = join_paths(global_prefix, endpoint) if global_prefix else endpoint
        app_router = cast(APIRouter, app.router)
        # TODO, remove with path and methods
        # Rebuild the list in place: deleting while enumerating skips the element
        # following each deletion. Not every route type has ``path_format``.
        app_router.routes[:] = [
            route
            for route in app_router.routes
            if getattr(route, "path_format", None) != path
        ]
        # Drop the cached OpenAPI schema so the removed route leaves the docs.
        app.openapi_schema = None

    def _resolved_endpoint(self) -> str:
        """Get the resolved endpoint.

        Replace the placeholder {dag_id} with the real dag_id.

        Raises:
            AWELHttpError: If the endpoint has the placeholder but no DAG is set.
        """
        endpoint = self._endpoint
        if ENDPOINT_PLACEHOLDER_DAG_ID not in endpoint:
            return endpoint
        if not self.dag:
            raise AWELHttpError("DAG is not set")
        dag_id = self.dag.dag_id
        return endpoint.replace(ENDPOINT_PLACEHOLDER_DAG_ID, dag_id)

    def _trigger_mode(self) -> str:
        """Return "chat" for CommonLLMHttpRequestBody bodies, else "command"."""
        if (
            self._req_body
            and isinstance(self._req_body, type)
            and issubclass(self._req_body, CommonLLMHttpRequestBody)
        ):
            return "chat"
        return "command"

    async def map(self, input_data: Any) -> Any:
        """Map the input data.

        Do some transformation for the input data.

        Args:
            input_data (Any): The input data from caller.

        Returns:
            Any: The mapped data.
        """
        if not self._req_body or not input_data:
            return await super().map(input_data)
        if (
            isinstance(self._req_body, type)
            and issubclass(self._req_body, BaseModel)
            and isinstance(input_data, dict)
        ):
            # Turn a plain dict into the declared pydantic request model.
            return self._req_body(**input_data)
        return await super().map(input_data)

    def _create_route_func(self):
        """Build the FastAPI route function for this trigger's request body."""
        from inspect import Parameter, Signature
        from typing import get_type_hints

        from starlette.requests import Request

        # GET/DELETE routes take their input from query parameters, not a body.
        is_query_method = (
            all(method.upper() in ["GET", "DELETE"] for method in self._methods)
            if self._methods
            else True
        )

        async def _trigger_dag_func(body: Union[Request, BaseModel, str, None]):
            streaming_response = self._streaming_response
            if self._streaming_predict_func:
                streaming_response = self._streaming_predict_func(body)
            elif isinstance(body, BaseHttpBody):
                # BaseHttpBody, read streaming flag from body
                streaming_response = _default_streaming_predict_func(body)
            dag = self.dag
            if not dag:
                raise AWELHttpError("DAG is not set")
            return await _trigger_dag(
                body,
                dag,
                streaming_response,
                self._response_headers,
                self._response_media_type,
            )

        def create_route_function(name, req_body_cls: Optional["RequestBody"]):
            async def route_function_request(request: Request):
                return await _trigger_dag_func(request)

            async def route_function_none():
                return await _trigger_dag_func(None)

            route_function_request.__name__ = name
            route_function_none.__name__ = name

            if not req_body_cls:
                return route_function_none
            if req_body_cls == Request:
                return route_function_request

            if is_query_method:
                if req_body_cls == str:
                    raise AWELHttpError(
                        f"Query methods {self._methods} not support str type"
                    )

                async def route_function_get(**kwargs):
                    if req_body_cls == dict or get_origin(req_body_cls) == dict:
                        body = kwargs
                    else:
                        body = req_body_cls(**kwargs)
                    return await _trigger_dag_func(body)

                if isinstance(req_body_cls, type) and issubclass(
                    req_body_cls, BaseModel
                ):
                    # Expose each model field as a keyword-only query parameter
                    # so FastAPI can generate and validate the query string.
                    fields = model_fields(req_body_cls)  # type: ignore
                    parameters = []
                    for field_name, field in fields.items():
                        default_value = (
                            Parameter.empty
                            if field_is_required(field)
                            else field.default
                        )
                        parameters.append(
                            Parameter(
                                name=field_name,
                                kind=Parameter.KEYWORD_ONLY,
                                default=default_value,
                                annotation=field_outer_type(field),
                            )
                        )
                elif req_body_cls == Dict[str, Any] or req_body_cls == dict:
                    raise AWELHttpError(
                        f"Query methods {self._methods} not support dict type"
                    )
                else:
                    parameters = []
                route_function_get.__signature__ = Signature(parameters)  # type: ignore
                if isinstance(req_body_cls, type):
                    route_function_get.__annotations__ = get_type_hints(req_body_cls)
                route_function_get.__name__ = name
                return route_function_get
            else:

                async def route_function(body: req_body_cls):  # type: ignore
                    return await _trigger_dag_func(body)

                route_function.__name__ = name
                return route_function

        function_name = f"AWEL_trigger_route_{self._endpoint.replace('/', '_')}"
        if isinstance(self._req_body, type) and (  # noqa: SIM114
            issubclass(self._req_body, Request)
            or issubclass(self._req_body, BaseModel)
            or issubclass(self._req_body, dict)
            or issubclass(self._req_body, str)
        ):  # noqa: SIM114
            request_model = self._req_body
        elif get_origin(self._req_body) == dict and not is_query_method:
            request_model = self._req_body
        elif is_query_method:
            request_model = None
        else:
            err_msg = f"Unsupported request body type {self._req_body}"
            raise AWELHttpError(err_msg)

        dynamic_route_function = create_route_function(function_name, request_model)
        logger.info(
            f"mount router function {dynamic_route_function}({function_name}), "
            f"endpoint: {self._endpoint}, methods: {self._methods}"
        )
        return dynamic_route_function
async def _trigger_dag(
    body: Any,
    dag: DAG,
    streaming_response: Optional[bool] = False,
    response_headers: Optional[Dict[str, str]] = None,
    response_media_type: Optional[str] = None,
) -> Any:
    """Run ``dag`` with ``body`` as input and build the HTTP result.

    Args:
        body: The parsed request body, passed as ``call_data`` to the leaf node.
        dag: The DAG to run; it must have exactly one leaf node.
        streaming_response: Whether to stream the leaf node's output.
        response_headers: Headers for a streaming response. When empty, default
            SSE-style headers are used.
        response_media_type: Media type for a streaming response; defaults to
            ``text/event-stream``.

    Returns:
        The leaf node's result for a non-streaming call, otherwise a
        ``StreamingResponse`` whose background task runs the DAG's
        end-of-run hook.

    Raises:
        ValueError: If the DAG does not have exactly one leaf node.
    """
    from fastapi import BackgroundTasks
    from fastapi.responses import StreamingResponse

    span_id = root_tracer._parse_span_id(body)

    leaf_nodes = dag.leaf_nodes
    if len(leaf_nodes) != 1:
        raise ValueError("HttpTrigger just support one leaf node in dag")
    end_node = cast(BaseOperator, leaf_nodes[0])
    metadata = {
        "awel_node_id": end_node.node_id,
        "awel_node_name": end_node.node_name,
    }
    if not streaming_response:
        with root_tracer.start_span(
            "dbgpt.core.trigger.http.run_dag", span_id, metadata=metadata
        ):
            return await end_node.call(call_data=body)

    media_type = response_media_type or "text/event-stream"
    headers = response_headers
    if not headers:
        headers = {
            # Must match ``media_type``: an explicit Content-Type header takes
            # precedence over StreamingResponse's media_type argument.
            "Content-Type": media_type,
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        }
    _generator = await end_node.call_stream(call_data=body)
    trace_generator = root_tracer.wrapper_async_stream(
        _generator, "dbgpt.core.trigger.http.run_dag", span_id, metadata=metadata
    )

    async def _after_dag_end():
        await dag._after_dag_end(end_node.current_event_loop_task_id)

    # Run the DAG's end hook only after the whole stream has been sent.
    background_tasks = BackgroundTasks()
    background_tasks.add_task(_after_dag_end)
    return StreamingResponse(
        trace_generator,
        headers=headers,
        media_type=media_type,
        background=background_tasks,
    )
# Shared flow-parameter templates for the concrete HTTP triggers below. The
# trigger classes call ``.new()`` on these when building their ViewMetadata
# (presumably to get an independent copy per trigger — confirm in flow API).

# Required API endpoint path.
_PARAMETER_ENDPOINT = Parameter.build_from(
    _("API Endpoint"), "endpoint", str, description=_("The API endpoint")
)

# HTTP method selector limited to body-carrying methods; defaults to POST.
_PARAMETER_METHODS_POST_PUT = Parameter.build_from(
    _("Http Methods"),
    "methods",
    str,
    optional=True,
    default="POST",
    description=_("The methods of the API endpoint"),
    options=[
        OptionValue(label=_("HTTP Method PUT"), name="http_put", value="PUT"),
        OptionValue(label=_("HTTP Method POST"), name="http_post", value="POST"),
    ],
)

# HTTP method selector offering GET/DELETE/PUT/POST; defaults to GET.
# NOTE(review): not referenced in the visible part of this module; it may be
# used by other modules.
_PARAMETER_METHODS_ALL = Parameter.build_from(
    _("Http Methods"),
    "methods",
    str,
    optional=True,
    default="GET",
    description=_("The methods of the API endpoint"),
    options=[
        OptionValue(label=_("HTTP Method GET"), name="http_get", value="GET"),
        OptionValue(label=_("HTTP Method DELETE"), name="http_delete", value="DELETE"),
        OptionValue(label=_("HTTP Method PUT"), name="http_put", value="PUT"),
        OptionValue(label=_("HTTP Method POST"), name="http_post", value="POST"),
    ],
)

# Whether the trigger streams its response (default: no).
_PARAMETER_STREAMING_RESPONSE = Parameter.build_from(
    _("Streaming Response"),
    "streaming_response",
    bool,
    optional=True,
    default=False,
    description=_("Whether the response is streaming"),
)

# Optional BaseHttpBody resource class describing the response body.
_PARAMETER_RESPONSE_BODY = Parameter.build_from(
    _("Http Response Body"),
    "http_response_body",
    BaseHttpBody,
    optional=True,
    default=None,
    description=_("The response body of the API endpoint"),
    resource_type=ResourceType.CLASS,
)

# Optional media type of the response.
_PARAMETER_MEDIA_TYPE = Parameter.build_from(
    _("Response Media Type"),
    "response_media_type",
    str,
    optional=True,
    default=None,
    description=_("The response media type"),
)

# HTTP status code of the route (default 200).
_PARAMETER_STATUS_CODE = Parameter.build_from(
    _("Http Status Code"),
    "status_code",
    int,
    optional=True,
    default=200,
    description=_("The http status code"),
)
class DictHttpTrigger(HttpTrigger):
    """HTTP trigger whose request body is parsed into a ``dict``.

    Always mounted on the FastAPI app (``register_to_app=True``).
    """

    metadata = ViewMetadata(
        label=_("Dict Http Trigger"),
        name="dict_http_trigger",
        category=OperatorCategory.TRIGGER,
        operator_type=OperatorType.INPUT,
        description=_(
            "Trigger your workflow by http request, and parse the request body"
            " as a dict"
        ),
        inputs=[],
        outputs=[
            IOField.build_from(
                _("Request Body"),
                "request_body",
                dict,
                description=_("The request body of the API endpoint"),
            ),
        ],
        parameters=[
            _PARAMETER_ENDPOINT.new(),
            _PARAMETER_METHODS_POST_PUT.new(),
            _PARAMETER_STREAMING_RESPONSE.new(),
            _PARAMETER_RESPONSE_BODY.new(),
            _PARAMETER_MEDIA_TYPE.new(),
            _PARAMETER_STATUS_CODE.new(),
        ],
    )

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "POST",
        streaming_response: bool = False,
        http_response_body: Optional[Type[BaseHttpBody]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str | Enum]] = None,
        **kwargs,
    ):
        """Create the trigger; empty ``router_tags`` fall back to a default tag."""
        tags = router_tags or ["AWEL DictHttpTrigger"]
        super().__init__(
            endpoint,
            methods=methods,
            request_body=dict,
            streaming_response=streaming_response,
            http_response_body=http_response_body,
            response_media_type=response_media_type,
            status_code=status_code,
            router_tags=tags,
            register_to_app=True,
            **kwargs,
        )
class StringHttpTrigger(HttpTrigger):
    """HTTP trigger whose request body is delivered as a ``str``.

    Always mounted on the FastAPI app (``register_to_app=True``).
    """

    metadata = ViewMetadata(
        label=_("String Http Trigger"),
        name="string_http_trigger",
        category=OperatorCategory.TRIGGER,
        operator_type=OperatorType.INPUT,
        description=_(
            "Trigger your workflow by http request, and parse the request body"
            " as a string"
        ),
        inputs=[],
        outputs=[
            IOField.build_from(
                _("Request Body"),
                "request_body",
                str,
                description=_(
                    "The request body of the API endpoint, parse as a json string"
                ),
            ),
        ],
        parameters=[
            _PARAMETER_ENDPOINT.new(),
            _PARAMETER_METHODS_POST_PUT.new(),
            _PARAMETER_STREAMING_RESPONSE.new(),
            _PARAMETER_RESPONSE_BODY.new(),
            _PARAMETER_MEDIA_TYPE.new(),
            _PARAMETER_STATUS_CODE.new(),
        ],
    )

    def __init__(
        self,
        endpoint: str,
        methods: Optional[Union[str, List[str]]] = "POST",
        streaming_response: bool = False,
        http_response_body: Optional[Type[BaseHttpBody]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str | Enum]] = None,
        **kwargs,
    ):
        """Create the trigger; empty ``router_tags`` fall back to a default tag."""
        tags = router_tags or ["AWEL StringHttpTrigger"]
        super().__init__(
            endpoint,
            methods=methods,
            request_body=str,
            streaming_response=streaming_response,
            http_response_body=http_response_body,
            response_media_type=response_media_type,
            status_code=status_code,
            router_tags=tags,
            register_to_app=True,
            **kwargs,
        )
class CommonLLMHttpTrigger(HttpTrigger):
    """HTTP trigger whose request body is a ``CommonLLMHttpRequestBody``.

    Always mounted on the FastAPI app (``register_to_app=True``); the default
    endpoint embeds the ``{dag_id}`` placeholder.
    """

    class MessagesOutputMapper(MapOperator[CommonLLMHttpRequestBody, str]):
        """Extract ``messages`` from the request body when it is a plain string."""

        async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
            """Return ``request_body.messages``; list-valued messages are rejected.

            Raises:
                ValueError: If ``messages`` is not a string.
            """
            messages = request_body.messages
            if not isinstance(messages, str):
                raise ValueError("Messages to be transformed is not a string")
            return messages

    metadata = ViewMetadata(
        label=_("Common LLM Http Trigger"),
        name="common_llm_http_trigger",
        category=OperatorCategory.TRIGGER,
        operator_type=OperatorType.INPUT,
        description=_(
            "Trigger your workflow by http request, and parse the request body "
            "as a common LLM http body"
        ),
        inputs=[],
        outputs=[
            IOField.build_from(
                _("Request Body"),
                "request_body",
                CommonLLMHttpRequestBody,
                description=_(
                    "The request body of the API endpoint, parse as a common "
                    "LLM http body"
                ),
            ),
            IOField.build_from(
                _("Request String Messages"),
                "request_string_messages",
                str,
                description=_(
                    "The request string messages of the API endpoint, parsed from "
                    "'messages' field of the request body"
                ),
                mappers=[MessagesOutputMapper],
            ),
        ],
        parameters=[
            Parameter.build_from(
                _("API Endpoint"),
                "endpoint",
                str,
                optional=True,
                default="/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
                description=_("The API endpoint"),
            ),
            _PARAMETER_METHODS_POST_PUT.new(),
            _PARAMETER_STREAMING_RESPONSE.new(),
            _PARAMETER_RESPONSE_BODY.new(),
            _PARAMETER_MEDIA_TYPE.new(),
            _PARAMETER_STATUS_CODE.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(
        self,
        endpoint: str = "/example/" + ENDPOINT_PLACEHOLDER_DAG_ID,
        methods: Optional[Union[str, List[str]]] = "POST",
        streaming_response: bool = False,
        http_response_body: Optional[Type[BaseHttpBody]] = None,
        response_media_type: Optional[str] = None,
        status_code: Optional[int] = 200,
        router_tags: Optional[List[str | Enum]] = None,
        **kwargs,
    ):
        """Create the trigger; empty ``router_tags`` fall back to a default tag."""
        tags = router_tags or ["AWEL CommonLLMHttpTrigger"]
        super().__init__(
            endpoint,
            methods=methods,
            request_body=CommonLLMHttpRequestBody,
            streaming_response=streaming_response,
            http_response_body=http_response_body,
            response_media_type=response_media_type,
            status_code=status_code,
            router_tags=tags,
            register_to_app=True,
            **kwargs,
        )
@register_resource(
    label=_("Example Http Response"),
    name="example_http_response",
    category=ResourceCategory.HTTP_BODY,
    resource_type=ResourceType.CLASS,
    description=_("Example Http Response"),
)
class ExampleHttpResponse(BaseHttpBody):
    """Example Http Response.

    Just for test.
    Register as a resource.
    """

    # Text produced by the operator.
    server_res: str = Field(..., description="The server response from Operator")
    # Echo of the original request body.
    request_body: Dict[str, Any] = Field(
        ..., description="The request body from Http request"
    )
class ExampleHttpHelloOperator(MapOperator[dict, ExampleHttpResponse]):
    """Example Http Hello Operator.

    Just for test. Builds a greeting from the ``name`` and ``age`` keys of the
    request body and echoes the body back.
    """

    metadata = ViewMetadata(
        label=_("Example Http Hello Operator"),
        name="example_http_hello_operator",
        category=OperatorCategory.COMMON,
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Http Request Body"),
                "request_body",
                dict,
                description=_("The request body of the API endpoint(Dict[str, Any])"),
            )
        ],
        outputs=[
            IOField.build_from(
                _("Response Body"),
                "response_body",
                ExampleHttpResponse,
                description=_("The response body of the API endpoint"),
            )
        ],
        description=_("Example Http Hello Operator"),
    )

    def __init__(self, **kwargs):
        """Initialize a ExampleHttpHelloOperator."""
        super().__init__(**kwargs)

    async def map(self, request_body: dict) -> ExampleHttpResponse:
        """Map the request body to response body.

        Missing ``name``/``age`` keys are rendered as ``None``.
        """
        print(f"Receive input value: {request_body}")
        name = request_body.get("name")
        age = request_body.get("age")
        server_res = f"Hello, {name}, your age is {age}"
        return ExampleHttpResponse(server_res=server_res, request_body=request_body)
class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str, Any]]):
    """Request body to dict operator.

    Converts the request body to a dict and, when ``prefix_key`` is set,
    returns the nested dict found by following its dot-separated keys.
    """

    metadata = ViewMetadata(
        label=_("Request Body To Dict Operator"),
        name="request_body_to_dict_operator",
        category=OperatorCategory.COMMON,
        parameters=[
            Parameter.build_from(
                _("Prefix Key"),
                "prefix_key",
                str,
                optional=True,
                default=None,
                description=_(
                    "The prefix key of the dict, link 'message' or 'extra.info'"
                ),
            )
        ],
        inputs=[
            IOField.build_from(
                _("Request Body"),
                "request_body",
                CommonLLMHttpRequestBody,
                description=_("The request body of the API endpoint"),
            )
        ],
        outputs=[
            IOField.build_from(
                _("Response Body"),
                "response_body",
                dict,
                description=_("The response body of the API endpoint"),
            )
        ],
        description=_("Request body to dict operator"),
    )

    def __init__(self, prefix_key: Optional[str] = None, **kwargs):
        """Initialize a RequestBodyToDictOperator.

        Args:
            prefix_key: Optional dot-separated path (e.g. ``"extra.info"``) into
                the request body dict.
        """
        super().__init__(**kwargs)
        self._key = prefix_key

    async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]:
        """Map the request body to response body.

        Raises:
            ValueError: If a key along ``prefix_key`` is missing, an intermediate
                value is not a dict, or the final value is not a dict.
        """
        dict_value = model_to_dict(request_body)
        if not self._key:
            return dict_value
        err_msg = f"Prefix key {self._key} is not a valid key of the request body"
        value: Any = dict_value
        for k in self._key.split("."):
            # Validate each step so a bad path raises ValueError, not a bare
            # KeyError/TypeError from indexing.
            if not isinstance(value, dict) or k not in value:
                raise ValueError(err_msg)
            value = value[k]
        if not isinstance(value, dict):
            raise ValueError(err_msg)
        return value
class UserInputParsedOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str, Any]]):
    """Wrap the request's ``messages`` in a single-entry dict.

    The entry's key is configurable and defaults to ``"user_input"``.
    """

    metadata = ViewMetadata(
        label=_("User Input Parsed Operator"),
        name="user_input_parsed_operator",
        category=OperatorCategory.COMMON,
        parameters=[
            Parameter.build_from(
                _("Key"),
                "key",
                str,
                optional=True,
                default="user_input",
                description=_("The key of the dict, link 'user_input'"),
            )
        ],
        inputs=[
            IOField.build_from(
                _("Request Body"),
                "request_body",
                CommonLLMHttpRequestBody,
                description=_("The request body of the API endpoint"),
            )
        ],
        outputs=[
            IOField.build_from(
                _("User Input Dict"),
                "user_input_dict",
                dict,
                description=_("The user input dict of the API endpoint"),
            )
        ],
        description=_(
            "User input parsed operator, parse the user input from request body"
            " and return as a dict"
        ),
    )

    def __init__(self, key: str = "user_input", **kwargs):
        """Store the output key, then initialize the operator."""
        self._key = key
        super().__init__(**kwargs)

    async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]:
        """Return ``{key: request_body.messages}``."""
        messages = request_body.messages
        result: Dict[str, Any] = {}
        result[self._key] = messages
        return result
class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
    """Request body parsed to string operator.

    Returns one top-level field of the request body, selected by ``key``.
    """

    metadata = ViewMetadata(
        label=_("Request Body Parsed To String Operator"),
        name="request_body_to_str__parsed_operator",
        category=OperatorCategory.COMMON,
        parameters=[
            Parameter.build_from(
                _("Key"),
                "key",
                str,
                optional=True,
                default="messages",
                description=_("The key of the dict, link 'user_input'"),
            )
        ],
        inputs=[
            IOField.build_from(
                _("Request Body"),
                "request_body",
                CommonLLMHttpRequestBody,
                description=_("The request body of the API endpoint"),
            )
        ],
        outputs=[
            IOField.build_from(
                _("User Input String"),
                "user_input_str",
                str,
                description=_("The user input dict of the API endpoint"),
            )
        ],
        description=_(
            "User input parsed operator, parse the user input from request body and "
            "return as a string"
        ),
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, key: str = "messages", **kwargs):
        """Initialize a RequestedParsedOperator.

        Args:
            key: Name of the request body field to return. Defaults to
                ``"messages"`` to match the flow parameter default; the request
                body has no ``user_input`` field, so that default always failed.
        """
        self._key = key
        super().__init__(**kwargs)

    async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
        """Map the request body to response body.

        Raises:
            ValueError: If ``key`` is empty or not a field of the request body.
        """
        dict_value = model_to_dict(request_body)
        if not self._key or self._key not in dict_value:
            raise ValueError(
                f"Prefix key {self._key} is not a valid key of the request body"
            )
        return dict_value[self._key]