mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-29 06:23:20 +00:00
docs: docstrings langchain_community
update (#14889)
Addded missed docstrings. Fixed inconsistency in docstrings. **Note** CC @efriis There were PR errors on `langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py` But, I didn't touch this file in this PR! Can it be some cache problems? I fixed this error.
This commit is contained in:
parent
583696732c
commit
b2fd41331e
@ -32,30 +32,44 @@ from langchain_community.utilities.github import GitHubAPIWrapper
|
||||
|
||||
|
||||
class NoInput(BaseModel):
|
||||
"""Schema for operations that do not require any input."""
|
||||
|
||||
no_input: str = Field("", description="No input required, e.g. `` (empty string).")
|
||||
|
||||
|
||||
class GetIssue(BaseModel):
|
||||
"""Schema for operations that require an issue number as input."""
|
||||
|
||||
issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`")
|
||||
|
||||
|
||||
class CommentOnIssue(BaseModel):
|
||||
"""Schema for operations that require a comment as input."""
|
||||
|
||||
input: str = Field(..., description="Follow the required formatting.")
|
||||
|
||||
|
||||
class GetPR(BaseModel):
|
||||
"""Schema for operations that require a PR number as input."""
|
||||
|
||||
pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`")
|
||||
|
||||
|
||||
class CreatePR(BaseModel):
|
||||
"""Schema for operations that require a PR title and body as input."""
|
||||
|
||||
formatted_pr: str = Field(..., description="Follow the required formatting.")
|
||||
|
||||
|
||||
class CreateFile(BaseModel):
|
||||
"""Schema for operations that require a file path and content as input."""
|
||||
|
||||
formatted_file: str = Field(..., description="Follow the required formatting.")
|
||||
|
||||
|
||||
class ReadFile(BaseModel):
|
||||
"""Schema for operations that require a file path as input."""
|
||||
|
||||
formatted_filepath: str = Field(
|
||||
...,
|
||||
description=(
|
||||
@ -66,12 +80,16 @@ class ReadFile(BaseModel):
|
||||
|
||||
|
||||
class UpdateFile(BaseModel):
|
||||
"""Schema for operations that require a file path and content as input."""
|
||||
|
||||
formatted_file_update: str = Field(
|
||||
..., description="Strictly follow the provided rules."
|
||||
)
|
||||
|
||||
|
||||
class DeleteFile(BaseModel):
|
||||
"""Schema for operations that require a file path as input."""
|
||||
|
||||
formatted_filepath: str = Field(
|
||||
...,
|
||||
description=(
|
||||
@ -84,6 +102,8 @@ class DeleteFile(BaseModel):
|
||||
|
||||
|
||||
class DirectoryPath(BaseModel):
|
||||
"""Schema for operations that require a directory path as input."""
|
||||
|
||||
input: str = Field(
|
||||
"",
|
||||
description=(
|
||||
@ -94,12 +114,16 @@ class DirectoryPath(BaseModel):
|
||||
|
||||
|
||||
class BranchName(BaseModel):
|
||||
"""Schema for operations that require a branch name as input."""
|
||||
|
||||
branch_name: str = Field(
|
||||
..., description="The name of the branch, e.g. `my_branch`."
|
||||
)
|
||||
|
||||
|
||||
class SearchCode(BaseModel):
|
||||
"""Schema for operations that require a search query as input."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description=(
|
||||
@ -110,6 +134,8 @@ class SearchCode(BaseModel):
|
||||
|
||||
|
||||
class CreateReviewRequest(BaseModel):
|
||||
"""Schema for operations that require a username as input."""
|
||||
|
||||
username: str = Field(
|
||||
...,
|
||||
description="GitHub username of the user being requested, e.g. `my_username`.",
|
||||
@ -117,6 +143,8 @@ class CreateReviewRequest(BaseModel):
|
||||
|
||||
|
||||
class SearchIssuesAndPRs(BaseModel):
|
||||
"""Schema for operations that require a search query as input."""
|
||||
|
||||
search_query: str = Field(
|
||||
...,
|
||||
description="Natural language search query, e.g. `My issue title or topic`.",
|
||||
|
@ -50,10 +50,15 @@ def import_comet_llm_api() -> SimpleNamespace:
|
||||
|
||||
|
||||
class CometTracer(BaseTracer):
|
||||
"""Comet Tracer."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the Comet Tracer."""
|
||||
super().__init__(**kwargs)
|
||||
self._span_map: Dict["UUID", "Span"] = {}
|
||||
"""Map from run id to span."""
|
||||
self._chains_map: Dict["UUID", "Chain"] = {}
|
||||
"""Map from run id to chain."""
|
||||
self._initialize_comet_modules()
|
||||
|
||||
def _initialize_comet_modules(self) -> None:
|
||||
|
@ -259,6 +259,16 @@ class ChatFireworks(BaseChatModel):
|
||||
def conditional_decorator(
|
||||
condition: bool, decorator: Callable[[Any], Any]
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Define conditional decorator.
|
||||
|
||||
Args:
|
||||
condition: The condition.
|
||||
decorator: The decorator.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
|
||||
if condition:
|
||||
return decorator(func)
|
||||
@ -281,6 +291,7 @@ def completion_with_retry(
|
||||
|
||||
@conditional_decorator(use_retry, retry_decorator)
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
return fireworks.client.ChatCompletion.create(
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -24,6 +24,8 @@ def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
|
||||
|
||||
|
||||
def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
|
||||
"""Convert a list of messages to a prompt for llama."""
|
||||
|
||||
return "\n".join(
|
||||
[_convert_one_message_to_text_llama(message) for message in messages]
|
||||
)
|
||||
|
@ -53,6 +53,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
"""Convert a dict to a message."""
|
||||
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
@ -72,6 +74,8 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a message to a dict."""
|
||||
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
|
@ -32,6 +32,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
|
||||
"""Convert a dict to a message."""
|
||||
|
||||
content = _dict.get("choice", {}).get("message", {}).get("content", "")
|
||||
return AIMessage(content=content)
|
||||
|
||||
|
@ -5,11 +5,12 @@ from langchain_community.document_loaders.sitemap import SitemapLoader
|
||||
|
||||
|
||||
class DocusaurusLoader(SitemapLoader):
|
||||
"""
|
||||
Loader that leverages the SitemapLoader to loop through the generated pages of a
|
||||
"""Load from Docusaurus Documentation.
|
||||
|
||||
It leverages the SitemapLoader to loop through the generated pages of a
|
||||
Docusaurus Documentation website and extracts the content by looking for specific
|
||||
HTML tags. By default, the parser searches for the main content of the Docusaurus
|
||||
page, which is normally the <article>. You also have the option to define your own
|
||||
page, which is normally the <article>. You can also define your own
|
||||
custom HTML tags by providing them as a list, for example: ["div", ".main", "a"].
|
||||
"""
|
||||
|
||||
@ -19,8 +20,8 @@ class DocusaurusLoader(SitemapLoader):
|
||||
custom_html_tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Initialize DocusaurusLoader
|
||||
"""Initialize DocusaurusLoader
|
||||
|
||||
Args:
|
||||
url: The base URL of the Docusaurus website.
|
||||
custom_html_tags: Optional custom html tags to extract content from pages.
|
||||
@ -39,7 +40,7 @@ class DocusaurusLoader(SitemapLoader):
|
||||
)
|
||||
|
||||
def _parsing_function(self, content: Any) -> str:
|
||||
"""Parses specific elements from a Docusarus page."""
|
||||
"""Parses specific elements from a Docusaurus page."""
|
||||
relevant_elements = content.select(",".join(self.custom_html_tags))
|
||||
|
||||
for element in relevant_elements:
|
||||
|
@ -13,6 +13,8 @@ from langchain_community.document_loaders.unstructured import UnstructuredBaseLo
|
||||
|
||||
|
||||
class LakeFSClient:
|
||||
"""Client for lakeFS."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lakefs_access_key: str,
|
||||
@ -126,6 +128,8 @@ class LakeFSLoader(BaseLoader):
|
||||
|
||||
|
||||
class UnstructuredLakeFSLoader(UnstructuredBaseLoader):
|
||||
"""Load from `lakeFS` as unstructured data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
@ -135,7 +139,7 @@ class UnstructuredLakeFSLoader(UnstructuredBaseLoader):
|
||||
presign: bool = True,
|
||||
**unstructured_kwargs: Any,
|
||||
):
|
||||
"""
|
||||
"""Initialize UnstructuredLakeFSLoader.
|
||||
|
||||
Args:
|
||||
|
||||
|
@ -9,11 +9,9 @@ from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class RSpaceLoader(BaseLoader):
|
||||
"""
|
||||
Loads content from RSpace notebooks, folders, documents or PDF Gallery files into
|
||||
Langchain documents.
|
||||
"""Load content from RSpace notebooks, folders, documents or PDF Gallery files.
|
||||
|
||||
Maps RSpace document <-> Langchain Document in 1-1. PDFs are imported using PyPDF.
|
||||
Map RSpace document <-> Langchain Document in 1-1. PDFs are imported using PyPDF.
|
||||
|
||||
Requirements are rspace_client (`pip install rspace_client`) and PyPDF if importing
|
||||
PDF docs (`pip install pypdf`).
|
||||
@ -45,7 +43,7 @@ class RSpaceLoader(BaseLoader):
|
||||
|
||||
@classmethod
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that API key and URL exists in environment."""
|
||||
"""Validate that API key and URL exist in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(values, "api_key", "RSPACE_API_KEY")
|
||||
values["url"] = get_from_dict_or_env(values, "url", "RSPACE_URL")
|
||||
if "global_id" not in values or values["global_id"] is None:
|
||||
|
@ -137,6 +137,15 @@ class BeautifulSoupTransformer(BaseDocumentTransformer):
|
||||
|
||||
|
||||
def get_navigable_strings(element: Any) -> Iterator[str]:
|
||||
"""Get all navigable strings from a BeautifulSoup element.
|
||||
|
||||
Args:
|
||||
element: A BeautifulSoup element.
|
||||
|
||||
Returns:
|
||||
A generator of strings.
|
||||
"""
|
||||
|
||||
from bs4 import NavigableString, Tag
|
||||
|
||||
for child in cast(Tag, element).children:
|
||||
|
@ -209,6 +209,16 @@ class Fireworks(BaseLLM):
|
||||
def conditional_decorator(
|
||||
condition: bool, decorator: Callable[[Any], Any]
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Conditionally apply a decorator.
|
||||
|
||||
Args:
|
||||
condition: A boolean indicating whether to apply the decorator.
|
||||
decorator: A decorator function.
|
||||
|
||||
Returns:
|
||||
A decorator function.
|
||||
"""
|
||||
|
||||
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
|
||||
if condition:
|
||||
return decorator(func)
|
||||
|
@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextGen(LLM):
|
||||
"""text-generation-webui models.
|
||||
"""Text generation models from WebUI.
|
||||
|
||||
To use, you should have the text-generation-webui installed, a model loaded,
|
||||
and --api added as a command-line option.
|
||||
|
@ -10,6 +10,8 @@ from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
|
||||
class TitanTakeoffPro(LLM):
|
||||
"""Titan Takeoff Pro is a language model that can be used to generate text."""
|
||||
|
||||
base_url: Optional[str] = "http://localhost:3000"
|
||||
"""Specifies the baseURL to use for the Titan Takeoff Pro API.
|
||||
Default = http://localhost:3000.
|
||||
|
@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Together(LLM):
|
||||
"""Wrapper around Together AI models.
|
||||
"""LLM models from `Together`.
|
||||
|
||||
To use, you'll need an API key which you can find here:
|
||||
https://api.together.xyz/settings/api-keys. This can be passed in as init param
|
||||
|
@ -9,7 +9,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Xinference(LLM):
|
||||
"""Wrapper for accessing Xinference's large-scale model inference service.
|
||||
"""`Xinference` large-scale model inference service.
|
||||
|
||||
To use, you should have the xinference library installed:
|
||||
|
||||
.. code-block:: bash
|
||||
|
@ -7,15 +7,19 @@ from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
|
||||
"""Configuration for vector search."""
|
||||
|
||||
numberOfResults: int = 4
|
||||
|
||||
|
||||
class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
|
||||
"""Configuration for retrieval."""
|
||||
|
||||
vectorSearchConfiguration: VectorSearchConfig
|
||||
|
||||
|
||||
class AmazonKnowledgeBasesRetriever(BaseRetriever):
|
||||
"""A retriever class for `Amazon Bedrock Knowledge Bases`.
|
||||
"""`Amazon Bedrock Knowledge Bases` retrieval.
|
||||
|
||||
See https://aws.amazon.com/bedrock/knowledge-bases for more info.
|
||||
|
||||
|
@ -10,6 +10,8 @@ from langchain_community.utilities.arxiv import ArxivAPIWrapper
|
||||
|
||||
|
||||
class ArxivInput(BaseModel):
|
||||
"""Input for the Arxiv tool."""
|
||||
|
||||
query: str = Field(description="search query to look up")
|
||||
|
||||
|
||||
|
@ -11,6 +11,8 @@ from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIW
|
||||
|
||||
|
||||
class DDGInput(BaseModel):
|
||||
"""Input for the DuckDuckGo search tool."""
|
||||
|
||||
query: str = Field(description="search query to look up")
|
||||
|
||||
|
||||
|
@ -728,6 +728,14 @@ class Unparser:
|
||||
|
||||
|
||||
def roundtrip(filename, output=sys.stdout):
|
||||
"""Parse a file and pretty-print it to output.
|
||||
|
||||
The output is formatted as valid Python source code.
|
||||
|
||||
Args:
|
||||
filename: The name of the file to parse.
|
||||
output: The output stream to write to.
|
||||
"""
|
||||
with open(filename, "rb") as pyfile:
|
||||
encoding = tokenize.detect_encoding(pyfile.readline)[0]
|
||||
with open(filename, "r", encoding=encoding) as pyfile:
|
||||
|
@ -13,6 +13,8 @@ from langchain_community.llms.gradient_ai import TrainResult
|
||||
|
||||
@runtime_checkable
|
||||
class TrainableLLM(Protocol):
|
||||
"""Protocol for trainable language models."""
|
||||
|
||||
@abstractmethod
|
||||
def train_unsupervised(
|
||||
self,
|
||||
@ -31,6 +33,8 @@ class TrainableLLM(Protocol):
|
||||
|
||||
|
||||
class Memorize(BaseTool):
|
||||
"""Tool that trains a language model."""
|
||||
|
||||
name: str = "Memorize"
|
||||
description: str = (
|
||||
"Useful whenever you observed novel information "
|
||||
|
@ -8,6 +8,8 @@ from langchain_community.tools.slack.base import SlackBaseTool
|
||||
|
||||
|
||||
class SlackGetChannel(SlackBaseTool):
|
||||
"""Tool that gets Slack channel information."""
|
||||
|
||||
name: str = "get_channelid_name_dict"
|
||||
description: str = "Use this tool to get channelid-name dict."
|
||||
|
||||
|
@ -18,6 +18,8 @@ class SlackGetMessageSchema(BaseModel):
|
||||
|
||||
|
||||
class SlackGetMessage(SlackBaseTool):
|
||||
"""Tool that gets Slack messages."""
|
||||
|
||||
name: str = "get_messages"
|
||||
description: str = "Use this tool to get messages from a channel."
|
||||
|
||||
|
@ -13,6 +13,8 @@ from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
|
||||
|
||||
|
||||
class TavilyInput(BaseModel):
|
||||
"""Input for the Tavily tool."""
|
||||
|
||||
query: str = Field(description="search query to look up")
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
class GoogleFinanceAPIWrapper(BaseModel):
|
||||
"""Wrapper for SerpApi's Google Finance API
|
||||
|
||||
You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
|
||||
The wrapper uses the SerpApi.com python package:
|
||||
https://serpapi.com/integrations/python
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
class GoogleJobsAPIWrapper(BaseModel):
|
||||
"""Wrapper for SerpApi's Google Scholar API
|
||||
|
||||
You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
|
||||
The wrapper uses the SerpApi.com python package:
|
||||
https://serpapi.com/integrations/python
|
||||
|
@ -8,6 +8,8 @@ IMAGE_AND_VIDEO_LIBRARY_URL = "https://images-api.nasa.gov"
|
||||
|
||||
|
||||
class NasaAPIWrapper(BaseModel):
|
||||
"""Wrapper for NASA API."""
|
||||
|
||||
def get_media(self, query: str) -> str:
|
||||
params = json.loads(query)
|
||||
if params.get("q"):
|
||||
|
@ -6,5 +6,6 @@ from packaging.version import parse
|
||||
|
||||
|
||||
def is_openai_v1() -> bool:
|
||||
"""Return whether OpenAI API is v1 or more."""
|
||||
_version = parse(version("openai"))
|
||||
return _version.major >= 1
|
||||
|
@ -23,6 +23,8 @@ class _ORMBase(DeclarativeBase):
|
||||
|
||||
|
||||
class PGVecto_rs(VectorStore):
|
||||
"""VectorStore backed by pgvecto_rs."""
|
||||
|
||||
_engine: sqlalchemy.engine.Engine
|
||||
_table: Type[_ORMBase]
|
||||
_embedding: Embeddings
|
||||
@ -35,6 +37,16 @@ class PGVecto_rs(VectorStore):
|
||||
collection_name: str,
|
||||
new_table: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a PGVecto_rs vectorstore.
|
||||
|
||||
Args:
|
||||
embedding: Embeddings to use.
|
||||
dimension: Dimension of the embeddings.
|
||||
db_url: Database URL.
|
||||
collection_name: Name of the collection.
|
||||
new_table: Whether to create a new table or connect to an existing one.
|
||||
Defaults to False.
|
||||
"""
|
||||
try:
|
||||
from pgvecto_rs.sqlalchemy import Vector
|
||||
except ImportError as e:
|
||||
|
@ -127,7 +127,7 @@ class SKLearnVectorStoreException(RuntimeError):
|
||||
|
||||
class SKLearnVectorStore(VectorStore):
|
||||
"""Simple in-memory vector store based on the `scikit-learn` library
|
||||
`NearestNeighbors` implementation."""
|
||||
`NearestNeighbors`."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -201,7 +201,6 @@ class SurrealDBStore(VectorStore):
|
||||
where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold}
|
||||
order by similarity desc LIMIT {k}
|
||||
""".format(**args)
|
||||
|
||||
results = await self.sdb.query(query)
|
||||
|
||||
if len(results) == 0:
|
||||
|
@ -64,7 +64,7 @@ class IndexParams:
|
||||
|
||||
|
||||
class TencentVectorDB(VectorStore):
|
||||
"""Initialize wrapper around the tencent vector database.
|
||||
"""Tencent VectorDB as a vector store.
|
||||
|
||||
In order to use this you need to have a database instance.
|
||||
See the following documentation for details:
|
||||
|
@ -37,23 +37,34 @@ def dependable_tiledb_import() -> Any:
|
||||
|
||||
|
||||
def get_vector_index_uri_from_group(group: Any) -> str:
|
||||
"""Get the URI of the vector index."""
|
||||
return group[VECTOR_INDEX_NAME].uri
|
||||
|
||||
|
||||
def get_documents_array_uri_from_group(group: Any) -> str:
|
||||
"""Get the URI of the documents array from group.
|
||||
|
||||
Args:
|
||||
group: TileDB group object.
|
||||
|
||||
Returns:
|
||||
URI of the documents array.
|
||||
"""
|
||||
return group[DOCUMENTS_ARRAY_NAME].uri
|
||||
|
||||
|
||||
def get_vector_index_uri(uri: str) -> str:
|
||||
"""Get the URI of the vector index."""
|
||||
return f"{uri}/{VECTOR_INDEX_NAME}"
|
||||
|
||||
|
||||
def get_documents_array_uri(uri: str) -> str:
|
||||
"""Get the URI of the documents array."""
|
||||
return f"{uri}/{DOCUMENTS_ARRAY_NAME}"
|
||||
|
||||
|
||||
class TileDB(VectorStore):
|
||||
"""Wrapper around TileDB vector database.
|
||||
"""TileDB vector store.
|
||||
|
||||
To use, you should have the ``tiledb-vector-search`` python package installed.
|
||||
|
||||
|
@ -37,8 +37,7 @@ _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_store"
|
||||
|
||||
|
||||
class TimescaleVector(VectorStore):
|
||||
"""VectorStore implementation using the timescale vector client to store vectors
|
||||
in Postgres.
|
||||
"""Timescale Postgres vector store
|
||||
|
||||
To use, you should have the ``timescale_vector`` python package installed.
|
||||
|
||||
|
@ -11,7 +11,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class PromptInjectionException(ValueError):
|
||||
def __init__(self, message="Prompt injection attack detected", score: float = 1.0):
|
||||
def __init__(
|
||||
self, message: str = "Prompt injection attack detected", score: float = 1.0
|
||||
):
|
||||
self.message = message
|
||||
self.score = score
|
||||
|
||||
@ -83,7 +85,7 @@ class HuggingFaceInjectionIdentifier(BaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool."""
|
||||
result = self.model(query)
|
||||
result = self.model(query) # type: ignore
|
||||
score = (
|
||||
result[0]["score"]
|
||||
if result[0]["label"] == self.injection_label
|
||||
|
Loading…
Reference in New Issue
Block a user