mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
Merge b6d4ec279e
into db2e94348f
This commit is contained in:
commit
cbcd9612b3
@ -1,8 +1,10 @@
|
||||
"""Indicator Agent action."""
|
||||
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.vis.tags.vis_api_response import VisApiResponse
|
||||
@ -46,6 +48,15 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
"""Init indicator action."""
|
||||
super().__init__(**kwargs)
|
||||
self._render_protocol = VisApiResponse()
|
||||
self._blocked_hosts = {
|
||||
"169.254.169.254",
|
||||
"metadata.google.internal",
|
||||
"metadata.goog",
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
"::1",
|
||||
}
|
||||
self._allowed_methods = {"GET", "POST"}
|
||||
|
||||
@property
|
||||
def resource_need(self) -> Optional[ResourceType]:
|
||||
@ -81,6 +92,42 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
Make sure the response is correct json and can be parsed by Python json.loads.
|
||||
"""
|
||||
|
||||
def _validate_request(self, url: str, method: str) -> Optional[str]:
|
||||
"""Validate URL and method to prevent SSRF attacks."""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
if parsed.scheme not in {"http", "https"}:
|
||||
return f"Scheme '{parsed.scheme}' not allowed"
|
||||
|
||||
if parsed.hostname in self._blocked_hosts:
|
||||
return f"Hostname '{parsed.hostname}' is blocked"
|
||||
|
||||
if parsed.hostname:
|
||||
try:
|
||||
ip = ipaddress.ip_address(parsed.hostname)
|
||||
if (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
):
|
||||
return f"Private/internal IP '{parsed.hostname}' is blocked"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if any(
|
||||
p in parsed.hostname.lower() for p in ["internal", "local", "intranet"]
|
||||
):
|
||||
return f"Hostname pattern '{parsed.hostname}' is blocked"
|
||||
|
||||
if method.upper() not in self._allowed_methods:
|
||||
return f"Method '{method}' not allowed"
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
return f"URL validation error: {e}"
|
||||
|
||||
def build_headers(self):
|
||||
"""Build headers."""
|
||||
return None
|
||||
@ -95,12 +142,8 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
) -> ActionOutput:
|
||||
"""Perform the action."""
|
||||
import requests
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"_input_convert: {type(self).__name__} ai_message: {ai_message}"
|
||||
)
|
||||
param: IndicatorInput = self._input_convert(ai_message, IndicatorInput)
|
||||
except Exception as e:
|
||||
logger.exception(str(e))
|
||||
@ -109,61 +152,72 @@ class IndicatorAction(Action[IndicatorInput]):
|
||||
content="The requested correctly structured answer could not be found.",
|
||||
)
|
||||
|
||||
try:
|
||||
status = Status.RUNNING.value
|
||||
response_success = True
|
||||
response_text = ""
|
||||
err_msg = None
|
||||
if error := self._validate_request(param.api, param.method):
|
||||
logger.warning(f"Blocked request: {error}")
|
||||
return ActionOutput(
|
||||
is_exe_success=False, content=f"Request blocked: {error}"
|
||||
)
|
||||
|
||||
try:
|
||||
if param.method.lower() == "get":
|
||||
response = requests.get(
|
||||
param.api, params=param.args, headers=self.build_headers()
|
||||
param.api,
|
||||
params=param.args,
|
||||
headers=self.build_headers(),
|
||||
timeout=10,
|
||||
allow_redirects=False,
|
||||
)
|
||||
elif param.method.lower() == "post":
|
||||
response = requests.post(
|
||||
param.api, json=param.args, headers=self.build_headers()
|
||||
param.api,
|
||||
json=param.args,
|
||||
headers=self.build_headers(),
|
||||
timeout=10,
|
||||
allow_redirects=False,
|
||||
)
|
||||
else:
|
||||
response = requests.request(
|
||||
param.method.lower(),
|
||||
param.api,
|
||||
data=param.args,
|
||||
headers=self.build_headers(),
|
||||
return ActionOutput(
|
||||
is_exe_success=False,
|
||||
content=f"Method '{param.method}' not supported",
|
||||
)
|
||||
response_text = response.text
|
||||
logger.info(f"API:{param.api}\nResult:{response_text}")
|
||||
# If the request returns an error status code, an HTTPError exception
|
||||
# is thrown
|
||||
|
||||
response.raise_for_status()
|
||||
status = Status.COMPLETE.value
|
||||
except HTTPError as http_err:
|
||||
print(f"HTTP error occurred: {http_err}")
|
||||
except Exception as e:
|
||||
response_success = False
|
||||
logger.exception(f"API [{param.indicator_name}] excute Failed!")
|
||||
status = Status.FAILED.value
|
||||
err_msg = f"API [{param.api}] request Failed!{str(e)}"
|
||||
logger.info(f"API:{param.api}\nResult:{response.text}")
|
||||
|
||||
plugin_param = {
|
||||
"name": param.indicator_name,
|
||||
"args": param.args,
|
||||
"status": status,
|
||||
"status": Status.COMPLETE.value,
|
||||
"logo": None,
|
||||
"result": response_text,
|
||||
"err_msg": err_msg,
|
||||
"result": response.text,
|
||||
"err_msg": None,
|
||||
}
|
||||
|
||||
view = (
|
||||
await self.render_protocol.display(content=plugin_param)
|
||||
if self.render_protocol
|
||||
else response_text
|
||||
else response.text
|
||||
)
|
||||
|
||||
return ActionOutput(
|
||||
is_exe_success=response_success, content=response_text, view=view
|
||||
)
|
||||
return ActionOutput(is_exe_success=True, content=response.text, view=view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Indicator Action Run Failed!")
|
||||
return ActionOutput(
|
||||
is_exe_success=False, content=f"Indicator action run failed!{str(e)}"
|
||||
logger.exception(f"API [{param.indicator_name}] failed: {e}")
|
||||
error_msg = f"API request failed: {str(e)}"
|
||||
|
||||
plugin_param = {
|
||||
"name": param.indicator_name,
|
||||
"args": param.args,
|
||||
"status": Status.FAILED.value,
|
||||
"logo": None,
|
||||
"result": "",
|
||||
"err_msg": error_msg,
|
||||
}
|
||||
|
||||
view = (
|
||||
await self.render_protocol.display(content=plugin_param)
|
||||
if self.render_protocol
|
||||
else error_msg
|
||||
)
|
||||
|
||||
return ActionOutput(is_exe_success=False, content=error_msg, view=view)
|
||||
|
Loading…
Reference in New Issue
Block a user