mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
IMPROVEMENT: Minor redis improvements (#13381)
- **Description:** - Fixes a `key_prefix` bug where passing it to `Redis.from_existing_index(...)` did not work properly. Updates docstrings accordingly. - Updates Redis filter classes logic with best practices on typing, string formatting, and handling "empty" filters. - Fixes a bug that would prevent multiple tag filters from being applied together in some scenarios. - Adds a whole new filter unit testing module. Also updates code formatting for a number of modules that were failing the `make` commands. - **Issue:** N/A - **Dependencies:** N/A - **Tag maintainer:** @baskaryan - **Twitter handle:** @tchutch94
This commit is contained in:
parent
674bd90a47
commit
190952fe76
@ -28,7 +28,7 @@ class TokenEscaper:
|
||||
|
||||
# Characters that RediSearch requires us to escape during queries.
|
||||
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
|
||||
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"
|
||||
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]"
|
||||
|
||||
def __init__(self, escape_chars_re: Optional[Pattern] = None):
|
||||
if escape_chars_re:
|
||||
@ -37,6 +37,12 @@ class TokenEscaper:
|
||||
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
|
||||
|
||||
def escape(self, value: str) -> str:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
"Value must be a string object for token escaping."
|
||||
f"Got type {type(value)}"
|
||||
)
|
||||
|
||||
def escape_symbol(match: re.Match) -> str:
|
||||
value = match.group(0)
|
||||
return f"\\{value}"
|
||||
|
@ -60,9 +60,9 @@ def check_index_exists(client: RedisType, index_name: str) -> bool:
|
||||
try:
|
||||
client.ft(index_name).info()
|
||||
except: # noqa: E722
|
||||
logger.info("Index does not exist")
|
||||
logger.debug("Index does not exist")
|
||||
return False
|
||||
logger.info("Index already exists")
|
||||
logger.debug("Index already exists")
|
||||
return True
|
||||
|
||||
|
||||
@ -155,9 +155,12 @@ class Redis(VectorStore):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
rds = Redis.from_existing_index(
|
||||
# must pass in schema and key_prefix from another index
|
||||
existing_rds = Redis.from_existing_index(
|
||||
embeddings, # an Embeddings object
|
||||
index_name="my-index",
|
||||
schema=rds.schema, # schema dumped from another index
|
||||
key_prefix=rds.key_prefix, # key prefix from another index
|
||||
redis_url="redis://localhost:6379",
|
||||
)
|
||||
|
||||
@ -249,7 +252,7 @@ class Redis(VectorStore):
|
||||
key_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
"""Initialize Redis vector store with necessary components."""
|
||||
self._check_deprecated_kwargs(kwargs)
|
||||
try:
|
||||
# TODO use importlib to check if redis is installed
|
||||
@ -401,6 +404,7 @@ class Redis(VectorStore):
|
||||
index_schema = generated_schema
|
||||
|
||||
# Create instance
|
||||
# init the class -- if Redis is unavailable, will throw exception
|
||||
instance = cls(
|
||||
redis_url,
|
||||
index_name,
|
||||
@ -495,6 +499,7 @@ class Redis(VectorStore):
|
||||
embedding: Embeddings,
|
||||
index_name: str,
|
||||
schema: Union[Dict[str, str], str, os.PathLike],
|
||||
key_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> Redis:
|
||||
"""Connect to an existing Redis index.
|
||||
@ -504,11 +509,16 @@ class Redis(VectorStore):
|
||||
|
||||
from langchain.vectorstores import Redis
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
redisearch = Redis.from_existing_index(
|
||||
|
||||
# must pass in schema and key_prefix from another index
|
||||
existing_rds = Redis.from_existing_index(
|
||||
embeddings,
|
||||
index_name="my-index",
|
||||
redis_url="redis://username:password@localhost:6379"
|
||||
schema=rds.schema, # schema dumped from another index
|
||||
key_prefix=rds.key_prefix, # key prefix from another index
|
||||
redis_url="redis://username:password@localhost:6379",
|
||||
)
|
||||
|
||||
Args:
|
||||
@ -516,8 +526,9 @@ class Redis(VectorStore):
|
||||
for embedding queries.
|
||||
index_name (str): Name of the index to connect to.
|
||||
schema (Union[Dict[str, str], str, os.PathLike]): Schema of the index
|
||||
and the vector schema. Can be a dict, or path to yaml file
|
||||
|
||||
and the vector schema. Can be a dict, or path to yaml file.
|
||||
key_prefix (Optional[str]): Prefix to use for all keys in Redis associated
|
||||
with this index.
|
||||
**kwargs (Any): Additional keyword arguments to pass to the Redis client.
|
||||
|
||||
Returns:
|
||||
@ -528,29 +539,32 @@ class Redis(VectorStore):
|
||||
ImportError: If the redis python package is not installed.
|
||||
"""
|
||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
||||
try:
|
||||
# We need to first remove redis_url from kwargs,
|
||||
# otherwise passing it to Redis will result in an error.
|
||||
if "redis_url" in kwargs:
|
||||
kwargs.pop("redis_url")
|
||||
client = get_client(redis_url=redis_url, **kwargs)
|
||||
# check if redis has redisearch module installed
|
||||
check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
|
||||
# ensure that the index already exists
|
||||
assert check_index_exists(
|
||||
client, index_name
|
||||
), f"Index {index_name} does not exist"
|
||||
except Exception as e:
|
||||
raise ValueError(f"Redis failed to connect: {e}")
|
||||
|
||||
return cls(
|
||||
# Create instance
|
||||
# init the class -- if Redis is unavailable, will throw exception
|
||||
instance = cls(
|
||||
redis_url,
|
||||
index_name,
|
||||
embedding,
|
||||
index_schema=schema,
|
||||
key_prefix=key_prefix,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Check for existence of the declared index
|
||||
if not check_index_exists(instance.client, index_name):
|
||||
# Will only raise if the running Redis server does not
|
||||
# have a record of this particular index
|
||||
raise ValueError(
|
||||
f"Redis failed to connect: Index {index_name} does not exist."
|
||||
)
|
||||
|
||||
return instance
|
||||
|
||||
@property
|
||||
def schema(self) -> Dict[str, List[Any]]:
|
||||
"""Return the schema of the index."""
|
||||
|
@ -1,7 +1,6 @@
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from langchain.utilities.redis import TokenEscaper
|
||||
|
||||
@ -57,7 +56,7 @@ class RedisFilterField:
|
||||
return self._field == other._field and self._value == other._value
|
||||
|
||||
def _set_value(
|
||||
self, val: Any, val_type: type, operator: RedisFilterOperator
|
||||
self, val: Any, val_type: Tuple[Any], operator: RedisFilterOperator
|
||||
) -> None:
|
||||
# check that the operator is supported by this class
|
||||
if operator not in self.OPERATORS:
|
||||
@ -108,15 +107,15 @@ class RedisTag(RedisFilterField):
|
||||
RedisFilterOperator.NE: "!=",
|
||||
RedisFilterOperator.IN: "==",
|
||||
}
|
||||
|
||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "@%s:{%s}",
|
||||
RedisFilterOperator.NE: "(-@%s:{%s})",
|
||||
RedisFilterOperator.IN: "@%s:{%s}",
|
||||
}
|
||||
SUPPORTED_VAL_TYPES = (list, set, tuple, str, type(None))
|
||||
|
||||
def __init__(self, field: str):
|
||||
"""Create a RedisTag FilterField
|
||||
"""Create a RedisTag FilterField.
|
||||
|
||||
Args:
|
||||
field (str): The name of the RedisTag field in the index to be queried
|
||||
@ -125,21 +124,33 @@ class RedisTag(RedisFilterField):
|
||||
super().__init__(field)
|
||||
|
||||
def _set_tag_value(
|
||||
self, other: Union[List[str], str], operator: RedisFilterOperator
|
||||
self,
|
||||
other: Union[List[str], Set[str], Tuple[str], str],
|
||||
operator: RedisFilterOperator,
|
||||
) -> None:
|
||||
if isinstance(other, list):
|
||||
if not all(isinstance(tag, str) for tag in other):
|
||||
raise ValueError("All tags must be strings")
|
||||
else:
|
||||
if isinstance(other, (list, set, tuple)):
|
||||
try:
|
||||
# "if val" clause removes non-truthy values from list
|
||||
other = [str(val) for val in other if val]
|
||||
except ValueError:
|
||||
raise ValueError("All tags within collection must be strings")
|
||||
# above to catch the "" case
|
||||
elif not other:
|
||||
other = []
|
||||
elif isinstance(other, str):
|
||||
other = [other]
|
||||
self._set_value(other, list, operator)
|
||||
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, operator) # type: ignore
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
||||
"""Create a RedisTag equality filter expression
|
||||
def __eq__(
|
||||
self, other: Union[List[str], Set[str], Tuple[str], str]
|
||||
) -> "RedisFilterExpression":
|
||||
"""Create a RedisTag equality filter expression.
|
||||
|
||||
Args:
|
||||
other (Union[List[str], str]): The tag(s) to filter on.
|
||||
other (Union[List[str], Set[str], Tuple[str], str]):
|
||||
The tag(s) to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisTag
|
||||
@ -149,11 +160,14 @@ class RedisTag(RedisFilterField):
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
||||
"""Create a RedisTag inequality filter expression
|
||||
def __ne__(
|
||||
self, other: Union[List[str], Set[str], Tuple[str], str]
|
||||
) -> "RedisFilterExpression":
|
||||
"""Create a RedisTag inequality filter expression.
|
||||
|
||||
Args:
|
||||
other (Union[List[str], str]): The tag(s) to filter on.
|
||||
other (Union[List[str], Set[str], Tuple[str], str]):
|
||||
The tag(s) to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisTag
|
||||
@ -167,12 +181,10 @@ class RedisTag(RedisFilterField):
|
||||
return "|".join([self.escaper.escape(tag) for tag in self._value])
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the query syntax for a RedisTag filter expression."""
|
||||
if not self._value:
|
||||
raise ValueError(
|
||||
f"Operator must be used before calling __str__. Operators are "
|
||||
f"{self.OPERATORS.values()}"
|
||||
)
|
||||
"""Return the Redis Query syntax for a RedisTag filter expression"""
|
||||
return "*"
|
||||
|
||||
return self.OPERATOR_MAP[self._operator] % (
|
||||
self._field,
|
||||
self._formatted_tag_value,
|
||||
@ -191,21 +203,19 @@ class RedisNum(RedisFilterField):
|
||||
RedisFilterOperator.GE: ">=",
|
||||
}
|
||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "@%s:[%f %f]",
|
||||
RedisFilterOperator.NE: "(-@%s:[%f %f])",
|
||||
RedisFilterOperator.GT: "@%s:[(%f +inf]",
|
||||
RedisFilterOperator.LT: "@%s:[-inf (%f]",
|
||||
RedisFilterOperator.GE: "@%s:[%f +inf]",
|
||||
RedisFilterOperator.LE: "@%s:[-inf %f]",
|
||||
RedisFilterOperator.EQ: "@%s:[%s %s]",
|
||||
RedisFilterOperator.NE: "(-@%s:[%s %s])",
|
||||
RedisFilterOperator.GT: "@%s:[(%s +inf]",
|
||||
RedisFilterOperator.LT: "@%s:[-inf (%s]",
|
||||
RedisFilterOperator.GE: "@%s:[%s +inf]",
|
||||
RedisFilterOperator.LE: "@%s:[-inf %s]",
|
||||
}
|
||||
SUPPORTED_VAL_TYPES = (int, float, type(None))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the Redis Query syntax for a Numeric filter expression"""
|
||||
"""Return the query syntax for a RedisNum filter expression."""
|
||||
if not self._value:
|
||||
raise ValueError(
|
||||
f"Operator must be used before calling __str__. Operators are "
|
||||
f"{self.OPERATORS.values()}"
|
||||
)
|
||||
return "*"
|
||||
|
||||
if (
|
||||
self._operator == RedisFilterOperator.EQ
|
||||
@ -221,102 +231,103 @@ class RedisNum(RedisFilterField):
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric equality filter expression
|
||||
"""Create a Numeric equality filter expression.
|
||||
|
||||
Args:
|
||||
other (Number): The value to filter on.
|
||||
other (Union[int, float]): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") == 90210
|
||||
"""
|
||||
self._set_value(other, Number, RedisFilterOperator.EQ)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
def __ne__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric inequality filter expression
|
||||
"""Create a Numeric inequality filter expression.
|
||||
|
||||
Args:
|
||||
other (Number): The value to filter on.
|
||||
other (Union[int, float]): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") != 90210
|
||||
"""
|
||||
self._set_value(other, Number, RedisFilterOperator.NE)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a RedisNumeric greater than filter expression
|
||||
"""Create a Numeric greater than filter expression.
|
||||
|
||||
Args:
|
||||
other (Number): The value to filter on.
|
||||
other (Union[int, float]): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") > 18
|
||||
"""
|
||||
self._set_value(other, Number, RedisFilterOperator.GT)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GT) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric less than filter expression
|
||||
"""Create a Numeric less than filter expression.
|
||||
|
||||
Args:
|
||||
other (Number): The value to filter on.
|
||||
other (Union[int, float]): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") < 18
|
||||
"""
|
||||
self._set_value(other, Number, RedisFilterOperator.LT)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LT) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric greater than or equal to filter expression
|
||||
"""Create a Numeric greater than or equal to filter expression.
|
||||
|
||||
Args:
|
||||
other (Number): The value to filter on.
|
||||
other (Union[int, float]): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") >= 18
|
||||
"""
|
||||
self._set_value(other, Number, RedisFilterOperator.GE)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.GE) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric less than or equal to filter expression
|
||||
"""Create a Numeric less than or equal to filter expression.
|
||||
|
||||
Args:
|
||||
other (Number): The value to filter on.
|
||||
other (Union[int, float]): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") <= 18
|
||||
"""
|
||||
self._set_value(other, Number, RedisFilterOperator.LE)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LE) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
|
||||
class RedisText(RedisFilterField):
|
||||
"""A RedisFilterField representing a text field in a Redis index."""
|
||||
|
||||
OPERATORS = {
|
||||
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "==",
|
||||
RedisFilterOperator.NE: "!=",
|
||||
RedisFilterOperator.LIKE: "%",
|
||||
}
|
||||
OPERATOR_MAP = {
|
||||
RedisFilterOperator.EQ: '@%s:"%s"',
|
||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: '@%s:("%s")',
|
||||
RedisFilterOperator.NE: '(-@%s:"%s")',
|
||||
RedisFilterOperator.LIKE: "@%s:%s",
|
||||
RedisFilterOperator.LIKE: "@%s:(%s)",
|
||||
}
|
||||
SUPPORTED_VAL_TYPES = (str, type(None))
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(self, other: str) -> "RedisFilterExpression":
|
||||
"""Create a RedisText equality filter expression
|
||||
"""Create a RedisText equality (exact match) filter expression.
|
||||
|
||||
Args:
|
||||
other (str): The text value to filter on.
|
||||
@ -325,12 +336,12 @@ class RedisText(RedisFilterField):
|
||||
>>> from langchain.vectorstores.redis import RedisText
|
||||
>>> filter = RedisText("job") == "engineer"
|
||||
"""
|
||||
self._set_value(other, str, RedisFilterOperator.EQ)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.EQ) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
def __ne__(self, other: str) -> "RedisFilterExpression":
|
||||
"""Create a RedisText inequality filter expression
|
||||
"""Create a RedisText inequality filter expression.
|
||||
|
||||
Args:
|
||||
other (str): The text value to filter on.
|
||||
@ -339,33 +350,34 @@ class RedisText(RedisFilterField):
|
||||
>>> from langchain.vectorstores.redis import RedisText
|
||||
>>> filter = RedisText("job") != "engineer"
|
||||
"""
|
||||
self._set_value(other, str, RedisFilterOperator.NE)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.NE) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __mod__(self, other: str) -> "RedisFilterExpression":
|
||||
"""Create a RedisText like filter expression
|
||||
"""Create a RedisText "LIKE" filter expression.
|
||||
|
||||
Args:
|
||||
other (str): The text value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisText
|
||||
>>> filter = RedisText("job") % "engineer"
|
||||
>>> filter = RedisText("job") % "engine*" # suffix wild card match
|
||||
>>> filter = RedisText("job") % "%%engine%%" # fuzzy match w/ LD
|
||||
>>> filter = RedisText("job") % "engineer|doctor" # contains either term
|
||||
>>> filter = RedisText("job") % "engineer doctor" # contains both terms
|
||||
"""
|
||||
self._set_value(other, str, RedisFilterOperator.LIKE)
|
||||
self._set_value(other, self.SUPPORTED_VAL_TYPES, RedisFilterOperator.LIKE) # type: ignore
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the query syntax for a RedisText filter expression."""
|
||||
if not self._value:
|
||||
raise ValueError(
|
||||
f"Operator must be used before calling __str__. Operators are "
|
||||
f"{self.OPERATORS.values()}"
|
||||
)
|
||||
return "*"
|
||||
|
||||
try:
|
||||
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
|
||||
except KeyError:
|
||||
raise Exception("Invalid operator")
|
||||
return self.OPERATOR_MAP[self._operator] % (
|
||||
self._field,
|
||||
self._value,
|
||||
)
|
||||
|
||||
|
||||
class RedisFilterExpression:
|
||||
@ -413,16 +425,36 @@ class RedisFilterExpression:
|
||||
operator=RedisFilterOperator.OR, left=self, right=other
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_expression(
|
||||
left: "RedisFilterExpression", right: "RedisFilterExpression", operator_str: str
|
||||
) -> str:
|
||||
_left, _right = str(left), str(right)
|
||||
if _left == _right == "*":
|
||||
return _left
|
||||
if _left == "*" != _right:
|
||||
return _right
|
||||
if _right == "*" != _left:
|
||||
return _left
|
||||
return f"({_left}{operator_str}{_right})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
# top level check that allows recursive calls to __str__
|
||||
if not self._filter and not self._operator:
|
||||
raise ValueError("Improperly initialized RedisFilterExpression")
|
||||
|
||||
# allow for single filter expression without operators as last
|
||||
# expression in the chain might not have an operator
|
||||
# if there's an operator, combine expressions accordingly
|
||||
if self._operator:
|
||||
if not isinstance(self._left, RedisFilterExpression) or not isinstance(
|
||||
self._right, RedisFilterExpression
|
||||
):
|
||||
raise TypeError(
|
||||
"Improper combination of filters."
|
||||
"Both left and right should be type FilterExpression"
|
||||
)
|
||||
|
||||
operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
|
||||
return f"({str(self._left)}{operator_str}{str(self._right)})"
|
||||
return self.format_expression(self._left, self._right, operator_str)
|
||||
|
||||
# check that base case, the filter is set
|
||||
if not self._filter:
|
||||
|
@ -0,0 +1,193 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.vectorstores.redis import (
|
||||
RedisNum as Num,
|
||||
)
|
||||
from langchain.vectorstores.redis import (
|
||||
RedisTag as Tag,
|
||||
)
|
||||
from langchain.vectorstores.redis import (
|
||||
RedisText as Text,
|
||||
)
|
||||
|
||||
|
||||
# Test cases for various tag scenarios
|
||||
@pytest.mark.parametrize(
    "operation,tags,expected",
    [
        # Single tags
        ("==", "simpletag", "@tag_field:{simpletag}"),
        # Spaces inside a tag are backslash-escaped
        ("==", "tag with space", r"@tag_field:{tag\ with\ space}"),
        # A special character is backslash-escaped
        ("==", "special$char", r"@tag_field:{special\$char}"),
        ("!=", "negated", "(-@tag_field:{negated})"),
        # Multiple tags are joined with "|"
        ("==", ["tag1", "tag2"], "@tag_field:{tag1|tag2}"),
        # Multiple tags with spaces and special characters
        (
            "==",
            ["alpha", "beta with space", "gamma$special"],
            r"@tag_field:{alpha|beta\ with\ space|gamma\$special}",
        ),
        ("!=", ["tagA", "tagB"], "(-@tag_field:{tagA|tagB})"),
        # Tags containing a colon or an ampersand
        ("==", "weird:tag", r"@tag_field:{weird\:tag}"),
        ("==", "tag&another", r"@tag_field:{tag\&another}"),
        # Slashes, hyphens and dots are escaped; underscores are not
        ("==", "tag/with/slashes", r"@tag_field:{tag\/with\/slashes}"),
        (
            "==",
            ["hyphen-tag", "under_score", "dot.tag"],
            r"@tag_field:{hyphen\-tag|under_score|dot\.tag}",
        ),
    ],
)
def test_tag_filter_varied(operation: str, tags: Any, expected: str) -> None:
    """Check the query string rendered for tag equality/inequality filters.

    Args:
        operation: Either ``"=="`` or ``"!="``; selects the operator under test.
        tags: A single tag string or a list of tag strings.
        expected: The exact string the resulting filter expression must render.

    Raises:
        ValueError: If ``operation`` is not one of the two supported operators.
    """
    field = Tag("tag_field")
    if operation == "==":
        tf = field == tags
    elif operation == "!=":
        tf = field != tags
    else:
        raise ValueError(f"Unsupported operation: {operation}")

    # The string form must match the expected RediSearch query fragment.
    assert str(tf) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "value, expected",
    [
        pytest.param(None, "*", id="none"),
        pytest.param([], "*", id="empty_list"),
        pytest.param("", "*", id="empty_string"),
        pytest.param([None], "*", id="list_with_none"),
        pytest.param([None, "tag"], "@tag_field:{tag}", id="list_with_none_and_tag"),
    ],
)
def test_nullable_tags(value: Any, expected: str) -> None:
    """Check how empty or partially empty tag values render.

    Each case compares a tag field for equality with ``value`` and asserts
    the string form of the resulting expression equals ``expected``.
    """
    expression = Tag("tag_field") == value
    assert str(expression) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "operation, value, expected",
    [
        pytest.param("__eq__", 5, "@numeric_field:[5 5]", id="eq"),
        pytest.param("__ne__", 5, "(-@numeric_field:[5 5])", id="ne"),
        pytest.param("__gt__", 5, "@numeric_field:[(5 +inf]", id="gt"),
        pytest.param("__ge__", 5, "@numeric_field:[5 +inf]", id="ge"),
        pytest.param("__lt__", 5.55, "@numeric_field:[-inf (5.55]", id="lt"),
        pytest.param("__le__", 5, "@numeric_field:[-inf 5]", id="le"),
        pytest.param("__le__", None, "*", id="le_none"),
        pytest.param("__eq__", None, "*", id="eq_none"),
        pytest.param("__ne__", None, "*", id="ne_none"),
    ],
)
def test_numeric_filter(operation: str, value: Any, expected: str) -> None:
    """Check the query string rendered for each numeric comparison operator.

    ``operation`` is the dunder method name looked up on the numeric field
    and called with ``value``; the string form of the result must equal
    ``expected``.
    """
    compare = getattr(Num("numeric_field"), operation)
    rendered = str(compare(value))
    assert rendered == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "operation, value, expected",
    [
        pytest.param("__eq__", "text", '@text_field:("text")', id="eq"),
        pytest.param("__ne__", "text", '(-@text_field:"text")', id="ne"),
        pytest.param("__eq__", "", "*", id="eq-empty"),
        pytest.param("__ne__", "", "*", id="ne-empty"),
        pytest.param("__eq__", None, "*", id="eq-none"),
        pytest.param("__ne__", None, "*", id="ne-none"),
        pytest.param("__mod__", "text", "@text_field:(text)", id="like"),
        pytest.param("__mod__", "tex*", "@text_field:(tex*)", id="like_wildcard"),
        pytest.param("__mod__", "%text%", "@text_field:(%text%)", id="like_full"),
        pytest.param("__mod__", "", "*", id="like_empty"),
        pytest.param("__mod__", None, "*", id="like_none"),
    ],
)
def test_text_filter(operation: str, value: Any, expected: str) -> None:
    """Check the query string rendered for text equality, inequality and LIKE.

    ``operation`` is the dunder method name looked up on the text field and
    called with ``value``; the string form of the result must equal
    ``expected``.
    """
    text_field = Text("text_field")
    expression = getattr(text_field, operation)(value)
    assert str(expression) == expected
|
||||
|
||||
|
||||
def test_filters_combination() -> None:
    """Check how filter expressions render when combined with ``&`` and ``|``.

    Covers two non-empty tag filters, and combinations where one or more
    operands are empty (``[]``, ``""`` or ``None``) and are therefore
    expected to render as ``"*"``.
    """
    # Two non-empty tag filters: AND joins with a space, OR with " | ".
    pair = Tag("tag_field") == ["tag1", "tag2"]
    third = Tag("tag_field") == "tag3"
    assert str(pair & third) == "(@tag_field:{tag1|tag2} @tag_field:{tag3})"
    assert str(pair | third) == "(@tag_field:{tag1|tag2} | @tag_field:{tag3})"

    # An empty list renders as "*" and leaves the other operand unchanged.
    empty = Tag("tag_field") == []
    assert str(empty) == "*"
    assert str(empty & third) == str(third)
    assert str(empty | third) == str(third)

    # Combining filters built from None values and empty strings.
    none_tag = Tag("tag_field") == None  # noqa: E711
    blank_tag = Tag("tag_field") == ""
    assert str(none_tag & blank_tag) == "*"

    none_tag = Tag("tag_field") == None  # noqa: E711
    one_tag = Tag("tag_field") == "tag"
    assert str(none_tag & one_tag) == str(one_tag)

    none_tag = Tag("tag_field") == None  # noqa: E711
    two_tags = Tag("tag_field") == ["tag1", "tag2"]
    assert str(none_tag & two_tags) == str(two_tags)

    none_eq = Tag("tag_field") == None  # noqa: E711
    none_ne = Tag("tag_field") != None  # noqa: E711
    assert str(none_eq & none_ne) == "*"

    blank_tag = Tag("tag_field") == ""
    one_tag = Tag("tag_field") == "tag"
    two_tags = Tag("tag_field") == ["tag1", "tag2"]
    assert str(blank_tag & one_tag & two_tags) == str(one_tag & two_tags)

    # None filters across Tag, Num and Text.
    tag_none = Tag("tag_field") == None  # noqa: E711
    num_none = Num("num_field") == None  # noqa: E711
    text_none = Text("text_field") == None  # noqa: E711
    assert str(tag_none & num_none & text_none) == "*"

    tag_none = Tag("tag_field") != None  # noqa: E711
    num_none = Num("num_field") != None  # noqa: E711
    text_none = Text("text_field") != None  # noqa: E711
    assert str(tag_none & num_none & text_none) == "*"

    # One real filter combined with None filters renders as the real one.
    real_tag = Tag("tag_field") == "tag"
    num_none = Num("num_field") == None  # noqa: E711
    text_none = Text("text_field") == None  # noqa: E711
    assert str(real_tag & num_none & text_none) == str(real_tag)
|
Loading…
Reference in New Issue
Block a user