mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-05 13:06:03 +00:00
Community: Update and fix ZenGuardTool docs and add ZenguardTool to init files (#23415)
Thank you for contributing to LangChain! - [x] **PR title**: "community: update docs and add tool to init.py" - [x] **PR message**: - **Description:** Fixed some errors and comments in the docs and added our ZenGuardTool and additional classes to init.py for easy access when importing - **Question:** when will you update the langchain-community package in pypi to make our tool available? - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Thank you for review! --------- Co-authored-by: Baur <baur.krykpayev@gmail.com>
This commit is contained in:
104
libs/community/langchain_community/tools/zenguard/tool.py
Normal file
104
libs/community/langchain_community/tools/zenguard/tool.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError, validator
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
class Detector(str, Enum):
|
||||
ALLOWED_TOPICS = "allowed_subjects"
|
||||
BANNED_TOPICS = "banned_subjects"
|
||||
PROMPT_INJECTION = "prompt_injection"
|
||||
KEYWORDS = "keywords"
|
||||
PII = "pii"
|
||||
SECRETS = "secrets"
|
||||
TOXICITY = "toxicity"
|
||||
|
||||
|
||||
class DetectorAPI(str, Enum):
|
||||
ALLOWED_TOPICS = "v1/detect/topics/allowed"
|
||||
BANNED_TOPICS = "v1/detect/topics/banned"
|
||||
PROMPT_INJECTION = "v1/detect/prompt_injection"
|
||||
KEYWORDS = "v1/detect/keywords"
|
||||
PII = "v1/detect/pii"
|
||||
SECRETS = "v1/detect/secrets"
|
||||
TOXICITY = "v1/detect/toxicity"
|
||||
|
||||
|
||||
class ZenGuardInput(BaseModel):
|
||||
prompts: List[str] = Field(
|
||||
...,
|
||||
min_items=1,
|
||||
min_length=1,
|
||||
description="Prompt to check",
|
||||
)
|
||||
detectors: List[Detector] = Field(
|
||||
...,
|
||||
min_items=1,
|
||||
description="List of detectors by which you want to check the prompt",
|
||||
)
|
||||
in_parallel: bool = Field(
|
||||
default=True,
|
||||
description="Run prompt detection by the detector in parallel or sequentially",
|
||||
)
|
||||
|
||||
|
||||
class ZenGuardTool(BaseTool):
|
||||
name: str = "ZenGuard"
|
||||
description: str = (
|
||||
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
|
||||
)
|
||||
args_schema = ZenGuardInput
|
||||
return_direct = True
|
||||
|
||||
zenguard_api_key: Optional[str] = Field(default=None)
|
||||
|
||||
_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
|
||||
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
|
||||
|
||||
@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
|
||||
def set_api_key(cls, v: str) -> str:
|
||||
if v is None:
|
||||
v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME)
|
||||
if v is None:
|
||||
raise ValidationError(
|
||||
"The zenguard_api_key tool option must be set either "
|
||||
"by passing zenguard_api_key to the tool or by setting "
|
||||
f"the f{cls._ZENGUARD_API_KEY_ENV_NAME} environment variable"
|
||||
)
|
||||
return v
|
||||
|
||||
def _run(
|
||||
self,
|
||||
prompts: List[str],
|
||||
detectors: List[Detector],
|
||||
in_parallel: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
postfix = None
|
||||
json: Optional[Dict[str, Any]] = None
|
||||
if len(detectors) == 1:
|
||||
postfix = self._convert_detector_to_api(detectors[0])
|
||||
json = {"messages": prompts}
|
||||
else:
|
||||
postfix = "v1/detect"
|
||||
json = {
|
||||
"messages": prompts,
|
||||
"in_parallel": in_parallel,
|
||||
"detectors": detectors,
|
||||
}
|
||||
response = requests.post(
|
||||
self._ZENGUARD_API_URL_ROOT + postfix,
|
||||
json=json,
|
||||
headers={"x-api-key": self.zenguard_api_key},
|
||||
timeout=5,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except (requests.HTTPError, requests.Timeout) as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def _convert_detector_to_api(self, detector: Detector) -> str:
|
||||
return DetectorAPI[detector.name].value
|
Reference in New Issue
Block a user