mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
IMPROVEMENT: Minor redis improvements (#13381)
- **Description:** - Fixes a `key_prefix` bug where passing it in on `Redis.from_existing(...)` did not work properly. Updates doc strings accordingly. - Updates Redis filter classes logic with best practices on typing, string formatting, and handling "empty" filters. - Fixes a bug that would prevent multiple tag filters from being applied together in some scenarios. - Added a whole new filter unit testing module. Also updated code formatting for a number of modules that were failing the `make` commands. - **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** @baskaryan - **Twitter handle:** @tchutch94
This commit is contained in:
parent
674bd90a47
commit
190952fe76
@ -28,7 +28,7 @@ class TokenEscaper:
|
|||||||
|
|
||||||
# Characters that RediSearch requires us to escape during queries.
|
# Characters that RediSearch requires us to escape during queries.
|
||||||
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
|
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
|
||||||
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"
|
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
|
||||||
|
|
||||||
def __init__(self, escape_chars_re: Optional[Pattern] = None):
|
def __init__(self, escape_chars_re: Optional[Pattern] = None):
|
||||||
if escape_chars_re:
|
if escape_chars_re:
|
||||||
@ -37,6 +37,12 @@ class TokenEscaper:
|
|||||||
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
|
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
|
||||||
|
|
||||||
def escape(self, value: str) -> str:
|
def escape(self, value: str) -> str:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise TypeError(
|
||||||
|
"Value must be a string object for token escaping."
|
||||||
|
f"Got type {type(value)}"
|
||||||
|
)
|
||||||
|
|
||||||
def escape_symbol(match: re.Match) -> str:
|
def escape_symbol(match: re.Match) -> str:
|
||||||
value = match.group(0)
|
value = match.group(0)
|
||||||
return f"\\{value}"
|
return f"\\{value}"
|
||||||
|
@ -60,9 +60,9 @@ def check_index_exists(client: RedisType, index_name: str) -> bool:
|
|||||||
try:
|
try:
|
||||||
client.ft(index_name).info()
|
client.ft(index_name).info()
|
||||||
except: # noqa: E722
|
except: # noqa: E722
|
||||||
logger.info("Index does not exist")
|
logger.debug("Index does not exist")
|
||||||
return False
|
return False
|
||||||
logger.info("Index already exists")
|
logger.debug("Index already exists")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -155,9 +155,12 @@ class Redis(VectorStore):
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
rds = Redis.from_existing_index(
|
# must pass in schema and key_prefix from another index
|
||||||
|
existing_rds = Redis.from_existing_index(
|
||||||
embeddings, # an Embeddings object
|
embeddings, # an Embeddings object
|
||||||
index_name="my-index",
|
index_name="my-index",
|
||||||
|
schema=rds.schema, # schema dumped from another index
|
||||||
|
key_prefix=rds.key_prefix, # key prefix from another index
|
||||||
redis_url="redis://localhost:6379",
|
redis_url="redis://localhost:6379",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -249,7 +252,7 @@ class Redis(VectorStore):
|
|||||||
key_prefix: Optional[str] = None,
|
key_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Initialize with necessary components."""
|
"""Initialize Redis vector store with necessary components."""
|
||||||
self._check_deprecated_kwargs(kwargs)
|
self._check_deprecated_kwargs(kwargs)
|
||||||
try:
|
try:
|
||||||
# TODO use importlib to check if redis is installed
|
# TODO use importlib to check if redis is installed
|
||||||
@ -401,6 +404,7 @@ class Redis(VectorStore):
|
|||||||
index_schema = generated_schema
|
index_schema = generated_schema
|
||||||
|
|
||||||
# Create instance
|
# Create instance
|
||||||
|
# init the class -- if Redis is unavailable, will throw exception
|
||||||
instance = cls(
|
instance = cls(
|
||||||
redis_url,
|
redis_url,
|
||||||
index_name,
|
index_name,
|
||||||
@ -495,6 +499,7 @@ class Redis(VectorStore):
|
|||||||
embedding: Embeddings,
|
embedding: Embeddings,
|
||||||
index_name: str,
|
index_name: str,
|
||||||
schema: Union[Dict[str, str], str, os.PathLike],
|
schema: Union[Dict[str, str], str, os.PathLike],
|
||||||
|
key_prefix: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Redis:
|
) -> Redis:
|
||||||
"""Connect to an existing Redis index.
|
"""Connect to an existing Redis index.
|
||||||
@ -504,11 +509,16 @@ class Redis(VectorStore):
|
|||||||
|
|
||||||
from langchain.vectorstores import Redis
|
from langchain.vectorstores import Redis
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
embeddings = OpenAIEmbeddings()
|
embeddings = OpenAIEmbeddings()
|
||||||
redisearch = Redis.from_existing_index(
|
|
||||||
|
# must pass in schema and key_prefix from another index
|
||||||
|
existing_rds = Redis.from_existing_index(
|
||||||
embeddings,
|
embeddings,
|
||||||
index_name="my-index",
|
index_name="my-index",
|
||||||
redis_url="redis://username:password@localhost:6379"
|
schema=rds.schema, # schema dumped from another index
|
||||||
|
key_prefix=rds.key_prefix, # key prefix from another index
|
||||||
|
redis_url="redis://username:password@localhost:6379",
|
||||||
)
|
)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -516,8 +526,9 @@ class Redis(VectorStore):
|
|||||||
for embedding queries.
|
for embedding queries.
|
||||||
index_name (str): Name of the index to connect to.
|
index_name (str): Name of the index to connect to.
|
||||||
schema (Union[Dict[str, str], str, os.PathLike]): Schema of the index
|
schema (Union[Dict[str, str], str, os.PathLike]): Schema of the index
|
||||||
and the vector schema. Can be a dict, or path to yaml file
|
and the vector schema. Can be a dict, or path to yaml file.
|
||||||
|
key_prefix (Optional[str]): Prefix to use for all keys in Redis associated
|
||||||
|
with this index.
|
||||||
**kwargs (Any): Additional keyword arguments to pass to the Redis client.
|
**kwargs (Any): Additional keyword arguments to pass to the Redis client.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -528,29 +539,32 @@ class Redis(VectorStore):
|
|||||||
ImportError: If the redis python package is not installed.
|
ImportError: If the redis python package is not installed.
|
||||||
"""
|
"""
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||||
try:
|
# We need to first remove redis_url from kwargs,
|
||||||
# We need to first remove redis_url from kwargs,
|
# otherwise passing it to Redis will result in an error.
|
||||||
# otherwise passing it to Redis will result in an error.
|
if "redis_url" in kwargs:
|
||||||
if "redis_url" in kwargs:
|
kwargs.pop("redis_url")
|
||||||
kwargs.pop("redis_url")
|
|
||||||
client = get_client(redis_url=redis_url, **kwargs)
|
|
||||||
# check if redis has redisearch module installed
|
|
||||||
check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
|
|
||||||
# ensure that the index already exists
|
|
||||||
assert check_index_exists(
|
|
||||||
client, index_name
|
|
||||||
), f"Index {index_name} does not exist"
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Redis failed to connect: {e}")
|
|
||||||
|
|
||||||
return cls(
|
# Create instance
|
||||||
|
# init the class -- if Redis is unavailable, will throw exception
|
||||||
|
instance = cls(
|
||||||
redis_url,
|
redis_url,
|
||||||
index_name,
|
index_name,
|
||||||
embedding,
|
embedding,
|
||||||
index_schema=schema,
|
index_schema=schema,
|
||||||
|
key_prefix=key_prefix,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check for existence of the declared index
|
||||||
|
if not check_index_exists(instance.client, index_name):
|
||||||
|
# Will only raise if the running Redis server does not
|
||||||
|
# have a record of this particular index
|
||||||
|
raise ValueError(
|
||||||
|
f"Redis failed to connect: Index {index_name} does not exist."
|
||||||
|
)
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schema(self) -> Dict[str, List[Any]]:
|
def schema(self) -> Dict[str, List[Any]]:
|
||||||
"""Return the schema of the index."""
|
"""Return the schema of the index."""
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from numbers import Number
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from langchain.utilities.redis import TokenEscaper
|
from langchain.utilities.redis import TokenEscaper
|
||||||
|
|
||||||
@ -57,7 +56,7 @@ class RedisFilterField:
|
|||||||
return self._field == other._field and self._value == other._value
|
return self._field == other._field and self._value == other._value
|
||||||
|
|
||||||
def _set_value(
|
def _set_value(
|
||||||
self, val: Any, val_type: type, operator: RedisFilterOperator
|
self, val: Any, val_type: Tuple[Any], operator: RedisFilterOperator
|
||||||
) -> None:
|
) -> None:
|
||||||
# check that the operator is supported by this class
|
# check that the operator is supported by this class
|
||||||
if operator not in self.OPERATORS:
|
if operator not in self.OPERATORS:
|
||||||
@ -108,15 +107,15 @@ class RedisTag(RedisFilterField):
|
|||||||
RedisFilterOperator.NE: "!=",
|
RedisFilterOperator.NE: "!=",
|
||||||
RedisFilterOperator.IN: "==",
|
RedisFilterOperator.IN: "==",
|
||||||
}
|
}
|
||||||
|
|
||||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||||
RedisFilterOperator.EQ: "@%s:{%s}",
|
RedisFilterOperator.EQ: "@%s:{%s}",
|
||||||
RedisFilterOperator.NE: "(-@%s:{%s})",
|
RedisFilterOperator.NE: "(-@%s:{%s})",
|
||||||
RedisFilterOperator.IN: "@%s:{%s}",
|
RedisFilterOperator.IN: "@%s:{%s}",
|
||||||
}
|
}
|
||||||
|
SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None))
|
||||||
|
|
||||||
def __init__(self, field: str):
|
def __init__(self, field: str):
|
||||||
"""Create a RedisTag FilterField
|
"""Create a RedisTag FilterField.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
field (str): The name of the RedisTag field in the index to be queried
|
field (str): The name of the RedisTag field in the index to be queried
|
||||||
@ -125,21 +124,33 @@ class RedisTag(RedisFilterField):
|
|||||||
super().__init__(field)
|
super().__init__(field)
|
||||||
|
|
||||||
def _set_tag_value(
|
def _set_tag_value(
|
||||||
self, other: Union[List[str], str], operator: RedisFilterOperator
|
self,
|
||||||
|
other: Union[List[str], Set[str], Tuple[str], str],
|
||||||
|
operator: RedisFilterOperator,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(other, list):
|
if isinstance(other, (list, set, tuple)):
|
||||||
if not all(isinstance(tag, str) for tag in other):
|
try:
|
||||||
raise ValueError("All tags must be strings")
|
# "if val" clause removes non-truthy values from list
|
||||||
else:
|
other = [str(val) for val in other if val]
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("All tags within collection must be strings")
|
||||||
|
# above to catch the "" case
|
||||||
|
elif not other:
|
||||||
|
other = []
|
||||||
|
elif isinstance(other, str):
|
||||||
other = [other]
|
other = [other]
|
||||||
self._set_value(other, list, operator)
|
|
||||||
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore
|
||||||
|
|
||||||
@check_operator_misuse
|
@check_operator_misuse
|
||||||
def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
def __eq__(
|
||||||
"""Create a RedisTag equality filter expression
|
self, other: Union[List[str], Set[str], Tuple[str], str]
|
||||||
|
) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisTag equality filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Union[List[str], str]): The tag(s) to filter on.
|
other (Union[List[str], Set[str], Tuple[str], str]):
|
||||||
|
The tag(s) to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisTag
|
>>> from langchain.vectorstores.redis import RedisTag
|
||||||
@ -149,11 +160,14 @@ class RedisTag(RedisFilterField):
|
|||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
@check_operator_misuse
|
@check_operator_misuse
|
||||||
def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
def __ne__(
|
||||||
"""Create a RedisTag inequality filter expression
|
self, other: Union[List[str], Set[str], Tuple[str], str]
|
||||||
|
) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisTag inequality filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Union[List[str], str]): The tag(s) to filter on.
|
other (Union[List[str], Set[str], Tuple[str], str]):
|
||||||
|
The tag(s) to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisTag
|
>>> from langchain.vectorstores.redis import RedisTag
|
||||||
@ -167,12 +181,10 @@ class RedisTag(RedisFilterField):
|
|||||||
return "|".join([self.escaper.escape(tag) for tag in self._value])
|
return "|".join([self.escaper.escape(tag) for tag in self._value])
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
"""Return the query syntax for a RedisTag filter expression."""
|
||||||
if not self._value:
|
if not self._value:
|
||||||
raise ValueError(
|
return "*"
|
||||||
f"Operator must be used before calling __str__. Operators are "
|
|
||||||
f"{self.OPERATORS.values()}"
|
|
||||||
)
|
|
||||||
"""Return the Redis Query syntax for a RedisTag filter expression"""
|
|
||||||
return self.OPERATOR_MAP[self._operator] % (
|
return self.OPERATOR_MAP[self._operator] % (
|
||||||
self._field,
|
self._field,
|
||||||
self._formatted_tag_value,
|
self._formatted_tag_value,
|
||||||
@ -191,21 +203,19 @@ class RedisNum(RedisFilterField):
|
|||||||
RedisFilterOperator.GE: ">=",
|
RedisFilterOperator.GE: ">=",
|
||||||
}
|
}
|
||||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||||
RedisFilterOperator.EQ: "@%s:[%f %f]",
|
RedisFilterOperator.EQ: "@%s:[%s %s]",
|
||||||
RedisFilterOperator.NE: "(-@%s:[%f %f])",
|
RedisFilterOperator.NE: "(-@%s:[%s %s])",
|
||||||
RedisFilterOperator.GT: "@%s:[(%f +inf]",
|
RedisFilterOperator.GT: "@%s:[(%s +inf]",
|
||||||
RedisFilterOperator.LT: "@%s:[-inf (%f]",
|
RedisFilterOperator.LT: "@%s:[-inf (%s]",
|
||||||
RedisFilterOperator.GE: "@%s:[%f +inf]",
|
RedisFilterOperator.GE: "@%s:[%s +inf]",
|
||||||
RedisFilterOperator.LE: "@%s:[-inf %f]",
|
RedisFilterOperator.LE: "@%s:[-inf %s]",
|
||||||
}
|
}
|
||||||
|
SUPPORTED_VAL_TYPES = (int, float, type(None))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""Return the Redis Query syntax for a Numeric filter expression"""
|
"""Return the query syntax for a RedisNum filter expression."""
|
||||||
if not self._value:
|
if not self._value:
|
||||||
raise ValueError(
|
return "*"
|
||||||
f"Operator must be used before calling __str__. Operators are "
|
|
||||||
f"{self.OPERATORS.values()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self._operator == RedisFilterOperator.EQ
|
self._operator == RedisFilterOperator.EQ
|
||||||
@ -221,102 +231,103 @@ class RedisNum(RedisFilterField):
|
|||||||
|
|
||||||
@check_operator_misuse
|
@check_operator_misuse
|
||||||
def __eq__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
def __eq__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||||
"""Create a Numeric equality filter expression
|
"""Create a Numeric equality filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Number): The value to filter on.
|
other (Union[int, float]): The value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisNum
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
>>> filter = RedisNum("zipcode") == 90210
|
>>> filter = RedisNum("zipcode") == 90210
|
||||||
"""
|
"""
|
||||||
self._set_value(other, Number, RedisFilterOperator.EQ)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
@check_operator_misuse
|
@check_operator_misuse
|
||||||
def __ne__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
def __ne__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||||
"""Create a Numeric inequality filter expression
|
"""Create a Numeric inequality filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Number): The value to filter on.
|
other (Union[int, float]): The value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisNum
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
>>> filter = RedisNum("zipcode") != 90210
|
>>> filter = RedisNum("zipcode") != 90210
|
||||||
"""
|
"""
|
||||||
self._set_value(other, Number, RedisFilterOperator.NE)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||||
"""Create a RedisNumeric greater than filter expression
|
"""Create a Numeric greater than filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Number): The value to filter on.
|
other (Union[int, float]): The value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisNum
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
>>> filter = RedisNum("age") > 18
|
>>> filter = RedisNum("age") > 18
|
||||||
"""
|
"""
|
||||||
self._set_value(other, Number, RedisFilterOperator.GT)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||||
"""Create a Numeric less than filter expression
|
"""Create a Numeric less than filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Number): The value to filter on.
|
other (Union[int, float]): The value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisNum
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
>>> filter = RedisNum("age") < 18
|
>>> filter = RedisNum("age") < 18
|
||||||
"""
|
"""
|
||||||
self._set_value(other, Number, RedisFilterOperator.LT)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||||
"""Create a Numeric greater than or equal to filter expression
|
"""Create a Numeric greater than or equal to filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Number): The value to filter on.
|
other (Union[int, float]): The value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisNum
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
>>> filter = RedisNum("age") >= 18
|
>>> filter = RedisNum("age") >= 18
|
||||||
"""
|
"""
|
||||||
self._set_value(other, Number, RedisFilterOperator.GE)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||||
"""Create a Numeric less than or equal to filter expression
|
"""Create a Numeric less than or equal to filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (Number): The value to filter on.
|
other (Union[int, float]): The value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisNum
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
>>> filter = RedisNum("age") <= 18
|
>>> filter = RedisNum("age") <= 18
|
||||||
"""
|
"""
|
||||||
self._set_value(other, Number, RedisFilterOperator.LE)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
|
||||||
class RedisText(RedisFilterField):
|
class RedisText(RedisFilterField):
|
||||||
"""A RedisFilterField representing a text field in a Redis index."""
|
"""A RedisFilterField representing a text field in a Redis index."""
|
||||||
|
|
||||||
OPERATORS = {
|
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||||
RedisFilterOperator.EQ: "==",
|
RedisFilterOperator.EQ: "==",
|
||||||
RedisFilterOperator.NE: "!=",
|
RedisFilterOperator.NE: "!=",
|
||||||
RedisFilterOperator.LIKE: "%",
|
RedisFilterOperator.LIKE: "%",
|
||||||
}
|
}
|
||||||
OPERATOR_MAP = {
|
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||||
RedisFilterOperator.EQ: '@%s:"%s"',
|
RedisFilterOperator.EQ: '@%s:("%s")',
|
||||||
RedisFilterOperator.NE: '(-@%s:"%s")',
|
RedisFilterOperator.NE: '(-@%s:"%s")',
|
||||||
RedisFilterOperator.LIKE: "@%s:%s",
|
RedisFilterOperator.LIKE: "@%s:(%s)",
|
||||||
}
|
}
|
||||||
|
SUPPORTED_VAL_TYPES = (str, type(None))
|
||||||
|
|
||||||
@check_operator_misuse
|
@check_operator_misuse
|
||||||
def __eq__(self, other: str) -> "RedisFilterExpression":
|
def __eq__(self, other: str) -> "RedisFilterExpression":
|
||||||
"""Create a RedisText equality filter expression
|
"""Create a RedisText equality (exact match) filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (str): The text value to filter on.
|
other (str): The text value to filter on.
|
||||||
@ -325,12 +336,12 @@ class RedisText(RedisFilterField):
|
|||||||
>>> from langchain.vectorstores.redis import RedisText
|
>>> from langchain.vectorstores.redis import RedisText
|
||||||
>>> filter = RedisText("job") == "engineer"
|
>>> filter = RedisText("job") == "engineer"
|
||||||
"""
|
"""
|
||||||
self._set_value(other, str, RedisFilterOperator.EQ)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
@check_operator_misuse
|
@check_operator_misuse
|
||||||
def __ne__(self, other: str) -> "RedisFilterExpression":
|
def __ne__(self, other: str) -> "RedisFilterExpression":
|
||||||
"""Create a RedisText inequality filter expression
|
"""Create a RedisText inequality filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (str): The text value to filter on.
|
other (str): The text value to filter on.
|
||||||
@ -339,33 +350,34 @@ class RedisText(RedisFilterField):
|
|||||||
>>> from langchain.vectorstores.redis import RedisText
|
>>> from langchain.vectorstores.redis import RedisText
|
||||||
>>> filter = RedisText("job") != "engineer"
|
>>> filter = RedisText("job") != "engineer"
|
||||||
"""
|
"""
|
||||||
self._set_value(other, str, RedisFilterOperator.NE)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
def __mod__(self, other: str) -> "RedisFilterExpression":
|
def __mod__(self, other: str) -> "RedisFilterExpression":
|
||||||
"""Create a RedisText like filter expression
|
"""Create a RedisText "LIKE" filter expression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
other (str): The text value to filter on.
|
other (str): The text value to filter on.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from langchain.vectorstores.redis import RedisText
|
>>> from langchain.vectorstores.redis import RedisText
|
||||||
>>> filter = RedisText("job") % "engineer"
|
>>> filter = RedisText("job") % "engine*" # suffix wild card match
|
||||||
|
>>> filter = RedisText("job") % "%%engine%%" # fuzzy match w/ LD
|
||||||
|
>>> filter = RedisText("job") % "engineer|doctor" # contains either term
|
||||||
|
>>> filter = RedisText("job") % "engineer doctor" # contains both terms
|
||||||
"""
|
"""
|
||||||
self._set_value(other, str, RedisFilterOperator.LIKE)
|
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore
|
||||||
return RedisFilterExpression(str(self))
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
"""Return the query syntax for a RedisText filter expression."""
|
||||||
if not self._value:
|
if not self._value:
|
||||||
raise ValueError(
|
return "*"
|
||||||
f"Operator must be used before calling __str__. Operators are "
|
|
||||||
f"{self.OPERATORS.values()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
return self.OPERATOR_MAP[self._operator] % (
|
||||||
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
|
self._field,
|
||||||
except KeyError:
|
self._value,
|
||||||
raise Exception("Invalid operator")
|
)
|
||||||
|
|
||||||
|
|
||||||
class RedisFilterExpression:
|
class RedisFilterExpression:
|
||||||
@ -413,16 +425,36 @@ class RedisFilterExpression:
|
|||||||
operator=RedisFilterOperator.OR, left=self, right=other
|
operator=RedisFilterOperator.OR, left=self, right=other
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def format_expression(
|
||||||
|
left: "RedisFilterExpression", right: "RedisFilterExpression", operator_str: str
|
||||||
|
) -> str:
|
||||||
|
_left, _right = str(left), str(right)
|
||||||
|
if _left == _right == "*":
|
||||||
|
return _left
|
||||||
|
if _left == "*" != _right:
|
||||||
|
return _right
|
||||||
|
if _right == "*" != _left:
|
||||||
|
return _left
|
||||||
|
return f"({_left}{operator_str}{_right})"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
# top level check that allows recursive calls to __str__
|
# top level check that allows recursive calls to __str__
|
||||||
if not self._filter and not self._operator:
|
if not self._filter and not self._operator:
|
||||||
raise ValueError("Improperly initialized RedisFilterExpression")
|
raise ValueError("Improperly initialized RedisFilterExpression")
|
||||||
|
|
||||||
# allow for single filter expression without operators as last
|
# if there's an operator, combine expressions accordingly
|
||||||
# expression in the chain might not have an operator
|
|
||||||
if self._operator:
|
if self._operator:
|
||||||
|
if not isinstance(self._left, RedisFilterExpression) or not isinstance(
|
||||||
|
self._right, RedisFilterExpression
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"Improper combination of filters."
|
||||||
|
"Both left and right should be type FilterExpression"
|
||||||
|
)
|
||||||
|
|
||||||
operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
|
operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
|
||||||
return f"({str(self._left)}{operator_str}{str(self._right)})"
|
return self.format_expression(self._left, self._right, operator_str)
|
||||||
|
|
||||||
# check that base case, the filter is set
|
# check that base case, the filter is set
|
||||||
if not self._filter:
|
if not self._filter:
|
||||||
|
@ -0,0 +1,193 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.vectorstores.redis import (
|
||||||
|
RedisNum as Num,
|
||||||
|
)
|
||||||
|
from langchain.vectorstores.redis import (
|
||||||
|
RedisTag as Tag,
|
||||||
|
)
|
||||||
|
from langchain.vectorstores.redis import (
|
||||||
|
RedisText as Text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test cases for various tag scenarios
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"operation,tags,expected",
|
||||||
|
[
|
||||||
|
# Testing single tags
|
||||||
|
("==", "simpletag", "@tag_field:{simpletag}"),
|
||||||
|
(
|
||||||
|
"==",
|
||||||
|
"tag with space",
|
||||||
|
"@tag_field:{tag\\ with\\ space}",
|
||||||
|
), # Escaping spaces within quotes
|
||||||
|
(
|
||||||
|
"==",
|
||||||
|
"special$char",
|
||||||
|
"@tag_field:{special\\$char}",
|
||||||
|
), # Escaping a special character
|
||||||
|
("!=", "negated", "(-@tag_field:{negated})"),
|
||||||
|
# Testing multiple tags
|
||||||
|
("==", ["tag1", "tag2"], "@tag_field:{tag1|tag2}"),
|
||||||
|
(
|
||||||
|
"==",
|
||||||
|
["alpha", "beta with space", "gamma$special"],
|
||||||
|
"@tag_field:{alpha|beta\\ with\\ space|gamma\\$special}",
|
||||||
|
), # Multiple tags with spaces and special chars
|
||||||
|
("!=", ["tagA", "tagB"], "(-@tag_field:{tagA|tagB})"),
|
||||||
|
# Complex tag scenarios with special characters
|
||||||
|
("==", "weird:tag", "@tag_field:{weird\\:tag}"), # Tags with colon
|
||||||
|
("==", "tag&another", "@tag_field:{tag\\&another}"), # Tags with ampersand
|
||||||
|
# Escaping various special characters within tags
|
||||||
|
("==", "tag/with/slashes", "@tag_field:{tag\\/with\\/slashes}"),
|
||||||
|
(
|
||||||
|
"==",
|
||||||
|
["hyphen-tag", "under_score", "dot.tag"],
|
||||||
|
"@tag_field:{hyphen\\-tag|under_score|dot\\.tag}",
|
||||||
|
),
|
||||||
|
# ...additional unique cases as desired...
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_tag_filter_varied(operation: str, tags: str, expected: str) -> None:
|
||||||
|
if operation == "==":
|
||||||
|
tf = Tag("tag_field") == tags
|
||||||
|
elif operation == "!=":
|
||||||
|
tf = Tag("tag_field") != tags
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported operation: {operation}")
|
||||||
|
|
||||||
|
# Verify the string representation matches the expected RediSearch query part
|
||||||
|
assert str(tf) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value, expected",
|
||||||
|
[
|
||||||
|
(None, "*"),
|
||||||
|
([], "*"),
|
||||||
|
("", "*"),
|
||||||
|
([None], "*"),
|
||||||
|
([None, "tag"], "@tag_field:{tag}"),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"none",
|
||||||
|
"empty_list",
|
||||||
|
"empty_string",
|
||||||
|
"list_with_none",
|
||||||
|
"list_with_none_and_tag",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_nullable_tags(value: Any, expected: str) -> None:
|
||||||
|
tag = Tag("tag_field")
|
||||||
|
assert str(tag == value) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"operation, value, expected",
|
||||||
|
[
|
||||||
|
("__eq__", 5, "@numeric_field:[5 5]"),
|
||||||
|
("__ne__", 5, "(-@numeric_field:[5 5])"),
|
||||||
|
("__gt__", 5, "@numeric_field:[(5 +inf]"),
|
||||||
|
("__ge__", 5, "@numeric_field:[5 +inf]"),
|
||||||
|
("__lt__", 5.55, "@numeric_field:[-inf (5.55]"),
|
||||||
|
("__le__", 5, "@numeric_field:[-inf 5]"),
|
||||||
|
("__le__", None, "*"),
|
||||||
|
("__eq__", None, "*"),
|
||||||
|
("__ne__", None, "*"),
|
||||||
|
],
|
||||||
|
ids=["eq", "ne", "gt", "ge", "lt", "le", "le_none", "eq_none", "ne_none"],
|
||||||
|
)
|
||||||
|
def test_numeric_filter(operation: str, value: Any, expected: str) -> None:
|
||||||
|
nf = Num("numeric_field")
|
||||||
|
assert str(getattr(nf, operation)(value)) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"operation, value, expected",
|
||||||
|
[
|
||||||
|
("__eq__", "text", '@text_field:("text")'),
|
||||||
|
("__ne__", "text", '(-@text_field:"text")'),
|
||||||
|
("__eq__", "", "*"),
|
||||||
|
("__ne__", "", "*"),
|
||||||
|
("__eq__", None, "*"),
|
||||||
|
("__ne__", None, "*"),
|
||||||
|
("__mod__", "text", "@text_field:(text)"),
|
||||||
|
("__mod__", "tex*", "@text_field:(tex*)"),
|
||||||
|
("__mod__", "%text%", "@text_field:(%text%)"),
|
||||||
|
("__mod__", "", "*"),
|
||||||
|
("__mod__", None, "*"),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"eq",
|
||||||
|
"ne",
|
||||||
|
"eq-empty",
|
||||||
|
"ne-empty",
|
||||||
|
"eq-none",
|
||||||
|
"ne-none",
|
||||||
|
"like",
|
||||||
|
"like_wildcard",
|
||||||
|
"like_full",
|
||||||
|
"like_empty",
|
||||||
|
"like_none",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_text_filter(operation: str, value: Any, expected: str) -> None:
|
||||||
|
txt_f = getattr(Text("text_field"), operation)(value)
|
||||||
|
assert str(txt_f) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_filters_combination() -> None:
|
||||||
|
tf1 = Tag("tag_field") == ["tag1", "tag2"]
|
||||||
|
tf2 = Tag("tag_field") == "tag3"
|
||||||
|
combined = tf1 & tf2
|
||||||
|
assert str(combined) == "(@tag_field:{tag1|tag2} @tag_field:{tag3})"
|
||||||
|
|
||||||
|
combined = tf1 | tf2
|
||||||
|
assert str(combined) == "(@tag_field:{tag1|tag2} | @tag_field:{tag3})"
|
||||||
|
|
||||||
|
tf1 = Tag("tag_field") == []
|
||||||
|
assert str(tf1) == "*"
|
||||||
|
assert str(tf1 & tf2) == str(tf2)
|
||||||
|
assert str(tf1 | tf2) == str(tf2)
|
||||||
|
|
||||||
|
# test combining filters with None values and empty strings
|
||||||
|
tf1 = Tag("tag_field") == None # noqa: E711
|
||||||
|
tf2 = Tag("tag_field") == ""
|
||||||
|
assert str(tf1 & tf2) == "*"
|
||||||
|
|
||||||
|
tf1 = Tag("tag_field") == None # noqa: E711
|
||||||
|
tf2 = Tag("tag_field") == "tag"
|
||||||
|
assert str(tf1 & tf2) == str(tf2)
|
||||||
|
|
||||||
|
tf1 = Tag("tag_field") == None # noqa: E711
|
||||||
|
tf2 = Tag("tag_field") == ["tag1", "tag2"]
|
||||||
|
assert str(tf1 & tf2) == str(tf2)
|
||||||
|
|
||||||
|
tf1 = Tag("tag_field") == None # noqa: E711
|
||||||
|
tf2 = Tag("tag_field") != None # noqa: E711
|
||||||
|
assert str(tf1 & tf2) == "*"
|
||||||
|
|
||||||
|
tf1 = Tag("tag_field") == ""
|
||||||
|
tf2 = Tag("tag_field") == "tag"
|
||||||
|
tf3 = Tag("tag_field") == ["tag1", "tag2"]
|
||||||
|
assert str(tf1 & tf2 & tf3) == str(tf2 & tf3)
|
||||||
|
|
||||||
|
# test none filters for Tag Num Text
|
||||||
|
tf1 = Tag("tag_field") == None # noqa: E711
|
||||||
|
tf2 = Num("num_field") == None # noqa: E711
|
||||||
|
tf3 = Text("text_field") == None # noqa: E711
|
||||||
|
assert str(tf1 & tf2 & tf3) == "*"
|
||||||
|
|
||||||
|
tf1 = Tag("tag_field") != None # noqa: E711
|
||||||
|
tf2 = Num("num_field") != None # noqa: E711
|
||||||
|
tf3 = Text("text_field") != None # noqa: E711
|
||||||
|
assert str(tf1 & tf2 & tf3) == "*"
|
||||||
|
|
||||||
|
# test combinations of real and None filters
|
||||||
|
tf1 = Tag("tag_field") == "tag"
|
||||||
|
tf2 = Num("num_field") == None # noqa: E711
|
||||||
|
tf3 = Text("text_field") == None # noqa: E711
|
||||||
|
assert str(tf1 & tf2 & tf3) == str(tf1)
|
Loading…
Reference in New Issue
Block a user