mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 04:50:37 +00:00
experimental[patch]: Fix import errors and docs of experimental prompt injection identifier (#21586)
**Description** I noticed that code in experimental was not functional due to small errors, and users were complaining at https://stackoverflow.com/questions/78272521/hugging-face-prompt-injection-identifier and https://github.com/langchain-ai/langchain/discussions/19995. I know folks are busy with `v2.0` now, but I believe this is worthy fix because it is isolated, minor, backwards compatible, and fixes an error with current v0.1 code **Issue** https://github.com/langchain-ai/langchain/discussions/19995 **Dependencies** None **Twitter handle** https://twitter.com/nelsonauner cc @asofter --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
5a40413bfd
commit
946a0a38bd
@ -50,7 +50,7 @@
|
||||
"from transformers import AutoTokenizer, pipeline\n",
|
||||
"\n",
|
||||
"# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n",
|
||||
"model_path = \"laiyer/deberta-v3-base-prompt-injection-v2\"\n",
|
||||
"model_path = \"protectai/deberta-v3-base-prompt-injection-v2\"\n",
|
||||
"revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
||||
" model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n",
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Tool for the identification of prompt injection attacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
@ -59,7 +59,7 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
||||
"Input should be any message from the user."
|
||||
)
|
||||
model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory)
|
||||
model: Any = Field(default_factory=_model_default_factory)
|
||||
"""Model to use for prompt injection detection.
|
||||
|
||||
Can be specified as transformers Pipeline or string. String should correspond to the
|
||||
@ -98,3 +98,6 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
raise PromptInjectionException("Prompt injection attack detected", score)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
HuggingFaceInjectionIdentifier.update_forward_refs()
|
||||
|
Loading…
Reference in New Issue
Block a user