diff --git a/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py b/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py index c587a280834..c12c8ca48b0 100644 --- a/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py +++ b/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py @@ -1,21 +1,28 @@ """Tool for the identification of prompt injection attacks.""" +from __future__ import annotations -from enum import Enum +from typing import TYPE_CHECKING +from langchain.pydantic_v1 import Field from langchain.tools.base import BaseTool -from transformers import Pipeline, pipeline + +if TYPE_CHECKING: + from transformers import Pipeline -class PromptInjectionModelOutput(str, Enum): - """Output of the prompt injection model.""" - - LEGIT = "LEGIT" - INJECTION = "INJECTION" +def _model_default_factory() -> Pipeline: + try: + from transformers import pipeline + except ImportError as e: + raise ImportError( + "Cannot import transformers, please install with " + "`pip install transformers`." + ) from e + return pipeline("text-classification", model="deepset/deberta-v3-base-injection") class HuggingFaceInjectionIdentifier(BaseTool): - """Tool that uses deberta-v3-base-injection model - to identify prompt injection attacks.""" + """Tool that uses deberta-v3-base-injection to detect prompt injection attacks.""" name: str = "hugging_face_injection_identifier" description: str = ( @@ -23,21 +30,12 @@ class HuggingFaceInjectionIdentifier(BaseTool): "Useful for when you need to ensure that prompt is free of injection attacks. " "Input should be any message from the user." ) - - model: Pipeline = pipeline( - "text-classification", model="deepset/deberta-v3-base-injection" - ) - - def _classify_user_input(self, query: str) -> bool: - result = self.model(query) - result = sorted(result, key=lambda x: x["score"], reverse=True) - if result[0]["label"] == PromptInjectionModelOutput.INJECTION: - return False - return True + model: Pipeline = Field(default_factory=_model_default_factory) def _run(self, query: str) -> str: """Use the tool.""" - is_query_safe = self._classify_user_input(query) - if not is_query_safe: + result = self.model(query) + result = sorted(result, key=lambda x: x["score"], reverse=True) + if result[0]["label"] == "INJECTION": raise ValueError("Prompt injection attack detected") return query