experimental[patch]: Fix import errors and docs of experimental prompt injection identifier (#21586)

**Description**

I noticed that code in experimental was not functional due to small
errors, and users were complaining at
https://stackoverflow.com/questions/78272521/hugging-face-prompt-injection-identifier
and https://github.com/langchain-ai/langchain/discussions/19995.

I know folks are busy with `v2.0` now, but I believe this is worthy fix
because it is isolated, minor, backwards compatible, and fixes an error
with current v0.1 code

**Issue**
https://github.com/langchain-ai/langchain/discussions/19995

**Dependencies**
None

**Twitter handle**
https://twitter.com/nelsonauner


cc @asofter

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Nelson Auner 2024-05-21 16:29:25 -07:00 committed by GitHub
parent 5a40413bfd
commit 946a0a38bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 3 deletions

View File

@ -50,7 +50,7 @@
"from transformers import AutoTokenizer, pipeline\n", "from transformers import AutoTokenizer, pipeline\n",
"\n", "\n",
"# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n", "# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n",
"model_path = \"laiyer/deberta-v3-base-prompt-injection-v2\"\n", "model_path = \"protectai/deberta-v3-base-prompt-injection-v2\"\n",
"revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n", "revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n",
"tokenizer = AutoTokenizer.from_pretrained(\n", "tokenizer = AutoTokenizer.from_pretrained(\n",
" model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n", " model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n",

View File

@ -1,7 +1,7 @@
"""Tool for the identification of prompt injection attacks.""" """Tool for the identification of prompt injection attacks."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Any
from langchain.pydantic_v1 import Field, root_validator from langchain.pydantic_v1 import Field, root_validator
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
@ -59,7 +59,7 @@ class HuggingFaceInjectionIdentifier(BaseTool):
"Useful for when you need to ensure that prompt is free of injection attacks. " "Useful for when you need to ensure that prompt is free of injection attacks. "
"Input should be any message from the user." "Input should be any message from the user."
) )
model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory) model: Any = Field(default_factory=_model_default_factory)
"""Model to use for prompt injection detection. """Model to use for prompt injection detection.
Can be specified as transformers Pipeline or string. String should correspond to the Can be specified as transformers Pipeline or string. String should correspond to the
@ -98,3 +98,6 @@ class HuggingFaceInjectionIdentifier(BaseTool):
raise PromptInjectionException("Prompt injection attack detected", score) raise PromptInjectionException("Prompt injection attack detected", score)
return query return query
HuggingFaceInjectionIdentifier.update_forward_refs()