This commit is contained in:
Gecko Security 2025-07-17 17:58:08 +01:00 committed by GitHub
commit cbcd9612b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,8 +1,10 @@
"""Indicator Agent action.""" """Indicator Agent action."""
import ipaddress
import json import json
import logging import logging
from typing import Optional from typing import Optional
from urllib.parse import urlparse
from dbgpt._private.pydantic import BaseModel, Field from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.vis.tags.vis_api_response import VisApiResponse from dbgpt.vis.tags.vis_api_response import VisApiResponse
@ -46,6 +48,15 @@ class IndicatorAction(Action[IndicatorInput]):
"""Init indicator action.""" """Init indicator action."""
super().__init__(**kwargs) super().__init__(**kwargs)
self._render_protocol = VisApiResponse() self._render_protocol = VisApiResponse()
self._blocked_hosts = {
"169.254.169.254",
"metadata.google.internal",
"metadata.goog",
"localhost",
"127.0.0.1",
"::1",
}
self._allowed_methods = {"GET", "POST"}
@property @property
def resource_need(self) -> Optional[ResourceType]: def resource_need(self) -> Optional[ResourceType]:
@ -81,6 +92,42 @@ class IndicatorAction(Action[IndicatorInput]):
Make sure the response is correct json and can be parsed by Python json.loads. Make sure the response is correct json and can be parsed by Python json.loads.
""" """
def _validate_request(self, url: str, method: str) -> Optional[str]:
"""Validate URL and method to prevent SSRF attacks."""
try:
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"}:
return f"Scheme '{parsed.scheme}' not allowed"
if parsed.hostname in self._blocked_hosts:
return f"Hostname '{parsed.hostname}' is blocked"
if parsed.hostname:
try:
ip = ipaddress.ip_address(parsed.hostname)
if (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
):
return f"Private/internal IP '{parsed.hostname}' is blocked"
except ValueError:
pass
if any(
p in parsed.hostname.lower() for p in ["internal", "local", "intranet"]
):
return f"Hostname pattern '{parsed.hostname}' is blocked"
if method.upper() not in self._allowed_methods:
return f"Method '{method}' not allowed"
return None
except Exception as e:
return f"URL validation error: {e}"
def build_headers(self): def build_headers(self):
"""Build headers.""" """Build headers."""
return None return None
@ -95,12 +142,8 @@ class IndicatorAction(Action[IndicatorInput]):
) -> ActionOutput: ) -> ActionOutput:
"""Perform the action.""" """Perform the action."""
import requests import requests
from requests.exceptions import HTTPError
try: try:
logger.info(
f"_input_convert: {type(self).__name__} ai_message: {ai_message}"
)
param: IndicatorInput = self._input_convert(ai_message, IndicatorInput) param: IndicatorInput = self._input_convert(ai_message, IndicatorInput)
except Exception as e: except Exception as e:
logger.exception(str(e)) logger.exception(str(e))
@ -109,61 +152,72 @@ class IndicatorAction(Action[IndicatorInput]):
content="The requested correctly structured answer could not be found.", content="The requested correctly structured answer could not be found.",
) )
if error := self._validate_request(param.api, param.method):
logger.warning(f"Blocked request: {error}")
return ActionOutput(
is_exe_success=False, content=f"Request blocked: {error}"
)
try: try:
status = Status.RUNNING.value if param.method.lower() == "get":
response_success = True response = requests.get(
response_text = "" param.api,
err_msg = None params=param.args,
try: headers=self.build_headers(),
if param.method.lower() == "get": timeout=10,
response = requests.get( allow_redirects=False,
param.api, params=param.args, headers=self.build_headers() )
) elif param.method.lower() == "post":
elif param.method.lower() == "post": response = requests.post(
response = requests.post( param.api,
param.api, json=param.args, headers=self.build_headers() json=param.args,
) headers=self.build_headers(),
else: timeout=10,
response = requests.request( allow_redirects=False,
param.method.lower(), )
param.api, else:
data=param.args, return ActionOutput(
headers=self.build_headers(), is_exe_success=False,
) content=f"Method '{param.method}' not supported",
response_text = response.text )
logger.info(f"API:{param.api}\nResult:{response_text}")
# If the request returns an error status code, an HTTPError exception response.raise_for_status()
# is thrown logger.info(f"API:{param.api}\nResult:{response.text}")
response.raise_for_status()
status = Status.COMPLETE.value
except HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except Exception as e:
response_success = False
logger.exception(f"API [{param.indicator_name}] excute Failed!")
status = Status.FAILED.value
err_msg = f"API [{param.api}] request Failed!{str(e)}"
plugin_param = { plugin_param = {
"name": param.indicator_name, "name": param.indicator_name,
"args": param.args, "args": param.args,
"status": status, "status": Status.COMPLETE.value,
"logo": None, "logo": None,
"result": response_text, "result": response.text,
"err_msg": err_msg, "err_msg": None,
} }
view = ( view = (
await self.render_protocol.display(content=plugin_param) await self.render_protocol.display(content=plugin_param)
if self.render_protocol if self.render_protocol
else response_text else response.text
) )
return ActionOutput( return ActionOutput(is_exe_success=True, content=response.text, view=view)
is_exe_success=response_success, content=response_text, view=view
)
except Exception as e: except Exception as e:
logger.exception("Indicator Action Run Failed") logger.exception(f"API [{param.indicator_name}] failed: {e}")
return ActionOutput( error_msg = f"API request failed: {str(e)}"
is_exe_success=False, content=f"Indicator action run failed!{str(e)}"
plugin_param = {
"name": param.indicator_name,
"args": param.args,
"status": Status.FAILED.value,
"logo": None,
"result": "",
"err_msg": error_msg,
}
view = (
await self.render_protocol.display(content=plugin_param)
if self.render_protocol
else error_msg
) )
return ActionOutput(is_exe_success=False, content=error_msg, view=view)