experimental[patch]: Fix import errors and docs of experimental prompt injection identifier (#21586)

**Description**

I noticed that code in experimental was not functional due to small
errors, and users were complaining at
https://stackoverflow.com/questions/78272521/hugging-face-prompt-injection-identifier
and https://github.com/langchain-ai/langchain/discussions/19995.

I know folks are busy with `v2.0` now, but I believe this is worthy fix
because it is isolated, minor, backwards compatible, and fixes an error
with current v0.1 code

**Issue**
https://github.com/langchain-ai/langchain/discussions/19995

**Dependencies**
None

**Twitter handle**
https://twitter.com/nelsonauner


cc @asofter

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Nelson Auner 2024-05-21 16:29:25 -07:00 committed by GitHub
parent 5a40413bfd
commit 946a0a38bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 3 deletions

View File

@ -50,7 +50,7 @@
"from transformers import AutoTokenizer, pipeline\n",
"\n",
"# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n",
"model_path = \"laiyer/deberta-v3-base-prompt-injection-v2\"\n",
"model_path = \"protectai/deberta-v3-base-prompt-injection-v2\"\n",
"revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
" model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n",

View File

@ -1,7 +1,7 @@
"""Tool for the identification of prompt injection attacks."""
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Any
from langchain.pydantic_v1 import Field, root_validator
from langchain.tools.base import BaseTool
@ -59,7 +59,7 @@ class HuggingFaceInjectionIdentifier(BaseTool):
"Useful for when you need to ensure that prompt is free of injection attacks. "
"Input should be any message from the user."
)
model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory)
model: Any = Field(default_factory=_model_default_factory)
"""Model to use for prompt injection detection.
Can be specified as transformers Pipeline or string. String should correspond to the
@ -98,3 +98,6 @@ class HuggingFaceInjectionIdentifier(BaseTool):
raise PromptInjectionException("Prompt injection attack detected", score)
return query
HuggingFaceInjectionIdentifier.update_forward_refs()