From c0277d06e83c9fdc40d27a4bf23d161e152fd78b Mon Sep 17 00:00:00 2001 From: Oleksandr Yaremchuk Date: Mon, 27 Nov 2023 23:56:53 +0100 Subject: [PATCH] experimental[patch] Update prompt injection model (#13930) - **Description:** Existing model used for Prompt Injection is quite outdated but we fine-tuned and open-source a new model based on the same model deberta-v3-base from Microsoft - [laiyer/deberta-v3-base-prompt-injection](https://huggingface.co/laiyer/deberta-v3-base-prompt-injection). It supports more up-to-date injections and less prone to false-positives. - **Dependencies:** No - **Tag maintainer:** - - **Twitter handle:** @alex_yaremchuk --------- Co-authored-by: Bagatur --- .../hugging_face_prompt_injection.ipynb | 39 ++++++++++++++----- .../hugging_face_identifier.py | 26 ++++++++++--- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/docs/docs/guides/safety/hugging_face_prompt_injection.ipynb b/docs/docs/guides/safety/hugging_face_prompt_injection.ipynb index 89c80b6fb84..21224ea0b0b 100644 --- a/docs/docs/guides/safety/hugging_face_prompt_injection.ipynb +++ b/docs/docs/guides/safety/hugging_face_prompt_injection.ipynb @@ -8,7 +8,7 @@ "# Hugging Face prompt injection identification\n", "\n", "This notebook shows how to prevent prompt injection attacks using the text classification model from `HuggingFace`.\n", - "It exploits the *deberta* model trained to identify prompt injections: https://huggingface.co/deepset/deberta-v3-base-injection" + "By default it uses a *deberta* model trained to identify prompt injections. In this walkthrough we'll use https://huggingface.co/laiyer/deberta-v3-base-prompt-injection." ] }, { @@ -21,19 +21,37 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "aea25588-3c3f-4506-9094-221b3a0d519b", "metadata": {}, "outputs": [ { "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "58ab3557623a495d8cc3c3e32a61938f", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - "'hugging_face_injection_identifier'" + "Downloading config.json: 0%| | 0.00/994 [00:00 Pipeline: +def _model_default_factory( + model_name: str = "deepset/deberta-v3-base-injection" +) -> Pipeline: try: from transformers import pipeline except ImportError as e: @@ -18,11 +20,11 @@ def _model_default_factory() -> Pipeline: "Cannot import transformers, please install with " "`pip install transformers`." ) from e - return pipeline("text-classification", model="deepset/deberta-v3-base-injection") + return pipeline("text-classification", model=model_name) class HuggingFaceInjectionIdentifier(BaseTool): - """Tool that uses deberta-v3-base-injection to detect prompt injection attacks.""" + """Tool that uses HF model to detect prompt injection attacks.""" name: str = "hugging_face_injection_identifier" description: str = ( @@ -30,7 +32,19 @@ class HuggingFaceInjectionIdentifier(BaseTool): "Useful for when you need to ensure that prompt is free of injection attacks. " "Input should be any message from the user." ) - model: Pipeline = Field(default_factory=_model_default_factory) + model: Any = Field(default_factory=_model_default_factory) + """Model to use for prompt injection detection. + + Can be specified as transformers Pipeline or string. String should correspond to the + model name of a text-classification transformers model. Defaults to + ``deepset/deberta-v3-base-injection`` model. + """ + + @root_validator(pre=True) + def validate_environment(cls, values: dict) -> dict: + if isinstance(values.get("model"), str): + values["model"] = _model_default_factory(model_name=values["model"]) + return values def _run(self, query: str) -> str: """Use the tool."""