mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 12:01:54 +00:00
experimental[patch]: Fix import errors and docs of experimental prompt injection identifier (#21586)
**Description** I noticed that code in experimental was not functional due to small errors, and users were complaining at https://stackoverflow.com/questions/78272521/hugging-face-prompt-injection-identifier and https://github.com/langchain-ai/langchain/discussions/19995. I know folks are busy with `v2.0` now, but I believe this is worthy fix because it is isolated, minor, backwards compatible, and fixes an error with current v0.1 code **Issue** https://github.com/langchain-ai/langchain/discussions/19995 **Dependencies** None **Twitter handle** https://twitter.com/nelsonauner cc @asofter --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -50,7 +50,7 @@
|
|||||||
"from transformers import AutoTokenizer, pipeline\n",
|
"from transformers import AutoTokenizer, pipeline\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n",
|
"# Using https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2\n",
|
||||||
"model_path = \"laiyer/deberta-v3-base-prompt-injection-v2\"\n",
|
"model_path = \"protectai/deberta-v3-base-prompt-injection-v2\"\n",
|
||||||
"revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n",
|
"revision = None # We recommend specifiying the revision to avoid breaking changes or supply chain attacks\n",
|
||||||
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
"tokenizer = AutoTokenizer.from_pretrained(\n",
|
||||||
" model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n",
|
" model_path, revision=revision, model_input_names=[\"input_ids\", \"attention_mask\"]\n",
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Tool for the identification of prompt injection attacks."""
|
"""Tool for the identification of prompt injection attacks."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain.pydantic_v1 import Field, root_validator
|
from langchain.pydantic_v1 import Field, root_validator
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
@@ -59,7 +59,7 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
|||||||
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
||||||
"Input should be any message from the user."
|
"Input should be any message from the user."
|
||||||
)
|
)
|
||||||
model: Union[Pipeline, str, None] = Field(default_factory=_model_default_factory)
|
model: Any = Field(default_factory=_model_default_factory)
|
||||||
"""Model to use for prompt injection detection.
|
"""Model to use for prompt injection detection.
|
||||||
|
|
||||||
Can be specified as transformers Pipeline or string. String should correspond to the
|
Can be specified as transformers Pipeline or string. String should correspond to the
|
||||||
@@ -98,3 +98,6 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
|||||||
raise PromptInjectionException("Prompt injection attack detected", score)
|
raise PromptInjectionException("Prompt injection attack detected", score)
|
||||||
|
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
HuggingFaceInjectionIdentifier.update_forward_refs()
|
||||||
|
Reference in New Issue
Block a user