mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 12:21:08 +00:00
Merge b6d4ec279e
into db2e94348f
This commit is contained in:
commit
cbcd9612b3
@ -1,8 +1,10 @@
|
|||||||
"""Indicator Agent action."""
|
"""Indicator Agent action."""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel, Field
|
from dbgpt._private.pydantic import BaseModel, Field
|
||||||
from dbgpt.vis.tags.vis_api_response import VisApiResponse
|
from dbgpt.vis.tags.vis_api_response import VisApiResponse
|
||||||
@ -46,6 +48,15 @@ class IndicatorAction(Action[IndicatorInput]):
|
|||||||
"""Init indicator action."""
|
"""Init indicator action."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._render_protocol = VisApiResponse()
|
self._render_protocol = VisApiResponse()
|
||||||
|
self._blocked_hosts = {
|
||||||
|
"169.254.169.254",
|
||||||
|
"metadata.google.internal",
|
||||||
|
"metadata.goog",
|
||||||
|
"localhost",
|
||||||
|
"127.0.0.1",
|
||||||
|
"::1",
|
||||||
|
}
|
||||||
|
self._allowed_methods = {"GET", "POST"}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def resource_need(self) -> Optional[ResourceType]:
|
def resource_need(self) -> Optional[ResourceType]:
|
||||||
@ -81,6 +92,42 @@ class IndicatorAction(Action[IndicatorInput]):
|
|||||||
Make sure the response is correct json and can be parsed by Python json.loads.
|
Make sure the response is correct json and can be parsed by Python json.loads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _validate_request(self, url: str, method: str) -> Optional[str]:
|
||||||
|
"""Validate URL and method to prevent SSRF attacks."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
|
if parsed.scheme not in {"http", "https"}:
|
||||||
|
return f"Scheme '{parsed.scheme}' not allowed"
|
||||||
|
|
||||||
|
if parsed.hostname in self._blocked_hosts:
|
||||||
|
return f"Hostname '{parsed.hostname}' is blocked"
|
||||||
|
|
||||||
|
if parsed.hostname:
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(parsed.hostname)
|
||||||
|
if (
|
||||||
|
ip.is_private
|
||||||
|
or ip.is_loopback
|
||||||
|
or ip.is_link_local
|
||||||
|
or ip.is_multicast
|
||||||
|
):
|
||||||
|
return f"Private/internal IP '{parsed.hostname}' is blocked"
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if any(
|
||||||
|
p in parsed.hostname.lower() for p in ["internal", "local", "intranet"]
|
||||||
|
):
|
||||||
|
return f"Hostname pattern '{parsed.hostname}' is blocked"
|
||||||
|
|
||||||
|
if method.upper() not in self._allowed_methods:
|
||||||
|
return f"Method '{method}' not allowed"
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
return f"URL validation error: {e}"
|
||||||
|
|
||||||
def build_headers(self):
|
def build_headers(self):
|
||||||
"""Build headers."""
|
"""Build headers."""
|
||||||
return None
|
return None
|
||||||
@ -95,12 +142,8 @@ class IndicatorAction(Action[IndicatorInput]):
|
|||||||
) -> ActionOutput:
|
) -> ActionOutput:
|
||||||
"""Perform the action."""
|
"""Perform the action."""
|
||||||
import requests
|
import requests
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
|
||||||
f"_input_convert: {type(self).__name__} ai_message: {ai_message}"
|
|
||||||
)
|
|
||||||
param: IndicatorInput = self._input_convert(ai_message, IndicatorInput)
|
param: IndicatorInput = self._input_convert(ai_message, IndicatorInput)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(str(e))
|
logger.exception(str(e))
|
||||||
@ -109,61 +152,72 @@ class IndicatorAction(Action[IndicatorInput]):
|
|||||||
content="The requested correctly structured answer could not be found.",
|
content="The requested correctly structured answer could not be found.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if error := self._validate_request(param.api, param.method):
|
||||||
|
logger.warning(f"Blocked request: {error}")
|
||||||
|
return ActionOutput(
|
||||||
|
is_exe_success=False, content=f"Request blocked: {error}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
status = Status.RUNNING.value
|
if param.method.lower() == "get":
|
||||||
response_success = True
|
response = requests.get(
|
||||||
response_text = ""
|
param.api,
|
||||||
err_msg = None
|
params=param.args,
|
||||||
try:
|
headers=self.build_headers(),
|
||||||
if param.method.lower() == "get":
|
timeout=10,
|
||||||
response = requests.get(
|
allow_redirects=False,
|
||||||
param.api, params=param.args, headers=self.build_headers()
|
)
|
||||||
)
|
elif param.method.lower() == "post":
|
||||||
elif param.method.lower() == "post":
|
response = requests.post(
|
||||||
response = requests.post(
|
param.api,
|
||||||
param.api, json=param.args, headers=self.build_headers()
|
json=param.args,
|
||||||
)
|
headers=self.build_headers(),
|
||||||
else:
|
timeout=10,
|
||||||
response = requests.request(
|
allow_redirects=False,
|
||||||
param.method.lower(),
|
)
|
||||||
param.api,
|
else:
|
||||||
data=param.args,
|
return ActionOutput(
|
||||||
headers=self.build_headers(),
|
is_exe_success=False,
|
||||||
)
|
content=f"Method '{param.method}' not supported",
|
||||||
response_text = response.text
|
)
|
||||||
logger.info(f"API:{param.api}\nResult:{response_text}")
|
|
||||||
# If the request returns an error status code, an HTTPError exception
|
response.raise_for_status()
|
||||||
# is thrown
|
logger.info(f"API:{param.api}\nResult:{response.text}")
|
||||||
response.raise_for_status()
|
|
||||||
status = Status.COMPLETE.value
|
|
||||||
except HTTPError as http_err:
|
|
||||||
print(f"HTTP error occurred: {http_err}")
|
|
||||||
except Exception as e:
|
|
||||||
response_success = False
|
|
||||||
logger.exception(f"API [{param.indicator_name}] excute Failed!")
|
|
||||||
status = Status.FAILED.value
|
|
||||||
err_msg = f"API [{param.api}] request Failed!{str(e)}"
|
|
||||||
|
|
||||||
plugin_param = {
|
plugin_param = {
|
||||||
"name": param.indicator_name,
|
"name": param.indicator_name,
|
||||||
"args": param.args,
|
"args": param.args,
|
||||||
"status": status,
|
"status": Status.COMPLETE.value,
|
||||||
"logo": None,
|
"logo": None,
|
||||||
"result": response_text,
|
"result": response.text,
|
||||||
"err_msg": err_msg,
|
"err_msg": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
view = (
|
view = (
|
||||||
await self.render_protocol.display(content=plugin_param)
|
await self.render_protocol.display(content=plugin_param)
|
||||||
if self.render_protocol
|
if self.render_protocol
|
||||||
else response_text
|
else response.text
|
||||||
)
|
)
|
||||||
|
|
||||||
return ActionOutput(
|
return ActionOutput(is_exe_success=True, content=response.text, view=view)
|
||||||
is_exe_success=response_success, content=response_text, view=view
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Indicator Action Run Failed!")
|
logger.exception(f"API [{param.indicator_name}] failed: {e}")
|
||||||
return ActionOutput(
|
error_msg = f"API request failed: {str(e)}"
|
||||||
is_exe_success=False, content=f"Indicator action run failed!{str(e)}"
|
|
||||||
|
plugin_param = {
|
||||||
|
"name": param.indicator_name,
|
||||||
|
"args": param.args,
|
||||||
|
"status": Status.FAILED.value,
|
||||||
|
"logo": None,
|
||||||
|
"result": "",
|
||||||
|
"err_msg": error_msg,
|
||||||
|
}
|
||||||
|
|
||||||
|
view = (
|
||||||
|
await self.render_protocol.display(content=plugin_param)
|
||||||
|
if self.render_protocol
|
||||||
|
else error_msg
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return ActionOutput(is_exe_success=False, content=error_msg, view=view)
|
||||||
|
Loading…
Reference in New Issue
Block a user