docs: docstrings langchain_community update (#14889)

Addded missed docstrings. Fixed inconsistency in docstrings.

**Note** CC @efriis 
There were PR errors on
`langchain_experimental/prompt_injection_identifier/hugging_face_identifier.py`
But, I didn't touch this file in this PR! Can it be some cache problems?
I fixed this error.
This commit is contained in:
Leonid Ganeline 2023-12-19 05:58:24 -08:00 committed by GitHub
parent 583696732c
commit b2fd41331e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 156 additions and 25 deletions

View File

@ -32,30 +32,44 @@ from langchain_community.utilities.github import GitHubAPIWrapper
class NoInput(BaseModel): class NoInput(BaseModel):
"""Schema for operations that do not require any input."""
no_input: str = Field("", description="No input required, e.g. `` (empty string).") no_input: str = Field("", description="No input required, e.g. `` (empty string).")
class GetIssue(BaseModel): class GetIssue(BaseModel):
"""Schema for operations that require an issue number as input."""
issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`") issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`")
class CommentOnIssue(BaseModel): class CommentOnIssue(BaseModel):
"""Schema for operations that require a comment as input."""
input: str = Field(..., description="Follow the required formatting.") input: str = Field(..., description="Follow the required formatting.")
class GetPR(BaseModel): class GetPR(BaseModel):
"""Schema for operations that require a PR number as input."""
pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`") pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`")
class CreatePR(BaseModel): class CreatePR(BaseModel):
"""Schema for operations that require a PR title and body as input."""
formatted_pr: str = Field(..., description="Follow the required formatting.") formatted_pr: str = Field(..., description="Follow the required formatting.")
class CreateFile(BaseModel): class CreateFile(BaseModel):
"""Schema for operations that require a file path and content as input."""
formatted_file: str = Field(..., description="Follow the required formatting.") formatted_file: str = Field(..., description="Follow the required formatting.")
class ReadFile(BaseModel): class ReadFile(BaseModel):
"""Schema for operations that require a file path as input."""
formatted_filepath: str = Field( formatted_filepath: str = Field(
..., ...,
description=( description=(
@ -66,12 +80,16 @@ class ReadFile(BaseModel):
class UpdateFile(BaseModel): class UpdateFile(BaseModel):
"""Schema for operations that require a file path and content as input."""
formatted_file_update: str = Field( formatted_file_update: str = Field(
..., description="Strictly follow the provided rules." ..., description="Strictly follow the provided rules."
) )
class DeleteFile(BaseModel): class DeleteFile(BaseModel):
"""Schema for operations that require a file path as input."""
formatted_filepath: str = Field( formatted_filepath: str = Field(
..., ...,
description=( description=(
@ -84,6 +102,8 @@ class DeleteFile(BaseModel):
class DirectoryPath(BaseModel): class DirectoryPath(BaseModel):
"""Schema for operations that require a directory path as input."""
input: str = Field( input: str = Field(
"", "",
description=( description=(
@ -94,12 +114,16 @@ class DirectoryPath(BaseModel):
class BranchName(BaseModel): class BranchName(BaseModel):
"""Schema for operations that require a branch name as input."""
branch_name: str = Field( branch_name: str = Field(
..., description="The name of the branch, e.g. `my_branch`." ..., description="The name of the branch, e.g. `my_branch`."
) )
class SearchCode(BaseModel): class SearchCode(BaseModel):
"""Schema for operations that require a search query as input."""
search_query: str = Field( search_query: str = Field(
..., ...,
description=( description=(
@ -110,6 +134,8 @@ class SearchCode(BaseModel):
class CreateReviewRequest(BaseModel): class CreateReviewRequest(BaseModel):
"""Schema for operations that require a username as input."""
username: str = Field( username: str = Field(
..., ...,
description="GitHub username of the user being requested, e.g. `my_username`.", description="GitHub username of the user being requested, e.g. `my_username`.",
@ -117,6 +143,8 @@ class CreateReviewRequest(BaseModel):
class SearchIssuesAndPRs(BaseModel): class SearchIssuesAndPRs(BaseModel):
"""Schema for operations that require a search query as input."""
search_query: str = Field( search_query: str = Field(
..., ...,
description="Natural language search query, e.g. `My issue title or topic`.", description="Natural language search query, e.g. `My issue title or topic`.",

View File

@ -50,10 +50,15 @@ def import_comet_llm_api() -> SimpleNamespace:
class CometTracer(BaseTracer): class CometTracer(BaseTracer):
"""Comet Tracer."""
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
"""Initialize the Comet Tracer."""
super().__init__(**kwargs) super().__init__(**kwargs)
self._span_map: Dict["UUID", "Span"] = {} self._span_map: Dict["UUID", "Span"] = {}
"""Map from run id to span."""
self._chains_map: Dict["UUID", "Chain"] = {} self._chains_map: Dict["UUID", "Chain"] = {}
"""Map from run id to chain."""
self._initialize_comet_modules() self._initialize_comet_modules()
def _initialize_comet_modules(self) -> None: def _initialize_comet_modules(self) -> None:

View File

@ -259,6 +259,16 @@ class ChatFireworks(BaseChatModel):
def conditional_decorator( def conditional_decorator(
condition: bool, decorator: Callable[[Any], Any] condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]: ) -> Callable[[Any], Any]:
"""Define conditional decorator.
Args:
condition: The condition.
decorator: The decorator.
Returns:
The decorated function.
"""
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]: def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
if condition: if condition:
return decorator(func) return decorator(func)
@ -281,6 +291,7 @@ def completion_with_retry(
@conditional_decorator(use_retry, retry_decorator) @conditional_decorator(use_retry, retry_decorator)
def _completion_with_retry(**kwargs: Any) -> Any: def _completion_with_retry(**kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
return fireworks.client.ChatCompletion.create( return fireworks.client.ChatCompletion.create(
**kwargs, **kwargs,
) )

View File

@ -24,6 +24,8 @@ def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str: def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
"""Convert a list of messages to a prompt for llama."""
return "\n".join( return "\n".join(
[_convert_one_message_to_text_llama(message) for message in messages] [_convert_one_message_to_text_llama(message) for message in messages]
) )

View File

@ -53,6 +53,8 @@ logger = logging.getLogger(__name__)
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dict to a message."""
role = _dict["role"] role = _dict["role"]
if role == "user": if role == "user":
return HumanMessage(content=_dict["content"]) return HumanMessage(content=_dict["content"])
@ -72,6 +74,8 @@ def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
def convert_message_to_dict(message: BaseMessage) -> dict: def convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a message to a dict."""
message_dict: Dict[str, Any] message_dict: Dict[str, Any]
if isinstance(message, ChatMessage): if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content} message_dict = {"role": message.role, "content": message.content}

View File

@ -32,6 +32,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage: def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
"""Convert a dict to a message."""
content = _dict.get("choice", {}).get("message", {}).get("content", "") content = _dict.get("choice", {}).get("message", {}).get("content", "")
return AIMessage(content=content) return AIMessage(content=content)

View File

@ -5,11 +5,12 @@ from langchain_community.document_loaders.sitemap import SitemapLoader
class DocusaurusLoader(SitemapLoader): class DocusaurusLoader(SitemapLoader):
""" """Load from Docusaurus Documentation.
Loader that leverages the SitemapLoader to loop through the generated pages of a
It leverages the SitemapLoader to loop through the generated pages of a
Docusaurus Documentation website and extracts the content by looking for specific Docusaurus Documentation website and extracts the content by looking for specific
HTML tags. By default, the parser searches for the main content of the Docusaurus HTML tags. By default, the parser searches for the main content of the Docusaurus
page, which is normally the <article>. You also have the option to define your own page, which is normally the <article>. You can also define your own
custom HTML tags by providing them as a list, for example: ["div", ".main", "a"]. custom HTML tags by providing them as a list, for example: ["div", ".main", "a"].
""" """
@ -19,8 +20,8 @@ class DocusaurusLoader(SitemapLoader):
custom_html_tags: Optional[List[str]] = None, custom_html_tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
): ):
""" """Initialize DocusaurusLoader
Initialize DocusaurusLoader
Args: Args:
url: The base URL of the Docusaurus website. url: The base URL of the Docusaurus website.
custom_html_tags: Optional custom html tags to extract content from pages. custom_html_tags: Optional custom html tags to extract content from pages.
@ -39,7 +40,7 @@ class DocusaurusLoader(SitemapLoader):
) )
def _parsing_function(self, content: Any) -> str: def _parsing_function(self, content: Any) -> str:
"""Parses specific elements from a Docusarus page.""" """Parses specific elements from a Docusaurus page."""
relevant_elements = content.select(",".join(self.custom_html_tags)) relevant_elements = content.select(",".join(self.custom_html_tags))
for element in relevant_elements: for element in relevant_elements:

View File

@ -13,6 +13,8 @@ from langchain_community.document_loaders.unstructured import UnstructuredBaseLo
class LakeFSClient: class LakeFSClient:
"""Client for lakeFS."""
def __init__( def __init__(
self, self,
lakefs_access_key: str, lakefs_access_key: str,
@ -126,6 +128,8 @@ class LakeFSLoader(BaseLoader):
class UnstructuredLakeFSLoader(UnstructuredBaseLoader): class UnstructuredLakeFSLoader(UnstructuredBaseLoader):
"""Load from `lakeFS` as unstructured data."""
def __init__( def __init__(
self, self,
url: str, url: str,
@ -135,7 +139,7 @@ class UnstructuredLakeFSLoader(UnstructuredBaseLoader):
presign: bool = True, presign: bool = True,
**unstructured_kwargs: Any, **unstructured_kwargs: Any,
): ):
""" """Initialize UnstructuredLakeFSLoader.
Args: Args:

View File

@ -9,11 +9,9 @@ from langchain_community.document_loaders.base import BaseLoader
class RSpaceLoader(BaseLoader): class RSpaceLoader(BaseLoader):
""" """Load content from RSpace notebooks, folders, documents or PDF Gallery files.
Loads content from RSpace notebooks, folders, documents or PDF Gallery files into
Langchain documents.
Maps RSpace document <-> Langchain Document in 1-1. PDFs are imported using PyPDF. Map RSpace document <-> Langchain Document in 1-1. PDFs are imported using PyPDF.
Requirements are rspace_client (`pip install rspace_client`) and PyPDF if importing Requirements are rspace_client (`pip install rspace_client`) and PyPDF if importing
PDF docs (`pip install pypdf`). PDF docs (`pip install pypdf`).
@ -45,7 +43,7 @@ class RSpaceLoader(BaseLoader):
@classmethod @classmethod
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that API key and URL exists in environment.""" """Validate that API key and URL exist in environment."""
values["api_key"] = get_from_dict_or_env(values, "api_key", "RSPACE_API_KEY") values["api_key"] = get_from_dict_or_env(values, "api_key", "RSPACE_API_KEY")
values["url"] = get_from_dict_or_env(values, "url", "RSPACE_URL") values["url"] = get_from_dict_or_env(values, "url", "RSPACE_URL")
if "global_id" not in values or values["global_id"] is None: if "global_id" not in values or values["global_id"] is None:

View File

@ -137,6 +137,15 @@ class BeautifulSoupTransformer(BaseDocumentTransformer):
def get_navigable_strings(element: Any) -> Iterator[str]: def get_navigable_strings(element: Any) -> Iterator[str]:
"""Get all navigable strings from a BeautifulSoup element.
Args:
element: A BeautifulSoup element.
Returns:
A generator of strings.
"""
from bs4 import NavigableString, Tag from bs4 import NavigableString, Tag
for child in cast(Tag, element).children: for child in cast(Tag, element).children:

View File

@ -209,6 +209,16 @@ class Fireworks(BaseLLM):
def conditional_decorator( def conditional_decorator(
condition: bool, decorator: Callable[[Any], Any] condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]: ) -> Callable[[Any], Any]:
"""Conditionally apply a decorator.
Args:
condition: A boolean indicating whether to apply the decorator.
decorator: A decorator function.
Returns:
A decorator function.
"""
def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]: def actual_decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
if condition: if condition:
return decorator(func) return decorator(func)

View File

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class TextGen(LLM): class TextGen(LLM):
"""text-generation-webui models. """Text generation models from WebUI.
To use, you should have the text-generation-webui installed, a model loaded, To use, you should have the text-generation-webui installed, a model loaded,
and --api added as a command-line option. and --api added as a command-line option.

View File

@ -10,6 +10,8 @@ from langchain_community.llms.utils import enforce_stop_tokens
class TitanTakeoffPro(LLM): class TitanTakeoffPro(LLM):
"""Titan Takeoff Pro is a language model that can be used to generate text."""
base_url: Optional[str] = "http://localhost:3000" base_url: Optional[str] = "http://localhost:3000"
"""Specifies the baseURL to use for the Titan Takeoff Pro API. """Specifies the baseURL to use for the Titan Takeoff Pro API.
Default = http://localhost:3000. Default = http://localhost:3000.

View File

@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class Together(LLM): class Together(LLM):
"""Wrapper around Together AI models. """LLM models from `Together`.
To use, you'll need an API key which you can find here: To use, you'll need an API key which you can find here:
https://api.together.xyz/settings/api-keys. This can be passed in as init param https://api.together.xyz/settings/api-keys. This can be passed in as init param

View File

@ -9,7 +9,8 @@ if TYPE_CHECKING:
class Xinference(LLM): class Xinference(LLM):
"""Wrapper for accessing Xinference's large-scale model inference service. """`Xinference` large-scale model inference service.
To use, you should have the xinference library installed: To use, you should have the xinference library installed:
.. code-block:: bash .. code-block:: bash

View File

@ -7,15 +7,19 @@ from langchain_core.retrievers import BaseRetriever
class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg] class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
"""Configuration for vector search."""
numberOfResults: int = 4 numberOfResults: int = 4
class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg] class RetrievalConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
"""Configuration for retrieval."""
vectorSearchConfiguration: VectorSearchConfig vectorSearchConfiguration: VectorSearchConfig
class AmazonKnowledgeBasesRetriever(BaseRetriever): class AmazonKnowledgeBasesRetriever(BaseRetriever):
"""A retriever class for `Amazon Bedrock Knowledge Bases`. """`Amazon Bedrock Knowledge Bases` retrieval.
See https://aws.amazon.com/bedrock/knowledge-bases for more info. See https://aws.amazon.com/bedrock/knowledge-bases for more info.

View File

@ -10,6 +10,8 @@ from langchain_community.utilities.arxiv import ArxivAPIWrapper
class ArxivInput(BaseModel): class ArxivInput(BaseModel):
"""Input for the Arxiv tool."""
query: str = Field(description="search query to look up") query: str = Field(description="search query to look up")

View File

@ -11,6 +11,8 @@ from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIW
class DDGInput(BaseModel): class DDGInput(BaseModel):
"""Input for the DuckDuckGo search tool."""
query: str = Field(description="search query to look up") query: str = Field(description="search query to look up")

View File

@ -728,6 +728,14 @@ class Unparser:
def roundtrip(filename, output=sys.stdout): def roundtrip(filename, output=sys.stdout):
"""Parse a file and pretty-print it to output.
The output is formatted as valid Python source code.
Args:
filename: The name of the file to parse.
output: The output stream to write to.
"""
with open(filename, "rb") as pyfile: with open(filename, "rb") as pyfile:
encoding = tokenize.detect_encoding(pyfile.readline)[0] encoding = tokenize.detect_encoding(pyfile.readline)[0]
with open(filename, "r", encoding=encoding) as pyfile: with open(filename, "r", encoding=encoding) as pyfile:

View File

@ -13,6 +13,8 @@ from langchain_community.llms.gradient_ai import TrainResult
@runtime_checkable @runtime_checkable
class TrainableLLM(Protocol): class TrainableLLM(Protocol):
"""Protocol for trainable language models."""
@abstractmethod @abstractmethod
def train_unsupervised( def train_unsupervised(
self, self,
@ -31,6 +33,8 @@ class TrainableLLM(Protocol):
class Memorize(BaseTool): class Memorize(BaseTool):
"""Tool that trains a language model."""
name: str = "Memorize" name: str = "Memorize"
description: str = ( description: str = (
"Useful whenever you observed novel information " "Useful whenever you observed novel information "

View File

@ -8,6 +8,8 @@ from langchain_community.tools.slack.base import SlackBaseTool
class SlackGetChannel(SlackBaseTool): class SlackGetChannel(SlackBaseTool):
"""Tool that gets Slack channel information."""
name: str = "get_channelid_name_dict" name: str = "get_channelid_name_dict"
description: str = "Use this tool to get channelid-name dict." description: str = "Use this tool to get channelid-name dict."

View File

@ -18,6 +18,8 @@ class SlackGetMessageSchema(BaseModel):
class SlackGetMessage(SlackBaseTool): class SlackGetMessage(SlackBaseTool):
"""Tool that gets Slack messages."""
name: str = "get_messages" name: str = "get_messages"
description: str = "Use this tool to get messages from a channel." description: str = "Use this tool to get messages from a channel."

View File

@ -13,6 +13,8 @@ from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
class TavilyInput(BaseModel): class TavilyInput(BaseModel):
"""Input for the Tavily tool."""
query: str = Field(description="search query to look up") query: str = Field(description="search query to look up")

View File

@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
class GoogleFinanceAPIWrapper(BaseModel): class GoogleFinanceAPIWrapper(BaseModel):
"""Wrapper for SerpApi's Google Finance API """Wrapper for SerpApi's Google Finance API
You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up. You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
The wrapper uses the SerpApi.com python package: The wrapper uses the SerpApi.com python package:
https://serpapi.com/integrations/python https://serpapi.com/integrations/python

View File

@ -7,6 +7,7 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
class GoogleJobsAPIWrapper(BaseModel): class GoogleJobsAPIWrapper(BaseModel):
"""Wrapper for SerpApi's Google Scholar API """Wrapper for SerpApi's Google Scholar API
You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up. You can create SerpApi.com key by signing up at: https://serpapi.com/users/sign_up.
The wrapper uses the SerpApi.com python package: The wrapper uses the SerpApi.com python package:
https://serpapi.com/integrations/python https://serpapi.com/integrations/python

View File

@ -8,6 +8,8 @@ IMAGE_AND_VIDEO_LIBRARY_URL = "https://images-api.nasa.gov"
class NasaAPIWrapper(BaseModel): class NasaAPIWrapper(BaseModel):
"""Wrapper for NASA API."""
def get_media(self, query: str) -> str: def get_media(self, query: str) -> str:
params = json.loads(query) params = json.loads(query)
if params.get("q"): if params.get("q"):

View File

@ -6,5 +6,6 @@ from packaging.version import parse
def is_openai_v1() -> bool: def is_openai_v1() -> bool:
"""Return whether OpenAI API is v1 or more."""
_version = parse(version("openai")) _version = parse(version("openai"))
return _version.major >= 1 return _version.major >= 1

View File

@ -23,6 +23,8 @@ class _ORMBase(DeclarativeBase):
class PGVecto_rs(VectorStore): class PGVecto_rs(VectorStore):
"""VectorStore backed by pgvecto_rs."""
_engine: sqlalchemy.engine.Engine _engine: sqlalchemy.engine.Engine
_table: Type[_ORMBase] _table: Type[_ORMBase]
_embedding: Embeddings _embedding: Embeddings
@ -35,6 +37,16 @@ class PGVecto_rs(VectorStore):
collection_name: str, collection_name: str,
new_table: bool = False, new_table: bool = False,
) -> None: ) -> None:
"""Initialize a PGVecto_rs vectorstore.
Args:
embedding: Embeddings to use.
dimension: Dimension of the embeddings.
db_url: Database URL.
collection_name: Name of the collection.
new_table: Whether to create a new table or connect to an existing one.
Defaults to False.
"""
try: try:
from pgvecto_rs.sqlalchemy import Vector from pgvecto_rs.sqlalchemy import Vector
except ImportError as e: except ImportError as e:

View File

@ -127,7 +127,7 @@ class SKLearnVectorStoreException(RuntimeError):
class SKLearnVectorStore(VectorStore): class SKLearnVectorStore(VectorStore):
"""Simple in-memory vector store based on the `scikit-learn` library """Simple in-memory vector store based on the `scikit-learn` library
`NearestNeighbors` implementation.""" `NearestNeighbors`."""
def __init__( def __init__(
self, self,

View File

@ -201,7 +201,6 @@ class SurrealDBStore(VectorStore):
where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold} where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold}
order by similarity desc LIMIT {k} order by similarity desc LIMIT {k}
""".format(**args) """.format(**args)
results = await self.sdb.query(query) results = await self.sdb.query(query)
if len(results) == 0: if len(results) == 0:

View File

@ -64,7 +64,7 @@ class IndexParams:
class TencentVectorDB(VectorStore): class TencentVectorDB(VectorStore):
"""Initialize wrapper around the tencent vector database. """Tencent VectorDB as a vector store.
In order to use this you need to have a database instance. In order to use this you need to have a database instance.
See the following documentation for details: See the following documentation for details:

View File

@ -37,23 +37,34 @@ def dependable_tiledb_import() -> Any:
def get_vector_index_uri_from_group(group: Any) -> str: def get_vector_index_uri_from_group(group: Any) -> str:
"""Get the URI of the vector index."""
return group[VECTOR_INDEX_NAME].uri return group[VECTOR_INDEX_NAME].uri
def get_documents_array_uri_from_group(group: Any) -> str: def get_documents_array_uri_from_group(group: Any) -> str:
"""Get the URI of the documents array from group.
Args:
group: TileDB group object.
Returns:
URI of the documents array.
"""
return group[DOCUMENTS_ARRAY_NAME].uri return group[DOCUMENTS_ARRAY_NAME].uri
def get_vector_index_uri(uri: str) -> str: def get_vector_index_uri(uri: str) -> str:
"""Get the URI of the vector index."""
return f"{uri}/{VECTOR_INDEX_NAME}" return f"{uri}/{VECTOR_INDEX_NAME}"
def get_documents_array_uri(uri: str) -> str: def get_documents_array_uri(uri: str) -> str:
"""Get the URI of the documents array."""
return f"{uri}/{DOCUMENTS_ARRAY_NAME}" return f"{uri}/{DOCUMENTS_ARRAY_NAME}"
class TileDB(VectorStore): class TileDB(VectorStore):
"""Wrapper around TileDB vector database. """TileDB vector store.
To use, you should have the ``tiledb-vector-search`` python package installed. To use, you should have the ``tiledb-vector-search`` python package installed.

View File

@ -37,8 +37,7 @@ _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_store"
class TimescaleVector(VectorStore): class TimescaleVector(VectorStore):
"""VectorStore implementation using the timescale vector client to store vectors """Timescale Postgres vector store
in Postgres.
To use, you should have the ``timescale_vector`` python package installed. To use, you should have the ``timescale_vector`` python package installed.

View File

@ -11,7 +11,9 @@ if TYPE_CHECKING:
class PromptInjectionException(ValueError): class PromptInjectionException(ValueError):
def __init__(self, message="Prompt injection attack detected", score: float = 1.0): def __init__(
self, message: str = "Prompt injection attack detected", score: float = 1.0
):
self.message = message self.message = message
self.score = score self.score = score
@ -83,7 +85,7 @@ class HuggingFaceInjectionIdentifier(BaseTool):
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
"""Use the tool.""" """Use the tool."""
result = self.model(query) result = self.model(query) # type: ignore
score = ( score = (
result[0]["score"] result[0]["score"]
if result[0]["label"] == self.injection_label if result[0]["label"] == self.injection_label