mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 11:39:03 +00:00
docstrings cleanup (#11640)
Added missed docstrings. Some reformatting.
This commit is contained in:
parent
78b4c7d5a0
commit
db67ccb0bb
@ -21,6 +21,7 @@ INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||
|
||||
|
||||
def trim_query(query: str) -> str:
|
||||
"""Trim the query to only include Cypher keywords."""
|
||||
keywords = (
|
||||
"CALL",
|
||||
"CREATE",
|
||||
|
@ -42,6 +42,7 @@ def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a message to a dictionary that can be passed to the API."""
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
@ -105,11 +106,12 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
"""
|
||||
|
||||
model: str = "ERNIE-Bot-turbo"
|
||||
"""Model name.
|
||||
"""Model name.
|
||||
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
|
||||
|
||||
preset models are mapping to an endpoint.
|
||||
`model` will be ignored if `endpoint` is set
|
||||
`model` will be ignored if `endpoint` is set.
|
||||
Default is ERNIE-Bot-turbo.
|
||||
"""
|
||||
|
||||
endpoint: Optional[str] = None
|
||||
|
@ -6,7 +6,17 @@ from langchain.schema import Document
|
||||
|
||||
|
||||
class MsWordParser(BaseBlobParser):
|
||||
"""Parse the Microsoft Word documents from a blob."""
|
||||
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Parse a Microsoft Word document into the Document iterator.
|
||||
|
||||
Args:
|
||||
blob: The blob to parse.
|
||||
|
||||
Returns: An iterator of Documents.
|
||||
|
||||
"""
|
||||
try:
|
||||
from unstructured.partition.doc import partition_doc
|
||||
from unstructured.partition.docx import partition_docx
|
||||
|
@ -6,7 +6,8 @@ from langchain.schema.embeddings import Embeddings
|
||||
|
||||
class XinferenceEmbeddings(Embeddings):
|
||||
|
||||
"""Wrapper around xinference embedding models.
|
||||
"""Xinference embedding models.
|
||||
|
||||
To use, you should have the xinference library installed:
|
||||
|
||||
.. code-block:: bash
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Push and pull to the LangChain Hub."""
|
||||
"""Interface with the LangChain Hub."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
@ -46,6 +46,15 @@ def _custom_parser(multiline_string: str) -> str:
|
||||
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
|
||||
# MIT License
|
||||
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
|
||||
"""Parse a JSON string that may be missing closing braces.
|
||||
|
||||
Args:
|
||||
s: The JSON string to parse.
|
||||
strict: Whether to use strict parsing. Defaults to False.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object as a Python dictionary.
|
||||
"""
|
||||
# Attempt to parse the string as-is.
|
||||
try:
|
||||
return json.loads(s, strict=strict)
|
||||
|
@ -22,7 +22,17 @@ UNARY_OPERATORS = [Operator.NOT]
|
||||
|
||||
|
||||
def process_value(value: Union[int, float, str]) -> str:
|
||||
# required for comparators involving strings
|
||||
"""Convert a value to a string and add double quotes if it is a string.
|
||||
|
||||
It required for comparators involving strings.
|
||||
|
||||
Args:
|
||||
value: The value to convert.
|
||||
|
||||
Returns:
|
||||
The converted value as a string.
|
||||
"""
|
||||
#
|
||||
if isinstance(value, str):
|
||||
# If the value is already a string, add double quotes
|
||||
return f'"{value}"'
|
||||
|
@ -11,6 +11,7 @@ from langchain.chains.query_constructor.ir import (
|
||||
|
||||
|
||||
def process_value(value: Union[int, float, str]) -> str:
|
||||
"""Convert a value to a string and add single quotes if it is a string."""
|
||||
if isinstance(value, str):
|
||||
return f"'{value}'"
|
||||
else:
|
||||
|
@ -30,11 +30,20 @@ Output = TypeVar("Output")
|
||||
|
||||
|
||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
"""Run a coroutine with a semaphore.
|
||||
Args:
|
||||
semaphore: The semaphore to use.
|
||||
coro: The coroutine to run.
|
||||
|
||||
Returns:
|
||||
The result of the coroutine.
|
||||
"""
|
||||
async with semaphore:
|
||||
return await coro
|
||||
|
||||
|
||||
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
|
||||
"""Gather coroutines with a limit on the number of concurrent coroutines."""
|
||||
if n is None:
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
@ -44,6 +53,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
||||
|
||||
|
||||
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable accepts a run_manager argument."""
|
||||
try:
|
||||
return signature(callable).parameters.get("run_manager") is not None
|
||||
except ValueError:
|
||||
@ -51,6 +61,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
||||
|
||||
|
||||
def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable accepts a config argument."""
|
||||
try:
|
||||
return signature(callable).parameters.get("config") is not None
|
||||
except ValueError:
|
||||
@ -58,6 +69,8 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
||||
|
||||
|
||||
class IsLocalDict(ast.NodeVisitor):
|
||||
"""Check if a name is a local dict."""
|
||||
|
||||
def __init__(self, name: str, keys: Set[str]) -> None:
|
||||
self.name = name
|
||||
self.keys = keys
|
||||
@ -88,6 +101,8 @@ class IsLocalDict(ast.NodeVisitor):
|
||||
|
||||
|
||||
class IsFunctionArgDict(ast.NodeVisitor):
|
||||
"""Check if the first argument of a function is a dict."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.keys: Set[str] = set()
|
||||
|
||||
@ -105,17 +120,22 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
||||
|
||||
|
||||
class GetLambdaSource(ast.NodeVisitor):
|
||||
"""Get the source code of a lambda function."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the visitor."""
|
||||
self.source: Optional[str] = None
|
||||
self.count = 0
|
||||
|
||||
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||
"""Visit a lambda function."""
|
||||
self.count += 1
|
||||
if hasattr(ast, "unparse"):
|
||||
self.source = ast.unparse(node)
|
||||
|
||||
|
||||
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||
"""Get the keys of the first argument of a function if it is a dict."""
|
||||
try:
|
||||
code = inspect.getsource(func)
|
||||
tree = ast.parse(textwrap.dedent(code))
|
||||
@ -190,6 +210,8 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
|
||||
|
||||
|
||||
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
||||
"""Protocol for objects that support addition."""
|
||||
|
||||
def __add__(self, __x: _T_contra) -> _T_co:
|
||||
...
|
||||
|
||||
@ -198,6 +220,7 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
|
||||
|
||||
|
||||
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
"""Add a sequence of addable objects together."""
|
||||
final = None
|
||||
for chunk in addables:
|
||||
if final is None:
|
||||
@ -208,6 +231,7 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
|
||||
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
"""Asynchronously add a sequence of addable objects together."""
|
||||
final = None
|
||||
async for chunk in addables:
|
||||
if final is None:
|
||||
@ -218,6 +242,8 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
||||
|
||||
|
||||
class ConfigurableField(NamedTuple):
|
||||
"""A field that can be configured by the user."""
|
||||
|
||||
id: str
|
||||
|
||||
name: Optional[str] = None
|
||||
@ -226,6 +252,8 @@ class ConfigurableField(NamedTuple):
|
||||
|
||||
|
||||
class ConfigurableFieldSingleOption(NamedTuple):
|
||||
"""A field that can be configured by the user with a default value."""
|
||||
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
default: str
|
||||
@ -235,6 +263,8 @@ class ConfigurableFieldSingleOption(NamedTuple):
|
||||
|
||||
|
||||
class ConfigurableFieldMultiOption(NamedTuple):
|
||||
"""A field that can be configured by the user with multiple default values."""
|
||||
|
||||
id: str
|
||||
options: Mapping[str, Any]
|
||||
default: Sequence[str]
|
||||
@ -249,6 +279,8 @@ AnyConfigurableField = Union[
|
||||
|
||||
|
||||
class ConfigurableFieldSpec(NamedTuple):
|
||||
"""A field that can be configured by the user. It is a specification of a field."""
|
||||
|
||||
id: str
|
||||
name: Optional[str]
|
||||
description: Optional[str]
|
||||
@ -260,6 +292,7 @@ class ConfigurableFieldSpec(NamedTuple):
|
||||
def get_unique_config_specs(
|
||||
specs: Iterable[ConfigurableFieldSpec],
|
||||
) -> Sequence[ConfigurableFieldSpec]:
|
||||
"""Get the unique config specs from a sequence of config specs."""
|
||||
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
|
||||
unique: List[ConfigurableFieldSpec] = []
|
||||
for id, dupes in grouped:
|
||||
|
@ -642,10 +642,16 @@ class HTMLHeaderTextSplitter:
|
||||
# @dataclass(frozen=True, kw_only=True, slots=True)
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
"""Tokenizer data class."""
|
||||
|
||||
chunk_overlap: int
|
||||
"""Overlap in tokens between chunks"""
|
||||
tokens_per_chunk: int
|
||||
"""Maximum number of tokens per chunk"""
|
||||
decode: Callable[[list[int]], str]
|
||||
""" Function to decode a list of token ids to a string"""
|
||||
encode: Callable[[str], List[int]]
|
||||
""" Function to encode a string to a list of token ids"""
|
||||
|
||||
|
||||
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
|
||||
|
@ -9,11 +9,15 @@ from langchain.tools.ainetwork.base import AINBaseTool
|
||||
|
||||
|
||||
class AppOperationType(str, Enum):
|
||||
"""Type of app operation as enumerator."""
|
||||
|
||||
SET_ADMIN = "SET_ADMIN"
|
||||
GET_CONFIG = "GET_CONFIG"
|
||||
|
||||
|
||||
class AppSchema(BaseModel):
|
||||
"""Schema for app operations."""
|
||||
|
||||
type: AppOperationType = Field(...)
|
||||
appName: str = Field(..., description="Name of the application on the blockchain")
|
||||
address: Optional[Union[str, List[str]]] = Field(
|
||||
@ -26,6 +30,8 @@ class AppSchema(BaseModel):
|
||||
|
||||
|
||||
class AINAppOps(AINBaseTool):
|
||||
"""Tool for app operations."""
|
||||
|
||||
name: str = "AINappOps"
|
||||
description: str = """
|
||||
Create an app in the AINetwork Blockchain database by creating the /apps/<appName> path.
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Base class for AINetwork tools."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@ -16,6 +15,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class OperationType(str, Enum):
|
||||
"""Type of operation as enumerator."""
|
||||
|
||||
SET = "SET"
|
||||
GET = "GET"
|
||||
|
||||
|
@ -8,6 +8,8 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType
|
||||
|
||||
|
||||
class RuleSchema(BaseModel):
|
||||
"""Schema for owner operations."""
|
||||
|
||||
type: OperationType = Field(...)
|
||||
path: str = Field(..., description="Blockchain reference path")
|
||||
address: Optional[Union[str, List[str]]] = Field(
|
||||
@ -28,6 +30,8 @@ class RuleSchema(BaseModel):
|
||||
|
||||
|
||||
class AINOwnerOps(AINBaseTool):
|
||||
"""Tool for owner operations."""
|
||||
|
||||
name: str = "AINownerOps"
|
||||
description: str = """
|
||||
Rules for `owner` in AINetwork Blockchain database.
|
||||
|
@ -8,12 +8,16 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType
|
||||
|
||||
|
||||
class RuleSchema(BaseModel):
|
||||
"""Schema for owner operations."""
|
||||
|
||||
type: OperationType = Field(...)
|
||||
path: str = Field(..., description="Path on the blockchain where the rule applies")
|
||||
eval: Optional[str] = Field(None, description="eval string to determine permission")
|
||||
|
||||
|
||||
class AINRuleOps(AINBaseTool):
|
||||
"""Tool for owner operations."""
|
||||
|
||||
name: str = "AINruleOps"
|
||||
description: str = """
|
||||
Covers the write `rule` for the AINetwork Blockchain database. The SET type specifies write permissions using the `eval` variable as a JavaScript eval string.
|
||||
|
@ -7,11 +7,15 @@ from langchain.tools.ainetwork.base import AINBaseTool
|
||||
|
||||
|
||||
class TransferSchema(BaseModel):
|
||||
"""Schema for transfer operations."""
|
||||
|
||||
address: str = Field(..., description="Address to transfer AIN to")
|
||||
amount: int = Field(..., description="Amount of AIN to transfer")
|
||||
|
||||
|
||||
class AINTransfer(AINBaseTool):
|
||||
"""Tool for transfer operations."""
|
||||
|
||||
name: str = "AINtransfer"
|
||||
description: str = "Transfers AIN to a specified address"
|
||||
args_schema: Type[TransferSchema] = TransferSchema
|
||||
|
@ -8,6 +8,8 @@ from langchain.tools.ainetwork.base import AINBaseTool, OperationType
|
||||
|
||||
|
||||
class ValueSchema(BaseModel):
|
||||
"""Schema for value operations."""
|
||||
|
||||
type: OperationType = Field(...)
|
||||
path: str = Field(..., description="Blockchain reference path")
|
||||
value: Optional[Union[int, str, float, dict]] = Field(
|
||||
@ -16,6 +18,8 @@ class ValueSchema(BaseModel):
|
||||
|
||||
|
||||
class AINValueOps(AINBaseTool):
|
||||
"""Tool for value operations."""
|
||||
|
||||
name: str = "AINvalueOps"
|
||||
description: str = """
|
||||
Covers the read and write value for the AINetwork Blockchain database.
|
||||
|
@ -31,6 +31,15 @@ DEFAULT_LINK_REGEX = (
|
||||
def find_all_links(
|
||||
raw_html: str, *, pattern: Union[str, re.Pattern, None] = None
|
||||
) -> List[str]:
|
||||
"""Extract all links from a raw html string.
|
||||
|
||||
Args:
|
||||
raw_html: original html.
|
||||
pattern: Regex to use for extracting links from raw html.
|
||||
|
||||
Returns:
|
||||
List[str]: all links
|
||||
"""
|
||||
pattern = pattern or DEFAULT_LINK_REGEX
|
||||
return list(set(re.findall(pattern, raw_html)))
|
||||
|
||||
|
@ -21,6 +21,7 @@ def convert_pydantic_to_openai_function(
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None
|
||||
) -> FunctionDescription:
|
||||
"""Converts a Pydantic model to a function description for the OpenAI API."""
|
||||
schema = dereference_refs(model.schema())
|
||||
schema.pop("definitions", None)
|
||||
return {
|
||||
|
@ -16,7 +16,10 @@ from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
|
||||
|
||||
|
||||
class LLMRails(VectorStore):
|
||||
"""Implementation of Vector Store using LLMRails (https://llmrails.com/).
|
||||
"""Implementation of Vector Store using LLMRails.
|
||||
|
||||
See https://llmrails.com/
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@ -224,6 +227,8 @@ class LLMRails(VectorStore):
|
||||
|
||||
|
||||
class LLMRailsRetriever(VectorStoreRetriever):
|
||||
"""Retriever for LLMRails."""
|
||||
|
||||
vectorstore: LLMRails
|
||||
search_kwargs: dict = Field(default_factory=lambda: {"k": 5})
|
||||
"""Search params.
|
||||
|
@ -61,6 +61,7 @@ def _get_search_index_query(search_type: SearchType) -> str:
|
||||
|
||||
|
||||
def check_if_not_null(props: List[str], values: List[Any]) -> None:
|
||||
"""Check if the values are not None or empty string"""
|
||||
for prop, value in zip(props, values):
|
||||
if not value:
|
||||
raise ValueError(f"Parameter `{prop}` must not be None or empty string")
|
||||
|
@ -80,7 +80,7 @@ def check_index_exists(client: RedisType, index_name: str) -> bool:
|
||||
|
||||
|
||||
class Redis(VectorStore):
|
||||
"""Wrapper around Redis vector database.
|
||||
"""Redis vector database.
|
||||
|
||||
To use, you should have the ``redis`` python package installed
|
||||
and have a running Redis Enterprise or Redis-Stack server
|
||||
|
@ -10,6 +10,8 @@ from langchain.utilities.redis import TokenEscaper
|
||||
|
||||
|
||||
class RedisFilterOperator(Enum):
|
||||
"""RedisFilterOperator enumerator is used to create RedisFilterExpressions."""
|
||||
|
||||
EQ = 1
|
||||
NE = 2
|
||||
LT = 3
|
||||
@ -23,6 +25,8 @@ class RedisFilterOperator(Enum):
|
||||
|
||||
|
||||
class RedisFilter:
|
||||
"""Collection of RedisFilterFields."""
|
||||
|
||||
@staticmethod
|
||||
def text(field: str) -> "RedisText":
|
||||
return RedisText(field)
|
||||
@ -37,6 +41,8 @@ class RedisFilter:
|
||||
|
||||
|
||||
class RedisFilterField:
|
||||
"""Base class for RedisFilterFields."""
|
||||
|
||||
escaper: "TokenEscaper" = TokenEscaper()
|
||||
OPERATORS: Dict[RedisFilterOperator, str] = {}
|
||||
|
||||
@ -72,6 +78,8 @@ class RedisFilterField:
|
||||
|
||||
|
||||
def check_operator_misuse(func: Callable) -> Callable:
|
||||
"""Decorator to check for misuse of equality operators."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
|
||||
# Extracting 'other' from positional arguments or keyword arguments
|
||||
@ -93,7 +101,7 @@ def check_operator_misuse(func: Callable) -> Callable:
|
||||
|
||||
|
||||
class RedisTag(RedisFilterField):
|
||||
"""A RedisTag is a RedisFilterField representing a tag in a Redis index."""
|
||||
"""A RedisFilterField representing a tag in a Redis index."""
|
||||
|
||||
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "==",
|
||||
@ -293,7 +301,7 @@ class RedisNum(RedisFilterField):
|
||||
|
||||
|
||||
class RedisText(RedisFilterField):
|
||||
"""A RedisText is a RedisFilterField representing a text field in a Redis index."""
|
||||
"""A RedisFilterField representing a text field in a Redis index."""
|
||||
|
||||
OPERATORS = {
|
||||
RedisFilterOperator.EQ: "==",
|
||||
@ -361,7 +369,7 @@ class RedisText(RedisFilterField):
|
||||
|
||||
|
||||
class RedisFilterExpression:
|
||||
"""A RedisFilterExpression is a logical expression of RedisFilterFields.
|
||||
"""A logical expression of RedisFilterFields.
|
||||
|
||||
RedisFilterExpressions can be combined using the & and | operators to create
|
||||
complex logical expressions that evaluate to the Redis Query language.
|
||||
|
@ -22,16 +22,22 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class RedisDistanceMetric(str, Enum):
|
||||
"""Distance metrics for Redis vector fields."""
|
||||
|
||||
l2 = "L2"
|
||||
cosine = "COSINE"
|
||||
ip = "IP"
|
||||
|
||||
|
||||
class RedisField(BaseModel):
|
||||
"""Base class for Redis fields."""
|
||||
|
||||
name: str = Field(...)
|
||||
|
||||
|
||||
class TextFieldSchema(RedisField):
|
||||
"""Schema for text fields in Redis."""
|
||||
|
||||
weight: float = 1
|
||||
no_stem: bool = False
|
||||
phonetic_matcher: Optional[str] = None
|
||||
@ -53,6 +59,8 @@ class TextFieldSchema(RedisField):
|
||||
|
||||
|
||||
class TagFieldSchema(RedisField):
|
||||
"""Schema for tag fields in Redis."""
|
||||
|
||||
separator: str = ","
|
||||
case_sensitive: bool = False
|
||||
no_index: bool = False
|
||||
@ -71,6 +79,8 @@ class TagFieldSchema(RedisField):
|
||||
|
||||
|
||||
class NumericFieldSchema(RedisField):
|
||||
"""Schema for numeric fields in Redis."""
|
||||
|
||||
no_index: bool = False
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
@ -81,6 +91,8 @@ class NumericFieldSchema(RedisField):
|
||||
|
||||
|
||||
class RedisVectorField(RedisField):
|
||||
"""Base class for Redis vector fields."""
|
||||
|
||||
dims: int = Field(...)
|
||||
algorithm: object = Field(...)
|
||||
datatype: str = Field(default="FLOAT32")
|
||||
@ -101,6 +113,8 @@ class RedisVectorField(RedisField):
|
||||
|
||||
|
||||
class FlatVectorField(RedisVectorField):
|
||||
"""Schema for flat vector fields in Redis."""
|
||||
|
||||
algorithm: Literal["FLAT"] = "FLAT"
|
||||
block_size: int = Field(default=1000)
|
||||
|
||||
@ -121,6 +135,8 @@ class FlatVectorField(RedisVectorField):
|
||||
|
||||
|
||||
class HNSWVectorField(RedisVectorField):
|
||||
"""Schema for HNSW vector fields in Redis."""
|
||||
|
||||
algorithm: Literal["HNSW"] = "HNSW"
|
||||
m: int = Field(default=16)
|
||||
ef_construction: int = Field(default=200)
|
||||
@ -147,6 +163,8 @@ class HNSWVectorField(RedisVectorField):
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
|
||||
"""Schema for Redis index."""
|
||||
|
||||
# always have a content field for text
|
||||
text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
|
||||
tag: Optional[List[TagFieldSchema]] = None
|
||||
@ -268,8 +286,11 @@ class RedisModel(BaseModel):
|
||||
def read_schema(
|
||||
index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
|
||||
) -> Dict[str, Any]:
|
||||
# check if its a dict and return RedisModel otherwise, check if it's a path and
|
||||
# read in the file assuming it's a yaml file and return a RedisModel
|
||||
"""Reads in the index schema from a dict or yaml file.
|
||||
|
||||
Check if it is a dict and return RedisModel otherwise, check if it's a path and
|
||||
read in the file assuming it's a yaml file and return a RedisModel
|
||||
"""
|
||||
if isinstance(index_schema, dict):
|
||||
return index_schema
|
||||
elif isinstance(index_schema, Path):
|
||||
|
Loading…
Reference in New Issue
Block a user