community[patch]: Add missing annotations (#24890)

This PR adds annotations in comunity package.

Annotations are only strictly needed in subclasses of BaseModel for
pydantic 2 compatibility.

This PR adds some unnecessary annotations, but they're not bad to have
regardless for documentation pages.
This commit is contained in:
Eugene Yurtsev 2024-07-31 14:13:44 -04:00 committed by GitHub
parent 7720483432
commit d24b82357f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 98 additions and 91 deletions

View File

@ -58,7 +58,7 @@ class UpstashRatelimitHandler(BaseCallbackHandler):
every time you invoke.
"""
raise_error = True
raise_error: bool = True
_checked: bool = False
def __init__(

View File

@ -76,7 +76,7 @@ class PebbloRetrievalQA(Chain):
"""Classifier endpoint."""
classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private:
_discover_sent: bool = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""

View File

@ -9,7 +9,7 @@ import os
import re
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast
from langchain_core.utils import pre_init
@ -164,7 +164,7 @@ class _KineticaLlmFileContextParser:
"""Parser for Kinetica LLM context datafiles."""
# parse line into a dict containing role and content
PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
@classmethod
def _removesuffix(cls, text: str, suffix: str) -> str:

View File

@ -135,7 +135,7 @@ class Provider(ABC):
class CohereProvider(Provider):
stop_sequence_key = "stop_sequences"
stop_sequence_key: str = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
@ -364,7 +364,7 @@ class CohereProvider(Provider):
class MetaProvider(Provider):
stop_sequence_key = "stop"
stop_sequence_key: str = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models

View File

@ -1,7 +1,7 @@
# LLM Lingua Document Compressor
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Pattern, Sequence, Tuple
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
@ -24,8 +24,8 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
# Pattern to match ref tags at the beginning or end of the string,
# allowing for malformed tags
_pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
_pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
_pattern_beginning: Pattern = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
_pattern_ending: Pattern = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
model_name: str = "NousResearch/Llama-2-7b-hf"
"""The hugging face model to use"""

View File

@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import Iterator, Union
from typing import Iterator, Pattern, Union
from langchain_core.documents import Document
@ -10,7 +10,9 @@ from langchain_community.document_loaders.base import BaseLoader
class AcreomLoader(BaseLoader):
"""Load `acreom` vault from a directory."""
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL)
FRONT_MATTER_REGEX: Pattern = re.compile(
r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL
)
"""Regex to match front matter metadata in markdown files."""
def __init__(

View File

@ -44,13 +44,13 @@ class DocugamiLoader(BaseLoader, BaseModel):
access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
"""The Docugami API access token to use."""
max_text_length = 4096
max_text_length: int = 4096
"""Max length of chunk text returned."""
min_text_length: int = 32
"""Threshold under which chunks are appended to next to avoid over-chunking."""
max_metadata_length = 512
max_metadata_length: int = 512
"""Max length of metadata text returned."""
include_xml_tags: bool = False

View File

@ -36,8 +36,8 @@ class HuggingFaceModelLoader(BaseLoader):
print(doc.metadata) # Metadata of the model
"""
BASE_URL = "https://huggingface.co/api/models"
README_BASE_URL = "https://huggingface.co/{model_id}/raw/main/README.md"
BASE_URL: str = "https://huggingface.co/api/models"
README_BASE_URL: str = "https://huggingface.co/{model_id}/raw/main/README.md"
def __init__(
self,

View File

@ -2,7 +2,7 @@ import functools
import logging
import re
from pathlib import Path
from typing import Any, Dict, Iterator, Union
from typing import Any, Dict, Iterator, Pattern, Union
import yaml
from langchain_core.documents import Document
@ -15,12 +15,16 @@ logger = logging.getLogger(__name__)
class ObsidianLoader(BaseLoader):
"""Load `Obsidian` files from directory."""
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
DATAVIEW_INLINE_PAREN_REGEX = re.compile(r"\((\w+)::\s*(.*)\)", re.MULTILINE)
FRONT_MATTER_REGEX: Pattern = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX: Pattern = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX: Pattern = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX: Pattern = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX: Pattern = re.compile(
r"\[(\w+)::\s*(.*)\]", re.MULTILINE
)
DATAVIEW_INLINE_PAREN_REGEX: Pattern = re.compile(
r"\((\w+)::\s*(.*)\)", re.MULTILINE
)
def __init__(
self,

View File

@ -39,7 +39,7 @@ class OneNoteLoader(BaseLoader, BaseModel):
"""Personal access token"""
onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
"""URL of Microsoft Graph API for OneNote"""
authority_url = "https://login.microsoftonline.com/consumers/"
authority_url: str = "https://login.microsoftonline.com/consumers/"
"""A URL that identifies a token authority"""
token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
"""Path to the file where the access token is stored"""

View File

@ -1,5 +1,5 @@
import re
from typing import Callable, List
from typing import Callable, List, Pattern
from langchain_community.document_loaders.parsers.language.code_segmenter import (
CodeSegmenter,
@ -9,11 +9,11 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
class CobolSegmenter(CodeSegmenter):
"""Code segmenter for `COBOL`."""
PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
DIVISION_PATTERN = re.compile(
PARAGRAPH_PATTERN: Pattern = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
DIVISION_PATTERN: Pattern = re.compile(
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
)
SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
SECTION_PATTERN: Pattern = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
def __init__(self, code: str):
super().__init__(code)

View File

@ -27,7 +27,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
chunk_size: int = 16
model_name = "ErnieBot-Embedding-V1"
model_name: str = "ErnieBot-Embedding-V1"
_lock = threading.Lock()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json
import re
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
@ -63,7 +63,7 @@ class AGEGraph(GraphStore):
}
# precompiled regex for checking chars in graph labels
label_regex = re.compile("[^0-9a-zA-Z]+")
label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")
def __init__(
self, graph_name: str, conf: Dict[str, Any], create: bool = True

View File

@ -39,9 +39,9 @@ class MoonshotCommon(BaseModel):
"""Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
model_name: str = Field(default="moonshot-v1-8k", alias="model")
"""Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
max_tokens = 1024
max_tokens: int = 1024
"""Maximum number of tokens to generate."""
temperature = 0.3
temperature: float = 0.3
"""Temperature parameter (higher values make the model more creative)."""
class Config:

View File

@ -244,7 +244,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
Defaults to True."""
return_full_text = False
return_full_text: bool = False
"""Whether to prepend the prompt to the generated text. Defaults to False."""
@property

View File

@ -26,7 +26,7 @@ class Provider(ABC):
class CohereProvider(Provider):
stop_sequence_key = "stop_sequences"
stop_sequence_key: str = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
@ -38,7 +38,7 @@ class CohereProvider(Provider):
class MetaProvider(Provider):
stop_sequence_key = "stop"
stop_sequence_key: str = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models

View File

@ -16,7 +16,7 @@ class SVEndpointHandler:
:param str host_url: Base URL of the DaaS API service
"""
API_BASE_PATH = "/api/predict"
API_BASE_PATH: str = "/api/predict"
def __init__(self, host_url: str):
"""

View File

@ -41,7 +41,7 @@ class SolarCommon(BaseModel):
model_name: str = Field(default="solar-1-mini-chat", alias="model")
"""Model name. Available models listed here: https://console.upstage.ai/services/solar"""
max_tokens: int = Field(default=1024)
temperature = 0.3
temperature: float = 0.3
class Config:
allow_population_by_field_name = True

View File

@ -27,7 +27,7 @@ class SupabaseVectorTranslator(Visitor):
]
"""Subset of allowed logical comparators."""
metadata_column = "metadata"
metadata_column: str = "metadata"
def _map_comparator(self, comparator: Comparator) -> str:
"""

View File

@ -74,8 +74,8 @@ class BearlyInterpreterTool:
"""Tool for evaluating python code in a sandbox environment."""
api_key: str
endpoint = "https://exec.bearly.ai/v1/interpreter"
name = "bearly_interpreter"
endpoint: str = "https://exec.bearly.ai/v1/interpreter"
name: str = "bearly_interpreter"
args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
files: Dict[str, FileInfo] = {}

View File

@ -51,12 +51,12 @@ class ZenGuardTool(BaseTool):
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
)
args_schema = ZenGuardInput
return_direct = True
return_direct: bool = True
zenguard_api_key: Optional[str] = Field(default=None)
_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
_ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/"
_ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY"
@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
def set_api_key(cls, v: str) -> str:

View File

@ -11,7 +11,7 @@ class Portkey:
Default: "https://api.portkey.ai/v1/proxy"
"""
base = "https://api.portkey.ai/v1/proxy"
base: str = "https://api.portkey.ai/v1/proxy"
@staticmethod
def Config(

View File

@ -28,7 +28,7 @@ class TokenEscaper:
# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
DEFAULT_ESCAPED_CHARS: str = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:

View File

@ -29,7 +29,7 @@ class AtlasDB(VectorStore):
vectorstore = AtlasDB("my_project", embeddings.embed_query)
"""
_ATLAS_DEFAULT_ID_FIELD = "atlas_id"
_ATLAS_DEFAULT_ID_FIELD: str = "atlas_id"
def __init__(
self,

View File

@ -21,7 +21,7 @@ DEFAULT_TOPN = 4
class AwaDB(VectorStore):
"""`AwaDB` vector store."""
_DEFAULT_TABLE_NAME = "langchain_awadb"
_DEFAULT_TABLE_NAME: str = "langchain_awadb"
def __init__(
self,

View File

@ -53,7 +53,7 @@ class Bagel(VectorStore):
vectorstore = Bagel(cluster_name="langchain_store")
"""
_LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain"
_LANGCHAIN_DEFAULT_CLUSTER_NAME: str = "langchain"
def __init__(
self,

View File

@ -66,7 +66,7 @@ class Chroma(VectorStore):
vectorstore = Chroma("langchain_store", embeddings)
"""
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
_LANGCHAIN_DEFAULT_COLLECTION_NAME: str = "langchain"
def __init__(
self,

View File

@ -60,10 +60,10 @@ class CouchbaseVectorStore(VectorStore):
"""
# Default batch size
DEFAULT_BATCH_SIZE = 100
_metadata_key = "metadata"
_default_text_key = "text"
_default_embedding_key = "embedding"
DEFAULT_BATCH_SIZE: int = 100
_metadata_key: str = "metadata"
_default_text_key: str = "text"
_default_embedding_key: str = "embedding"
def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster"""

View File

@ -51,7 +51,7 @@ class DeepLake(VectorStore):
vectorstore = DeepLake("langchain_store", embeddings.embed_query)
"""
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH = "./deeplake/"
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH: str = "./deeplake/"
_valid_search_kwargs = ["lambda_mult"]
def __init__(

View File

@ -45,9 +45,9 @@ class Epsilla(VectorStore):
epsilla = Epsilla(client, embeddings, db_path, db_name)
"""
_LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
_LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
_LANGCHAIN_DEFAULT_DB_NAME: str = "langchain_store"
_LANGCHAIN_DEFAULT_DB_PATH: str = "/tmp/langchain-epsilla"
_LANGCHAIN_DEFAULT_TABLE_NAME: str = "langchain_collection"
def __init__(
self,

View File

@ -13,6 +13,7 @@ from typing import (
Iterable,
List,
Optional,
Pattern,
Tuple,
Type,
)
@ -223,7 +224,7 @@ class HanaDB(VectorStore):
return embedding
# Compile pattern only once, for better performance
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
_compiled_pattern: Pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
@staticmethod
def _sanitize_metadata_keys(metadata: dict) -> dict:

View File

@ -48,7 +48,7 @@ class ManticoreSearchSettings(BaseSettings):
hnsw_m: int = 16 # The default is 16.
# An optional setting that defines a construction time/accuracy trade-off.
hnsw_ef_construction = 100
hnsw_ef_construction: int = 100
def get_connection_string(self) -> str:
return self.proto + "://" + self.host + ":" + str(self.port)

View File

@ -85,8 +85,8 @@ class Qdrant(VectorStore):
qdrant = Qdrant(client, collection_name, embedding_function)
"""
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
CONTENT_KEY: str = "page_content"
METADATA_KEY: str = "metadata"
VECTOR_NAME = None
def __init__(

View File

@ -25,7 +25,7 @@ class SemaDB(VectorStore):
"""
HOST = "semadb.p.rapidapi.com"
HOST: str = "semadb.p.rapidapi.com"
BASE_URL = "https://" + HOST
def __init__(

View File

@ -23,9 +23,9 @@ def mock_quip(): # type: ignore
@pytest.mark.requires("quip_api")
class TestQuipLoader:
API_URL = "https://example-api.quip.com"
API_URL: str = "https://example-api.quip.com"
DOC_URL_PREFIX = ("https://example.quip.com",)
ACCESS_TOKEN = "api_token"
ACCESS_TOKEN: str = "api_token"
MOCK_FOLDER_IDS = ["ABC"]
MOCK_THREAD_IDS = ["ABC", "DEF"]

View File

@ -59,8 +59,8 @@ def test_custom_formatter() -> None:
"""Test ability to create a custom content formatter."""
class CustomFormatter(ContentFormatterBase):
content_type = "application/json"
accepts = "application/json"
content_type: str = "application/json"
accepts: str = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps(
@ -101,8 +101,8 @@ def test_invalid_request_format() -> None:
"""Test invalid request format."""
class CustomContentFormatter(ContentFormatterBase):
content_type = "application/json"
accepts = "application/json"
content_type: str = "application/json"
accepts: str = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps(

View File

@ -13,19 +13,19 @@ README_PATH = Path(__file__).parents[4] / "README.md"
class FakeUploadResponse:
status_code = 200
text = "fake_uuid"
status_code: int = 200
text: str = "fake_uuid"
class FakePushResponse:
status_code = 200
status_code: int = 200
def json(self) -> Any:
return {"uuid": "fake_uuid"}
class FakePullResponse:
status_code = 200
status_code: int = 200
def json(self) -> Any:
return {

View File

@ -11,7 +11,7 @@ LOG = logging.getLogger(__name__)
class TestChatKinetica:
test_ctx_json = """
test_ctx_json: str = """
{
"payload":{
"context":[

View File

@ -20,10 +20,10 @@ def mock_confluence(): # type: ignore
@pytest.mark.requires("atlassian", "bs4", "lxml")
class TestConfluenceLoader:
CONFLUENCE_URL = "https://example.atlassian.com/wiki"
MOCK_USERNAME = "user@gmail.com"
MOCK_API_TOKEN = "api_token"
MOCK_SPACE_KEY = "spaceId123"
CONFLUENCE_URL: str = "https://example.atlassian.com/wiki"
MOCK_USERNAME: str = "user@gmail.com"
MOCK_API_TOKEN: str = "api_token"
MOCK_SPACE_KEY: str = "spaceId123"
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
ConfluenceLoader(

View File

@ -57,12 +57,12 @@ def mock_lakefs_client_no_presign_local() -> Any:
class TestLakeFSLoader(unittest.TestCase):
lakefs_access_key = "lakefs_access_key"
lakefs_secret_key = "lakefs_secret_key"
endpoint = "endpoint"
repo = "repo"
ref = "ref"
path = "path"
lakefs_access_key: str = "lakefs_access_key"
lakefs_secret_key: str = "lakefs_secret_key"
endpoint: str = "endpoint"
repo: str = "repo"
ref: str = "ref"
path: str = "path"
@requests_mock.Mocker()
@pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local")

View File

@ -21,9 +21,9 @@ def mock_connector_id(): # type: ignore
@pytest.mark.requires("psychicapi")
class TestPsychicLoader:
MOCK_API_KEY = "api_key"
MOCK_CONNECTOR_ID = "notion"
MOCK_ACCOUNT_ID = "account_id"
MOCK_API_KEY: str = "api_key"
MOCK_CONNECTOR_ID: str = "notion"
MOCK_ACCOUNT_ID: str = "account_id"
def test_psychic_loader_initialization(
self, mock_psychic: MagicMock, mock_connector_id: MagicMock

View File

@ -4,9 +4,9 @@ from langchain_community.document_loaders.rspace import RSpaceLoader
class TestRSpaceLoader(unittest.TestCase):
url = "https://community.researchspace.com"
api_key = "myapikey"
global_id = "SD12345"
url: str = "https://community.researchspace.com"
api_key: str = "myapikey"
global_id: str = "SD12345"
def test_valid_arguments(self) -> None:
loader = RSpaceLoader(

View File

@ -70,7 +70,7 @@ class MockGradientaiPackage(MagicMock):
"""Mock Gradientai package."""
Gradient = MockGradient
__version__ = "1.4.0"
__version__: str = "1.4.0"
def test_gradient_llm_sync() -> None:

View File

@ -5,7 +5,7 @@ from langchain_community.embeddings.ollama import OllamaEmbeddings
class MockResponse:
status_code = 200
status_code: int = 200
def json(self) -> dict:
return {"embedding": [1, 2, 3]}