diff --git a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/indicator_action.py b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/indicator_action.py index 9156e3ff8..09f6752aa 100644 --- a/packages/dbgpt-core/src/dbgpt/agent/expand/actions/indicator_action.py +++ b/packages/dbgpt-core/src/dbgpt/agent/expand/actions/indicator_action.py @@ -1,8 +1,10 @@ """Indicator Agent action.""" +import ipaddress import json import logging from typing import Optional +from urllib.parse import urlparse from dbgpt._private.pydantic import BaseModel, Field from dbgpt.vis.tags.vis_api_response import VisApiResponse @@ -46,6 +48,15 @@ class IndicatorAction(Action[IndicatorInput]): """Init indicator action.""" super().__init__(**kwargs) self._render_protocol = VisApiResponse() + self._blocked_hosts = { + "169.254.169.254", + "metadata.google.internal", + "metadata.goog", + "localhost", + "127.0.0.1", + "::1", + } + self._allowed_methods = {"GET", "POST"} @property def resource_need(self) -> Optional[ResourceType]: @@ -81,6 +92,42 @@ class IndicatorAction(Action[IndicatorInput]): Make sure the response is correct json and can be parsed by Python json.loads. """ + def _validate_request(self, url: str, method: str) -> Optional[str]: + """Validate URL and method to prevent SSRF attacks.""" + try: + parsed = urlparse(url) + + if parsed.scheme not in {"http", "https"}: + return f"Scheme '{parsed.scheme}' not allowed" + + if parsed.hostname in self._blocked_hosts: + return f"Hostname '{parsed.hostname}' is blocked" + + if parsed.hostname: + try: + ip = ipaddress.ip_address(parsed.hostname) + if ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + ): + return f"Private/internal IP '{parsed.hostname}' is blocked" + except ValueError: + pass + + if any( + p in parsed.hostname.lower() for p in ["internal", "local", "intranet"] + ): + return f"Hostname pattern '{parsed.hostname}' is blocked" + + if method.upper() not in self._allowed_methods: + return f"Method '{method}' not allowed" + + return None + except Exception as e: + return f"URL validation error: {e}" + def build_headers(self): """Build headers.""" return None @@ -95,12 +142,8 @@ class IndicatorAction(Action[IndicatorInput]): ) -> ActionOutput: """Perform the action.""" import requests - from requests.exceptions import HTTPError try: - logger.info( - f"_input_convert: {type(self).__name__} ai_message: {ai_message}" - ) param: IndicatorInput = self._input_convert(ai_message, IndicatorInput) except Exception as e: logger.exception(str(e)) @@ -109,61 +152,72 @@ class IndicatorAction(Action[IndicatorInput]): content="The requested correctly structured answer could not be found.", ) + if error := self._validate_request(param.api, param.method): + logger.warning(f"Blocked request: {error}") + return ActionOutput( + is_exe_success=False, content=f"Request blocked: {error}" + ) + try: - status = Status.RUNNING.value - response_success = True - response_text = "" - err_msg = None - try: - if param.method.lower() == "get": - response = requests.get( - param.api, params=param.args, headers=self.build_headers() - ) - elif param.method.lower() == "post": - response = requests.post( - param.api, json=param.args, headers=self.build_headers() - ) - else: - response = requests.request( - param.method.lower(), - param.api, - data=param.args, - headers=self.build_headers(), - ) - response_text = response.text - logger.info(f"API:{param.api}\nResult:{response_text}") - # If the request returns an error status code, an HTTPError exception - # is thrown - response.raise_for_status() - status = Status.COMPLETE.value - except HTTPError as http_err: - print(f"HTTP error occurred: {http_err}") - except Exception as e: - response_success = False - logger.exception(f"API [{param.indicator_name}] excute Failed!") - status = Status.FAILED.value - err_msg = f"API [{param.api}] request Failed!{str(e)}" + if param.method.lower() == "get": + response = requests.get( + param.api, + params=param.args, + headers=self.build_headers(), + timeout=10, + allow_redirects=False, + ) + elif param.method.lower() == "post": + response = requests.post( + param.api, + json=param.args, + headers=self.build_headers(), + timeout=10, + allow_redirects=False, + ) + else: + return ActionOutput( + is_exe_success=False, + content=f"Method '{param.method}' not supported", + ) + + response.raise_for_status() + logger.info(f"API:{param.api}\nResult:{response.text}") plugin_param = { "name": param.indicator_name, "args": param.args, - "status": status, + "status": Status.COMPLETE.value, "logo": None, - "result": response_text, - "err_msg": err_msg, + "result": response.text, + "err_msg": None, } view = ( await self.render_protocol.display(content=plugin_param) if self.render_protocol - else response_text + else response.text ) - return ActionOutput( - is_exe_success=response_success, content=response_text, view=view - ) + return ActionOutput(is_exe_success=True, content=response.text, view=view) + except Exception as e: - logger.exception("Indicator Action Run Failed!") - return ActionOutput( - is_exe_success=False, content=f"Indicator action run failed!{str(e)}" + logger.exception(f"API [{param.indicator_name}] failed: {e}") + error_msg = f"API request failed: {str(e)}" + + plugin_param = { + "name": param.indicator_name, + "args": param.args, + "status": Status.FAILED.value, + "logo": None, + "result": "", + "err_msg": error_msg, + } + + view = ( + await self.render_protocol.display(content=plugin_param) + if self.render_protocol + else error_msg ) + + return ActionOutput(is_exe_success=False, content=error_msg, view=view)