mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
117 lines
3.8 KiB
Python
117 lines
3.8 KiB
Python
import os
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Type
|
|
|
|
import requests
|
|
from langchain_core.tools import BaseTool
|
|
from pydantic import BaseModel, Field, ValidationError, validator
|
|
|
|
|
|
class Detector(str, Enum):
    """Identifiers of the detectors a prompt can be checked with.

    Members are also ``str`` instances, so each one compares equal to (and
    serializes as) its string value.
    """

    # Topic-based detectors.
    ALLOWED_TOPICS = "allowed_subjects"
    BANNED_TOPICS = "banned_subjects"

    # Content-based detectors.
    PROMPT_INJECTION = "prompt_injection"
    KEYWORDS = "keywords"
    PII = "pii"
    SECRETS = "secrets"
    TOXICITY = "toxicity"
|
|
|
|
|
|
class DetectorAPI(str, Enum):
    """URL path suffix of the single-detector endpoint for each detector.

    Member names mirror those of ``Detector`` so a detector can be mapped to
    its endpoint by name.
    """

    # Topic endpoints.
    ALLOWED_TOPICS = "v1/detect/topics/allowed"
    BANNED_TOPICS = "v1/detect/topics/banned"

    # Content endpoints.
    PROMPT_INJECTION = "v1/detect/prompt_injection"
    KEYWORDS = "v1/detect/keywords"
    PII = "v1/detect/pii"
    SECRETS = "v1/detect/secrets"
    TOXICITY = "v1/detect/toxicity"
|
|
|
|
|
|
class ZenGuardInput(BaseModel):
    # Argument schema for the ZenGuard tool: the prompts to check, the
    # detectors to check them with, and whether to run detectors in parallel.
    # (Documented with comments rather than a docstring so the generated
    # schema for this model is left exactly as it was.)

    # Required; must contain at least one prompt.
    prompts: List[str] = Field(..., description="Prompt to check", min_length=1)

    # Required; must name at least one detector.
    detectors: List[Detector] = Field(
        ...,
        description="List of detectors by which you want to check the prompt",
        min_length=1,
    )

    # Optional; defaults to parallel execution.
    in_parallel: bool = Field(
        description="Run prompt detection by the detector in parallel or sequentially",
        default=True,
    )
|
|
|
|
|
|
# Defaults shared by the class-level validator and the attributes declared on
# the tool. They live at module scope so the validator, which only receives
# the class, reads plain strings instead of depending on how the model exposes
# its underscore-prefixed attributes at class level.
_DEFAULT_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
_DEFAULT_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"


class ZenGuardTool(BaseTool):  # type: ignore[override, override]
    """Tool that posts prompts to the ZenGuard AI HTTP API.

    The API key is taken from the ``zenguard_api_key`` option or, when that is
    not given, from the ``ZENGUARD_API_KEY`` environment variable.
    """

    name: str = "ZenGuard"
    description: str = (
        "ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
    )
    args_schema: Type[BaseModel] = ZenGuardInput
    return_direct: bool = True

    zenguard_api_key: Optional[str] = Field(default=None)

    _ZENGUARD_API_URL_ROOT: str = _DEFAULT_ZENGUARD_API_URL_ROOT
    _ZENGUARD_API_KEY_ENV_NAME: str = _DEFAULT_ZENGUARD_API_KEY_ENV_NAME

    @validator("zenguard_api_key", pre=True, always=True, check_fields=False)
    def set_api_key(cls, v: Optional[str]) -> str:
        """Resolve the API key, falling back to the environment variable.

        Args:
            v: The value passed for ``zenguard_api_key``, or ``None`` when the
                option was omitted.

        Returns:
            The explicit key if given, otherwise the environment value.

        Raises:
            ValueError: If neither source provides a key. Raised as a plain
                ``ValueError`` because that is what a validator is expected to
                raise; ``ValidationError`` cannot be built from a message.
        """
        if v is None:
            v = os.getenv(_DEFAULT_ZENGUARD_API_KEY_ENV_NAME)
        if v is None:
            raise ValueError(
                "The zenguard_api_key tool option must be set either "
                "by passing zenguard_api_key to the tool or by setting "
                f"the {_DEFAULT_ZENGUARD_API_KEY_ENV_NAME} environment variable"
            )
        return v

    @property
    def _api_key(self) -> str:
        """Return the configured API key.

        Raises:
            ValueError: If no API key is set on the tool.
        """
        if self.zenguard_api_key is None:
            raise ValueError(
                "API key is required for the ZenGuardTool. "
                "Please provide the API key by either:\n"
                "1. Manually specifying it when initializing the tool: "
                "ZenGuardTool(zenguard_api_key='your_api_key')\n"
                "2. Setting it as an environment variable:"
                f" {self._ZENGUARD_API_KEY_ENV_NAME}"
            )
        return self.zenguard_api_key

    def _run(
        self,
        prompts: List[str],
        detectors: List[Detector],
        in_parallel: bool = True,
    ) -> Dict[str, Any]:
        """Send the prompts to the API and return its JSON response.

        With exactly one detector the request goes to that detector's own
        endpoint and carries only the prompts; otherwise it goes to the
        generic ``v1/detect`` endpoint together with the detector list and the
        ``in_parallel`` flag.

        Args:
            prompts: Prompts to check.
            detectors: Detectors to check the prompts with.
            in_parallel: Forwarded to the API in the multi-detector request.

        Returns:
            The decoded JSON body of the response, or ``{"error": <message>}``
            when the request times out or the response has an error status.
            Other exceptions propagate to the caller.
        """
        payload: Dict[str, Any] = {"messages": prompts}
        if len(detectors) == 1:
            endpoint = self._convert_detector_to_api(detectors[0])
        else:
            endpoint = "v1/detect"
            payload["in_parallel"] = in_parallel
            payload["detectors"] = detectors

        url = self._ZENGUARD_API_URL_ROOT + endpoint
        headers = {"x-api-key": self._api_key}

        # Only the HTTP round trip is guarded; configuration errors above are
        # deliberately left to surface as exceptions.
        try:
            response = requests.post(url, json=payload, headers=headers, timeout=5)
            response.raise_for_status()
            return response.json()
        except (requests.HTTPError, requests.Timeout) as e:
            return {"error": str(e)}

    def _convert_detector_to_api(self, detector: Detector) -> str:
        """Return the endpoint path suffix for ``detector``.

        The value is first normalized through ``Detector`` so that a plain
        string holding a detector value is accepted as well as an enum member.
        """
        return DetectorAPI[Detector(detector).name].value
|