mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
Add redis self-query support (#10199)
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.query_constructor.base import load_query_constructor_chain
|
||||
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
@@ -16,6 +16,7 @@ from langchain.retrievers.self_query.milvus import MilvusTranslator
|
||||
from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
||||
from langchain.retrievers.self_query.pinecone import PineconeTranslator
|
||||
from langchain.retrievers.self_query.qdrant import QdrantTranslator
|
||||
from langchain.retrievers.self_query.redis import RedisTranslator
|
||||
from langchain.retrievers.self_query.supabase import SupabaseVectorTranslator
|
||||
from langchain.retrievers.self_query.vectara import VectaraTranslator
|
||||
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
|
||||
@@ -30,6 +31,7 @@ from langchain.vectorstores import (
|
||||
MyScale,
|
||||
Pinecone,
|
||||
Qdrant,
|
||||
Redis,
|
||||
SupabaseVectorStore,
|
||||
Vectara,
|
||||
VectorStore,
|
||||
@@ -39,7 +41,6 @@ from langchain.vectorstores import (
|
||||
|
||||
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
||||
"""Get the translator class corresponding to the vector store class."""
|
||||
vectorstore_cls = vectorstore.__class__
|
||||
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
|
||||
Pinecone: PineconeTranslator,
|
||||
Chroma: ChromaTranslator,
|
||||
@@ -53,16 +54,19 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
||||
Milvus: MilvusTranslator,
|
||||
SupabaseVectorStore: SupabaseVectorTranslator,
|
||||
}
|
||||
if vectorstore_cls not in BUILTIN_TRANSLATORS:
|
||||
raise ValueError(
|
||||
f"Self query retriever with Vector Store type {vectorstore_cls}"
|
||||
f" not supported."
|
||||
)
|
||||
if isinstance(vectorstore, Qdrant):
|
||||
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
|
||||
elif isinstance(vectorstore, MyScale):
|
||||
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
|
||||
return BUILTIN_TRANSLATORS[vectorstore_cls]()
|
||||
elif isinstance(vectorstore, Redis):
|
||||
return RedisTranslator.from_vectorstore(vectorstore)
|
||||
elif vectorstore.__class__ in BUILTIN_TRANSLATORS:
|
||||
return BUILTIN_TRANSLATORS[vectorstore.__class__]()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Self query retriever with Vector Store type {vectorstore.__class__}"
|
||||
f" not supported."
|
||||
)
|
||||
|
||||
|
||||
class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
@@ -80,8 +84,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
structured_query_translator: Visitor
|
||||
"""Translator for turning internal query language into vectorstore search params."""
|
||||
verbose: bool = False
|
||||
"""Use original query instead of the revised new query from LLM"""
|
||||
|
||||
use_original_query: bool = False
|
||||
"""Use original query instead of the revised new query from LLM"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
102
libs/langchain/langchain/retrievers/self_query/redis.py
Normal file
102
libs/langchain/langchain/retrievers/self_query/redis.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
Visitor,
|
||||
)
|
||||
from langchain.vectorstores.redis import Redis
|
||||
from langchain.vectorstores.redis.filters import (
|
||||
RedisFilterExpression,
|
||||
RedisFilterField,
|
||||
RedisFilterOperator,
|
||||
RedisNum,
|
||||
RedisTag,
|
||||
RedisText,
|
||||
)
|
||||
from langchain.vectorstores.redis.schema import RedisModel
|
||||
|
||||
# Maps each supported query-language comparator to the name of the dunder
# method that is looked up on a Redis filter field (via ``getattr``) to build
# the corresponding filter expression.
_COMPARATOR_TO_BUILTIN_METHOD = {
    # Equality / inequality.
    Comparator.EQ: "__eq__",
    Comparator.NE: "__ne__",
    # Ordering comparisons.
    Comparator.LT: "__lt__",
    Comparator.GT: "__gt__",
    Comparator.LTE: "__le__",
    Comparator.GTE: "__ge__",
    # CONTAIN reuses the equality method of the filter field.
    Comparator.CONTAIN: "__eq__",
    # LIKE is expressed through the ``%`` operator of the filter field.
    Comparator.LIKE: "__mod__",
}
|
||||
|
||||
|
||||
class RedisTranslator(Visitor):
    """Translate structured-query elements into Redis filter expressions.

    The translator is bound to a ``RedisModel`` schema, which it uses to decide
    whether an attribute referenced in a comparison should be treated as a
    text, tag or numeric filter field.
    """

    allowed_comparators = (
        Comparator.EQ,
        Comparator.NE,
        Comparator.LT,
        Comparator.LTE,
        Comparator.GT,
        Comparator.GTE,
        Comparator.CONTAIN,
        Comparator.LIKE,
    )
    """Subset of allowed logical comparators."""
    allowed_operators = (Operator.AND, Operator.OR)
    """Subset of allowed logical operators."""

    def __init__(self, schema: RedisModel) -> None:
        """Create a translator for the given vector store schema.

        Args:
            schema: Schema used to resolve attribute names to field types.
        """
        self._schema = schema

    def _attribute_to_filter_field(self, attribute: str) -> RedisFilterField:
        """Return the filter field object matching ``attribute`` in the schema.

        The text fields are searched first, then the tag fields, then the
        numeric fields; the first group containing the name wins.

        Args:
            attribute: Name of the metadata attribute being filtered on.

        Returns:
            A ``RedisText``, ``RedisTag`` or ``RedisNum`` built for the attribute.

        Raises:
            ValueError: If the attribute is not declared in any of the text,
                tag or numeric field lists of the schema.
        """
        # Each group may be unset, hence the ``or []`` guard applied uniformly.
        field_groups = (
            (self._schema.text, RedisText),
            (self._schema.tag, RedisTag),
            (self._schema.numeric, RedisNum),
        )
        for schema_fields, filter_field_cls in field_groups:
            if any(field.name == attribute for field in schema_fields or []):
                return filter_field_cls(attribute)
        raise ValueError(
            f"Invalid attribute {attribute} not in vector store schema. Schema is:"
            f"\n{self._schema.as_dict()}"
        )

    def visit_comparison(self, comparison: Comparison) -> RedisFilterExpression:
        """Translate a single comparison into a Redis filter expression.

        The comparator is mapped to a dunder method name, which is then invoked
        on the filter field resolved for the comparison's attribute.

        Raises:
            ValueError: If the attribute is not part of the schema.
            KeyError: If the comparator has no entry in the comparator mapping.
        """
        filter_field = self._attribute_to_filter_field(comparison.attribute)
        comparison_method = _COMPARATOR_TO_BUILTIN_METHOD[comparison.comparator]
        return getattr(filter_field, comparison_method)(comparison.value)

    def visit_operation(self, operation: Operation) -> Any:
        """Translate a logical operation into a nested Redis filter expression.

        Arguments are combined right-associatively, i.e. ``op(a, b, c)`` becomes
        ``a op (b op c)``. An operation with a single argument translates to
        that argument's filter on its own.

        Raises:
            ValueError: If the operator is not one of ``allowed_operators``.
            IndexError: If the operation has no arguments.
        """
        # Reject unsupported operators explicitly instead of silently treating
        # them as AND.
        if operation.operator not in self.allowed_operators:
            raise ValueError(
                f"Received disallowed operator {operation.operator}. "
                f"Allowed operators are {self.allowed_operators}"
            )
        redis_operator = (
            RedisFilterOperator.OR
            if operation.operator == Operator.OR
            else RedisFilterOperator.AND
        )
        translated = [argument.accept(self) for argument in operation.arguments]
        # Fold from the right so the nesting matches a recursive
        # "first argument op (rest of the arguments)" expansion.
        combined = translated[-1]
        for left in reversed(translated[:-1]):
            combined = RedisFilterExpression(
                operator=redis_operator, left=left, right=combined
            )
        return combined

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Translate a structured query into a query string and search kwargs.

        Returns:
            A tuple of the query text and a kwargs dict. The dict is empty when
            the structured query has no filter; otherwise it holds the
            translated filter under the ``"filter"`` key.
        """
        if structured_query.filter is None:
            kwargs = {}
        else:
            kwargs = {"filter": structured_query.filter.accept(self)}
        return structured_query.query, kwargs

    @classmethod
    def from_vectorstore(cls, vectorstore: Redis) -> RedisTranslator:
        """Build a translator from the schema attached to a Redis vector store."""
        return cls(vectorstore._schema)
|
@@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from numbers import Number
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.utilities.redis import TokenEscaper
|
||||
@@ -56,14 +57,15 @@ class RedisFilterField:
|
||||
if operator not in self.OPERATORS:
|
||||
raise ValueError(
|
||||
f"Operator {operator} not supported by {self.__class__.__name__}. "
|
||||
+ f"Supported operators are {self.OPERATORS.values()}"
|
||||
+ f"Supported operators are {self.OPERATORS.values()}."
|
||||
)
|
||||
|
||||
if not isinstance(val, val_type):
|
||||
raise TypeError(
|
||||
f"Right side argument passed to operator {self.OPERATORS[operator]} "
|
||||
f"with left side "
|
||||
f"argument {self.__class__.__name__} must be of type {val_type}"
|
||||
f"argument {self.__class__.__name__} must be of type {val_type}, "
|
||||
f"received value {val}"
|
||||
)
|
||||
self._value = val
|
||||
self._operator = operator
|
||||
@@ -181,12 +183,12 @@ class RedisNum(RedisFilterField):
|
||||
RedisFilterOperator.GE: ">=",
|
||||
}
|
||||
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||
RedisFilterOperator.EQ: "@%s:[%i %i]",
|
||||
RedisFilterOperator.NE: "(-@%s:[%i %i])",
|
||||
RedisFilterOperator.GT: "@%s:[(%i +inf]",
|
||||
RedisFilterOperator.LT: "@%s:[-inf (%i]",
|
||||
RedisFilterOperator.GE: "@%s:[%i +inf]",
|
||||
RedisFilterOperator.LE: "@%s:[-inf %i]",
|
||||
RedisFilterOperator.EQ: "@%s:[%f %f]",
|
||||
RedisFilterOperator.NE: "(-@%s:[%f %f])",
|
||||
RedisFilterOperator.GT: "@%s:[(%f +inf]",
|
||||
RedisFilterOperator.LT: "@%s:[-inf (%f]",
|
||||
RedisFilterOperator.GE: "@%s:[%f +inf]",
|
||||
RedisFilterOperator.LE: "@%s:[-inf %f]",
|
||||
}
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -210,83 +212,83 @@ class RedisNum(RedisFilterField):
|
||||
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
|
||||
|
||||
@check_operator_misuse
|
||||
def __eq__(self, other: int) -> "RedisFilterExpression":
|
||||
def __eq__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric equality filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
other (Number): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") == 90210
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.EQ)
|
||||
self._set_value(other, Number, RedisFilterOperator.EQ)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
@check_operator_misuse
|
||||
def __ne__(self, other: int) -> "RedisFilterExpression":
|
||||
def __ne__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric inequality filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
other (Number): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("zipcode") != 90210
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.NE)
|
||||
self._set_value(other, Number, RedisFilterOperator.NE)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __gt__(self, other: int) -> "RedisFilterExpression":
|
||||
def __gt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a RedisNumeric greater than filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
other (Number): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") > 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.GT)
|
||||
self._set_value(other, Number, RedisFilterOperator.GT)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __lt__(self, other: int) -> "RedisFilterExpression":
|
||||
def __lt__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric less than filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
other (Number): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") < 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.LT)
|
||||
self._set_value(other, Number, RedisFilterOperator.LT)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __ge__(self, other: int) -> "RedisFilterExpression":
|
||||
def __ge__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric greater than or equal to filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
other (Number): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") >= 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.GE)
|
||||
self._set_value(other, Number, RedisFilterOperator.GE)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
def __le__(self, other: int) -> "RedisFilterExpression":
|
||||
def __le__(self, other: Union[int, float]) -> "RedisFilterExpression":
|
||||
"""Create a Numeric less than or equal to filter expression
|
||||
|
||||
Args:
|
||||
other (int): The value to filter on.
|
||||
other (Number): The value to filter on.
|
||||
|
||||
Example:
|
||||
>>> from langchain.vectorstores.redis import RedisNum
|
||||
>>> filter = RedisNum("age") <= 18
|
||||
"""
|
||||
self._set_value(other, int, RedisFilterOperator.LE)
|
||||
self._set_value(other, Number, RedisFilterOperator.LE)
|
||||
return RedisFilterExpression(str(self))
|
||||
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
@@ -5,19 +7,19 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
# ignore type error here as it's a redis-py type problem
|
||||
from redis.commands.search.field import ( # type: ignore
|
||||
NumericField,
|
||||
TagField,
|
||||
TextField,
|
||||
VectorField,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
from typing_extensions import TYPE_CHECKING, Literal
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Field, validator
|
||||
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.commands.search.field import ( # type: ignore
|
||||
NumericField,
|
||||
TagField,
|
||||
TextField,
|
||||
VectorField,
|
||||
)
|
||||
|
||||
|
||||
class RedisDistanceMetric(str, Enum):
|
||||
l2 = "L2"
|
||||
@@ -38,6 +40,8 @@ class TextFieldSchema(RedisField):
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
def as_field(self) -> TextField:
|
||||
from redis.commands.search.field import TextField # type: ignore
|
||||
|
||||
return TextField(
|
||||
self.name,
|
||||
weight=self.weight,
|
||||
@@ -55,6 +59,8 @@ class TagFieldSchema(RedisField):
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
def as_field(self) -> TagField:
|
||||
from redis.commands.search.field import TagField # type: ignore
|
||||
|
||||
return TagField(
|
||||
self.name,
|
||||
separator=self.separator,
|
||||
@@ -69,6 +75,8 @@ class NumericFieldSchema(RedisField):
|
||||
sortable: Optional[bool] = False
|
||||
|
||||
def as_field(self) -> NumericField:
|
||||
from redis.commands.search.field import NumericField # type: ignore
|
||||
|
||||
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
|
||||
|
||||
|
||||
@@ -97,6 +105,8 @@ class FlatVectorField(RedisVectorField):
|
||||
block_size: int = Field(default=1000)
|
||||
|
||||
def as_field(self) -> VectorField:
|
||||
from redis.commands.search.field import VectorField # type: ignore
|
||||
|
||||
return VectorField(
|
||||
self.name,
|
||||
self.algorithm,
|
||||
@@ -118,6 +128,8 @@ class HNSWVectorField(RedisVectorField):
|
||||
epsilon: float = Field(default=0.8)
|
||||
|
||||
def as_field(self) -> VectorField:
|
||||
from redis.commands.search.field import VectorField # type: ignore
|
||||
|
||||
return VectorField(
|
||||
self.name,
|
||||
self.algorithm,
|
||||
|
@@ -0,0 +1,122 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.redis import RedisTranslator
|
||||
from langchain.vectorstores.redis.filters import (
|
||||
RedisFilterExpression,
|
||||
RedisNum,
|
||||
RedisTag,
|
||||
RedisText,
|
||||
)
|
||||
from langchain.vectorstores.redis.schema import (
|
||||
NumericFieldSchema,
|
||||
RedisModel,
|
||||
TagFieldSchema,
|
||||
TextFieldSchema,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def translator() -> RedisTranslator:
    """Provide a translator whose schema has one text, numeric and tag field."""
    text_fields = [TextFieldSchema(name="bar")]
    numeric_fields = [NumericFieldSchema(name="foo")]
    tag_fields = [TagFieldSchema(name="tag")]
    redis_schema = RedisModel(
        text=text_fields,
        numeric=numeric_fields,
        tag=tag_fields,
    )
    return RedisTranslator(redis_schema)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("comp", "expected"),
    [
        (
            Comparison(comparator=Comparator.LT, attribute="foo", value=1),
            RedisNum("foo") < 1,
        ),
        (
            Comparison(comparator=Comparator.LIKE, attribute="bar", value="baz*"),
            RedisText("bar") % "baz*",
        ),
        (
            Comparison(
                comparator=Comparator.CONTAIN, attribute="tag", value=["blue", "green"]
            ),
            RedisTag("tag") == ["blue", "green"],
        ),
    ],
)
def test_visit_comparison(
    translator: RedisTranslator, comp: Comparison, expected: RedisFilterExpression
) -> None:
    """Check that each parametrized comparison translates to the expected filter.

    The parametrized ``comp`` and ``expected`` values are used as given; they
    must not be reassigned here, otherwise every case would test the same
    comparison.
    """
    actual = translator.visit_comparison(comp)
    # Filter expressions are compared through their rendered query strings.
    assert str(expected) == str(actual)
|
||||
|
||||
|
||||
def test_visit_operation(translator: RedisTranslator) -> None:
    """A three-argument AND is translated into a right-nested AND expression."""
    comparisons = [
        Comparison(comparator=Comparator.LT, attribute="foo", value=2),
        Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
        Comparison(comparator=Comparator.EQ, attribute="tag", value="high"),
    ]
    operation = Operation(operator=Operator.AND, arguments=comparisons)

    # Build the expected filter with the same nesting: first & (second & third).
    nested_rhs = (RedisText("bar") == "baz") & (RedisTag("tag") == "high")
    expected = (RedisNum("foo") < 2) & nested_rhs

    actual = translator.visit_operation(operation)
    assert str(expected) == str(actual)
|
||||
|
||||
|
||||
def test_visit_structured_query_no_filter(translator: RedisTranslator) -> None:
    """Without a filter, the query is passed through with empty kwargs."""
    query = "What is the capital of France?"
    unfiltered_query = StructuredQuery(query=query, filter=None)

    expected: Tuple[str, Dict] = (query, {})
    actual = translator.visit_structured_query(unfiltered_query)

    assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query_comparison(translator: RedisTranslator) -> None:
    """A single comparison filter ends up under the ``filter`` kwarg."""
    query = "What is the capital of France?"
    gte_comparison = Comparison(comparator=Comparator.GTE, attribute="foo", value=2)
    structured_query = StructuredQuery(query=query, filter=gte_comparison)

    expected_filter = RedisNum("foo") >= 2
    actual_query, actual_filter = translator.visit_structured_query(structured_query)

    assert actual_query == query
    # Filter expressions are compared through their rendered query strings.
    assert str(actual_filter["filter"]) == str(expected_filter)
|
||||
|
||||
|
||||
def test_visit_structured_query_operation(translator: RedisTranslator) -> None:
    """An OR operation filter is translated and returned under ``filter``."""
    query = "What is the capital of France?"
    numeric_equals = Comparison(comparator=Comparator.EQ, attribute="foo", value=2)
    text_contains = Comparison(
        comparator=Comparator.CONTAIN, attribute="bar", value="baz"
    )
    or_operation = Operation(
        operator=Operator.OR,
        arguments=[numeric_equals, text_contains],
    )
    structured_query = StructuredQuery(query=query, filter=or_operation)

    expected_filter = (RedisNum("foo") == 2) | (RedisText("bar") == "baz")
    actual_query, actual_filter = translator.visit_structured_query(structured_query)

    assert actual_query == query
    # Filter expressions are compared through their rendered query strings.
    assert str(actual_filter["filter"]) == str(expected_filter)
|
Reference in New Issue
Block a user