mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 11:30:37 +00:00
experimental[patch] Update prompt injection model (#13930)
- **Description:** Existing model used for Prompt Injection is quite outdated but we fine-tuned and open-source a new model based on the same model deberta-v3-base from Microsoft - [laiyer/deberta-v3-base-prompt-injection](https://huggingface.co/laiyer/deberta-v3-base-prompt-injection). It supports more up-to-date injections and less prone to false-positives. - **Dependencies:** No - **Tag maintainer:** - - **Twitter handle:** @alex_yaremchuk --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e6ebde9688
commit
c0277d06e8
@@ -1,16 +1,18 @@
|
||||
"""Tool for the identification of prompt injection attacks."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Pipeline
|
||||
|
||||
|
||||
def _model_default_factory() -> Pipeline:
|
||||
def _model_default_factory(
|
||||
model_name: str = "deepset/deberta-v3-base-injection"
|
||||
) -> Pipeline:
|
||||
try:
|
||||
from transformers import pipeline
|
||||
except ImportError as e:
|
||||
@@ -18,11 +20,11 @@ def _model_default_factory() -> Pipeline:
|
||||
"Cannot import transformers, please install with "
|
||||
"`pip install transformers`."
|
||||
) from e
|
||||
return pipeline("text-classification", model="deepset/deberta-v3-base-injection")
|
||||
return pipeline("text-classification", model=model_name)
|
||||
|
||||
|
||||
class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
"""Tool that uses deberta-v3-base-injection to detect prompt injection attacks."""
|
||||
"""Tool that uses HF model to detect prompt injection attacks."""
|
||||
|
||||
name: str = "hugging_face_injection_identifier"
|
||||
description: str = (
|
||||
@@ -30,7 +32,19 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
"Useful for when you need to ensure that prompt is free of injection attacks. "
|
||||
"Input should be any message from the user."
|
||||
)
|
||||
model: Pipeline = Field(default_factory=_model_default_factory)
|
||||
model: Any = Field(default_factory=_model_default_factory)
|
||||
"""Model to use for prompt injection detection.
|
||||
|
||||
Can be specified as transformers Pipeline or string. String should correspond to the
|
||||
model name of a text-classification transformers model. Defaults to
|
||||
``deepset/deberta-v3-base-injection`` model.
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
if isinstance(values.get("model"), str):
|
||||
values["model"] = _model_default_factory(model_name=values["model"])
|
||||
return values
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool."""
|
||||
|
Reference in New Issue
Block a user