feat: Support HTTP sender (#1383)

This commit is contained in:
Fangyin Cheng 2024-04-08 16:13:49 +08:00 committed by GitHub
parent df36b947d1
commit 634e62cb6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 430 additions and 4 deletions

View File

@ -82,6 +82,12 @@ def run():
pass
@click.group()
def net():
"""Net tools."""
pass
stop_all_func_list = []
@ -100,6 +106,7 @@ cli.add_command(new)
cli.add_command(app)
cli.add_command(repo)
cli.add_command(run)
cli.add_command(net)
add_command_alias(stop_all, name="all", parent_group=stop)
try:
@ -200,6 +207,13 @@ try:
except ImportError as e:
logging.warning(f"Integrating dbgpt client command line tool failed: {e}")
try:
from dbgpt.util.network._cli import start_forward
add_command_alias(start_forward, name="forward", parent_group=net)
except ImportError as e:
logging.warning(f"Integrating dbgpt net command line tool failed: {e}")
def main():
return cli()

View File

@ -108,6 +108,7 @@ class _CategoryDetail:
_OPERATOR_CATEGORY_DETAIL = {
"trigger": _CategoryDetail("Trigger", "Trigger your AWEL flow"),
"sender": _CategoryDetail("Sender", "Send the data to the target"),
"llm": _CategoryDetail("LLM", "Invoke LLM model"),
"conversion": _CategoryDetail("Conversion", "Handle the conversion"),
"output_parser": _CategoryDetail("Output Parser", "Parse the output of LLM model"),
@ -121,6 +122,7 @@ class OperatorCategory(str, Enum):
"""The category of the operator."""
TRIGGER = "trigger"
SENDER = "sender"
LLM = "llm"
CONVERSION = "conversion"
OUTPUT_PARSER = "output_parser"

View File

@ -20,6 +20,7 @@ from .base import (
from .exceptions import (
FlowClassMetadataException,
FlowDAGMetadataException,
FlowException,
FlowMetadataException,
)
@ -720,5 +721,5 @@ def fill_flow_panel(flow_panel: FlowPanel):
param.default = new_param.default
param.placeholder = new_param.placeholder
except ValueError as e:
except (FlowException, ValueError) as e:
logger.warning(f"Unable to fill the flow panel: {e}")

View File

@ -3,13 +3,14 @@
Supports more trigger types, such as RequestHttpTrigger.
"""
from enum import Enum
from typing import List, Optional, Type, Union
from typing import Dict, List, Optional, Type, Union
from starlette.requests import Request
from dbgpt.util.i18n_utils import _
from ..flow import IOField, OperatorCategory, OperatorType, ViewMetadata
from ..flow import IOField, OperatorCategory, OperatorType, Parameter, ViewMetadata
from ..operators.common_operator import MapOperator
from .http_trigger import (
_PARAMETER_ENDPOINT,
_PARAMETER_MEDIA_TYPE,
@ -82,3 +83,122 @@ class RequestHttpTrigger(HttpTrigger):
register_to_app=True,
**kwargs,
)
class DictHTTPSender(MapOperator[Dict, Dict]):
"""HTTP Sender operator for AWEL."""
metadata = ViewMetadata(
label=_("HTTP Sender"),
name="awel_dict_http_sender",
category=OperatorCategory.SENDER,
description=_("Send a HTTP request to a specified endpoint"),
inputs=[
IOField.build_from(
_("Request Body"),
"request_body",
dict,
description=_("The request body to send"),
)
],
outputs=[
IOField.build_from(
_("Response Body"),
"response_body",
dict,
description=_("The response body of the HTTP request"),
)
],
parameters=[
Parameter.build_from(
_("HTTP Address"),
_("address"),
type=str,
description=_("The address to send the HTTP request to"),
),
_PARAMETER_METHODS_ALL.new(),
_PARAMETER_STATUS_CODE.new(),
Parameter.build_from(
_("Timeout"),
"timeout",
type=int,
optional=True,
default=60,
description=_("The timeout of the HTTP request in seconds"),
),
Parameter.build_from(
_("Token"),
"token",
type=str,
optional=True,
default=None,
description=_("The token to use for the HTTP request"),
),
Parameter.build_from(
_("Cookies"),
"cookies",
type=str,
optional=True,
default=None,
description=_("The cookies to use for the HTTP request"),
),
],
)
def __init__(
self,
address: str,
methods: Optional[str] = "GET",
status_code: Optional[int] = 200,
timeout: Optional[int] = 60,
token: Optional[str] = None,
cookies: Optional[Dict[str, str]] = None,
**kwargs,
):
"""Initialize a HTTPSender."""
try:
import aiohttp # noqa: F401
except ImportError:
raise ImportError(
"aiohttp is required for HTTPSender, please install it with "
"`pip install aiohttp`"
)
self._address = address
self._methods = methods
self._status_code = status_code
self._timeout = timeout
self._token = token
self._cookies = cookies
super().__init__(**kwargs)
async def map(self, request_body: Dict) -> Dict:
"""Send the request body to the specified address."""
import aiohttp
if self._methods in ["POST", "PUT"]:
req_kwargs = {"json": request_body}
else:
req_kwargs = {"params": request_body}
method = self._methods or "GET"
headers = {}
if self._token:
headers["Authorization"] = f"Bearer {self._token}"
async with aiohttp.ClientSession(
headers=headers,
cookies=self._cookies,
timeout=aiohttp.ClientTimeout(total=self._timeout),
) as session:
async with session.request(
method,
self._address,
raise_for_status=False,
**req_kwargs,
) as response:
status_code = response.status
if status_code != self._status_code:
raise ValueError(
f"HTTP request failed with status code {status_code}"
)
response_body = await response.json()
return response_body

View File

@ -1036,7 +1036,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str,
keys = self._key.split(".")
for k in keys:
dict_value = dict_value[k]
if isinstance(dict_value, dict):
if not isinstance(dict_value, dict):
raise ValueError(
f"Prefix key {self._key} is not a valid key of the request body"
)

View File

289
dbgpt/util/network/_cli.py Normal file
View File

@ -0,0 +1,289 @@
import os
import socket
import ssl as py_ssl
import threading
import click
from ..console import CliLogger
logger = CliLogger()
def forward_data(source, destination):
"""Forward data from source to destination."""
try:
while True:
data = source.recv(4096)
if b"" == data:
destination.sendall(data)
break
if not data:
break # no more data or connection closed
destination.sendall(data)
except Exception as e:
logger.error(f"Error forwarding data: {e}")
def handle_client(
client_socket,
remote_host: str,
remote_port: int,
is_ssl: bool = False,
http_proxy=None,
):
"""Handle client connection.
Create a connection to the remote host and port, and forward data between the
client and the remote host.
Close the client socket and remote socket when all forwarding threads are done.
"""
# remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if http_proxy:
proxy_host, proxy_port = http_proxy
remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
remote_socket.connect((proxy_host, proxy_port))
client_ip = client_socket.getpeername()[0]
scheme = "https" if is_ssl else "http"
connect_request = (
f"CONNECT {remote_host}:{remote_port} HTTP/1.1\r\n"
f"Host: {remote_host}\r\n"
f"Connection: keep-alive\r\n"
f"X-Real-IP: {client_ip}\r\n"
f"X-Forwarded-For: {client_ip}\r\n"
f"X-Forwarded-Proto: {scheme}\r\n\r\n"
)
logger.info(f"Sending connect request: {connect_request}")
remote_socket.sendall(connect_request.encode())
response = b""
while True:
part = remote_socket.recv(4096)
response += part
if b"\r\n\r\n" in part:
break
if b"200 Connection established" not in response:
logger.error("Failed to establish connection through proxy")
return
else:
remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
remote_socket.connect((remote_host, remote_port))
if is_ssl:
# context = py_ssl.create_default_context(py_ssl.Purpose.CLIENT_AUTH)
context = py_ssl.create_default_context(py_ssl.Purpose.SERVER_AUTH)
# ssl_target_socket = py_ssl.wrap_socket(remote_socket)
ssl_target_socket = context.wrap_socket(
remote_socket, server_hostname=remote_host
)
else:
ssl_target_socket = remote_socket
try:
# ssl_target_socket.connect((remote_host, remote_port))
# Forward data from client to server
client_to_server = threading.Thread(
target=forward_data, args=(client_socket, ssl_target_socket)
)
client_to_server.start()
# Forward data from server to client
server_to_client = threading.Thread(
target=forward_data, args=(ssl_target_socket, client_socket)
)
server_to_client.start()
client_to_server.join()
server_to_client.join()
except Exception as e:
logger.error(f"Error handling client connection: {e}")
finally:
# close the client and server sockets
client_socket.close()
ssl_target_socket.close()
@click.command(name="forward")
@click.option("--local-port", required=True, type=int, help="Local port to listen on.")
@click.option(
"--remote-host", required=True, type=str, help="Remote host to forward to."
)
@click.option(
"--remote-port", required=True, type=int, help="Remote port to forward to."
)
@click.option(
"--ssl",
is_flag=True,
help="Whether to use SSL for the connection to the remote host.",
)
@click.option(
"--tcp",
is_flag=True,
help="Whether to forward TCP traffic. "
"Default is HTTP. TCP has higher performance but not support proxies now.",
)
@click.option("--timeout", type=int, default=120, help="Timeout for the connection.")
@click.option(
"--proxies",
type=str,
help="HTTP proxy to use for forwarding requests. e.g. http://127.0.0.1:7890, "
"if not specified, try to read from environment variable http_proxy and "
"https_proxy.",
)
def start_forward(
local_port,
remote_host,
remote_port,
ssl: bool,
tcp: bool,
timeout: int,
proxies: str | None = None,
):
"""Start a TCP/HTTP proxy server that forwards traffic from a local port to a remote
host and port, just for debugging purposes, please don't use it in production
environment.
"""
"""
Example:
1. Forward HTTP traffic:
```
dbgpt net forward --local-port 5010 \
--remote-host api.openai.com \
--remote-port 443 \
--ssl \
--proxies http://127.0.0.1:7890 \
--timeout 30
```
Then you can set your environment variable `OPENAI_API_BASE` to
`http://127.0.0.1:5010/v1`
"""
if not tcp:
_start_http_forward(local_port, remote_host, remote_port, ssl, timeout, proxies)
else:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
server.bind(("0.0.0.0", local_port))
server.listen(5)
logger.info(
f"[*] Listening on 0.0.0.0:{local_port}, forwarding to "
f"{remote_host}:{remote_port}"
)
# http_proxy = ("127.0.0.1", 7890)
proxies = (
proxies or os.environ.get("http_proxy") or os.environ.get("https_proxy")
)
if proxies:
# proxies = "http://127.0.0.1:7890"
if proxies.startswith("http://") or proxies.startswith("https://"):
proxies = proxies.split("//")[1]
http_proxy = proxies.split(":")[0], int(proxies.split(":")[1])
while True:
client_socket, addr = server.accept()
logger.info(f"[*] Accepted connection from: {addr[0]}:{addr[1]}")
client_thread = threading.Thread(
target=handle_client,
args=(client_socket, remote_host, remote_port, ssl, http_proxy),
)
client_thread.start()
def _start_http_forward(
local_port, remote_host, remote_port, ssl: bool, timeout, proxies: str | None = None
):
import httpx
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request, Response
from fastapi.responses import StreamingResponse
app = FastAPI()
@app.middleware("http")
async def forward_http_request(request: Request, call_next):
"""Forward HTTP request to remote host."""
nonlocal proxies
req_body = await request.body()
scheme = request.scope.get("scheme")
path = request.scope.get("path")
headers = dict(request.headers)
# Remove needless headers
stream_response = False
if request.method in ["POST", "PUT"]:
try:
import json
stream_config = json.loads(req_body.decode("utf-8"))
stream_response = stream_config.get("stream", False)
except Exception:
pass
headers.pop("host", None)
if not proxies:
proxies = os.environ.get("http_proxy") or os.environ.get("https_proxy")
if proxies:
client_req = {
"proxies": {
"http://": proxies,
"https://": proxies,
}
}
else:
client_req = {}
if timeout:
client_req["timeout"] = timeout
client = httpx.AsyncClient(**client_req)
# async with httpx.AsyncClient(**client_req) as client:
proxy_url = f"{remote_host}:{remote_port}"
if ssl:
scheme = "https"
new_url = (
proxy_url if "://" in proxy_url else (scheme + "://" + proxy_url + path)
)
req = client.build_request(
method=request.method,
url=new_url,
cookies=request.cookies,
content=req_body,
headers=headers,
params=request.query_params,
)
has_connection = False
try:
logger.info(f"Forwarding request to {new_url}")
res = await client.send(req, stream=stream_response)
has_connection = True
if stream_response:
res_headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
background_tasks = BackgroundTasks()
background_tasks.add_task(client.aclose)
return StreamingResponse(
res.aiter_raw(),
headers=res_headers,
media_type=res.headers.get("content-type"),
)
else:
return Response(
content=res.content,
status_code=res.status_code,
headers=dict(res.headers),
)
except httpx.ConnectTimeout:
return Response(
content=f"Connection to remote server timeout", status_code=500
)
except Exception as e:
return Response(content=str(e), status_code=500)
finally:
if has_connection and not stream_response:
await client.aclose()
uvicorn.run(app, host="0.0.0.0", port=local_port)