diff --git a/docs/docs/guides/productionization/safety/hugging_face_prompt_injection.ipynb b/docs/docs/guides/productionization/safety/hugging_face_prompt_injection.ipynb index f4b306f4553..2bafd7a95b7 100644 --- a/docs/docs/guides/productionization/safety/hugging_face_prompt_injection.ipynb +++ b/docs/docs/guides/productionization/safety/hugging_face_prompt_injection.ipynb @@ -50,7 +50,7 @@ "from transformers import AutoTokenizer, pipeline\n", "\n", "# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n", - "model_path = \"laiyer/deberta-v3-base-prompt-injection-v2\"\n", + "model_path = \"protectai/deberta-v3-base-prompt-injection-v2\"\n", "revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n", "tokenizer = AutoTokenizer.from_pretrained(\n", " model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n", diff --git a/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py b/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py index ecd5f7f6822..519ee54c85f 100644 --- a/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py +++ b/libs/experimental/langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py @@ -1,7 +1,7 @@ """Tool for the identification of prompt injection attacks.""" from __future__ import annotations -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any from langchain.pydantic_v1 import Field, root_validator from langchain.tools.base import BaseTool @@ -59,7 +59,7 @@ class HuggingFaceInjectionIdentifier(BaseTool): "Useful for when you need to ensure that prompt is free of injection attacks. " "Input should be any message from the user." ) - model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory) + model: Any = Field(default_factory=_model_default_factory) """Model to use for prompt injection detection. Can be specified as transformers Pipeline or string. String should correspond to the @@ -98,3 +98,6 @@ class HuggingFaceInjectionIdentifier(BaseTool): raise PromptInjectionException("Prompt injection attack detected", score) return query + + +HuggingFaceInjectionIdentifier.update_forward_refs()