mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 09:28:48 +00:00
community[patch]: Add missing annotations (#24890)
This PR adds annotations in comunity package. Annotations are only strictly needed in subclasses of BaseModel for pydantic 2 compatibility. This PR adds some unnecessary annotations, but they're not bad to have regardless for documentation pages.
This commit is contained in:
parent
7720483432
commit
d24b82357f
@ -58,7 +58,7 @@ class UpstashRatelimitHandler(BaseCallbackHandler):
|
|||||||
every time you invoke.
|
every time you invoke.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
raise_error = True
|
raise_error: bool = True
|
||||||
_checked: bool = False
|
_checked: bool = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -76,7 +76,7 @@ class PebbloRetrievalQA(Chain):
|
|||||||
"""Classifier endpoint."""
|
"""Classifier endpoint."""
|
||||||
classifier_location: str = "local" #: :meta private:
|
classifier_location: str = "local" #: :meta private:
|
||||||
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
|
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
|
||||||
_discover_sent = False #: :meta private:
|
_discover_sent: bool = False #: :meta private:
|
||||||
"""Flag to check if discover payload has been sent."""
|
"""Flag to check if discover payload has been sent."""
|
||||||
_prompt_sent: bool = False #: :meta private:
|
_prompt_sent: bool = False #: :meta private:
|
||||||
"""Flag to check if prompt payload has been sent."""
|
"""Flag to check if prompt payload has been sent."""
|
||||||
|
@ -9,7 +9,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, cast
|
||||||
|
|
||||||
from langchain_core.utils import pre_init
|
from langchain_core.utils import pre_init
|
||||||
|
|
||||||
@ -164,7 +164,7 @@ class _KineticaLlmFileContextParser:
|
|||||||
"""Parser for Kinetica LLM context datafiles."""
|
"""Parser for Kinetica LLM context datafiles."""
|
||||||
|
|
||||||
# parse line into a dict containing role and content
|
# parse line into a dict containing role and content
|
||||||
PARSER = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
|
PARSER: Pattern = re.compile(r"^<\|(?P<role>\w+)\|>\W*(?P<content>.*)$", re.DOTALL)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _removesuffix(cls, text: str, suffix: str) -> str:
|
def _removesuffix(cls, text: str, suffix: str) -> str:
|
||||||
|
@ -135,7 +135,7 @@ class Provider(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class CohereProvider(Provider):
|
class CohereProvider(Provider):
|
||||||
stop_sequence_key = "stop_sequences"
|
stop_sequence_key: str = "stop_sequences"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
from oci.generative_ai_inference import models
|
from oci.generative_ai_inference import models
|
||||||
@ -364,7 +364,7 @@ class CohereProvider(Provider):
|
|||||||
|
|
||||||
|
|
||||||
class MetaProvider(Provider):
|
class MetaProvider(Provider):
|
||||||
stop_sequence_key = "stop"
|
stop_sequence_key: str = "stop"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
from oci.generative_ai_inference import models
|
from oci.generative_ai_inference import models
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# LLM Lingua Document Compressor
|
# LLM Lingua Document Compressor
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
from typing import Any, Dict, List, Optional, Pattern, Sequence, Tuple
|
||||||
|
|
||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -24,8 +24,8 @@ class LLMLinguaCompressor(BaseDocumentCompressor):
|
|||||||
|
|
||||||
# Pattern to match ref tags at the beginning or end of the string,
|
# Pattern to match ref tags at the beginning or end of the string,
|
||||||
# allowing for malformed tags
|
# allowing for malformed tags
|
||||||
_pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
|
_pattern_beginning: Pattern = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
|
||||||
_pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
|
_pattern_ending: Pattern = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
|
||||||
|
|
||||||
model_name: str = "NousResearch/Llama-2-7b-hf"
|
model_name: str = "NousResearch/Llama-2-7b-hf"
|
||||||
"""The hugging face model to use"""
|
"""The hugging face model to use"""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, Union
|
from typing import Iterator, Pattern, Union
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
@ -10,7 +10,9 @@ from langchain_community.document_loaders.base import BaseLoader
|
|||||||
class AcreomLoader(BaseLoader):
|
class AcreomLoader(BaseLoader):
|
||||||
"""Load `acreom` vault from a directory."""
|
"""Load `acreom` vault from a directory."""
|
||||||
|
|
||||||
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL)
|
FRONT_MATTER_REGEX: Pattern = re.compile(
|
||||||
|
r"^---\n(.*?)\n---\n", re.MULTILINE | re.DOTALL
|
||||||
|
)
|
||||||
"""Regex to match front matter metadata in markdown files."""
|
"""Regex to match front matter metadata in markdown files."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -44,13 +44,13 @@ class DocugamiLoader(BaseLoader, BaseModel):
|
|||||||
access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
|
access_token: Optional[str] = os.environ.get("DOCUGAMI_API_KEY")
|
||||||
"""The Docugami API access token to use."""
|
"""The Docugami API access token to use."""
|
||||||
|
|
||||||
max_text_length = 4096
|
max_text_length: int = 4096
|
||||||
"""Max length of chunk text returned."""
|
"""Max length of chunk text returned."""
|
||||||
|
|
||||||
min_text_length: int = 32
|
min_text_length: int = 32
|
||||||
"""Threshold under which chunks are appended to next to avoid over-chunking."""
|
"""Threshold under which chunks are appended to next to avoid over-chunking."""
|
||||||
|
|
||||||
max_metadata_length = 512
|
max_metadata_length: int = 512
|
||||||
"""Max length of metadata text returned."""
|
"""Max length of metadata text returned."""
|
||||||
|
|
||||||
include_xml_tags: bool = False
|
include_xml_tags: bool = False
|
||||||
|
@ -36,8 +36,8 @@ class HuggingFaceModelLoader(BaseLoader):
|
|||||||
print(doc.metadata) # Metadata of the model
|
print(doc.metadata) # Metadata of the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BASE_URL = "https://huggingface.co/api/models"
|
BASE_URL: str = "https://huggingface.co/api/models"
|
||||||
README_BASE_URL = "https://huggingface.co/{model_id}/raw/main/README.md"
|
README_BASE_URL: str = "https://huggingface.co/{model_id}/raw/main/README.md"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -2,7 +2,7 @@ import functools
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Iterator, Union
|
from typing import Any, Dict, Iterator, Pattern, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -15,12 +15,16 @@ logger = logging.getLogger(__name__)
|
|||||||
class ObsidianLoader(BaseLoader):
|
class ObsidianLoader(BaseLoader):
|
||||||
"""Load `Obsidian` files from directory."""
|
"""Load `Obsidian` files from directory."""
|
||||||
|
|
||||||
FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
|
FRONT_MATTER_REGEX: Pattern = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL)
|
||||||
TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL)
|
TEMPLATE_VARIABLE_REGEX: Pattern = re.compile(r"{{(.*?)}}", re.DOTALL)
|
||||||
TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
|
TAG_REGEX: Pattern = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)")
|
||||||
DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
|
DATAVIEW_LINE_REGEX: Pattern = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE)
|
||||||
DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE)
|
DATAVIEW_INLINE_BRACKET_REGEX: Pattern = re.compile(
|
||||||
DATAVIEW_INLINE_PAREN_REGEX = re.compile(r"\((\w+)::\s*(.*)\)", re.MULTILINE)
|
r"\[(\w+)::\s*(.*)\]", re.MULTILINE
|
||||||
|
)
|
||||||
|
DATAVIEW_INLINE_PAREN_REGEX: Pattern = re.compile(
|
||||||
|
r"\((\w+)::\s*(.*)\)", re.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -39,7 +39,7 @@ class OneNoteLoader(BaseLoader, BaseModel):
|
|||||||
"""Personal access token"""
|
"""Personal access token"""
|
||||||
onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
|
onenote_api_base_url: str = "https://graph.microsoft.com/v1.0/me/onenote"
|
||||||
"""URL of Microsoft Graph API for OneNote"""
|
"""URL of Microsoft Graph API for OneNote"""
|
||||||
authority_url = "https://login.microsoftonline.com/consumers/"
|
authority_url: str = "https://login.microsoftonline.com/consumers/"
|
||||||
"""A URL that identifies a token authority"""
|
"""A URL that identifies a token authority"""
|
||||||
token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
|
token_path: FilePath = Path.home() / ".credentials" / "onenote_graph_token.txt"
|
||||||
"""Path to the file where the access token is stored"""
|
"""Path to the file where the access token is stored"""
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Pattern
|
||||||
|
|
||||||
from langchain_community.document_loaders.parsers.language.code_segmenter import (
|
from langchain_community.document_loaders.parsers.language.code_segmenter import (
|
||||||
CodeSegmenter,
|
CodeSegmenter,
|
||||||
@ -9,11 +9,11 @@ from langchain_community.document_loaders.parsers.language.code_segmenter import
|
|||||||
class CobolSegmenter(CodeSegmenter):
|
class CobolSegmenter(CodeSegmenter):
|
||||||
"""Code segmenter for `COBOL`."""
|
"""Code segmenter for `COBOL`."""
|
||||||
|
|
||||||
PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
|
PARAGRAPH_PATTERN: Pattern = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
|
||||||
DIVISION_PATTERN = re.compile(
|
DIVISION_PATTERN: Pattern = re.compile(
|
||||||
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
|
r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
|
||||||
)
|
)
|
||||||
SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
|
SECTION_PATTERN: Pattern = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION.$", re.IGNORECASE)
|
||||||
|
|
||||||
def __init__(self, code: str):
|
def __init__(self, code: str):
|
||||||
super().__init__(code)
|
super().__init__(code)
|
||||||
|
@ -27,7 +27,7 @@ class ErnieEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
chunk_size: int = 16
|
chunk_size: int = 16
|
||||||
|
|
||||||
model_name = "ErnieBot-Embedding-V1"
|
model_name: str = "ErnieBot-Embedding-V1"
|
||||||
|
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Pattern, Tuple, Union
|
||||||
|
|
||||||
from langchain_community.graphs.graph_document import GraphDocument
|
from langchain_community.graphs.graph_document import GraphDocument
|
||||||
from langchain_community.graphs.graph_store import GraphStore
|
from langchain_community.graphs.graph_store import GraphStore
|
||||||
@ -63,7 +63,7 @@ class AGEGraph(GraphStore):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# precompiled regex for checking chars in graph labels
|
# precompiled regex for checking chars in graph labels
|
||||||
label_regex = re.compile("[^0-9a-zA-Z]+")
|
label_regex: Pattern = re.compile("[^0-9a-zA-Z]+")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, graph_name: str, conf: Dict[str, Any], create: bool = True
|
self, graph_name: str, conf: Dict[str, Any], create: bool = True
|
||||||
|
@ -39,9 +39,9 @@ class MoonshotCommon(BaseModel):
|
|||||||
"""Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
|
"""Moonshot API key. Get it here: https://platform.moonshot.cn/console/api-keys"""
|
||||||
model_name: str = Field(default="moonshot-v1-8k", alias="model")
|
model_name: str = Field(default="moonshot-v1-8k", alias="model")
|
||||||
"""Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
|
"""Model name. Available models listed here: https://platform.moonshot.cn/pricing"""
|
||||||
max_tokens = 1024
|
max_tokens: int = 1024
|
||||||
"""Maximum number of tokens to generate."""
|
"""Maximum number of tokens to generate."""
|
||||||
temperature = 0.3
|
temperature: float = 0.3
|
||||||
"""Temperature parameter (higher values make the model more creative)."""
|
"""Temperature parameter (higher values make the model more creative)."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -244,7 +244,7 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
|
|||||||
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
|
"""Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
|
||||||
Defaults to True."""
|
Defaults to True."""
|
||||||
|
|
||||||
return_full_text = False
|
return_full_text: bool = False
|
||||||
"""Whether to prepend the prompt to the generated text. Defaults to False."""
|
"""Whether to prepend the prompt to the generated text. Defaults to False."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -26,7 +26,7 @@ class Provider(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class CohereProvider(Provider):
|
class CohereProvider(Provider):
|
||||||
stop_sequence_key = "stop_sequences"
|
stop_sequence_key: str = "stop_sequences"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
from oci.generative_ai_inference import models
|
from oci.generative_ai_inference import models
|
||||||
@ -38,7 +38,7 @@ class CohereProvider(Provider):
|
|||||||
|
|
||||||
|
|
||||||
class MetaProvider(Provider):
|
class MetaProvider(Provider):
|
||||||
stop_sequence_key = "stop"
|
stop_sequence_key: str = "stop"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
from oci.generative_ai_inference import models
|
from oci.generative_ai_inference import models
|
||||||
|
@ -16,7 +16,7 @@ class SVEndpointHandler:
|
|||||||
:param str host_url: Base URL of the DaaS API service
|
:param str host_url: Base URL of the DaaS API service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
API_BASE_PATH = "/api/predict"
|
API_BASE_PATH: str = "/api/predict"
|
||||||
|
|
||||||
def __init__(self, host_url: str):
|
def __init__(self, host_url: str):
|
||||||
"""
|
"""
|
||||||
|
@ -41,7 +41,7 @@ class SolarCommon(BaseModel):
|
|||||||
model_name: str = Field(default="solar-1-mini-chat", alias="model")
|
model_name: str = Field(default="solar-1-mini-chat", alias="model")
|
||||||
"""Model name. Available models listed here: https://console.upstage.ai/services/solar"""
|
"""Model name. Available models listed here: https://console.upstage.ai/services/solar"""
|
||||||
max_tokens: int = Field(default=1024)
|
max_tokens: int = Field(default=1024)
|
||||||
temperature = 0.3
|
temperature: float = 0.3
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
allow_population_by_field_name = True
|
allow_population_by_field_name = True
|
||||||
|
@ -27,7 +27,7 @@ class SupabaseVectorTranslator(Visitor):
|
|||||||
]
|
]
|
||||||
"""Subset of allowed logical comparators."""
|
"""Subset of allowed logical comparators."""
|
||||||
|
|
||||||
metadata_column = "metadata"
|
metadata_column: str = "metadata"
|
||||||
|
|
||||||
def _map_comparator(self, comparator: Comparator) -> str:
|
def _map_comparator(self, comparator: Comparator) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -74,8 +74,8 @@ class BearlyInterpreterTool:
|
|||||||
"""Tool for evaluating python code in a sandbox environment."""
|
"""Tool for evaluating python code in a sandbox environment."""
|
||||||
|
|
||||||
api_key: str
|
api_key: str
|
||||||
endpoint = "https://exec.bearly.ai/v1/interpreter"
|
endpoint: str = "https://exec.bearly.ai/v1/interpreter"
|
||||||
name = "bearly_interpreter"
|
name: str = "bearly_interpreter"
|
||||||
args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
|
args_schema: Type[BaseModel] = BearlyInterpreterToolArguments
|
||||||
files: Dict[str, FileInfo] = {}
|
files: Dict[str, FileInfo] = {}
|
||||||
|
|
||||||
|
@ -51,12 +51,12 @@ class ZenGuardTool(BaseTool):
|
|||||||
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
|
"ZenGuard AI integration package. ZenGuard AI - the fastest GenAI guardrails."
|
||||||
)
|
)
|
||||||
args_schema = ZenGuardInput
|
args_schema = ZenGuardInput
|
||||||
return_direct = True
|
return_direct: bool = True
|
||||||
|
|
||||||
zenguard_api_key: Optional[str] = Field(default=None)
|
zenguard_api_key: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
|
_ZENGUARD_API_URL_ROOT: str = "https://api.zenguard.ai/"
|
||||||
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"
|
_ZENGUARD_API_KEY_ENV_NAME: str = "ZENGUARD_API_KEY"
|
||||||
|
|
||||||
@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
|
@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
|
||||||
def set_api_key(cls, v: str) -> str:
|
def set_api_key(cls, v: str) -> str:
|
||||||
|
@ -11,7 +11,7 @@ class Portkey:
|
|||||||
Default: "https://api.portkey.ai/v1/proxy"
|
Default: "https://api.portkey.ai/v1/proxy"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
base = "https://api.portkey.ai/v1/proxy"
|
base: str = "https://api.portkey.ai/v1/proxy"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def Config(
|
def Config(
|
||||||
|
@ -28,7 +28,7 @@ class TokenEscaper:
|
|||||||
|
|
||||||
# Characters that RediSearch requires us to escape during queries.
|
# Characters that RediSearch requires us to escape during queries.
|
||||||
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
|
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
|
||||||
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
|
DEFAULT_ESCAPED_CHARS: str = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
|
||||||
|
|
||||||
def __init__(self, escape_chars_re: Optional[Pattern] = None):
|
def __init__(self, escape_chars_re: Optional[Pattern] = None):
|
||||||
if escape_chars_re:
|
if escape_chars_re:
|
||||||
|
@ -29,7 +29,7 @@ class AtlasDB(VectorStore):
|
|||||||
vectorstore = AtlasDB("my_project", embeddings.embed_query)
|
vectorstore = AtlasDB("my_project", embeddings.embed_query)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_ATLAS_DEFAULT_ID_FIELD = "atlas_id"
|
_ATLAS_DEFAULT_ID_FIELD: str = "atlas_id"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -21,7 +21,7 @@ DEFAULT_TOPN = 4
|
|||||||
class AwaDB(VectorStore):
|
class AwaDB(VectorStore):
|
||||||
"""`AwaDB` vector store."""
|
"""`AwaDB` vector store."""
|
||||||
|
|
||||||
_DEFAULT_TABLE_NAME = "langchain_awadb"
|
_DEFAULT_TABLE_NAME: str = "langchain_awadb"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -53,7 +53,7 @@ class Bagel(VectorStore):
|
|||||||
vectorstore = Bagel(cluster_name="langchain_store")
|
vectorstore = Bagel(cluster_name="langchain_store")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain"
|
_LANGCHAIN_DEFAULT_CLUSTER_NAME: str = "langchain"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -66,7 +66,7 @@ class Chroma(VectorStore):
|
|||||||
vectorstore = Chroma("langchain_store", embeddings)
|
vectorstore = Chroma("langchain_store", embeddings)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
_LANGCHAIN_DEFAULT_COLLECTION_NAME: str = "langchain"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -60,10 +60,10 @@ class CouchbaseVectorStore(VectorStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Default batch size
|
# Default batch size
|
||||||
DEFAULT_BATCH_SIZE = 100
|
DEFAULT_BATCH_SIZE: int = 100
|
||||||
_metadata_key = "metadata"
|
_metadata_key: str = "metadata"
|
||||||
_default_text_key = "text"
|
_default_text_key: str = "text"
|
||||||
_default_embedding_key = "embedding"
|
_default_embedding_key: str = "embedding"
|
||||||
|
|
||||||
def _check_bucket_exists(self) -> bool:
|
def _check_bucket_exists(self) -> bool:
|
||||||
"""Check if the bucket exists in the linked Couchbase cluster"""
|
"""Check if the bucket exists in the linked Couchbase cluster"""
|
||||||
|
@ -51,7 +51,7 @@ class DeepLake(VectorStore):
|
|||||||
vectorstore = DeepLake("langchain_store", embeddings.embed_query)
|
vectorstore = DeepLake("langchain_store", embeddings.embed_query)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH = "./deeplake/"
|
_LANGCHAIN_DEFAULT_DEEPLAKE_PATH: str = "./deeplake/"
|
||||||
_valid_search_kwargs = ["lambda_mult"]
|
_valid_search_kwargs = ["lambda_mult"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -45,9 +45,9 @@ class Epsilla(VectorStore):
|
|||||||
epsilla = Epsilla(client, embeddings, db_path, db_name)
|
epsilla = Epsilla(client, embeddings, db_path, db_name)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
|
_LANGCHAIN_DEFAULT_DB_NAME: str = "langchain_store"
|
||||||
_LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
|
_LANGCHAIN_DEFAULT_DB_PATH: str = "/tmp/langchain-epsilla"
|
||||||
_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
|
_LANGCHAIN_DEFAULT_TABLE_NAME: str = "langchain_collection"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -13,6 +13,7 @@ from typing import (
|
|||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Pattern,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
@ -223,7 +224,7 @@ class HanaDB(VectorStore):
|
|||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
# Compile pattern only once, for better performance
|
# Compile pattern only once, for better performance
|
||||||
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
|
_compiled_pattern: Pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sanitize_metadata_keys(metadata: dict) -> dict:
|
def _sanitize_metadata_keys(metadata: dict) -> dict:
|
||||||
|
@ -48,7 +48,7 @@ class ManticoreSearchSettings(BaseSettings):
|
|||||||
hnsw_m: int = 16 # The default is 16.
|
hnsw_m: int = 16 # The default is 16.
|
||||||
|
|
||||||
# An optional setting that defines a construction time/accuracy trade-off.
|
# An optional setting that defines a construction time/accuracy trade-off.
|
||||||
hnsw_ef_construction = 100
|
hnsw_ef_construction: int = 100
|
||||||
|
|
||||||
def get_connection_string(self) -> str:
|
def get_connection_string(self) -> str:
|
||||||
return self.proto + "://" + self.host + ":" + str(self.port)
|
return self.proto + "://" + self.host + ":" + str(self.port)
|
||||||
|
@ -85,8 +85,8 @@ class Qdrant(VectorStore):
|
|||||||
qdrant = Qdrant(client, collection_name, embedding_function)
|
qdrant = Qdrant(client, collection_name, embedding_function)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CONTENT_KEY = "page_content"
|
CONTENT_KEY: str = "page_content"
|
||||||
METADATA_KEY = "metadata"
|
METADATA_KEY: str = "metadata"
|
||||||
VECTOR_NAME = None
|
VECTOR_NAME = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -25,7 +25,7 @@ class SemaDB(VectorStore):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
HOST = "semadb.p.rapidapi.com"
|
HOST: str = "semadb.p.rapidapi.com"
|
||||||
BASE_URL = "https://" + HOST
|
BASE_URL = "https://" + HOST
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -23,9 +23,9 @@ def mock_quip(): # type: ignore
|
|||||||
|
|
||||||
@pytest.mark.requires("quip_api")
|
@pytest.mark.requires("quip_api")
|
||||||
class TestQuipLoader:
|
class TestQuipLoader:
|
||||||
API_URL = "https://example-api.quip.com"
|
API_URL: str = "https://example-api.quip.com"
|
||||||
DOC_URL_PREFIX = ("https://example.quip.com",)
|
DOC_URL_PREFIX = ("https://example.quip.com",)
|
||||||
ACCESS_TOKEN = "api_token"
|
ACCESS_TOKEN: str = "api_token"
|
||||||
|
|
||||||
MOCK_FOLDER_IDS = ["ABC"]
|
MOCK_FOLDER_IDS = ["ABC"]
|
||||||
MOCK_THREAD_IDS = ["ABC", "DEF"]
|
MOCK_THREAD_IDS = ["ABC", "DEF"]
|
||||||
|
@ -59,8 +59,8 @@ def test_custom_formatter() -> None:
|
|||||||
"""Test ability to create a custom content formatter."""
|
"""Test ability to create a custom content formatter."""
|
||||||
|
|
||||||
class CustomFormatter(ContentFormatterBase):
|
class CustomFormatter(ContentFormatterBase):
|
||||||
content_type = "application/json"
|
content_type: str = "application/json"
|
||||||
accepts = "application/json"
|
accepts: str = "application/json"
|
||||||
|
|
||||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
|
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
|
||||||
input_str = json.dumps(
|
input_str = json.dumps(
|
||||||
@ -101,8 +101,8 @@ def test_invalid_request_format() -> None:
|
|||||||
"""Test invalid request format."""
|
"""Test invalid request format."""
|
||||||
|
|
||||||
class CustomContentFormatter(ContentFormatterBase):
|
class CustomContentFormatter(ContentFormatterBase):
|
||||||
content_type = "application/json"
|
content_type: str = "application/json"
|
||||||
accepts = "application/json"
|
accepts: str = "application/json"
|
||||||
|
|
||||||
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
|
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
|
||||||
input_str = json.dumps(
|
input_str = json.dumps(
|
||||||
|
@ -13,19 +13,19 @@ README_PATH = Path(__file__).parents[4] / "README.md"
|
|||||||
|
|
||||||
|
|
||||||
class FakeUploadResponse:
|
class FakeUploadResponse:
|
||||||
status_code = 200
|
status_code: int = 200
|
||||||
text = "fake_uuid"
|
text: str = "fake_uuid"
|
||||||
|
|
||||||
|
|
||||||
class FakePushResponse:
|
class FakePushResponse:
|
||||||
status_code = 200
|
status_code: int = 200
|
||||||
|
|
||||||
def json(self) -> Any:
|
def json(self) -> Any:
|
||||||
return {"uuid": "fake_uuid"}
|
return {"uuid": "fake_uuid"}
|
||||||
|
|
||||||
|
|
||||||
class FakePullResponse:
|
class FakePullResponse:
|
||||||
status_code = 200
|
status_code: int = 200
|
||||||
|
|
||||||
def json(self) -> Any:
|
def json(self) -> Any:
|
||||||
return {
|
return {
|
||||||
|
@ -11,7 +11,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class TestChatKinetica:
|
class TestChatKinetica:
|
||||||
test_ctx_json = """
|
test_ctx_json: str = """
|
||||||
{
|
{
|
||||||
"payload":{
|
"payload":{
|
||||||
"context":[
|
"context":[
|
||||||
|
@ -20,10 +20,10 @@ def mock_confluence(): # type: ignore
|
|||||||
|
|
||||||
@pytest.mark.requires("atlassian", "bs4", "lxml")
|
@pytest.mark.requires("atlassian", "bs4", "lxml")
|
||||||
class TestConfluenceLoader:
|
class TestConfluenceLoader:
|
||||||
CONFLUENCE_URL = "https://example.atlassian.com/wiki"
|
CONFLUENCE_URL: str = "https://example.atlassian.com/wiki"
|
||||||
MOCK_USERNAME = "user@gmail.com"
|
MOCK_USERNAME: str = "user@gmail.com"
|
||||||
MOCK_API_TOKEN = "api_token"
|
MOCK_API_TOKEN: str = "api_token"
|
||||||
MOCK_SPACE_KEY = "spaceId123"
|
MOCK_SPACE_KEY: str = "spaceId123"
|
||||||
|
|
||||||
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
|
def test_confluence_loader_initialization(self, mock_confluence: MagicMock) -> None:
|
||||||
ConfluenceLoader(
|
ConfluenceLoader(
|
||||||
|
@ -57,12 +57,12 @@ def mock_lakefs_client_no_presign_local() -> Any:
|
|||||||
|
|
||||||
|
|
||||||
class TestLakeFSLoader(unittest.TestCase):
|
class TestLakeFSLoader(unittest.TestCase):
|
||||||
lakefs_access_key = "lakefs_access_key"
|
lakefs_access_key: str = "lakefs_access_key"
|
||||||
lakefs_secret_key = "lakefs_secret_key"
|
lakefs_secret_key: str = "lakefs_secret_key"
|
||||||
endpoint = "endpoint"
|
endpoint: str = "endpoint"
|
||||||
repo = "repo"
|
repo: str = "repo"
|
||||||
ref = "ref"
|
ref: str = "ref"
|
||||||
path = "path"
|
path: str = "path"
|
||||||
|
|
||||||
@requests_mock.Mocker()
|
@requests_mock.Mocker()
|
||||||
@pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local")
|
@pytest.mark.usefixtures("mock_lakefs_client_no_presign_not_local")
|
||||||
|
@ -21,9 +21,9 @@ def mock_connector_id(): # type: ignore
|
|||||||
|
|
||||||
@pytest.mark.requires("psychicapi")
|
@pytest.mark.requires("psychicapi")
|
||||||
class TestPsychicLoader:
|
class TestPsychicLoader:
|
||||||
MOCK_API_KEY = "api_key"
|
MOCK_API_KEY: str = "api_key"
|
||||||
MOCK_CONNECTOR_ID = "notion"
|
MOCK_CONNECTOR_ID: str = "notion"
|
||||||
MOCK_ACCOUNT_ID = "account_id"
|
MOCK_ACCOUNT_ID: str = "account_id"
|
||||||
|
|
||||||
def test_psychic_loader_initialization(
|
def test_psychic_loader_initialization(
|
||||||
self, mock_psychic: MagicMock, mock_connector_id: MagicMock
|
self, mock_psychic: MagicMock, mock_connector_id: MagicMock
|
||||||
|
@ -4,9 +4,9 @@ from langchain_community.document_loaders.rspace import RSpaceLoader
|
|||||||
|
|
||||||
|
|
||||||
class TestRSpaceLoader(unittest.TestCase):
|
class TestRSpaceLoader(unittest.TestCase):
|
||||||
url = "https://community.researchspace.com"
|
url: str = "https://community.researchspace.com"
|
||||||
api_key = "myapikey"
|
api_key: str = "myapikey"
|
||||||
global_id = "SD12345"
|
global_id: str = "SD12345"
|
||||||
|
|
||||||
def test_valid_arguments(self) -> None:
|
def test_valid_arguments(self) -> None:
|
||||||
loader = RSpaceLoader(
|
loader = RSpaceLoader(
|
||||||
|
@ -70,7 +70,7 @@ class MockGradientaiPackage(MagicMock):
|
|||||||
"""Mock Gradientai package."""
|
"""Mock Gradientai package."""
|
||||||
|
|
||||||
Gradient = MockGradient
|
Gradient = MockGradient
|
||||||
__version__ = "1.4.0"
|
__version__: str = "1.4.0"
|
||||||
|
|
||||||
|
|
||||||
def test_gradient_llm_sync() -> None:
|
def test_gradient_llm_sync() -> None:
|
||||||
|
@ -5,7 +5,7 @@ from langchain_community.embeddings.ollama import OllamaEmbeddings
|
|||||||
|
|
||||||
|
|
||||||
class MockResponse:
|
class MockResponse:
|
||||||
status_code = 200
|
status_code: int = 200
|
||||||
|
|
||||||
def json(self) -> dict:
|
def json(self) -> dict:
|
||||||
return {"embedding": [1, 2, 3]}
|
return {"embedding": [1, 2, 3]}
|
||||||
|
Loading…
Reference in New Issue
Block a user