From db67ccb0bb7084f65a1afdd65f3b50044bc5005b Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Tue, 10 Oct 2023 19:56:47 -0700 Subject: [PATCH] docstrings cleanup (#11640) Added missed docstrings. Some reformatting. --- .../chains/graph_qa/neptune_cypher.py | 1 + .../chat_models/baidu_qianfan_endpoint.py | 6 ++-- .../document_loaders/parsers/msword.py | 10 ++++++ .../langchain/embeddings/xinference.py | 3 +- libs/langchain/langchain/hub.py | 2 +- .../langchain/output_parsers/json.py | 9 +++++ .../langchain/retrievers/self_query/milvus.py | 12 ++++++- .../retrievers/self_query/vectara.py | 1 + .../langchain/schema/runnable/utils.py | 33 +++++++++++++++++++ libs/langchain/langchain/text_splitter.py | 6 ++++ .../langchain/tools/ainetwork/app.py | 6 ++++ .../langchain/tools/ainetwork/base.py | 3 +- .../langchain/tools/ainetwork/owner.py | 4 +++ .../langchain/tools/ainetwork/rule.py | 4 +++ .../langchain/tools/ainetwork/transfer.py | 4 +++ .../langchain/tools/ainetwork/value.py | 4 +++ libs/langchain/langchain/utils/html.py | 9 +++++ .../langchain/utils/openai_functions.py | 1 + .../langchain/vectorstores/llm_rails.py | 7 +++- .../langchain/vectorstores/neo4j_vector.py | 1 + .../langchain/vectorstores/redis/base.py | 2 +- .../langchain/vectorstores/redis/filters.py | 14 ++++++-- .../langchain/vectorstores/redis/schema.py | 25 ++++++++++++-- 23 files changed, 154 insertions(+), 13 deletions(-) diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py index 1e3cb646440..67f0303b0a4 100644 --- a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py @@ -21,6 +21,7 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps" def trim_query(query: str) -> str: + """Trim the query to only include Cypher keywords.""" keywords = ( "CALL", "CREATE", diff --git a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py index 2bcd70c1641..df035464e53 100644 --- a/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py +++ b/libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py @@ -42,6 +42,7 @@ def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk: def convert_message_to_dict(message: BaseMessage) -> dict: + """Convert a message to a dictionary that can be passed to the API.""" message_dict: Dict[str, Any] if isinstance(message, ChatMessage): message_dict = {"role": message.role, "content": message.content} @@ -105,11 +106,12 @@ class QianfanChatEndpoint(BaseChatModel): """ model: str = "ERNIE-Bot-turbo" - """Model name. + """Model name. you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu preset models are mapping to an endpoint. - `model` will be ignored if `endpoint` is set + `model` will be ignored if `endpoint` is set. + Default is ERNIE-Bot-turbo. """ endpoint: Optional[str] = None diff --git a/libs/langchain/langchain/document_loaders/parsers/msword.py b/libs/langchain/langchain/document_loaders/parsers/msword.py index 6a83a4d954d..3823a191974 100644 --- a/libs/langchain/langchain/document_loaders/parsers/msword.py +++ b/libs/langchain/langchain/document_loaders/parsers/msword.py @@ -6,7 +6,17 @@ from langchain.schema import Document class MsWordParser(BaseBlobParser): + """Parse the Microsoft Word documents from a blob.""" + def lazy_parse(self, blob: Blob) -> Iterator[Document]: + """Parse a Microsoft Word document into the Document iterator. + + Args: + blob: The blob to parse. + + Returns: An iterator of Documents. + + """ try: from unstructured.partition.doc import partition_doc from unstructured.partition.docx import partition_docx diff --git a/libs/langchain/langchain/embeddings/xinference.py b/libs/langchain/langchain/embeddings/xinference.py index 573dbd868bc..62ab74d2aad 100644 --- a/libs/langchain/langchain/embeddings/xinference.py +++ b/libs/langchain/langchain/embeddings/xinference.py @@ -6,7 +6,8 @@ from langchain.schema.embeddings import Embeddings class XinferenceEmbeddings(Embeddings): - """Wrapper around xinference embedding models. + """Xinference embedding models. + To use, you should have the xinference library installed: .. code-block:: bash diff --git a/libs/langchain/langchain/hub.py b/libs/langchain/langchain/hub.py index ac5012e635a..ebacbad6c8e 100644 --- a/libs/langchain/langchain/hub.py +++ b/libs/langchain/langchain/hub.py @@ -1,4 +1,4 @@ -"""Push and pull to the LangChain Hub.""" +"""Interface with the LangChain Hub.""" from __future__ import annotations from typing import TYPE_CHECKING, Any, Optional diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index 557151f61ab..53c2b1c3c98 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -46,6 +46,15 @@ def _custom_parser(multiline_string: str) -> str: # Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py # MIT License def parse_partial_json(s: str, *, strict: bool = False) -> Any: + """Parse a JSON string that may be missing closing braces. + + Args: + s: The JSON string to parse. + strict: Whether to use strict parsing. Defaults to False. + + Returns: + The parsed JSON object as a Python dictionary. + """ # Attempt to parse the string as-is. try: return json.loads(s, strict=strict) diff --git a/libs/langchain/langchain/retrievers/self_query/milvus.py b/libs/langchain/langchain/retrievers/self_query/milvus.py index 2b4af4500f4..f855bf99093 100644 --- a/libs/langchain/langchain/retrievers/self_query/milvus.py +++ b/libs/langchain/langchain/retrievers/self_query/milvus.py @@ -22,7 +22,17 @@ UNARY_OPERATORS = [Operator.NOT] def process_value(value: Union[int, float, str]) -> str: - # required for comparators involving strings + """Convert a value to a string and add double quotes if it is a string. + + It required for comparators involving strings. + + Args: + value: The value to convert. + + Returns: + The converted value as a string. + """ + # if isinstance(value, str): # If the value is already a string, add double quotes return f'"{value}"' diff --git a/libs/langchain/langchain/retrievers/self_query/vectara.py b/libs/langchain/langchain/retrievers/self_query/vectara.py index 73dc46ff592..02d64f04708 100644 --- a/libs/langchain/langchain/retrievers/self_query/vectara.py +++ b/libs/langchain/langchain/retrievers/self_query/vectara.py @@ -11,6 +11,7 @@ from langchain.chains.query_constructor.ir import ( def process_value(value: Union[int, float, str]) -> str: + """Convert a value to a string and add single quotes if it is a string.""" if isinstance(value, str): return f"'{value}'" else: diff --git a/libs/langchain/langchain/schema/runnable/utils.py b/libs/langchain/langchain/schema/runnable/utils.py index 1d62b50dc89..284cf8cc237 100644 --- a/libs/langchain/langchain/schema/runnable/utils.py +++ b/libs/langchain/langchain/schema/runnable/utils.py @@ -30,11 +30,20 @@ Output = TypeVar("Output") async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any: + """Run a coroutine with a semaphore. + Args: + semaphore: The semaphore to use. + coro: The coroutine to run. + + Returns: + The result of the coroutine. + """ async with semaphore: return await coro async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list: + """Gather coroutines with a limit on the number of concurrent coroutines.""" if n is None: return await asyncio.gather(*coros) @@ -44,6 +53,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis def accepts_run_manager(callable: Callable[..., Any]) -> bool: + """Check if a callable accepts a run_manager argument.""" try: return signature(callable).parameters.get("run_manager") is not None except ValueError: @@ -51,6 +61,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool: def accepts_config(callable: Callable[..., Any]) -> bool: + """Check if a callable accepts a config argument.""" try: return signature(callable).parameters.get("config") is not None except ValueError: @@ -58,6 +69,8 @@ def accepts_config(callable: Callable[..., Any]) -> bool: class IsLocalDict(ast.NodeVisitor): + """Check if a name is a local dict.""" + def __init__(self, name: str, keys: Set[str]) -> None: self.name = name self.keys = keys @@ -88,6 +101,8 @@ class IsLocalDict(ast.NodeVisitor): class IsFunctionArgDict(ast.NodeVisitor): + """Check if the first argument of a function is a dict.""" + def __init__(self) -> None: self.keys: Set[str] = set() @@ -105,17 +120,22 @@ class IsFunctionArgDict(ast.NodeVisitor): class GetLambdaSource(ast.NodeVisitor): + """Get the source code of a lambda function.""" + def __init__(self) -> None: + """Initialize the visitor.""" self.source: Optional[str] = None self.count = 0 def visit_Lambda(self, node: ast.Lambda) -> Any: + """Visit a lambda function.""" self.count += 1 if hasattr(ast, "unparse"): self.source = ast.unparse(node) def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]: + """Get the keys of the first argument of a function if it is a dict.""" try: code = inspect.getsource(func) tree = ast.parse(textwrap.dedent(code)) @@ -190,6 +210,8 @@ _T_contra = TypeVar("_T_contra", contravariant=True) class SupportsAdd(Protocol[_T_contra, _T_co]): + """Protocol for objects that support addition.""" + def __add__(self, __x: _T_contra) -> _T_co: ... @@ -198,6 +220,7 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any]) def add(addables: Iterable[Addable]) -> Optional[Addable]: + """Add a sequence of addable objects together.""" final = None for chunk in addables: if final is None: @@ -208,6 +231,7 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]: async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: + """Asynchronously add a sequence of addable objects together.""" final = None async for chunk in addables: if final is None: @@ -218,6 +242,8 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: class ConfigurableField(NamedTuple): + """A field that can be configured by the user.""" + id: str name: Optional[str] = None @@ -226,6 +252,8 @@ class ConfigurableField(NamedTuple): class ConfigurableFieldSingleOption(NamedTuple): + """A field that can be configured by the user with a default value.""" + id: str options: Mapping[str, Any] default: str @@ -235,6 +263,8 @@ class ConfigurableFieldSingleOption(NamedTuple): class ConfigurableFieldMultiOption(NamedTuple): + """A field that can be configured by the user with multiple default values.""" + id: str options: Mapping[str, Any] default: Sequence[str] @@ -249,6 +279,8 @@ AnyConfigurableField = Union[ class ConfigurableFieldSpec(NamedTuple): + """A field that can be configured by the user. It is a specification of a field.""" + id: str name: Optional[str] description: Optional[str] @@ -260,6 +292,7 @@ class ConfigurableFieldSpec(NamedTuple): def get_unique_config_specs( specs: Iterable[ConfigurableFieldSpec], ) -> Sequence[ConfigurableFieldSpec]: + """Get the unique config specs from a sequence of config specs.""" grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id) unique: List[ConfigurableFieldSpec] = [] for id, dupes in grouped: diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index c7ad4a2c377..b894937a57c 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -642,10 +642,16 @@ class HTMLHeaderTextSplitter: # @dataclass(frozen=True, kw_only=True, slots=True) @dataclass(frozen=True) class Tokenizer: + """Tokenizer data class.""" + chunk_overlap: int + """Overlap in tokens between chunks""" tokens_per_chunk: int + """Maximum number of tokens per chunk""" decode: Callable[[list[int]], str] + """ Function to decode a list of token ids to a string""" encode: Callable[[str], List[int]] + """ Function to encode a string to a list of token ids""" def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]: diff --git a/libs/langchain/langchain/tools/ainetwork/app.py b/libs/langchain/langchain/tools/ainetwork/app.py index 26599d75d7e..64c32046e86 100644 --- a/libs/langchain/langchain/tools/ainetwork/app.py +++ b/libs/langchain/langchain/tools/ainetwork/app.py @@ -9,11 +9,15 @@ from langchain.tools.ainetwork.base import AINBaseTool class AppOperationType(str, Enum): + """Type of app operation as enumerator.""" + SET_ADMIN = "SET_ADMIN" GET_CONFIG = "GET_CONFIG" class AppSchema(BaseModel): + """Schema for app operations.""" + type: AppOperationType = Field(...) appName: str = Field(..., description="Name of the application on the blockchain") address: Optional[Union[str, List[str]]] = Field( @@ -26,6 +30,8 @@ class AppSchema(BaseModel): class AINAppOps(AINBaseTool): + """Tool for app operations.""" + name: str = "AINappOps" description: str = """ Create an app in the AINetwork Blockchain database by creating the /apps/ path. diff --git a/libs/langchain/langchain/tools/ainetwork/base.py b/libs/langchain/langchain/tools/ainetwork/base.py index 7899761302b..0d40f403896 100644 --- a/libs/langchain/langchain/tools/ainetwork/base.py +++ b/libs/langchain/langchain/tools/ainetwork/base.py @@ -1,4 +1,3 @@ -"""Base class for AINetwork tools.""" from __future__ import annotations import asyncio @@ -16,6 +15,8 @@ if TYPE_CHECKING: class OperationType(str, Enum): + """Type of operation as enumerator.""" + SET = "SET" GET = "GET" diff --git a/libs/langchain/langchain/tools/ainetwork/owner.py b/libs/langchain/langchain/tools/ainetwork/owner.py index 42a2b18d861..33c182ac710 100644 --- a/libs/langchain/langchain/tools/ainetwork/owner.py +++ b/libs/langchain/langchain/tools/ainetwork/owner.py @@ -8,6 +8,8 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType class RuleSchema(BaseModel): + """Schema for owner operations.""" + type: OperationType = Field(...) path: str = Field(..., description="Blockchain reference path") address: Optional[Union[str, List[str]]] = Field( @@ -28,6 +30,8 @@ class RuleSchema(BaseModel): class AINOwnerOps(AINBaseTool): + """Tool for owner operations.""" + name: str = "AINownerOps" description: str = """ Rules for `owner` in AINetwork Blockchain database. diff --git a/libs/langchain/langchain/tools/ainetwork/rule.py b/libs/langchain/langchain/tools/ainetwork/rule.py index d66fd20205b..030edb2eb37 100644 --- a/libs/langchain/langchain/tools/ainetwork/rule.py +++ b/libs/langchain/langchain/tools/ainetwork/rule.py @@ -8,12 +8,16 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType class RuleSchema(BaseModel): + """Schema for owner operations.""" + type: OperationType = Field(...) path: str = Field(..., description="Path on the blockchain where the rule applies") eval: Optional[str] = Field(None, description="eval string to determine permission") class AINRuleOps(AINBaseTool): + """Tool for owner operations.""" + name: str = "AINruleOps" description: str = """ Covers the write `rule` for the AINetwork Blockchain database. The SET type specifies write permissions using the `eval` variable as a JavaScript eval string. diff --git a/libs/langchain/langchain/tools/ainetwork/transfer.py b/libs/langchain/langchain/tools/ainetwork/transfer.py index b267f724a32..04f15c6748b 100644 --- a/libs/langchain/langchain/tools/ainetwork/transfer.py +++ b/libs/langchain/langchain/tools/ainetwork/transfer.py @@ -7,11 +7,15 @@ from langchain.tools.ainetwork.base import AINBaseTool class TransferSchema(BaseModel): + """Schema for transfer operations.""" + address: str = Field(..., description="Address to transfer AIN to") amount: int = Field(..., description="Amount of AIN to transfer") class AINTransfer(AINBaseTool): + """Tool for transfer operations.""" + name: str = "AINtransfer" description: str = "Transfers AIN to a specified address" args_schema: Type[TransferSchema] = TransferSchema diff --git a/libs/langchain/langchain/tools/ainetwork/value.py b/libs/langchain/langchain/tools/ainetwork/value.py index 2e6c92c685d..844b98e9968 100644 --- a/libs/langchain/langchain/tools/ainetwork/value.py +++ b/libs/langchain/langchain/tools/ainetwork/value.py @@ -8,6 +8,8 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType class ValueSchema(BaseModel): + """Schema for value operations.""" + type: OperationType = Field(...) path: str = Field(..., description="Blockchain reference path") value: Optional[Union[int, str, float, dict]] = Field( @@ -16,6 +18,8 @@ class ValueSchema(BaseModel): class AINValueOps(AINBaseTool): + """Tool for value operations.""" + name: str = "AINvalueOps" description: str = """ Covers the read and write value for the AINetwork Blockchain database. diff --git a/libs/langchain/langchain/utils/html.py b/libs/langchain/langchain/utils/html.py index 09a76876d1c..ee830e246ee 100644 --- a/libs/langchain/langchain/utils/html.py +++ b/libs/langchain/langchain/utils/html.py @@ -31,6 +31,15 @@ DEFAULT_LINK_REGEX = ( def find_all_links( raw_html: str, *, pattern: Union[str, re.Pattern, None] = None ) -> List[str]: + """Extract all links from a raw html string. + + Args: + raw_html: original html. + pattern: Regex to use for extracting links from raw html. + + Returns: + List[str]: all links + """ pattern = pattern or DEFAULT_LINK_REGEX return list(set(re.findall(pattern, raw_html))) diff --git a/libs/langchain/langchain/utils/openai_functions.py b/libs/langchain/langchain/utils/openai_functions.py index cfb1e76d595..0106894a839 100644 --- a/libs/langchain/langchain/utils/openai_functions.py +++ b/libs/langchain/langchain/utils/openai_functions.py @@ -21,6 +21,7 @@ def convert_pydantic_to_openai_function( name: Optional[str] = None, description: Optional[str] = None ) -> FunctionDescription: + """Converts a Pydantic model to a function description for the OpenAI API.""" schema = dereference_refs(model.schema()) schema.pop("definitions", None) return { diff --git a/libs/langchain/langchain/vectorstores/llm_rails.py b/libs/langchain/langchain/vectorstores/llm_rails.py index 5bf772f5fe7..23ed41ad1ec 100644 --- a/libs/langchain/langchain/vectorstores/llm_rails.py +++ b/libs/langchain/langchain/vectorstores/llm_rails.py @@ -16,7 +16,10 @@ from langchain.vectorstores.base import VectorStore, VectorStoreRetriever class LLMRails(VectorStore): - """Implementation of Vector Store using LLMRails (https://llmrails.com/). + """Implementation of Vector Store using LLMRails. + + See https://llmrails.com/ + Example: .. code-block:: python @@ -224,6 +227,8 @@ class LLMRails(VectorStore): class LLMRailsRetriever(VectorStoreRetriever): + """Retriever for LLMRails.""" + vectorstore: LLMRails search_kwargs: dict = Field(default_factory=lambda: {"k": 5}) """Search params. diff --git a/libs/langchain/langchain/vectorstores/neo4j_vector.py b/libs/langchain/langchain/vectorstores/neo4j_vector.py index 7e524ad4ffc..e826a502212 100644 --- a/libs/langchain/langchain/vectorstores/neo4j_vector.py +++ b/libs/langchain/langchain/vectorstores/neo4j_vector.py @@ -61,6 +61,7 @@ def _get_search_index_query(search_type: SearchType) -> str: def check_if_not_null(props: List[str], values: List[Any]) -> None: + """Check if the values are not None or empty string""" for prop, value in zip(props, values): if not value: raise ValueError(f"Parameter `{prop}` must not be None or empty string") diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index 8409f4fa55f..1306ba0177a 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -80,7 +80,7 @@ def check_index_exists(client: RedisType, index_name: str) -> bool: class Redis(VectorStore): - """Wrapper around Redis vector database. + """Redis vector database. To use, you should have the ``redis`` python package installed and have a running Redis Enterprise or Redis-Stack server diff --git a/libs/langchain/langchain/vectorstores/redis/filters.py b/libs/langchain/langchain/vectorstores/redis/filters.py index 633c8f4073c..2ea59e645ad 100644 --- a/libs/langchain/langchain/vectorstores/redis/filters.py +++ b/libs/langchain/langchain/vectorstores/redis/filters.py @@ -10,6 +10,8 @@ from langchain.utilities.redis import TokenEscaper class RedisFilterOperator(Enum): + """RedisFilterOperator enumerator is used to create RedisFilterExpressions.""" + EQ = 1 NE = 2 LT = 3 @@ -23,6 +25,8 @@ class RedisFilterOperator(Enum): class RedisFilter: + """Collection of RedisFilterFields.""" + @staticmethod def text(field: str) -> "RedisText": return RedisText(field) @@ -37,6 +41,8 @@ class RedisFilter: class RedisFilterField: + """Base class for RedisFilterFields.""" + escaper: "TokenEscaper" = TokenEscaper() OPERATORS: Dict[RedisFilterOperator, str] = {} @@ -72,6 +78,8 @@ class RedisFilterField: def check_operator_misuse(func: Callable) -> Callable: + """Decorator to check for misuse of equality operators.""" + @wraps(func) def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: # Extracting 'other' from positional arguments or keyword arguments @@ -93,7 +101,7 @@ def check_operator_misuse(func: Callable) -> Callable: class RedisTag(RedisFilterField): - """A RedisTag is a RedisFilterField representing a tag in a Redis index.""" + """A RedisFilterField representing a tag in a Redis index.""" OPERATORS: Dict[RedisFilterOperator, str] = { RedisFilterOperator.EQ: "==", @@ -293,7 +301,7 @@ class RedisNum(RedisFilterField): class RedisText(RedisFilterField): - """A RedisText is a RedisFilterField representing a text field in a Redis index.""" + """A RedisFilterField representing a text field in a Redis index.""" OPERATORS = { RedisFilterOperator.EQ: "==", @@ -361,7 +369,7 @@ class RedisText(RedisFilterField): class RedisFilterExpression: - """A RedisFilterExpression is a logical expression of RedisFilterFields. + """A logical expression of RedisFilterFields. RedisFilterExpressions can be combined using the & and | operators to create complex logical expressions that evaluate to the Redis Query language. diff --git a/libs/langchain/langchain/vectorstores/redis/schema.py b/libs/langchain/langchain/vectorstores/redis/schema.py index 79833a94bc6..55e3639a540 100644 --- a/libs/langchain/langchain/vectorstores/redis/schema.py +++ b/libs/langchain/langchain/vectorstores/redis/schema.py @@ -22,16 +22,22 @@ if TYPE_CHECKING: class RedisDistanceMetric(str, Enum): + """Distance metrics for Redis vector fields.""" + l2 = "L2" cosine = "COSINE" ip = "IP" class RedisField(BaseModel): + """Base class for Redis fields.""" + name: str = Field(...) class TextFieldSchema(RedisField): + """Schema for text fields in Redis.""" + weight: float = 1 no_stem: bool = False phonetic_matcher: Optional[str] = None @@ -53,6 +59,8 @@ class TextFieldSchema(RedisField): class TagFieldSchema(RedisField): + """Schema for tag fields in Redis.""" + separator: str = "," case_sensitive: bool = False no_index: bool = False @@ -71,6 +79,8 @@ class TagFieldSchema(RedisField): class NumericFieldSchema(RedisField): + """Schema for numeric fields in Redis.""" + no_index: bool = False sortable: Optional[bool] = False @@ -81,6 +91,8 @@ class NumericFieldSchema(RedisField): class RedisVectorField(RedisField): + """Base class for Redis vector fields.""" + dims: int = Field(...) algorithm: object = Field(...) datatype: str = Field(default="FLOAT32") @@ -101,6 +113,8 @@ class RedisVectorField(RedisField): class FlatVectorField(RedisVectorField): + """Schema for flat vector fields in Redis.""" + algorithm: Literal["FLAT"] = "FLAT" block_size: int = Field(default=1000) @@ -121,6 +135,8 @@ class FlatVectorField(RedisVectorField): class HNSWVectorField(RedisVectorField): + """Schema for HNSW vector fields in Redis.""" + algorithm: Literal["HNSW"] = "HNSW" m: int = Field(default=16) ef_construction: int = Field(default=200) @@ -147,6 +163,8 @@ class HNSWVectorField(RedisVectorField): class RedisModel(BaseModel): + """Schema for Redis index.""" + # always have a content field for text text: List[TextFieldSchema] = [TextFieldSchema(name="content")] tag: Optional[List[TagFieldSchema]] = None @@ -268,8 +286,11 @@ class RedisModel(BaseModel): def read_schema( index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] ) -> Dict[str, Any]: - # check if its a dict and return RedisModel otherwise, check if it's a path and - # read in the file assuming it's a yaml file and return a RedisModel + """Reads in the index schema from a dict or yaml file. + + Check if it is a dict and return RedisModel otherwise, check if it's a path and + read in the file assuming it's a yaml file and return a RedisModel + """ if isinstance(index_schema, dict): return index_schema elif isinstance(index_schema, Path):