Add redis self-query support (#10199)

This commit is contained in:
Bagatur
2023-09-08 16:43:16 -07:00
committed by GitHub
parent 4258c23867
commit 7203c97e8f
7 changed files with 785 additions and 72 deletions

View File

@@ -2,8 +2,8 @@
from typing import Any, Dict, List, Optional, Type, cast
from langchain import LLMChain
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains import LLMChain
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
@@ -16,6 +16,7 @@ from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.redis import RedisTranslator
from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator
from langchain.retrievers.self_query.vectara import VectaraTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
@@ -30,6 +31,7 @@ from langchain.vectorstores import (
MyScale,
Pinecone,
Qdrant,
Redis,
SupabaseVectorStore,
Vectara,
VectorStore,
@@ -39,7 +41,6 @@ from langchain.vectorstores import (
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
vectorstore_cls = vectorstore.__class__
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
@@ -53,16 +54,19 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Milvus: MilvusTranslator,
SupabaseVectorStore: SupabaseVectorTranslator,
}
if vectorstore_cls not in BUILTIN_TRANSLATORS:
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore_cls}"
f" not supported."
)
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
return BUILTIN_TRANSLATORS[vectorstore_cls]()
elif isinstance(vectorstore, Redis):
return RedisTranslator.from_vectorstore(vectorstore)
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
else:
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore.__class__}"
f" not supported."
)
class SelfQueryRetriever(BaseRetriever, BaseModel):
@@ -80,8 +84,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
structured_query_translator: Visitor
"""Translator for turning internal query language into vectorstore search params."""
verbose: bool = False
"""Use original query instead of the revised new query from LLM"""
use_original_query: bool = False
"""Use original query instead of the revised new query from LLM"""
class Config:
"""Configuration for this pydantic object."""

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
from typing import Any, Tuple
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
from langchain.vectorstores.redis import Redis
from langchain.vectorstores.redis.filters import (
RedisFilterExpression,
RedisFilterField,
RedisFilterOperator,
RedisNum,
RedisTag,
RedisText,
)
from langchain.vectorstores.redis.schema import RedisModel
# Maps each query-constructor comparator to the name of the special method
# that is looked up (via ``getattr``) on a Redis filter field and called with
# the comparison value to build the filter expression.
_COMPARATOR_TO_BUILTIN_METHOD = {
    Comparator.EQ: "__eq__",
    Comparator.NE: "__ne__",
    Comparator.LT: "__lt__",
    Comparator.GT: "__gt__",
    Comparator.LTE: "__le__",
    Comparator.GTE: "__ge__",
    # CONTAIN reuses the equality method.
    # NOTE(review): presumably because ``==`` on tag/text fields already acts
    # as a membership-style match -- confirm against the filter classes.
    Comparator.CONTAIN: "__eq__",
    # LIKE is expressed through the ``%`` operator on the filter field.
    Comparator.LIKE: "__mod__",
}
class RedisTranslator(Visitor):
    """Translate the internal query language into Redis filter expressions.

    The translator is bound to a Redis index schema; the schema decides
    whether an attribute is filtered as a text, tag or numeric field.
    """

    allowed_comparators = (
        Comparator.EQ,
        Comparator.NE,
        Comparator.LT,
        Comparator.LTE,
        Comparator.GT,
        Comparator.GTE,
        Comparator.CONTAIN,
        Comparator.LIKE,
    )
    """Subset of allowed logical comparators."""
    allowed_operators = (Operator.AND, Operator.OR)
    """Subset of allowed logical operators."""

    def __init__(self, schema: RedisModel) -> None:
        """Create a translator for the given Redis index schema.

        Args:
            schema: Schema whose ``text``, ``tag`` and ``numeric`` field lists
                are used to resolve attribute names to filter field types.
        """
        self._schema = schema

    def _attribute_to_filter_field(self, attribute: str) -> RedisFilterField:
        """Return the filter field matching ``attribute`` in the schema.

        Field kinds are checked in the order text, tag, numeric; a missing
        (``None``) field list is treated as empty.

        Args:
            attribute: Name of the metadata attribute being filtered on.

        Returns:
            A ``RedisText``, ``RedisTag`` or ``RedisNum`` for the attribute.

        Raises:
            ValueError: If the attribute is not declared in the schema.
        """
        if attribute in [tf.name for tf in self._schema.text or []]:
            return RedisText(attribute)
        elif attribute in [tf.name for tf in self._schema.tag or []]:
            return RedisTag(attribute)
        elif attribute in [tf.name for tf in self._schema.numeric or []]:
            return RedisNum(attribute)
        else:
            raise ValueError(
                f"Invalid attribute {attribute} not in vector store schema. Schema is:"
                f"\n{self._schema.as_dict()}"
            )

    def visit_comparison(self, comparison: Comparison) -> RedisFilterExpression:
        """Translate a single comparison into a Redis filter expression.

        Args:
            comparison: Comparison naming an attribute, comparator and value.

        Returns:
            The result of calling the comparator's special method on the
            attribute's filter field with the comparison value.
        """
        filter_field = self._attribute_to_filter_field(comparison.attribute)
        comparison_method = _COMPARATOR_TO_BUILTIN_METHOD[comparison.comparator]
        return getattr(filter_field, comparison_method)(comparison.value)

    def visit_operation(self, operation: Operation) -> Any:
        """Translate a logical operation into nested binary filter expressions.

        Arguments are combined right-to-left, so ``AND(a, b, c)`` becomes
        ``a AND (b AND c)``. An operation with a single argument translates
        to that argument alone.

        Args:
            operation: Operation whose operator is AND or OR.

        Returns:
            The combined filter expression.

        Raises:
            ValueError: If the operator is not in ``allowed_operators`` or the
                operation has no arguments.
        """
        # Reject anything but AND/OR explicitly; otherwise an unsupported
        # operator would silently be treated as AND.
        if operation.operator not in self.allowed_operators:
            raise ValueError(
                f"Received disallowed operator {operation.operator}. "
                f"Allowed operators are {self.allowed_operators}"
            )
        if not operation.arguments:
            raise ValueError("Cannot translate an operation without arguments.")

        redis_operator = (
            RedisFilterOperator.OR
            if operation.operator == Operator.OR
            else RedisFilterOperator.AND
        )
        # Visit the arguments left to right, then fold from the right so the
        # nesting is identical to a recursive "first AND rest" expansion.
        visited = [argument.accept(self) for argument in operation.arguments]
        combined = visited[-1]
        for left in reversed(visited[:-1]):
            combined = RedisFilterExpression(
                operator=redis_operator, left=left, right=combined
            )
        return combined

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Translate a structured query into a query string and search kwargs.

        Args:
            structured_query: Query text plus an optional filter.

        Returns:
            The query string and a dict that holds the translated filter
            under the ``"filter"`` key, or is empty when there is no filter.
        """
        if structured_query.filter is None:
            kwargs = {}
        else:
            kwargs = {"filter": structured_query.filter.accept(self)}
        return structured_query.query, kwargs

    @classmethod
    def from_vectorstore(cls, vectorstore: Redis) -> RedisTranslator:
        """Build a translator from the schema stored on a Redis vector store.

        Args:
            vectorstore: Redis vector store whose ``_schema`` is used.

        Returns:
            A translator bound to the vector store's schema.
        """
        return cls(vectorstore._schema)

View File

@@ -1,5 +1,6 @@
from enum import Enum
from functools import wraps
from numbers import Number
from typing import Any, Callable, Dict, List, Optional, Union
from langchain.utilities.redis import TokenEscaper
@@ -56,14 +57,15 @@ class RedisFilterField:
if operator not in self.OPERATORS:
raise ValueError(
f"Operator {operator} not supported by {self.__class__.__name__}. "
+ f"Supported operators are {self.OPERATORS.values()}"
+ f"Supported operators are {self.OPERATORS.values()}."
)
if not isinstance(val, val_type):
raise TypeError(
f"Right side argument passed to operator {self.OPERATORS[operator]} "
f"with left side "
f"argument {self.__class__.__name__} must be of type {val_type}"
f"argument {self.__class__.__name__} must be of type {val_type}, "
f"received value {val}"
)
self._value = val
self._operator = operator
@@ -181,12 +183,12 @@ class RedisNum(RedisFilterField):
RedisFilterOperator.GE: ">=",
}
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "@%s:[%i %i]",
RedisFilterOperator.NE: "(-@%s:[%i %i])",
RedisFilterOperator.GT: "@%s:[(%i +inf]",
RedisFilterOperator.LT: "@%s:[-inf (%i]",
RedisFilterOperator.GE: "@%s:[%i +inf]",
RedisFilterOperator.LE: "@%s:[-inf %i]",
RedisFilterOperator.EQ: "@%s:[%f %f]",
RedisFilterOperator.NE: "(-@%s:[%f %f])",
RedisFilterOperator.GT: "@%s:[(%f +inf]",
RedisFilterOperator.LT: "@%s:[-inf (%f]",
RedisFilterOperator.GE: "@%s:[%f +inf]",
RedisFilterOperator.LE: "@%s:[-inf %f]",
}
def __str__(self) -> str:
@@ -210,83 +212,83 @@ class RedisNum(RedisFilterField):
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
@check_operator_misuse
def __eq__(self, other: int) -> "RedisFilterExpression":
def __eq__(self, other: Union[int, float]) -> "RedisFilterExpression":
"""Create a Numeric equality filter expression
Args:
other (int): The value to filter on.
other (Number): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") == 90210
"""
self._set_value(other, int, RedisFilterOperator.EQ)
self._set_value(other, Number, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: int) -> "RedisFilterExpression":
def __ne__(self, other: Union[int, float]) -> "RedisFilterExpression":
"""Create a Numeric inequality filter expression
Args:
other (int): The value to filter on.
other (Number): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") != 90210
"""
self._set_value(other, int, RedisFilterOperator.NE)
self._set_value(other, Number, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
def __gt__(self, other: int) -> "RedisFilterExpression":
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
"""Create a RedisNumeric greater than filter expression
Args:
other (int): The value to filter on.
other (Number): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") > 18
"""
self._set_value(other, int, RedisFilterOperator.GT)
self._set_value(other, Number, RedisFilterOperator.GT)
return RedisFilterExpression(str(self))
def __lt__(self, other: int) -> "RedisFilterExpression":
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
"""Create a Numeric less than filter expression
Args:
other (int): The value to filter on.
other (Number): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") < 18
"""
self._set_value(other, int, RedisFilterOperator.LT)
self._set_value(other, Number, RedisFilterOperator.LT)
return RedisFilterExpression(str(self))
def __ge__(self, other: int) -> "RedisFilterExpression":
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
"""Create a Numeric greater than or equal to filter expression
Args:
other (int): The value to filter on.
other (Number): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") >= 18
"""
self._set_value(other, int, RedisFilterOperator.GE)
self._set_value(other, Number, RedisFilterOperator.GE)
return RedisFilterExpression(str(self))
def __le__(self, other: int) -> "RedisFilterExpression":
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
"""Create a Numeric less than or equal to filter expression
Args:
other (int): The value to filter on.
other (Number): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") <= 18
"""
self._set_value(other, int, RedisFilterOperator.LE)
self._set_value(other, Number, RedisFilterOperator.LE)
return RedisFilterExpression(str(self))

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
from enum import Enum
from pathlib import Path
@@ -5,19 +7,19 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
import yaml
# ignore type error here as it's a redis-py type problem
from redis.commands.search.field import ( # type: ignore
NumericField,
TagField,
TextField,
VectorField,
)
from typing_extensions import Literal
from typing_extensions import TYPE_CHECKING, Literal
from langchain.pydantic_v1 import BaseModel, Field, validator
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
if TYPE_CHECKING:
from redis.commands.search.field import ( # type: ignore
NumericField,
TagField,
TextField,
VectorField,
)
class RedisDistanceMetric(str, Enum):
l2 = "L2"
@@ -38,6 +40,8 @@ class TextFieldSchema(RedisField):
sortable: Optional[bool] = False
def as_field(self) -> TextField:
from redis.commands.search.field import TextField # type: ignore
return TextField(
self.name,
weight=self.weight,
@@ -55,6 +59,8 @@ class TagFieldSchema(RedisField):
sortable: Optional[bool] = False
def as_field(self) -> TagField:
from redis.commands.search.field import TagField # type: ignore
return TagField(
self.name,
separator=self.separator,
@@ -69,6 +75,8 @@ class NumericFieldSchema(RedisField):
sortable: Optional[bool] = False
def as_field(self) -> NumericField:
from redis.commands.search.field import NumericField # type: ignore
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
@@ -97,6 +105,8 @@ class FlatVectorField(RedisVectorField):
block_size: int = Field(default=1000)
def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore
return VectorField(
self.name,
self.algorithm,
@@ -118,6 +128,8 @@ class HNSWVectorField(RedisVectorField):
epsilon: float = Field(default=0.8)
def as_field(self) -> VectorField:
from redis.commands.search.field import VectorField # type: ignore
return VectorField(
self.name,
self.algorithm,

View File

@@ -0,0 +1,122 @@
from typing import Dict, Tuple
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
)
from langchain.retrievers.self_query.redis import RedisTranslator
from langchain.vectorstores.redis.filters import (
RedisFilterExpression,
RedisNum,
RedisTag,
RedisText,
)
from langchain.vectorstores.redis.schema import (
NumericFieldSchema,
RedisModel,
TagFieldSchema,
TextFieldSchema,
)
@pytest.fixture
def translator() -> RedisTranslator:
    """Provide a translator over a schema with one text, numeric and tag field."""
    text_fields = [TextFieldSchema(name="bar")]
    numeric_fields = [NumericFieldSchema(name="foo")]
    tag_fields = [TagFieldSchema(name="tag")]
    redis_schema = RedisModel(
        text=text_fields, numeric=numeric_fields, tag=tag_fields
    )
    return RedisTranslator(redis_schema)
@pytest.mark.parametrize(
    ("comp", "expected"),
    [
        (
            Comparison(comparator=Comparator.LT, attribute="foo", value=1),
            RedisNum("foo") < 1,
        ),
        (
            Comparison(comparator=Comparator.LIKE, attribute="bar", value="baz*"),
            RedisText("bar") % "baz*",
        ),
        (
            Comparison(
                comparator=Comparator.CONTAIN, attribute="tag", value=["blue", "green"]
            ),
            RedisTag("tag") == ["blue", "green"],
        ),
    ],
)
def test_visit_comparison(
    translator: RedisTranslator, comp: Comparison, expected: RedisFilterExpression
) -> None:
    """Each parametrized comparison translates to the expected expression.

    The parametrized ``comp`` and ``expected`` values are used as given; they
    must not be reassigned in the body, or every case would check the same
    numeric comparison and the LIKE and CONTAIN cases would go untested.
    """
    actual = translator.visit_comparison(comp)
    # Compare the rendered filter strings of the two expressions.
    assert str(expected) == str(actual)
def test_visit_operation(translator: RedisTranslator) -> None:
    """A three-argument AND is translated into right-nested binary ANDs."""
    foo_below_two = Comparison(comparator=Comparator.LT, attribute="foo", value=2)
    bar_is_baz = Comparison(comparator=Comparator.EQ, attribute="bar", value="baz")
    tag_is_high = Comparison(comparator=Comparator.EQ, attribute="tag", value="high")
    conjunction = Operation(
        operator=Operator.AND,
        arguments=[foo_below_two, bar_is_baz, tag_is_high],
    )

    # The last two comparisons are grouped first, then joined with the first.
    bar_and_tag = (RedisText("bar") == "baz") & (RedisTag("tag") == "high")
    expected = (RedisNum("foo") < 2) & bar_and_tag

    translated = translator.visit_operation(conjunction)
    assert str(expected) == str(translated)
def test_visit_structured_query_no_filter(translator: RedisTranslator) -> None:
    """Without a filter, the query is returned alongside empty search kwargs."""
    query = "What is the capital of France?"
    unfiltered = StructuredQuery(query=query, filter=None)

    result = translator.visit_structured_query(unfiltered)

    expected: Tuple[str, Dict] = (query, {})
    assert expected == result
def test_visit_structured_query_comparison(translator: RedisTranslator) -> None:
    """A lone comparison filter is translated and returned under "filter"."""
    query = "What is the capital of France?"
    structured_query = StructuredQuery(
        query=query,
        filter=Comparison(comparator=Comparator.GTE, attribute="foo", value=2),
    )
    expected_filter = RedisNum("foo") >= 2

    returned_query, search_kwargs = translator.visit_structured_query(
        structured_query
    )

    assert returned_query == query
    assert str(search_kwargs["filter"]) == str(expected_filter)
def test_visit_structured_query_operation(translator: RedisTranslator) -> None:
    """An OR filter over two comparisons is translated into a Redis OR."""
    query = "What is the capital of France?"
    foo_is_two = Comparison(comparator=Comparator.EQ, attribute="foo", value=2)
    bar_has_baz = Comparison(
        comparator=Comparator.CONTAIN, attribute="bar", value="baz"
    )
    either = Operation(operator=Operator.OR, arguments=[foo_is_two, bar_has_baz])
    structured_query = StructuredQuery(query=query, filter=either)
    expected_filter = (RedisNum("foo") == 2) | (RedisText("bar") == "baz")

    returned_query, search_kwargs = translator.visit_structured_query(
        structured_query
    )

    assert returned_query == query
    assert str(search_kwargs["filter"]) == str(expected_filter)