mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-03 09:34:04 +00:00
292 lines
9.6 KiB
Python
292 lines
9.6 KiB
Python
import os
|
|
import socket
|
|
import ssl as py_ssl
|
|
import threading
|
|
|
|
import click
|
|
|
|
from ..console import CliLogger
|
|
|
|
# Module-level CLI logger shared by the forwarding helpers below (info and error output).
logger = CliLogger()
|
|
|
|
|
|
def forward_data(source, destination):
    """Copy bytes from ``source`` to ``destination`` until ``source`` reaches EOF.

    This is one direction of the bidirectional relay set up by ``handle_client``.
    When ``source`` is exhausted (``recv`` returns ``b""``) or an error occurs, the
    write half of ``destination`` is shut down so its peer sees EOF. That lets the
    opposite relay direction finish instead of blocking in ``recv`` forever.

    Args:
        source: Connected socket to read from.
        destination: Connected socket to write to.

    Errors are logged rather than raised, so a broken connection only ends the
    calling thread.
    """
    try:
        while data := source.recv(4096):
            destination.sendall(data)
    except Exception as e:
        logger.error(f"Error forwarding data: {e}")
    finally:
        # Propagate EOF to the peer. TLS has no half-close: in CPython,
        # SSLSocket.shutdown() also discards the TLS object, which would corrupt
        # the reader running in the other direction, so SSL sockets are skipped.
        if not isinstance(destination, py_ssl.SSLSocket):
            try:
                destination.shutdown(socket.SHUT_WR)
            except OSError:
                pass  # peer already gone or socket already closed: nothing to signal
|
|
|
|
|
|
def _establish_proxy_tunnel(
    proxy_socket, client_socket, remote_host: str, remote_port: int, is_ssl: bool
) -> bool:
    """Ask the HTTP proxy on ``proxy_socket`` to CONNECT to ``remote_host:remote_port``.

    Returns True when the proxy answers with a 2xx status line. Any 2xx reply to
    CONNECT means the tunnel is open, and proxies vary the reason phrase
    ("Connection established", "OK", ...). Returns False otherwise, including when
    the proxy closes the connection before its response headers are complete.
    """
    client_ip = client_socket.getpeername()[0]
    scheme = "https" if is_ssl else "http"
    connect_request = (
        f"CONNECT {remote_host}:{remote_port} HTTP/1.1\r\n"
        f"Host: {remote_host}\r\n"
        f"Connection: keep-alive\r\n"
        f"X-Real-IP: {client_ip}\r\n"
        f"X-Forwarded-For: {client_ip}\r\n"
        f"X-Forwarded-Proto: {scheme}\r\n\r\n"
    )
    logger.info(f"Sending connect request: {connect_request}")
    proxy_socket.sendall(connect_request.encode())

    # Accumulate until the blank line that ends the response headers. The
    # terminator may straddle two recv() chunks, so search the whole buffer.
    # Stop on EOF or an oversized header block instead of looping forever.
    max_head = 64 * 1024
    response = b""
    while b"\r\n\r\n" not in response and len(response) < max_head:
        chunk = proxy_socket.recv(4096)
        if not chunk:
            break
        response += chunk

    if b"\r\n\r\n" not in response:
        return False
    status_line = response.split(b"\r\n", 1)[0].split()
    if len(status_line) < 2 or not status_line[0].startswith(b"HTTP/"):
        return False
    code = status_line[1]
    return len(code) == 3 and code.isdigit() and code.startswith(b"2")


def handle_client(
    client_socket,
    remote_host: str,
    remote_port: int,
    is_ssl: bool = False,
    http_proxy=None,
):
    """Relay one accepted client connection to ``remote_host:remote_port``.

    Connects directly, or through ``http_proxy`` via an HTTP CONNECT tunnel. When
    ``is_ssl`` is set, the upstream side is wrapped in TLS using the default
    server-auth context with ``remote_host`` as the server name. Bytes are then
    copied in both directions on two ``forward_data`` threads until both finish.

    Args:
        client_socket: Accepted socket from the listening server.
        remote_host: Target host name or address.
        remote_port: Target port.
        is_ssl: Wrap the upstream connection in TLS.
        http_proxy: Optional ``(host, port)`` of an HTTP proxy that supports CONNECT.

    Errors are logged rather than raised. The client socket and any upstream
    socket are closed before returning, including on failure.
    """
    remote_socket = None
    target_socket = None
    try:
        if http_proxy:
            proxy_host, proxy_port = http_proxy
            remote_socket = socket.create_connection((proxy_host, proxy_port))
            if not _establish_proxy_tunnel(
                remote_socket, client_socket, remote_host, remote_port, is_ssl
            ):
                logger.error("Failed to establish connection through proxy")
                return  # the finally block still closes both sockets
        else:
            remote_socket = socket.create_connection((remote_host, remote_port))

        if is_ssl:
            context = py_ssl.create_default_context(py_ssl.Purpose.SERVER_AUTH)
            target_socket = context.wrap_socket(
                remote_socket, server_hostname=remote_host
            )
        else:
            target_socket = remote_socket

        client_to_server = threading.Thread(
            target=forward_data, args=(client_socket, target_socket)
        )
        server_to_client = threading.Thread(
            target=forward_data, args=(target_socket, client_socket)
        )
        client_to_server.start()
        server_to_client.start()
        client_to_server.join()
        server_to_client.join()
    except Exception as e:
        logger.error(f"Error handling client connection: {e}")
    finally:
        client_socket.close()
        # wrap_socket() detaches remote_socket, so closing both objects is safe
        # and also covers a failed TLS handshake, where only remote_socket exists.
        for sock in (target_socket, remote_socket):
            if sock is not None:
                sock.close()
|
|
|
|
|
|
def _parse_proxy_address(proxy: str) -> tuple[str, int]:
    """Split an HTTP proxy spec such as ``http://127.0.0.1:7890`` into ``(host, port)``.

    The ``http://``/``https://`` prefix is optional. Any userinfo or trailing path
    is ignored; ``handle_client`` sends no ``Proxy-Authorization`` header.

    Raises:
        click.BadParameter: if the scheme is not http(s) or no valid port is given.
    """
    from urllib.parse import urlsplit

    # Without "://", urlsplit would parse "localhost:7890" as scheme "localhost".
    parsed = urlsplit(proxy if "://" in proxy else f"http://{proxy}")
    try:
        port = parsed.port  # ValueError for a non-numeric or out-of-range port
    except ValueError:
        port = None
    if parsed.scheme not in ("http", "https") or not parsed.hostname or port is None:
        raise click.BadParameter(
            f"invalid HTTP proxy {proxy!r}, expected [http://]host:port",
            param_hint="'--proxies'",
        )
    return parsed.hostname, port


@click.command(name="forward")
@click.option("--local-port", required=True, type=int, help="Local port to listen on.")
@click.option(
    "--remote-host", required=True, type=str, help="Remote host to forward to."
)
@click.option(
    "--remote-port", required=True, type=int, help="Remote port to forward to."
)
@click.option(
    "--ssl",
    is_flag=True,
    help="Whether to use SSL for the connection to the remote host.",
)
@click.option(
    "--tcp",
    is_flag=True,
    help="Whether to forward TCP traffic. "
    "Default is HTTP. TCP has higher performance; in TCP mode a proxy is used "
    "through an HTTP CONNECT tunnel.",
)
@click.option(
    "--timeout",
    type=int,
    default=120,
    help="Timeout in seconds for upstream requests (HTTP mode only).",
)
@click.option(
    "--proxies",
    type=str,
    help="HTTP proxy to use for forwarding requests. e.g. http://127.0.0.1:7890, "
    "if not specified, try to read from environment variable http_proxy and "
    "https_proxy.",
)
def start_forward(
    local_port,
    remote_host,
    remote_port,
    ssl: bool,
    tcp: bool,
    timeout: int,
    proxies: str | None = None,
):
    """Start a TCP/HTTP proxy server that forwards traffic from a local port to a
    remote host and port. For debugging only; do not use it in production.

    In HTTP mode (the default) each request is re-issued with httpx, honouring
    --timeout and --proxies. With --tcp, raw bytes are relayed; --proxies (or the
    http_proxy/https_proxy environment variables) is used as an HTTP CONNECT
    proxy, and --timeout is not applied.

    Example: forward HTTP traffic, then set OPENAI_API_BASE to
    http://127.0.0.1:5010/v1

    \b
    dbgpt net forward --local-port 5010 \\
        --remote-host api.openai.com \\
        --remote-port 443 \\
        --ssl \\
        --proxies http://127.0.0.1:7890 \\
        --timeout 30
    """
    if not tcp:
        _start_http_forward(local_port, remote_host, remote_port, ssl, timeout, proxies)
        return

    proxies = proxies or os.environ.get("http_proxy") or os.environ.get("https_proxy")
    # None means "connect directly". It must be bound even without a proxy,
    # because every accepted connection passes it to handle_client. Parsing
    # before binding makes a bad proxy spec fail fast.
    http_proxy = _parse_proxy_address(proxies) if proxies else None

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.bind(("0.0.0.0", local_port))
        server.listen(5)
        logger.info(
            f"[*] Listening on 0.0.0.0:{local_port}, forwarding to "
            f"{remote_host}:{remote_port}"
        )
        while True:
            client_socket, addr = server.accept()
            logger.info(f"[*] Accepted connection from: {addr[0]}:{addr[1]}")
            # Daemon threads so live relays do not keep the process alive once
            # the accept loop is interrupted (e.g. Ctrl+C).
            threading.Thread(
                target=handle_client,
                args=(client_socket, remote_host, remote_port, ssl, http_proxy),
                daemon=True,
            ).start()
|
|
|
|
|
|
def _start_http_forward(
    local_port, remote_host, remote_port, ssl: bool, timeout, proxies: str | None = None
):
    """Serve an HTTP reverse proxy on ``0.0.0.0:local_port`` (blocks in ``uvicorn.run``).

    Each incoming request is re-issued with httpx to ``remote_host:remote_port``.
    Method, path, query string, body and end-to-end headers are preserved. If
    ``remote_host`` has no scheme, ``https`` is used when ``ssl`` is set, otherwise
    the incoming request's scheme. A POST/PUT whose JSON body has a truthy
    ``"stream"`` field is relayed incrementally as a ``StreamingResponse``.

    Args:
        local_port: Port uvicorn listens on.
        remote_host: Upstream host, optionally with a scheme (``https://host``).
        remote_port: Upstream port.
        ssl: Use ``https`` toward the upstream (ignored if ``remote_host`` has a scheme).
        timeout: httpx timeout in seconds; a falsy value keeps httpx's default.
        proxies: Outbound proxy URL. Falls back to the ``http_proxy`` and then the
            ``https_proxy`` environment variables.
    """
    import json

    import httpx
    import uvicorn
    from fastapi import BackgroundTasks, Request, Response
    from fastapi.responses import StreamingResponse

    from dbgpt.util.fastapi import create_app

    # Hop-by-hop headers (RFC 9110 section 7.6.1) describe a single connection and
    # must not be forwarded by a proxy.
    hop_by_hop = frozenset(
        {
            "connection",
            "keep-alive",
            "proxy-connection",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
        }
    )
    # httpx derives these itself: "host" from the URL, "content-length" from the
    # body, and "accept-encoding" from its client defaults. Using the defaults
    # means the upstream only picks encodings httpx can decode.
    request_skip_headers = hop_by_hop | {"host", "content-length", "accept-encoding"}
    # httpx hands back an already-decoded body and Starlette recomputes its
    # length, so the upstream values would describe different bytes.
    response_skip_headers = hop_by_hop | {"content-encoding", "content-length"}

    # Resolved once at startup.
    proxies = proxies or os.environ.get("http_proxy") or os.environ.get("https_proxy")

    def _new_client():
        """Build a per-request ``httpx.AsyncClient`` honouring ``timeout`` and ``proxies``."""
        kwargs = {"timeout": timeout} if timeout else {}
        if not proxies:
            return httpx.AsyncClient(**kwargs)
        try:
            # httpx >= 0.26 takes a single ``proxy``; ``proxies`` was removed in 0.28.
            return httpx.AsyncClient(proxy=proxies, **kwargs)
        except TypeError:
            # Older httpx that does not know ``proxy`` yet.
            return httpx.AsyncClient(
                proxies={"http://": proxies, "https://": proxies}, **kwargs
            )

    app = create_app()

    @app.middleware("http")
    async def forward_http_request(request: Request, call_next):
        """Forward HTTP request to remote host.

        ``call_next`` is never invoked: every request is answered from upstream.
        """
        req_body = await request.body()
        scheme = "https" if ssl else request.scope.get("scheme")
        path = request.scope.get("path")

        headers = dict(request.headers)
        for name in request_skip_headers:
            headers.pop(name, None)

        stream_response = False
        if request.method in ("POST", "PUT"):
            try:
                payload = json.loads(req_body.decode("utf-8"))
            except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
                payload = None  # not a JSON body: forward as a plain request
            if isinstance(payload, dict):
                stream_response = bool(payload.get("stream", False))

        authority = f"{remote_host}:{remote_port}"
        base_url = authority if "://" in authority else f"{scheme}://{authority}"
        # Append the path even when remote_host already carries a scheme.
        new_url = base_url + path

        client = _new_client()
        # Becomes True once a StreamingResponse takes over closing `res` and `client`.
        handed_off = False
        try:
            # The Cookie header is already in `headers`; httpx deprecates cookies=.
            req = client.build_request(
                method=request.method,
                url=new_url,
                content=req_body,
                headers=headers,
                params=request.query_params,
            )
            logger.info(f"Forwarding request to {new_url}")
            res = await client.send(req, stream=stream_response)
            if stream_response:
                res_headers = {
                    "Content-Type": res.headers.get("content-type", "text/event-stream"),
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                }
                # Runs after the stream ends (or the client disconnects): close the
                # upstream response first, then the client that owns its connection.
                cleanup = BackgroundTasks()
                cleanup.add_task(res.aclose)
                cleanup.add_task(client.aclose)
                response = StreamingResponse(
                    # aiter_bytes() decodes any content-encoding, matching the
                    # headers above, which declare none.
                    res.aiter_bytes(),
                    status_code=res.status_code,
                    headers=res_headers,
                    background=cleanup,
                )
                handed_off = True
                return response
            return Response(
                content=res.content,
                status_code=res.status_code,
                headers={
                    key: value
                    for key, value in res.headers.items()
                    if key not in response_skip_headers
                },
            )
        except httpx.ConnectTimeout:
            return Response(
                content="Connection to remote server timeout", status_code=500
            )
        except Exception as e:
            logger.error(f"Error forwarding request to {new_url}: {e}")
            return Response(content=str(e), status_code=500)
        finally:
            if not handed_off:
                await client.aclose()

    uvicorn.run(app, host="0.0.0.0", port=local_port)
|