mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
feat: Support HTTP sender (#1383)
This commit is contained in:
parent
df36b947d1
commit
634e62cb6e
@ -82,6 +82,12 @@ def run():
|
||||
pass
|
||||
|
||||
|
||||
@click.group()
|
||||
def net():
|
||||
"""Net tools."""
|
||||
pass
|
||||
|
||||
|
||||
stop_all_func_list = []
|
||||
|
||||
|
||||
@ -100,6 +106,7 @@ cli.add_command(new)
|
||||
cli.add_command(app)
|
||||
cli.add_command(repo)
|
||||
cli.add_command(run)
|
||||
cli.add_command(net)
|
||||
add_command_alias(stop_all, name="all", parent_group=stop)
|
||||
|
||||
try:
|
||||
@ -200,6 +207,13 @@ try:
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt client command line tool failed: {e}")
|
||||
|
||||
try:
|
||||
from dbgpt.util.network._cli import start_forward
|
||||
|
||||
add_command_alias(start_forward, name="forward", parent_group=net)
|
||||
except ImportError as e:
|
||||
logging.warning(f"Integrating dbgpt net command line tool failed: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
return cli()
|
||||
|
@ -108,6 +108,7 @@ class _CategoryDetail:
|
||||
|
||||
_OPERATOR_CATEGORY_DETAIL = {
|
||||
"trigger": _CategoryDetail("Trigger", "Trigger your AWEL flow"),
|
||||
"sender": _CategoryDetail("Sender", "Send the data to the target"),
|
||||
"llm": _CategoryDetail("LLM", "Invoke LLM model"),
|
||||
"conversion": _CategoryDetail("Conversion", "Handle the conversion"),
|
||||
"output_parser": _CategoryDetail("Output Parser", "Parse the output of LLM model"),
|
||||
@ -121,6 +122,7 @@ class OperatorCategory(str, Enum):
|
||||
"""The category of the operator."""
|
||||
|
||||
TRIGGER = "trigger"
|
||||
SENDER = "sender"
|
||||
LLM = "llm"
|
||||
CONVERSION = "conversion"
|
||||
OUTPUT_PARSER = "output_parser"
|
||||
|
@ -20,6 +20,7 @@ from .base import (
|
||||
from .exceptions import (
|
||||
FlowClassMetadataException,
|
||||
FlowDAGMetadataException,
|
||||
FlowException,
|
||||
FlowMetadataException,
|
||||
)
|
||||
|
||||
@ -720,5 +721,5 @@ def fill_flow_panel(flow_panel: FlowPanel):
|
||||
param.default = new_param.default
|
||||
param.placeholder = new_param.placeholder
|
||||
|
||||
except ValueError as e:
|
||||
except (FlowException, ValueError) as e:
|
||||
logger.warning(f"Unable to fill the flow panel: {e}")
|
||||
|
@ -3,13 +3,14 @@
|
||||
Supports more trigger types, such as RequestHttpTrigger.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Type, Union
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from ..flow import IOField, OperatorCategory, OperatorType, ViewMetadata
|
||||
from ..flow import IOField, OperatorCategory, OperatorType, Parameter, ViewMetadata
|
||||
from ..operators.common_operator import MapOperator
|
||||
from .http_trigger import (
|
||||
_PARAMETER_ENDPOINT,
|
||||
_PARAMETER_MEDIA_TYPE,
|
||||
@ -82,3 +83,122 @@ class RequestHttpTrigger(HttpTrigger):
|
||||
register_to_app=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class DictHTTPSender(MapOperator[Dict, Dict]):
|
||||
"""HTTP Sender operator for AWEL."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label=_("HTTP Sender"),
|
||||
name="awel_dict_http_sender",
|
||||
category=OperatorCategory.SENDER,
|
||||
description=_("Send a HTTP request to a specified endpoint"),
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Request Body"),
|
||||
"request_body",
|
||||
dict,
|
||||
description=_("The request body to send"),
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
_("Response Body"),
|
||||
"response_body",
|
||||
dict,
|
||||
description=_("The response body of the HTTP request"),
|
||||
)
|
||||
],
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
_("HTTP Address"),
|
||||
_("address"),
|
||||
type=str,
|
||||
description=_("The address to send the HTTP request to"),
|
||||
),
|
||||
_PARAMETER_METHODS_ALL.new(),
|
||||
_PARAMETER_STATUS_CODE.new(),
|
||||
Parameter.build_from(
|
||||
_("Timeout"),
|
||||
"timeout",
|
||||
type=int,
|
||||
optional=True,
|
||||
default=60,
|
||||
description=_("The timeout of the HTTP request in seconds"),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Token"),
|
||||
"token",
|
||||
type=str,
|
||||
optional=True,
|
||||
default=None,
|
||||
description=_("The token to use for the HTTP request"),
|
||||
),
|
||||
Parameter.build_from(
|
||||
_("Cookies"),
|
||||
"cookies",
|
||||
type=str,
|
||||
optional=True,
|
||||
default=None,
|
||||
description=_("The cookies to use for the HTTP request"),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str,
|
||||
methods: Optional[str] = "GET",
|
||||
status_code: Optional[int] = 200,
|
||||
timeout: Optional[int] = 60,
|
||||
token: Optional[str] = None,
|
||||
cookies: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize a HTTPSender."""
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"aiohttp is required for HTTPSender, please install it with "
|
||||
"`pip install aiohttp`"
|
||||
)
|
||||
self._address = address
|
||||
self._methods = methods
|
||||
self._status_code = status_code
|
||||
self._timeout = timeout
|
||||
self._token = token
|
||||
self._cookies = cookies
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, request_body: Dict) -> Dict:
|
||||
"""Send the request body to the specified address."""
|
||||
import aiohttp
|
||||
|
||||
if self._methods in ["POST", "PUT"]:
|
||||
req_kwargs = {"json": request_body}
|
||||
else:
|
||||
req_kwargs = {"params": request_body}
|
||||
method = self._methods or "GET"
|
||||
|
||||
headers = {}
|
||||
if self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
async with aiohttp.ClientSession(
|
||||
headers=headers,
|
||||
cookies=self._cookies,
|
||||
timeout=aiohttp.ClientTimeout(total=self._timeout),
|
||||
) as session:
|
||||
async with session.request(
|
||||
method,
|
||||
self._address,
|
||||
raise_for_status=False,
|
||||
**req_kwargs,
|
||||
) as response:
|
||||
status_code = response.status
|
||||
if status_code != self._status_code:
|
||||
raise ValueError(
|
||||
f"HTTP request failed with status code {status_code}"
|
||||
)
|
||||
response_body = await response.json()
|
||||
return response_body
|
||||
|
@ -1036,7 +1036,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str,
|
||||
keys = self._key.split(".")
|
||||
for k in keys:
|
||||
dict_value = dict_value[k]
|
||||
if isinstance(dict_value, dict):
|
||||
if not isinstance(dict_value, dict):
|
||||
raise ValueError(
|
||||
f"Prefix key {self._key} is not a valid key of the request body"
|
||||
)
|
||||
|
0
dbgpt/util/network/__init__.py
Normal file
0
dbgpt/util/network/__init__.py
Normal file
289
dbgpt/util/network/_cli.py
Normal file
289
dbgpt/util/network/_cli.py
Normal file
@ -0,0 +1,289 @@
|
||||
import os
|
||||
import socket
|
||||
import ssl as py_ssl
|
||||
import threading
|
||||
|
||||
import click
|
||||
|
||||
from ..console import CliLogger
|
||||
|
||||
logger = CliLogger()
|
||||
|
||||
|
||||
def forward_data(source, destination):
|
||||
"""Forward data from source to destination."""
|
||||
try:
|
||||
while True:
|
||||
data = source.recv(4096)
|
||||
if b"" == data:
|
||||
destination.sendall(data)
|
||||
break
|
||||
if not data:
|
||||
break # no more data or connection closed
|
||||
destination.sendall(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forwarding data: {e}")
|
||||
|
||||
|
||||
def handle_client(
|
||||
client_socket,
|
||||
remote_host: str,
|
||||
remote_port: int,
|
||||
is_ssl: bool = False,
|
||||
http_proxy=None,
|
||||
):
|
||||
"""Handle client connection.
|
||||
|
||||
Create a connection to the remote host and port, and forward data between the
|
||||
client and the remote host.
|
||||
|
||||
Close the client socket and remote socket when all forwarding threads are done.
|
||||
"""
|
||||
# remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
if http_proxy:
|
||||
proxy_host, proxy_port = http_proxy
|
||||
remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
remote_socket.connect((proxy_host, proxy_port))
|
||||
client_ip = client_socket.getpeername()[0]
|
||||
scheme = "https" if is_ssl else "http"
|
||||
connect_request = (
|
||||
f"CONNECT {remote_host}:{remote_port} HTTP/1.1\r\n"
|
||||
f"Host: {remote_host}\r\n"
|
||||
f"Connection: keep-alive\r\n"
|
||||
f"X-Real-IP: {client_ip}\r\n"
|
||||
f"X-Forwarded-For: {client_ip}\r\n"
|
||||
f"X-Forwarded-Proto: {scheme}\r\n\r\n"
|
||||
)
|
||||
logger.info(f"Sending connect request: {connect_request}")
|
||||
remote_socket.sendall(connect_request.encode())
|
||||
|
||||
response = b""
|
||||
while True:
|
||||
part = remote_socket.recv(4096)
|
||||
response += part
|
||||
if b"\r\n\r\n" in part:
|
||||
break
|
||||
|
||||
if b"200 Connection established" not in response:
|
||||
logger.error("Failed to establish connection through proxy")
|
||||
return
|
||||
|
||||
else:
|
||||
remote_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
remote_socket.connect((remote_host, remote_port))
|
||||
|
||||
if is_ssl:
|
||||
# context = py_ssl.create_default_context(py_ssl.Purpose.CLIENT_AUTH)
|
||||
context = py_ssl.create_default_context(py_ssl.Purpose.SERVER_AUTH)
|
||||
# ssl_target_socket = py_ssl.wrap_socket(remote_socket)
|
||||
ssl_target_socket = context.wrap_socket(
|
||||
remote_socket, server_hostname=remote_host
|
||||
)
|
||||
else:
|
||||
ssl_target_socket = remote_socket
|
||||
try:
|
||||
# ssl_target_socket.connect((remote_host, remote_port))
|
||||
|
||||
# Forward data from client to server
|
||||
client_to_server = threading.Thread(
|
||||
target=forward_data, args=(client_socket, ssl_target_socket)
|
||||
)
|
||||
client_to_server.start()
|
||||
|
||||
# Forward data from server to client
|
||||
server_to_client = threading.Thread(
|
||||
target=forward_data, args=(ssl_target_socket, client_socket)
|
||||
)
|
||||
server_to_client.start()
|
||||
|
||||
client_to_server.join()
|
||||
server_to_client.join()
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling client connection: {e}")
|
||||
finally:
|
||||
# close the client and server sockets
|
||||
client_socket.close()
|
||||
ssl_target_socket.close()
|
||||
|
||||
|
||||
@click.command(name="forward")
|
||||
@click.option("--local-port", required=True, type=int, help="Local port to listen on.")
|
||||
@click.option(
|
||||
"--remote-host", required=True, type=str, help="Remote host to forward to."
|
||||
)
|
||||
@click.option(
|
||||
"--remote-port", required=True, type=int, help="Remote port to forward to."
|
||||
)
|
||||
@click.option(
|
||||
"--ssl",
|
||||
is_flag=True,
|
||||
help="Whether to use SSL for the connection to the remote host.",
|
||||
)
|
||||
@click.option(
|
||||
"--tcp",
|
||||
is_flag=True,
|
||||
help="Whether to forward TCP traffic. "
|
||||
"Default is HTTP. TCP has higher performance but not support proxies now.",
|
||||
)
|
||||
@click.option("--timeout", type=int, default=120, help="Timeout for the connection.")
|
||||
@click.option(
|
||||
"--proxies",
|
||||
type=str,
|
||||
help="HTTP proxy to use for forwarding requests. e.g. http://127.0.0.1:7890, "
|
||||
"if not specified, try to read from environment variable http_proxy and "
|
||||
"https_proxy.",
|
||||
)
|
||||
def start_forward(
|
||||
local_port,
|
||||
remote_host,
|
||||
remote_port,
|
||||
ssl: bool,
|
||||
tcp: bool,
|
||||
timeout: int,
|
||||
proxies: str | None = None,
|
||||
):
|
||||
"""Start a TCP/HTTP proxy server that forwards traffic from a local port to a remote
|
||||
host and port, just for debugging purposes, please don't use it in production
|
||||
environment.
|
||||
"""
|
||||
|
||||
"""
|
||||
Example:
|
||||
1. Forward HTTP traffic:
|
||||
|
||||
```
|
||||
dbgpt net forward --local-port 5010 \
|
||||
--remote-host api.openai.com \
|
||||
--remote-port 443 \
|
||||
--ssl \
|
||||
--proxies http://127.0.0.1:7890 \
|
||||
--timeout 30
|
||||
```
|
||||
Then you can set your environment variable `OPENAI_API_BASE` to
|
||||
`http://127.0.0.1:5010/v1`
|
||||
"""
|
||||
if not tcp:
|
||||
_start_http_forward(local_port, remote_host, remote_port, ssl, timeout, proxies)
|
||||
else:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
|
||||
server.bind(("0.0.0.0", local_port))
|
||||
server.listen(5)
|
||||
logger.info(
|
||||
f"[*] Listening on 0.0.0.0:{local_port}, forwarding to "
|
||||
f"{remote_host}:{remote_port}"
|
||||
)
|
||||
# http_proxy = ("127.0.0.1", 7890)
|
||||
proxies = (
|
||||
proxies or os.environ.get("http_proxy") or os.environ.get("https_proxy")
|
||||
)
|
||||
if proxies:
|
||||
# proxies = "http://127.0.0.1:7890"
|
||||
if proxies.startswith("http://") or proxies.startswith("https://"):
|
||||
proxies = proxies.split("//")[1]
|
||||
http_proxy = proxies.split(":")[0], int(proxies.split(":")[1])
|
||||
|
||||
while True:
|
||||
client_socket, addr = server.accept()
|
||||
logger.info(f"[*] Accepted connection from: {addr[0]}:{addr[1]}")
|
||||
client_thread = threading.Thread(
|
||||
target=handle_client,
|
||||
args=(client_socket, remote_host, remote_port, ssl, http_proxy),
|
||||
)
|
||||
client_thread.start()
|
||||
|
||||
|
||||
def _start_http_forward(
|
||||
local_port, remote_host, remote_port, ssl: bool, timeout, proxies: str | None = None
|
||||
):
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, FastAPI, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.middleware("http")
|
||||
async def forward_http_request(request: Request, call_next):
|
||||
"""Forward HTTP request to remote host."""
|
||||
nonlocal proxies
|
||||
req_body = await request.body()
|
||||
scheme = request.scope.get("scheme")
|
||||
path = request.scope.get("path")
|
||||
headers = dict(request.headers)
|
||||
# Remove needless headers
|
||||
stream_response = False
|
||||
if request.method in ["POST", "PUT"]:
|
||||
try:
|
||||
import json
|
||||
|
||||
stream_config = json.loads(req_body.decode("utf-8"))
|
||||
stream_response = stream_config.get("stream", False)
|
||||
except Exception:
|
||||
pass
|
||||
headers.pop("host", None)
|
||||
if not proxies:
|
||||
proxies = os.environ.get("http_proxy") or os.environ.get("https_proxy")
|
||||
if proxies:
|
||||
client_req = {
|
||||
"proxies": {
|
||||
"http://": proxies,
|
||||
"https://": proxies,
|
||||
}
|
||||
}
|
||||
else:
|
||||
client_req = {}
|
||||
if timeout:
|
||||
client_req["timeout"] = timeout
|
||||
|
||||
client = httpx.AsyncClient(**client_req)
|
||||
# async with httpx.AsyncClient(**client_req) as client:
|
||||
proxy_url = f"{remote_host}:{remote_port}"
|
||||
if ssl:
|
||||
scheme = "https"
|
||||
new_url = (
|
||||
proxy_url if "://" in proxy_url else (scheme + "://" + proxy_url + path)
|
||||
)
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
url=new_url,
|
||||
cookies=request.cookies,
|
||||
content=req_body,
|
||||
headers=headers,
|
||||
params=request.query_params,
|
||||
)
|
||||
has_connection = False
|
||||
try:
|
||||
logger.info(f"Forwarding request to {new_url}")
|
||||
res = await client.send(req, stream=stream_response)
|
||||
has_connection = True
|
||||
if stream_response:
|
||||
res_headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
}
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(client.aclose)
|
||||
return StreamingResponse(
|
||||
res.aiter_raw(),
|
||||
headers=res_headers,
|
||||
media_type=res.headers.get("content-type"),
|
||||
)
|
||||
else:
|
||||
return Response(
|
||||
content=res.content,
|
||||
status_code=res.status_code,
|
||||
headers=dict(res.headers),
|
||||
)
|
||||
except httpx.ConnectTimeout:
|
||||
return Response(
|
||||
content=f"Connection to remote server timeout", status_code=500
|
||||
)
|
||||
except Exception as e:
|
||||
return Response(content=str(e), status_code=500)
|
||||
finally:
|
||||
if has_connection and not stream_response:
|
||||
await client.aclose()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=local_port)
|
Loading…
Reference in New Issue
Block a user