mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-28 17:38:36 +00:00)
community[patch]: Add missing annotations (#24890)
This PR adds annotations in the community package. Annotations are only strictly needed in subclasses of BaseModel for pydantic 2 compatibility. Some of the annotations added here are not strictly necessary, but they are still worth having for the documentation pages.
This commit is contained in:
parent 7720483432
commit d24b82357f
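
For context, here is a minimal sketch of the pydantic 2 behavior this PR works around (illustrative only, not part of the diff; the Settings classes are hypothetical). Pydantic 1 inferred a field from an un-annotated assignment on a BaseModel subclass; pydantic 2 rejects the class outright:

# Illustrative sketch, assuming pydantic 2 is installed; not part of this PR.
from pydantic import BaseModel, PydanticUserError

try:

    class Settings(BaseModel):  # hypothetical model
        max_tokens = 1024  # no annotation -> rejected by pydantic 2

except PydanticUserError as err:
    print(err)  # "A non-annotated attribute was detected: `max_tokens = 1024` ..."


class SettingsFixed(BaseModel):  # hypothetical model
    max_tokens: int = 1024  # annotated -> a regular, validated field


print(SettingsFixed(max_tokens=16).max_tokens)  # prints 16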
@@ -58,7 +58,7 @@ class UpstashRatelimitHandler(BaseCallbackHandler):
     every time you invoke.
     """

-    raise_error = True
+    raise_error: bool = True
     _checked: bool = False

     def __init__(
@@ -76,7 +76,7 @@ class PebbloRetrievalQA(Chain):
     """Classifier endpoint."""
     classifier_location: str = "local"  #: :meta private:
     """Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
-    _discover_sent = False  #: :meta private:
+    _discover_sent: bool = False  #: :meta private:
     """Flag to check if discover payload has been sent."""
     _prompt_sent: bool = False  #: :meta private:
     """Flag to check if prompt payload has been sent."""
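
Note the underscore-prefixed names above (_discover_sent, _prompt_sent). A hedged sketch of why annotating them is safe, assuming pydantic 2 semantics (the Chainlike class is hypothetical): leading-underscore attributes become private attributes rather than validated fields, so the annotation documents the type without changing the model's public schema.

# Illustrative sketch, not from the diff; assumes pydantic 2 semantics.
from pydantic import BaseModel


class Chainlike(BaseModel):  # hypothetical stand-in for PebbloRetrievalQA
    classifier_location: str = "local"  # regular, validated field
    _discover_sent: bool = False  # leading underscore -> private attribute


print(Chainlike.model_fields.keys())  # dict_keys(['classifier_location'])
print(Chainlike()._discover_sent)  # prints False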
@@ -9,7 +9,7 @@ import os
 import re
 from importlib.metadata import version
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast

 from langchain_core.utils import pre_init

@@ -164,7 +164,7 @@ class _KineticaLlmFileContextParser:
     """Parser for Kinetica LLM context datafiles."""

     # parse line into a dict containing role and content
-    PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
+    PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)

     @classmethod
     def _removesuffix(cls, text: str, suffix: str) -> str:
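
The hunk above is typical of the regex changes in this PR: the compiled pattern gets a typing.Pattern annotation. On plain classes like this parser the annotation is documentation; on the BaseModel subclasses elsewhere in the PR it is required. A small sketch of the idiom (illustrative only; has_role_marker is a made-up helper, and newer code often spells the same type re.Pattern):

# Illustrative sketch, not from the diff.
import re
from typing import Pattern

PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)


def has_role_marker(line: str) -> bool:  # hypothetical helper
    # True when the line carries a <|role|> marker, as in Kinetica context files.
    return PARSER.match(line) is not None


print(has_role_marker("<|system|> You are helpful."))  # prints True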
@@ -135,7 +135,7 @@ class Provider(ABC):


 class CohereProvider(Provider):
-    stop_sequence_key = "stop_sequences"
+    stop_sequence_key: str = "stop_sequences"

     def __init__(self) -> None:
         from oci.generative_ai_inference import models
@@ -364,7 +364,7 @@ class CohereProvider(Provider):


 class MetaProvider(Provider):
-    stop_sequence_key = "stop"
+    stop_sequence_key: str = "stop"

     def __init__(self) -> None:
         from oci.generative_ai_inference import models
@@ -1,7 +1,7 @@
 # LLM Lingua Document Compressor

 import re
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from typing import Any, Dict, List, Optional, Pattern, Sequence, Tuple

 from langchain_core.callbacks import Callbacks
 from langchain_core.documents import Document
@@ -24,8 +24,8 @@ class LLMLinguaCompressor(BaseDocumentCompressor):

     # Pattern to match ref tags at the beginning or end of the string,
     # allowing for malformed tags
-    _pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
-    _pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
+    _pattern_beginning: Pattern = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
+    _pattern_ending: Pattern = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")

     model_name: str = "NousResearch/Llama-2-7b-hf"
     """The hugging face model to use"""
@@ -1,6 +1,6 @@
 import re
 from pathlib import Path
-from typing import Iterator, Union
+from typing import Iterator, Pattern, Union

 from langchain_core.documents import Document

@@ -10,7 +10,9 @@ from langchain_community.document_loaders.base import BaseLoader
 class AcreomLoader(BaseLoader):
     """Load `acreom` vault from a directory."""

-    FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL)
+    FRONT_MATTER_REGEX: Pattern = re.compile(
+        r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL
+    )
     """Regex to match front matter metadata in markdown files."""

     def __init__(
@@ -44,13 +44,13 @@ class DocugamiLoader(BaseLoader, BaseModel):
     access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
     """The Docugami API access token to use."""

-    max_text_length = 4096
+    max_text_length: int = 4096
     """Max length of chunk text returned."""

     min_text_length: int = 32
     """Threshold under which chunks are appended to next to avoid over-chunking."""

-    max_metadata_length = 512
+    max_metadata_length: int = 512
     """Max length of metadata text returned."""

     include_xml_tags: bool = False
@@ -36,8 +36,8 @@ class HuggingFaceModelLoader(BaseLoader):
             print(doc.metadata)  # Metadata of the model
     """

-    BASE_URL = "https://huggingface.co/api/models"
-    README_BASE_URL = "https://huggingface.co/{model_id}/raw/main/README.md"
+    BASE_URL: str = "https://huggingface.co/api/models"
+    README_BASE_URL: str = "https://huggingface.co/{model_id}/raw/main/README.md"

     def __init__(
         self,
@@ -2,7 +2,7 @@ import functools
 import logging
 import re
 from pathlib import Path
-from typing import Any, Dict, Iterator, Union
+from typing import Any, Dict, Iterator, Pattern, Union

 import yaml
 from langchain_core.documents import Document
@@ -15,12 +15,16 @@ logger = logging.getLogger(__name__)
 class ObsidianLoader(BaseLoader):
     """Load `Obsidian` files from directory."""

-    FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
-    TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
-    TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
-    DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
-    DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
-    DATAVIEW_INLINE_PAREN_REGEX = re.compile(r"\((\w+)::\s*(.*)\)", re.MULTILINE)
+    FRONT_MATTER_REGEX: Pattern = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
+    TEMPLATE_VARIABLE_REGEX: Pattern = re.compile(r"{{(.*?)}}", re.DOTALL)
+    TAG_REGEX: Pattern = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
+    DATAVIEW_LINE_REGEX: Pattern = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
+    DATAVIEW_INLINE_BRACKET_REGEX: Pattern = re.compile(
+        r"\[(\w+)::\s*(.*)\]", re.MULTILINE
+    )
+    DATAVIEW_INLINE_PAREN_REGEX: Pattern = re.compile(
+        r"\((\w+)::\s*(.*)\)", re.MULTILINE
+    )

     def __init__(
         self,
@@ -39,7 +39,7 @@ class OneNoteLoader(BaseLoader, BaseModel):
     """Personal access token"""
     onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
     """URL of Microsoft Graph API for OneNote"""
-    authority_url = "https://login.microsoftonline.com/consumers/"
+    authority_url: str = "https://login.microsoftonline.com/consumers/"
     """A URL that identifies a token authority"""
     token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
     """Path to the file where the access token is stored"""
@@ -1,5 +1,5 @@
 import re
-from typing import Callable, List
+from typing import Callable, List, Pattern

 from langchain_community.document_loaders.parsers.language.code_segmenter import (
     CodeSegmenter,
@@ -9,11 +9,11 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
 class CobolSegmenter(CodeSegmenter):
     """Code segmenter for `COBOL`."""

-    PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
-    DIVISION_PATTERN = re.compile(
+    PARAGRAPH_PATTERN: Pattern = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
+    DIVISION_PATTERN: Pattern = re.compile(
         r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
     )
-    SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
+    SECTION_PATTERN: Pattern = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)

     def __init__(self, code: str):
         super().__init__(code)
@@ -27,7 +27,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):

     chunk_size: int = 16

-    model_name = "ErnieBot-Embedding-V1"
+    model_name: str = "ErnieBot-Embedding-V1"

     _lock = threading.Lock()

@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import re
 from hashlib import md5
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union

 from langchain_community.graphs.graph_document import GraphDocument
 from langchain_community.graphs.graph_store import GraphStore
@@ -63,7 +63,7 @@ class AGEGraph(GraphStore):
     }

     # precompiled regex for checking chars in graph labels
-    label_regex = re.compile("[^0-9a-zA-Z]+")
+    label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")

     def __init__(
         self, graph_name: str, conf: Dict[str, Any], create: bool = True
@@ -39,9 +39,9 @@ class MoonshotCommon(BaseModel):
     """Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
     model_name: str = Field(default="moonshot-v1-8k", alias="model")
     """Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
-    max_tokens = 1024
+    max_tokens: int = 1024
     """Maximum number of tokens to generate."""
-    temperature = 0.3
+    temperature: float = 0.3
     """Temperature parameter (higher values make the model more creative)."""

     class Config:
@@ -244,7 +244,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
     """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
     Defaults to True."""

-    return_full_text = False
+    return_full_text: bool = False
     """Whether to prepend the prompt to the generated text. Defaults to False."""

     @property
@@ -26,7 +26,7 @@ class Provider(ABC):


 class CohereProvider(Provider):
-    stop_sequence_key = "stop_sequences"
+    stop_sequence_key: str = "stop_sequences"

     def __init__(self) -> None:
         from oci.generative_ai_inference import models
@@ -38,7 +38,7 @@ class CohereProvider(Provider):


 class MetaProvider(Provider):
-    stop_sequence_key = "stop"
+    stop_sequence_key: str = "stop"

     def __init__(self) -> None:
         from oci.generative_ai_inference import models
@@ -16,7 +16,7 @@ class SVEndpointHandler:
     :param str host_url: Base URL of the DaaS API service
     """

-    API_BASE_PATH = "/api/predict"
+    API_BASE_PATH: str = "/api/predict"

     def __init__(self, host_url: str):
         """
@@ -41,7 +41,7 @@ class SolarCommon(BaseModel):
     model_name: str = Field(default="solar-1-mini-chat", alias="model")
     """Model name. Available models listed here: https://console.upstage.ai/services/solar"""
     max_tokens: int = Field(default=1024)
-    temperature = 0.3
+    temperature: float = 0.3

     class Config:
         allow_population_by_field_name = True
@@ -27,7 +27,7 @@ class SupabaseVectorTranslator(Visitor):
     ]
     """Subset of allowed logical comparators."""

-    metadata_column = "metadata"
+    metadata_column: str = "metadata"

     def _map_comparator(self, comparator: Comparator) -> str:
         """
@@ -74,8 +74,8 @@ class BearlyInterpreterTool:
     """Tool for evaluating python code in a sandbox environment."""

     api_key: str
-    endpoint = "https://exec.bearly.ai/v1/interpreter"
-    name = "bearly_interpreter"
+    endpoint: str = "https://exec.bearly.ai/v1/interpreter"
+    name: str = "bearly_interpreter"
     args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
     files: Dict[str, FileInfo] = {}

@@ -51,12 +51,12 @@ class ZenGuardTool(BaseTool):
         "ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
     )
     args_schema = ZenGuardInput
-    return_direct = True
+    return_direct: bool = True

     zenguard_api_key: Optional[str] = Field(default=None)

-    _ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
-    _ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
+    _ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/"
+    _ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY"

     @validator("zenguard_api_key", pre=True, always=True, check_fields=False)
     def set_api_key(cls, v: str) -> str:
@@ -11,7 +11,7 @@ class Portkey:
     Default: "https://api.portkey.ai/v1/proxy"
     """

-    base = "https://api.portkey.ai/v1/proxy"
+    base: str = "https://api.portkey.ai/v1/proxy"

     @staticmethod
     def Config(
@@ -28,7 +28,7 @@ class TokenEscaper:

     # Characters that RediSearch requires us to escape during queries.
     # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
-    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
+    DEFAULT_ESCAPED_CHARS: str = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"

     def __init__(self, escape_chars_re: Optional[Pattern] = None):
         if escape_chars_re:
@@ -29,7 +29,7 @@ class AtlasDB(VectorStore):
             vectorstore = AtlasDB("my_project", embeddings.embed_query)
     """

-    _ATLAS_DEFAULT_ID_FIELD = "atlas_id"
+    _ATLAS_DEFAULT_ID_FIELD: str = "atlas_id"

     def __init__(
         self,
@@ -21,7 +21,7 @@ DEFAULT_TOPN = 4
 class AwaDB(VectorStore):
     """`AwaDB` vector store."""

-    _DEFAULT_TABLE_NAME = "langchain_awadb"
+    _DEFAULT_TABLE_NAME: str = "langchain_awadb"

     def __init__(
         self,
@@ -53,7 +53,7 @@ class Bagel(VectorStore):
             vectorstore = Bagel(cluster_name="langchain_store")
     """

-    _LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain"
+    _LANGCHAIN_DEFAULT_CLUSTER_NAME: str = "langchain"

     def __init__(
         self,
@@ -66,7 +66,7 @@ class Chroma(VectorStore):
             vectorstore = Chroma("langchain_store", embeddings)
     """

-    _LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+    _LANGCHAIN_DEFAULT_COLLECTION_NAME: str = "langchain"

     def __init__(
         self,
@@ -60,10 +60,10 @@ class CouchbaseVectorStore(VectorStore):
     """

     # Default batch size
-    DEFAULT_BATCH_SIZE = 100
-    _metadata_key = "metadata"
-    _default_text_key = "text"
-    _default_embedding_key = "embedding"
+    DEFAULT_BATCH_SIZE: int = 100
+    _metadata_key: str = "metadata"
+    _default_text_key: str = "text"
+    _default_embedding_key: str = "embedding"

     def _check_bucket_exists(self) -> bool:
         """Check if the bucket exists in the linked Couchbase cluster"""
@@ -51,7 +51,7 @@ class DeepLake(VectorStore):
             vectorstore = DeepLake("langchain_store", embeddings.embed_query)
     """

-    _LANGCHAIN_DEFAULT_DEEPLAKE_PATH = "./deeplake/"
+    _LANGCHAIN_DEFAULT_DEEPLAKE_PATH: str = "./deeplake/"
     _valid_search_kwargs = ["lambda_mult"]

     def __init__(
@@ -45,9 +45,9 @@ class Epsilla(VectorStore):
             epsilla = Epsilla(client, embeddings, db_path, db_name)
     """

-    _LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
-    _LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
-    _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
+    _LANGCHAIN_DEFAULT_DB_NAME: str = "langchain_store"
+    _LANGCHAIN_DEFAULT_DB_PATH: str = "/tmp/langchain-epsilla"
+    _LANGCHAIN_DEFAULT_TABLE_NAME: str = "langchain_collection"

     def __init__(
         self,
@@ -13,6 +13,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Pattern,
     Tuple,
     Type,
 )
@@ -223,7 +224,7 @@ class HanaDB(VectorStore):
         return embedding

     # Compile pattern only once, for better performance
-    _compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
+    _compiled_pattern: Pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")

     @staticmethod
     def _sanitize_metadata_keys(metadata: dict) -> dict:
@@ -48,7 +48,7 @@ class ManticoreSearchSettings(BaseSettings):
     hnsw_m: int = 16  # The default is 16.

     # An optional setting that defines a construction time/accuracy trade-off.
-    hnsw_ef_construction = 100
+    hnsw_ef_construction: int = 100

     def get_connection_string(self) -> str:
         return self.proto + "://" + self.host + ":" + str(self.port)
@@ -85,8 +85,8 @@ class Qdrant(VectorStore):
             qdrant = Qdrant(client, collection_name, embedding_function)
     """

-    CONTENT_KEY = "page_content"
-    METADATA_KEY = "metadata"
+    CONTENT_KEY: str = "page_content"
+    METADATA_KEY: str = "metadata"
     VECTOR_NAME = None

     def __init__(
@@ -25,7 +25,7 @@ class SemaDB(VectorStore):

     """

-    HOST = "semadb.p.rapidapi.com"
+    HOST: str = "semadb.p.rapidapi.com"
     BASE_URL = "https://" + HOST

     def __init__(
@@ -23,9 +23,9 @@ def mock_quip():  # type: ignore

 @pytest.mark.requires("quip_api")
 class TestQuipLoader:
-    API_URL = "https://example-api.quip.com"
+    API_URL: str = "https://example-api.quip.com"
     DOC_URL_PREFIX = ("https://example.quip.com",)
-    ACCESS_TOKEN = "api_token"
+    ACCESS_TOKEN: str = "api_token"

     MOCK_FOLDER_IDS = ["ABC"]
     MOCK_THREAD_IDS = ["ABC", "DEF"]
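
The test-suite hunks here and below annotate attributes on plain Python classes, where the commit message's caveat applies: the annotations are not required for pydantic 2 and are inert at runtime, useful only to type checkers and documentation. A minimal sketch (hypothetical class):

# Illustrative sketch, not from the diff: on a plain (non-pydantic) class an
# annotated assignment behaves exactly like an un-annotated one at runtime;
# the annotation is merely recorded in __annotations__ for tools.
class PlainConfig:  # hypothetical example
    API_URL: str = "https://example-api.quip.com"


print(PlainConfig.API_URL)  # normal attribute access, unchanged behavior
print(PlainConfig.__annotations__)  # {'API_URL': <class 'str'>}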
@@ -59,8 +59,8 @@ def test_custom_formatter() -> None:
     """Test ability to create a custom content formatter."""

     class CustomFormatter(ContentFormatterBase):
-        content_type = "application/json"
-        accepts = "application/json"
+        content_type: str = "application/json"
+        accepts: str = "application/json"

         def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
             input_str = json.dumps(
@@ -101,8 +101,8 @@ def test_invalid_request_format() -> None:
     """Test invalid request format."""

     class CustomContentFormatter(ContentFormatterBase):
-        content_type = "application/json"
-        accepts = "application/json"
+        content_type: str = "application/json"
+        accepts: str = "application/json"

         def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
             input_str = json.dumps(
@@ -13,19 +13,19 @@ README_PATH = Path(__file__).parents[4] / "README.md"


 class FakeUploadResponse:
-    status_code = 200
-    text = "fake_uuid"
+    status_code: int = 200
+    text: str = "fake_uuid"


 class FakePushResponse:
-    status_code = 200
+    status_code: int = 200

     def json(self) -> Any:
         return {"uuid": "fake_uuid"}


 class FakePullResponse:
-    status_code = 200
+    status_code: int = 200

     def json(self) -> Any:
         return {
@@ -11,7 +11,7 @@ LOG = logging.getLogger(__name__)


 class TestChatKinetica:
-    test_ctx_json = """
+    test_ctx_json: str = """
     {
         "payload":{
             "context":[
@@ -20,10 +20,10 @@ def mock_confluence():  # type: ignore

 @pytest.mark.requires("atlassian", "bs4", "lxml")
 class TestConfluenceLoader:
-    CONFLUENCE_URL = "https://example.atlassian.com/wiki"
-    MOCK_USERNAME = "user@gmail.com"
-    MOCK_API_TOKEN = "api_token"
-    MOCK_SPACE_KEY = "spaceId123"
+    CONFLUENCE_URL: str = "https://example.atlassian.com/wiki"
+    MOCK_USERNAME: str = "user@gmail.com"
+    MOCK_API_TOKEN: str = "api_token"
+    MOCK_SPACE_KEY: str = "spaceId123"

     def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
         ConfluenceLoader(
@@ -57,12 +57,12 @@ def mock_lakefs_client_no_presign_local() -> Any:


 class TestLakeFSLoader(unittest.TestCase):
-    lakefs_access_key = "lakefs_access_key"
-    lakefs_secret_key = "lakefs_secret_key"
-    endpoint = "endpoint"
-    repo = "repo"
-    ref = "ref"
-    path = "path"
+    lakefs_access_key: str = "lakefs_access_key"
+    lakefs_secret_key: str = "lakefs_secret_key"
+    endpoint: str = "endpoint"
+    repo: str = "repo"
+    ref: str = "ref"
+    path: str = "path"

     @requests_mock.Mocker()
     @pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local")
@@ -21,9 +21,9 @@ def mock_connector_id():  # type: ignore

 @pytest.mark.requires("psychicapi")
 class TestPsychicLoader:
-    MOCK_API_KEY = "api_key"
-    MOCK_CONNECTOR_ID = "notion"
-    MOCK_ACCOUNT_ID = "account_id"
+    MOCK_API_KEY: str = "api_key"
+    MOCK_CONNECTOR_ID: str = "notion"
+    MOCK_ACCOUNT_ID: str = "account_id"

     def test_psychic_loader_initialization(
         self, mock_psychic: MagicMock, mock_connector_id: MagicMock
@@ -4,9 +4,9 @@ from langchain_community.document_loaders.rspace import RSpaceLoader


 class TestRSpaceLoader(unittest.TestCase):
-    url = "https://community.researchspace.com"
-    api_key = "myapikey"
-    global_id = "SD12345"
+    url: str = "https://community.researchspace.com"
+    api_key: str = "myapikey"
+    global_id: str = "SD12345"

     def test_valid_arguments(self) -> None:
         loader = RSpaceLoader(
@@ -70,7 +70,7 @@ class MockGradientaiPackage(MagicMock):
     """Mock Gradientai package."""

     Gradient = MockGradient
-    __version__ = "1.4.0"
+    __version__: str = "1.4.0"


 def test_gradient_llm_sync() -> None:
@@ -5,7 +5,7 @@ from langchain_community.embeddings.ollama import OllamaEmbeddings


 class MockResponse:
-    status_code = 200
+    status_code: int = 200

     def json(self) -> dict:
         return {"embedding": [1, 2, 3]}