community[patch]: Add missing annotations (#24890)

This PR adds annotations in comunity package.

Annotations are only strictly needed in subclasses of BaseModel for
pydantic 2 compatibility.

This PR adds some unnecessary annotations, but they're not bad to have
regardless for documentation pages.
This commit is contained in:
Eugene Yurtsev 2024-07-31 14:13:44 -04:00 committed by GitHub
parent 7720483432
commit d24b82357f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 98 additions and 91 deletions

View File

@ -58,7 +58,7 @@ class UpstashRatelimitHandler(BaseCallbackHandler):
every time you invoke. every time you invoke.
""" """
raise_error = True raise_error: bool = True
_checked: bool = False _checked: bool = False
def __init__( def __init__(

View File

@ -76,7 +76,7 @@ class PebbloRetrievalQA(Chain):
"""Classifier endpoint.""" """Classifier endpoint."""
classifier_location: str = "local" #: :meta private: classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'.""" """Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent = False #: :meta private: _discover_sent: bool = False #: :meta private:
"""Flag to check if discover payload has been sent.""" """Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private: _prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent.""" """Flag to check if prompt payload has been sent."""

View File

@ -9,7 +9,7 @@ import os
import re import re
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast
from langchain_core.utils import pre_init from langchain_core.utils import pre_init
@ -164,7 +164,7 @@ class _KineticaLlmFileContextParser:
"""Parser for Kinetica LLM context datafiles.""" """Parser for Kinetica LLM context datafiles."""
# parse line into a dict containing role and content # parse line into a dict containing role and content
PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL) PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
@classmethod @classmethod
def _removesuffix(cls, text: str, suffix: str) -> str: def _removesuffix(cls, text: str, suffix: str) -> str:

View File

@ -135,7 +135,7 @@ class Provider(ABC):
class CohereProvider(Provider): class CohereProvider(Provider):
stop_sequence_key = "stop_sequences" stop_sequence_key: str = "stop_sequences"
def __init__(self) -> None: def __init__(self) -> None:
from oci.generative_ai_inference import models from oci.generative_ai_inference import models
@ -364,7 +364,7 @@ class CohereProvider(Provider):
class MetaProvider(Provider): class MetaProvider(Provider):
stop_sequence_key = "stop" stop_sequence_key: str = "stop"
def __init__(self) -> None: def __init__(self) -> None:
from oci.generative_ai_inference import models from oci.generative_ai_inference import models

View File

@ -1,7 +1,7 @@
# LLM Lingua Document Compressor # LLM Lingua Document Compressor
import re import re
from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import Any, Dict, List, Optional, Pattern, Sequence, Tuple
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document from langchain_core.documents import Document
@ -24,8 +24,8 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
# Pattern to match ref tags at the beginning or end of the string, # Pattern to match ref tags at the beginning or end of the string,
# allowing for malformed tags # allowing for malformed tags
_pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?") _pattern_beginning: Pattern = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
_pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z") _pattern_ending: Pattern = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
model_name: str = "NousResearch/Llama-2-7b-hf" model_name: str = "NousResearch/Llama-2-7b-hf"
"""The hugging face model to use""" """The hugging face model to use"""

View File

@ -1,6 +1,6 @@
import re import re
from pathlib import Path from pathlib import Path
from typing import Iterator, Union from typing import Iterator, Pattern, Union
from langchain_core.documents import Document from langchain_core.documents import Document
@ -10,7 +10,9 @@ from langchain_community.document_loaders.base import BaseLoader
class AcreomLoader(BaseLoader): class AcreomLoader(BaseLoader):
"""Load `acreom` vault from a directory.""" """Load `acreom` vault from a directory."""
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL) FRONT_MATTER_REGEX: Pattern = re.compile(
r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL
)
"""Regex to match front matter metadata in markdown files.""" """Regex to match front matter metadata in markdown files."""
def __init__( def __init__(

View File

@ -44,13 +44,13 @@ class DocugamiLoader(BaseLoader, BaseModel):
access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY") access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
"""The Docugami API access token to use.""" """The Docugami API access token to use."""
max_text_length = 4096 max_text_length: int = 4096
"""Max length of chunk text returned.""" """Max length of chunk text returned."""
min_text_length: int = 32 min_text_length: int = 32
"""Threshold under which chunks are appended to next to avoid over-chunking.""" """Threshold under which chunks are appended to next to avoid over-chunking."""
max_metadata_length = 512 max_metadata_length: int = 512
"""Max length of metadata text returned.""" """Max length of metadata text returned."""
include_xml_tags: bool = False include_xml_tags: bool = False

View File

@ -36,8 +36,8 @@ class HuggingFaceModelLoader(BaseLoader):
print(doc.metadata) # Metadata of the model print(doc.metadata) # Metadata of the model
""" """
BASE_URL = "https://huggingface.co/api/models" BASE_URL: str = "https://huggingface.co/api/models"
README_BASE_URL = "https://huggingface.co/{model_id}/raw/main/README.md" README_BASE_URL: str = "https://huggingface.co/{model_id}/raw/main/README.md"
def __init__( def __init__(
self, self,

View File

@ -2,7 +2,7 @@ import functools
import logging import logging
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Iterator, Union from typing import Any, Dict, Iterator, Pattern, Union
import yaml import yaml
from langchain_core.documents import Document from langchain_core.documents import Document
@ -15,12 +15,16 @@ logger = logging.getLogger(__name__)
class ObsidianLoader(BaseLoader): class ObsidianLoader(BaseLoader):
"""Load `Obsidian` files from directory.""" """Load `Obsidian` files from directory."""
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL) FRONT_MATTER_REGEX: Pattern = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL) TEMPLATE_VARIABLE_REGEX: Pattern = re.compile(r"{{(.*?)}}", re.DOTALL)
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)") TAG_REGEX: Pattern = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE) DATAVIEW_LINE_REGEX: Pattern = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE) DATAVIEW_INLINE_BRACKET_REGEX: Pattern = re.compile(
DATAVIEW_INLINE_PAREN_REGEX = re.compile(r"\((\w+)::\s*(.*)\)", re.MULTILINE) r"\[(\w+)::\s*(.*)\]", re.MULTILINE
)
DATAVIEW_INLINE_PAREN_REGEX: Pattern = re.compile(
r"\((\w+)::\s*(.*)\)", re.MULTILINE
)
def __init__( def __init__(
self, self,

View File

@ -39,7 +39,7 @@ class OneNoteLoader(BaseLoader, BaseModel):
"""Personal access token""" """Personal access token"""
onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote" onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
"""URL of Microsoft Graph API for OneNote""" """URL of Microsoft Graph API for OneNote"""
authority_url = "https://login.microsoftonline.com/consumers/" authority_url: str = "https://login.microsoftonline.com/consumers/"
"""A URL that identifies a token authority""" """A URL that identifies a token authority"""
token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt" token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
"""Path to the file where the access token is stored""" """Path to the file where the access token is stored"""

View File

@ -1,5 +1,5 @@
import re import re
from typing import Callable, List from typing import Callable, List, Pattern
from langchain_community.document_loaders.parsers.language.code_segmenter import ( from langchain_community.document_loaders.parsers.language.code_segmenter import (
CodeSegmenter, CodeSegmenter,
@ -9,11 +9,11 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
class CobolSegmenter(CodeSegmenter): class CobolSegmenter(CodeSegmenter):
"""Code segmenter for `COBOL`.""" """Code segmenter for `COBOL`."""
PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE) PARAGRAPH_PATTERN: Pattern = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
DIVISION_PATTERN = re.compile( DIVISION_PATTERN: Pattern = re.compile(
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
) )
SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE) SECTION_PATTERN: Pattern = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
def __init__(self, code: str): def __init__(self, code: str):
super().__init__(code) super().__init__(code)

View File

@ -27,7 +27,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
chunk_size: int = 16 chunk_size: int = 16
model_name = "ErnieBot-Embedding-V1" model_name: str = "ErnieBot-Embedding-V1"
_lock = threading.Lock() _lock = threading.Lock()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
import re import re
from hashlib import md5 from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union
from langchain_community.graphs.graph_document import GraphDocument from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore from langchain_community.graphs.graph_store import GraphStore
@ -63,7 +63,7 @@ class AGEGraph(GraphStore):
} }
# precompiled regex for checking chars in graph labels # precompiled regex for checking chars in graph labels
label_regex = re.compile("[^0-9a-zA-Z]+") label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")
def __init__( def __init__(
self, graph_name: str, conf: Dict[str, Any], create: bool = True self, graph_name: str, conf: Dict[str, Any], create: bool = True

View File

@ -39,9 +39,9 @@ class MoonshotCommon(BaseModel):
"""Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys""" """Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
model_name: str = Field(default="moonshot-v1-8k", alias="model") model_name: str = Field(default="moonshot-v1-8k", alias="model")
"""Model name. Available models listed here: https://platform.moonshot.cn/pricing""" """Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
max_tokens = 1024 max_tokens: int = 1024
"""Maximum number of tokens to generate.""" """Maximum number of tokens to generate."""
temperature = 0.3 temperature: float = 0.3
"""Temperature parameter (higher values make the model more creative).""" """Temperature parameter (higher values make the model more creative)."""
class Config: class Config:

View File

@ -244,7 +244,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_. """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
Defaults to True.""" Defaults to True."""
return_full_text = False return_full_text: bool = False
"""Whether to prepend the prompt to the generated text. Defaults to False.""" """Whether to prepend the prompt to the generated text. Defaults to False."""
@property @property

View File

@ -26,7 +26,7 @@ class Provider(ABC):
class CohereProvider(Provider): class CohereProvider(Provider):
stop_sequence_key = "stop_sequences" stop_sequence_key: str = "stop_sequences"
def __init__(self) -> None: def __init__(self) -> None:
from oci.generative_ai_inference import models from oci.generative_ai_inference import models
@ -38,7 +38,7 @@ class CohereProvider(Provider):
class MetaProvider(Provider): class MetaProvider(Provider):
stop_sequence_key = "stop" stop_sequence_key: str = "stop"
def __init__(self) -> None: def __init__(self) -> None:
from oci.generative_ai_inference import models from oci.generative_ai_inference import models

View File

@ -16,7 +16,7 @@ class SVEndpointHandler:
:param str host_url: Base URL of the DaaS API service :param str host_url: Base URL of the DaaS API service
""" """
API_BASE_PATH = "/api/predict" API_BASE_PATH: str = "/api/predict"
def __init__(self, host_url: str): def __init__(self, host_url: str):
""" """

View File

@ -41,7 +41,7 @@ class SolarCommon(BaseModel):
model_name: str = Field(default="solar-1-mini-chat", alias="model") model_name: str = Field(default="solar-1-mini-chat", alias="model")
"""Model name. Available models listed here: https://console.upstage.ai/services/solar""" """Model name. Available models listed here: https://console.upstage.ai/services/solar"""
max_tokens: int = Field(default=1024) max_tokens: int = Field(default=1024)
temperature = 0.3 temperature: float = 0.3
class Config: class Config:
allow_population_by_field_name = True allow_population_by_field_name = True

View File

@ -27,7 +27,7 @@ class SupabaseVectorTranslator(Visitor):
] ]
"""Subset of allowed logical comparators.""" """Subset of allowed logical comparators."""
metadata_column = "metadata" metadata_column: str = "metadata"
def _map_comparator(self, comparator: Comparator) -> str: def _map_comparator(self, comparator: Comparator) -> str:
""" """

View File

@ -74,8 +74,8 @@ class BearlyInterpreterTool:
"""Tool for evaluating python code in a sandbox environment.""" """Tool for evaluating python code in a sandbox environment."""
api_key: str api_key: str
endpoint = "https://exec.bearly.ai/v1/interpreter" endpoint: str = "https://exec.bearly.ai/v1/interpreter"
name = "bearly_interpreter" name: str = "bearly_interpreter"
args_schema: Type[BaseModel] = BearlyInterpreterToolArguments args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
files: Dict[str, FileInfo] = {} files: Dict[str, FileInfo] = {}

View File

@ -51,12 +51,12 @@ class ZenGuardTool(BaseTool):
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails." "ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
) )
args_schema = ZenGuardInput args_schema = ZenGuardInput
return_direct = True return_direct: bool = True
zenguard_api_key: Optional[str] = Field(default=None) zenguard_api_key: Optional[str] = Field(default=None)
_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/" _ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/"
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY" _ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY"
@validator("zenguard_api_key", pre=True, always=True, check_fields=False) @validator("zenguard_api_key", pre=True, always=True, check_fields=False)
def set_api_key(cls, v: str) -> str: def set_api_key(cls, v: str) -> str:

View File

@ -11,7 +11,7 @@ class Portkey:
Default: "https://api.portkey.ai/v1/proxy" Default: "https://api.portkey.ai/v1/proxy"
""" """
base = "https://api.portkey.ai/v1/proxy" base: str = "https://api.portkey.ai/v1/proxy"
@staticmethod @staticmethod
def Config( def Config(

View File

@ -28,7 +28,7 @@ class TokenEscaper:
# Characters that RediSearch requires us to escape during queries. # Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]" DEFAULT_ESCAPED_CHARS: str = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
def __init__(self, escape_chars_re: Optional[Pattern] = None): def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re: if escape_chars_re:

View File

@ -29,7 +29,7 @@ class AtlasDB(VectorStore):
vectorstore = AtlasDB("my_project", embeddings.embed_query) vectorstore = AtlasDB("my_project", embeddings.embed_query)
""" """
_ATLAS_DEFAULT_ID_FIELD = "atlas_id" _ATLAS_DEFAULT_ID_FIELD: str = "atlas_id"
def __init__( def __init__(
self, self,

View File

@ -21,7 +21,7 @@ DEFAULT_TOPN = 4
class AwaDB(VectorStore): class AwaDB(VectorStore):
"""`AwaDB` vector store.""" """`AwaDB` vector store."""
_DEFAULT_TABLE_NAME = "langchain_awadb" _DEFAULT_TABLE_NAME: str = "langchain_awadb"
def __init__( def __init__(
self, self,

View File

@ -53,7 +53,7 @@ class Bagel(VectorStore):
vectorstore = Bagel(cluster_name="langchain_store") vectorstore = Bagel(cluster_name="langchain_store")
""" """
_LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain" _LANGCHAIN_DEFAULT_CLUSTER_NAME: str = "langchain"
def __init__( def __init__(
self, self,

View File

@ -66,7 +66,7 @@ class Chroma(VectorStore):
vectorstore = Chroma("langchain_store", embeddings) vectorstore = Chroma("langchain_store", embeddings)
""" """
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" _LANGCHAIN_DEFAULT_COLLECTION_NAME: str = "langchain"
def __init__( def __init__(
self, self,

View File

@ -60,10 +60,10 @@ class CouchbaseVectorStore(VectorStore):
""" """
# Default batch size # Default batch size
DEFAULT_BATCH_SIZE = 100 DEFAULT_BATCH_SIZE: int = 100
_metadata_key = "metadata" _metadata_key: str = "metadata"
_default_text_key = "text" _default_text_key: str = "text"
_default_embedding_key = "embedding" _default_embedding_key: str = "embedding"
def _check_bucket_exists(self) -> bool: def _check_bucket_exists(self) -> bool:
"""Check if the bucket exists in the linked Couchbase cluster""" """Check if the bucket exists in the linked Couchbase cluster"""

View File

@ -51,7 +51,7 @@ class DeepLake(VectorStore):
vectorstore = DeepLake("langchain_store", embeddings.embed_query) vectorstore = DeepLake("langchain_store", embeddings.embed_query)
""" """
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH = "./deeplake/" _LANGCHAIN_DEFAULT_DEEPLAKE_PATH: str = "./deeplake/"
_valid_search_kwargs = ["lambda_mult"] _valid_search_kwargs = ["lambda_mult"]
def __init__( def __init__(

View File

@ -45,9 +45,9 @@ class Epsilla(VectorStore):
epsilla = Epsilla(client, embeddings, db_path, db_name) epsilla = Epsilla(client, embeddings, db_path, db_name)
""" """
_LANGCHAIN_DEFAULT_DB_NAME = "langchain_store" _LANGCHAIN_DEFAULT_DB_NAME: str = "langchain_store"
_LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla" _LANGCHAIN_DEFAULT_DB_PATH: str = "/tmp/langchain-epsilla"
_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection" _LANGCHAIN_DEFAULT_TABLE_NAME: str = "langchain_collection"
def __init__( def __init__(
self, self,

View File

@ -13,6 +13,7 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Pattern,
Tuple, Tuple,
Type, Type,
) )
@ -223,7 +224,7 @@ class HanaDB(VectorStore):
return embedding return embedding
# Compile pattern only once, for better performance # Compile pattern only once, for better performance
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$") _compiled_pattern: Pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
@staticmethod @staticmethod
def _sanitize_metadata_keys(metadata: dict) -> dict: def _sanitize_metadata_keys(metadata: dict) -> dict:

View File

@ -48,7 +48,7 @@ class ManticoreSearchSettings(BaseSettings):
hnsw_m: int = 16 # The default is 16. hnsw_m: int = 16 # The default is 16.
# An optional setting that defines a construction time/accuracy trade-off. # An optional setting that defines a construction time/accuracy trade-off.
hnsw_ef_construction = 100 hnsw_ef_construction: int = 100
def get_connection_string(self) -> str: def get_connection_string(self) -> str:
return self.proto + "://" + self.host + ":" + str(self.port) return self.proto + "://" + self.host + ":" + str(self.port)

View File

@ -85,8 +85,8 @@ class Qdrant(VectorStore):
qdrant = Qdrant(client, collection_name, embedding_function) qdrant = Qdrant(client, collection_name, embedding_function)
""" """
CONTENT_KEY = "page_content" CONTENT_KEY: str = "page_content"
METADATA_KEY = "metadata" METADATA_KEY: str = "metadata"
VECTOR_NAME = None VECTOR_NAME = None
def __init__( def __init__(

View File

@ -25,7 +25,7 @@ class SemaDB(VectorStore):
""" """
HOST = "semadb.p.rapidapi.com" HOST: str = "semadb.p.rapidapi.com"
BASE_URL = "https://" + HOST BASE_URL = "https://" + HOST
def __init__( def __init__(

View File

@ -23,9 +23,9 @@ def mock_quip(): # type: ignore
@pytest.mark.requires("quip_api") @pytest.mark.requires("quip_api")
class TestQuipLoader: class TestQuipLoader:
API_URL = "https://example-api.quip.com" API_URL: str = "https://example-api.quip.com"
DOC_URL_PREFIX = ("https://example.quip.com",) DOC_URL_PREFIX = ("https://example.quip.com",)
ACCESS_TOKEN = "api_token" ACCESS_TOKEN: str = "api_token"
MOCK_FOLDER_IDS = ["ABC"] MOCK_FOLDER_IDS = ["ABC"]
MOCK_THREAD_IDS = ["ABC", "DEF"] MOCK_THREAD_IDS = ["ABC", "DEF"]

View File

@ -59,8 +59,8 @@ def test_custom_formatter() -> None:
"""Test ability to create a custom content formatter.""" """Test ability to create a custom content formatter."""
class CustomFormatter(ContentFormatterBase): class CustomFormatter(ContentFormatterBase):
content_type = "application/json" content_type: str = "application/json"
accepts = "application/json" accepts: str = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override] def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps( input_str = json.dumps(
@ -101,8 +101,8 @@ def test_invalid_request_format() -> None:
"""Test invalid request format.""" """Test invalid request format."""
class CustomContentFormatter(ContentFormatterBase): class CustomContentFormatter(ContentFormatterBase):
content_type = "application/json" content_type: str = "application/json"
accepts = "application/json" accepts: str = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override] def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps( input_str = json.dumps(

View File

@ -13,19 +13,19 @@ README_PATH = Path(__file__).parents[4] / "README.md"
class FakeUploadResponse: class FakeUploadResponse:
status_code = 200 status_code: int = 200
text = "fake_uuid" text: str = "fake_uuid"
class FakePushResponse: class FakePushResponse:
status_code = 200 status_code: int = 200
def json(self) -> Any: def json(self) -> Any:
return {"uuid": "fake_uuid"} return {"uuid": "fake_uuid"}
class FakePullResponse: class FakePullResponse:
status_code = 200 status_code: int = 200
def json(self) -> Any: def json(self) -> Any:
return { return {

View File

@ -11,7 +11,7 @@ LOG = logging.getLogger(__name__)
class TestChatKinetica: class TestChatKinetica:
test_ctx_json = """ test_ctx_json: str = """
{ {
"payload":{ "payload":{
"context":[ "context":[

View File

@ -20,10 +20,10 @@ def mock_confluence(): # type: ignore
@pytest.mark.requires("atlassian", "bs4", "lxml") @pytest.mark.requires("atlassian", "bs4", "lxml")
class TestConfluenceLoader: class TestConfluenceLoader:
CONFLUENCE_URL = "https://example.atlassian.com/wiki" CONFLUENCE_URL: str = "https://example.atlassian.com/wiki"
MOCK_USERNAME = "user@gmail.com" MOCK_USERNAME: str = "user@gmail.com"
MOCK_API_TOKEN = "api_token" MOCK_API_TOKEN: str = "api_token"
MOCK_SPACE_KEY = "spaceId123" MOCK_SPACE_KEY: str = "spaceId123"
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None: def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
ConfluenceLoader( ConfluenceLoader(

View File

@ -57,12 +57,12 @@ def mock_lakefs_client_no_presign_local() -> Any:
class TestLakeFSLoader(unittest.TestCase): class TestLakeFSLoader(unittest.TestCase):
lakefs_access_key = "lakefs_access_key" lakefs_access_key: str = "lakefs_access_key"
lakefs_secret_key = "lakefs_secret_key" lakefs_secret_key: str = "lakefs_secret_key"
endpoint = "endpoint" endpoint: str = "endpoint"
repo = "repo" repo: str = "repo"
ref = "ref" ref: str = "ref"
path = "path" path: str = "path"
@requests_mock.Mocker() @requests_mock.Mocker()
@pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local") @pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local")

View File

@ -21,9 +21,9 @@ def mock_connector_id(): # type: ignore
@pytest.mark.requires("psychicapi") @pytest.mark.requires("psychicapi")
class TestPsychicLoader: class TestPsychicLoader:
MOCK_API_KEY = "api_key" MOCK_API_KEY: str = "api_key"
MOCK_CONNECTOR_ID = "notion" MOCK_CONNECTOR_ID: str = "notion"
MOCK_ACCOUNT_ID = "account_id" MOCK_ACCOUNT_ID: str = "account_id"
def test_psychic_loader_initialization( def test_psychic_loader_initialization(
self, mock_psychic: MagicMock, mock_connector_id: MagicMock self, mock_psychic: MagicMock, mock_connector_id: MagicMock

View File

@ -4,9 +4,9 @@ from langchain_community.document_loaders.rspace import RSpaceLoader
class TestRSpaceLoader(unittest.TestCase): class TestRSpaceLoader(unittest.TestCase):
url = "https://community.researchspace.com" url: str = "https://community.researchspace.com"
api_key = "myapikey" api_key: str = "myapikey"
global_id = "SD12345" global_id: str = "SD12345"
def test_valid_arguments(self) -> None: def test_valid_arguments(self) -> None:
loader = RSpaceLoader( loader = RSpaceLoader(

View File

@ -70,7 +70,7 @@ class MockGradientaiPackage(MagicMock):
"""Mock Gradientai package.""" """Mock Gradientai package."""
Gradient = MockGradient Gradient = MockGradient
__version__ = "1.4.0" __version__: str = "1.4.0"
def test_gradient_llm_sync() -> None: def test_gradient_llm_sync() -> None:

View File

@ -5,7 +5,7 @@ from langchain_community.embeddings.ollama import OllamaEmbeddings
class MockResponse: class MockResponse:
status_code = 200 status_code: int = 200
def json(self) -> dict: def json(self) -> dict:
return {"embedding": [1, 2, 3]} return {"embedding": [1, 2, 3]}