mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-30 10:23:30 +00:00
Redis metadata filtering and specification, index customization (#8612)
### Description

The previous Redis implementation did not allow the user to specify the index configuration (e.g. changing the underlying algorithm) or to add additional metadata to use for querying (e.g. hybrid or "filtered" search). This PR introduces the ability to specify custom index attributes and metadata attributes, and to use that metadata in filtered queries. Overall, more structure was introduced to the Redis implementation, which should make it easier to maintain moving forward.

# New Features

The following features are now available with the Redis integration in LangChain.

## Index schema generation

The schema for the index will now be automatically generated if not specified by the user. For example, the example data (`texts` and `metadata`) has multiple metadata categories. Take the following example:

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.redis import Redis

embeddings = OpenAIEmbeddings()

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users"
)
```

Loading the data in through this and the other ``from_documents`` and ``from_texts`` methods will now generate an index schema in Redis like the following. You can view the index schema with the ``redisvl`` tool ([link](redisvl.com)).

```bash
$ rvl index info -i users
```

Index Information:

| Index Name | Storage Type | Prefixes      | Index Options | Indexing |
|------------|--------------|---------------|---------------|----------|
| users      | HASH         | ['doc:users'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |

### Custom Metadata specification

The metadata schema generation has the following rules:

1. All text fields are indexed as text fields.
2. All numeric fields are indexed as numeric fields.

If you would like a text field to be indexed as a tag field, you can specify overrides like the following for the example data:

```python
# this can also be a path to a yaml file
index_schema = {
    "text": [{"name": "user"}, {"name": "job"}],
    "tag": [{"name": "credit_score"}],
    "numeric": [{"name": "age"}],
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users2",
    index_schema=index_schema  # pass the override defined above
)
```

This will change the index specification to:

Index Information:

| Index Name | Storage Type | Prefixes       | Index Options | Indexing |
|------------|--------------|----------------|---------------|----------|
| users2     | HASH         | ['doc:users2'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TAG     | SEPARATOR    | ,            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |

It will also log a warning to the user that the generated schema does not match the specified schema:

```text
index_schema does not match generated schema from metadata.
index_schema: {'text': [{'name': 'user'}, {'name': 'job'}], 'tag': [{'name': 'credit_score'}], 'numeric': [{'name': 'age'}]}
generated_schema: {'text': [{'name': 'user'}, {'name': 'job'}, {'name': 'credit_score'}], 'numeric': [{'name': 'age'}]}
```

As long as the mismatch is intentional, this is fine.

The schema can be defined as a yaml file or a dictionary:

```yaml
text:
  - name: user
  - name: job
tag:
  - name: credit_score
numeric:
  - name: age
```

You then pass in a path like this:

```python
from pathlib import Path

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    index_schema=Path("sample1.yml").resolve()
)
```

This will create the same schema as defined in the dictionary example:

Index Information:

| Index Name | Storage Type | Prefixes       | Index Options | Indexing |
|------------|--------------|----------------|---------------|----------|
| users3     | HASH         | ['doc:users3'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TAG     | SEPARATOR    | ,            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |

### Custom Vector Indexing Schema

Users with large use cases may want to change how the vector index created by LangChain is configured. To use the full feature set of Redis for vector database use cases like this, you can now pass in index attribute modifiers, for example to change the indexing algorithm to HNSW.

```python
vector_schema = {
    "algorithm": "HNSW"
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    vector_schema=vector_schema
)
```

A more complex example may look like:

```python
vector_schema = {
    "algorithm": "HNSW",
    "ef_construction": 200,
    "ef_runtime": 20
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    vector_schema=vector_schema
)
```

All names correspond to the arguments you would set if using Redis-py or RedisVL. (put in doc link later)

### Better Querying

Both vector queries and range (limit) queries are now available, and metadata is returned by default. Example outputs are shown.

```python
>>> query = "foo"
>>> results = rds.similarity_search(query, k=1)
>>> print(results)
[Document(page_content='foo', metadata={'user': 'derrick', 'job': 'doctor', 'credit_score': 'low', 'age': '14', 'id': 'doc:users:657a47d7db8b447e88598b83da879b9d', 'score': '7.15255737305e-07'})]

>>> results = rds.similarity_search_with_score(query, k=1, return_metadata=False)
>>> print(results) # no metadata, but with scores
[(Document(page_content='foo', metadata={}), 7.15255737305e-07)]

>>> results = rds.similarity_search_limit_score(query, k=6, score_threshold=0.0001)
>>> print(len(results)) # range query (only above threshold even if k is higher)
4
```

### Custom metadata filtering

A big advantage of Redis in this space is being able to filter on data stored alongside the vector itself. With the example above, the following is now possible in LangChain. Python's comparison operators (such as `==` and `>`) are overloaded to describe a new expression language that mimics that of [redisvl](redisvl.com). This allows for arbitrarily long sequences of filters, resembling SQL commands, that can be used directly with vector queries and range queries.

There are two interfaces for doing so, and both are shown.

```python
>>> from langchain.vectorstores.redis import RedisFilter, RedisNum, RedisText

>>> age_filter = RedisFilter.num("age") > 18
>>> age_filter = RedisNum("age") > 18 # equivalent
>>> results = rds.similarity_search(query, filter=age_filter)
>>> print(len(results))
3

>>> job_filter = RedisFilter.text("job") == "engineer"
>>> job_filter = RedisText("job") == "engineer" # equivalent
>>> results = rds.similarity_search(query, filter=job_filter)
>>> print(len(results))
2

# fuzzy match text search
>>> job_filter = RedisFilter.text("job") % "eng*"
>>> results = rds.similarity_search(query, filter=job_filter)
>>> print(len(results))
2

# combined filters (AND)
>>> combined = age_filter & job_filter
>>> results = rds.similarity_search(query, filter=combined)
>>> print(len(results))
1

# combined filters (OR)
>>> combined = age_filter | job_filter
>>> results = rds.similarity_search(query, filter=combined)
>>> print(len(results))
4
```

All of the filter results above can be checked against the example data.

### Other

- Issue: #3967
- Dependencies: No added dependencies
- Tag maintainer: @hwchase17 @baskaryan @rlancemartin
- Twitter handle: @sampartee

---------

Co-authored-by: Naresh Rangan <naresh.rangan0@walmart.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
parent fa0b8f3368
commit a28eea5767
File diff suppressed because it is too large
@@ -33,6 +33,7 @@ from typing import (
     Any,
     Callable,
     Dict,
+    List,
     Optional,
     Sequence,
     Tuple,
@@ -302,6 +303,14 @@ class RedisSemanticCache(BaseCache):
 
     # TODO - implement a TTL policy in Redis
 
+    DEFAULT_SCHEMA = {
+        "content_key": "prompt",
+        "text": [
+            {"name": "prompt"},
+        ],
+        "extra": [{"name": "return_val"}, {"name": "llm_string"}],
+    }
+
     def __init__(
         self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
     ):
@@ -349,12 +358,14 @@ class RedisSemanticCache(BaseCache):
                 embedding=self.embedding,
                 index_name=index_name,
                 redis_url=self.redis_url,
+                schema=cast(Dict, self.DEFAULT_SCHEMA),
             )
         except ValueError:
             redis = RedisVectorstore(
-                embedding_function=self.embedding.embed_query,
+                embedding=self.embedding,
                 index_name=index_name,
                 redis_url=self.redis_url,
+                index_schema=cast(Dict, self.DEFAULT_SCHEMA),
             )
             _embedding = self.embedding.embed_query(text="test")
             redis._create_index(dim=len(_embedding))
@@ -374,17 +385,18 @@ class RedisSemanticCache(BaseCache):
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
         llm_cache = self._get_llm_cache(llm_string)
-        generations = []
+        generations: List = []
         # Read from a Hash
-        results = llm_cache.similarity_search_limit_score(
+        results = llm_cache.similarity_search(
             query=prompt,
             k=1,
-            score_threshold=self.score_threshold,
+            distance_threshold=self.score_threshold,
         )
         if results:
             for document in results:
-                for text in document.metadata["return_val"]:
-                    generations.append(Generation(text=text))
+                generations.extend(
+                    _load_generations_from_json(document.metadata["return_val"])
+                )
         return generations if generations else None
 
     def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@@ -402,11 +414,11 @@ class RedisSemanticCache(BaseCache):
             )
             return
         llm_cache = self._get_llm_cache(llm_string)
-        # Write to vectorstore
+        _dump_generations_to_json([g for g in return_val])
         metadata = {
             "llm_string": llm_string,
             "prompt": prompt,
-            "return_val": [generation.text for generation in return_val],
+            "return_val": _dump_generations_to_json([g for g in return_val]),
         }
         llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
 
@@ -1,7 +1,5 @@
 from typing import Any, Dict, List
 
-from langchain.schema.messages import get_buffer_string # noqa: 401
-
 
 def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
     """
@@ -1,16 +1,64 @@
 from __future__ import annotations
 
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-)
+import re
+from typing import TYPE_CHECKING, Any, List, Optional, Pattern
 from urllib.parse import urlparse
 
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from redis.client import Redis as RedisType
 
-logger = logging.getLogger(__name__)
+
+def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
+    return np.array(array).astype(dtype).tobytes()
+
+
+class TokenEscaper:
+    """
+    Escape punctuation within an input string.
+    """
+
+    # Characters that RediSearch requires us to escape during queries.
+    # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
+    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"
+
+    def __init__(self, escape_chars_re: Optional[Pattern] = None):
+        if escape_chars_re:
+            self.escaped_chars_re = escape_chars_re
+        else:
+            self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
+
+    def escape(self, value: str) -> str:
+        def escape_symbol(match: re.Match) -> str:
+            value = match.group(0)
+            return f"\\{value}"
+
+        return self.escaped_chars_re.sub(escape_symbol, value)
+
+
+def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
+    """Check if the correct Redis modules are installed."""
+    installed_modules = client.module_list()
+    installed_modules = {
+        module[b"name"].decode("utf-8"): module for module in installed_modules
+    }
+    for module in required_modules:
+        if module["name"] in installed_modules and int(
+            installed_modules[module["name"]][b"ver"]
+        ) >= int(module["ver"]):
+            return
+    # otherwise raise error
+    error_message = (
+        "Redis cannot be used as a vector database without RediSearch >=2.4"
+        "Please head to https://redis.io/docs/stack/search/quick_start/"
+        "to know more about installing the RediSearch module within Redis Stack."
+    )
+    logger.error(error_message)
+    raise ValueError(error_message)
 
 
 def get_client(redis_url: str, **kwargs: Any) -> RedisType:
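Reviewer note: the following is a minimal sketch, not part of the commit, of what the `TokenEscaper.escape` method added in the hunk above does to a query value. The character class is copied from `DEFAULT_ESCAPED_CHARS`; the module-level names `ESCAPED_CHARS` and `escape` are illustrative assumptions so the snippet runs without LangChain or a Redis server.

```python
import re

# Character class copied from DEFAULT_ESCAPED_CHARS in the hunk above.
# ESCAPED_CHARS and escape() are hypothetical names used only for this sketch.
ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]")


def escape(value: str) -> str:
    """Prefix every character matched by ESCAPED_CHARS with a backslash."""
    return ESCAPED_CHARS.sub(lambda match: f"\\{match.group(0)}", value)


if __name__ == "__main__":
    # Punctuation is escaped; letters, digits, spaces and underscores are not.
    for raw in ["doc:users", "low-risk", "credit_score", "plain text"]:
        print(repr(raw), "->", repr(escape(raw)))
```

Escaping matters mainly for values interpolated into RediSearch query strings, where characters such as `:` and `-` would otherwise be read as query syntax.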
@@ -1,664 +0,0 @@
-from __future__ import annotations
-
-import json
-import logging
-import uuid
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Literal,
-    Mapping,
-    Optional,
-    Tuple,
-    Type,
-)
-
-import numpy as np
-
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForRetrieverRun,
-    CallbackManagerForRetrieverRun,
-)
-from langchain.docstore.document import Document
-from langchain.embeddings.base import Embeddings
-from langchain.pydantic_v1 import root_validator
-from langchain.utilities.redis import get_client
-from langchain.utils import get_from_dict_or_env
-from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
-
-logger = logging.getLogger(__name__)
-
-if TYPE_CHECKING:
-    from redis.client import Redis as RedisType
-    from redis.commands.search.query import Query
-
-
-# required modules
-REDIS_REQUIRED_MODULES = [
-    {"name": "search", "ver": 20400},
-    {"name": "searchlight", "ver": 20400},
-]
-
-# distance mmetrics
-REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"]
-
-
-def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
-    """Check if the correct Redis modules are installed."""
-    installed_modules = client.module_list()
-    installed_modules = {
-        module[b"name"].decode("utf-8"): module for module in installed_modules
-    }
-    for module in required_modules:
-        if module["name"] in installed_modules and int(
-            installed_modules[module["name"]][b"ver"]
-        ) >= int(module["ver"]):
-            return
-    # otherwise raise error
-    error_message = (
-        "Redis cannot be used as a vector database without RediSearch >=2.4"
-        "Please head to https://redis.io/docs/stack/search/quick_start/"
-        "to know more about installing the RediSearch module within Redis Stack."
-    )
-    logger.error(error_message)
-    raise ValueError(error_message)
-
-
-def _check_index_exists(client: RedisType, index_name: str) -> bool:
-    """Check if Redis index exists."""
-    try:
-        client.ft(index_name).info()
-    except: # noqa: E722
-        logger.info("Index does not exist")
-        return False
-    logger.info("Index already exists")
-    return True
-
-
-def _redis_key(prefix: str) -> str:
-    """Redis key schema for a given prefix."""
-    return f"{prefix}:{uuid.uuid4().hex}"
-
-
-def _redis_prefix(index_name: str) -> str:
-    """Redis key prefix for a given index."""
-    return f"doc:{index_name}"
-
-
-def _default_relevance_score(val: float) -> float:
-    return 1 - val
-
-
-class Redis(VectorStore):
-    """`Redis` vector store.
-
-    To use, you should have the ``redis`` python package installed.
-
-    Example:
-        .. code-block:: python
-
-            from langchain.vectorstores import Redis
-            from langchain.embeddings import OpenAIEmbeddings
-
-            embeddings = OpenAIEmbeddings()
-            vectorstore = Redis(
-                redis_url="redis://username:password@localhost:6379"
-                index_name="my-index",
-                embedding_function=embeddings.embed_query,
-            )
-
-    To use a redis replication setup with multiple redis server and redis sentinels
-    set "redis_url" to "redis+sentinel://" scheme. With this url format a path is
-    needed holding the name of the redis service within the sentinels to get the
-    correct redis server connection. The default service name is "mymaster".
-
-    An optional username or password is used for booth connections to the rediserver
-    and the sentinel, different passwords for server and sentinel are not supported.
-    And as another constraint only one sentinel instance can be given:
-
-    Example:
-        .. code-block:: python
-
-            vectorstore = Redis(
-                redis_url="redis+sentinel://username:password@sentinelhost:26379/mymaster/0"
-                index_name="my-index",
-                embedding_function=embeddings.embed_query,
-            )
-    """
-
-    def __init__(
-        self,
-        redis_url: str,
-        index_name: str,
-        embedding_function: Callable,
-        content_key: str = "content",
-        metadata_key: str = "metadata",
-        vector_key: str = "content_vector",
-        relevance_score_fn: Optional[Callable[[float], float]] = None,
-        distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
-        **kwargs: Any,
-    ):
-        """Initialize with necessary components."""
-        self.embedding_function = embedding_function
-        self.index_name = index_name
-        try:
-            redis_client = get_client(redis_url=redis_url, **kwargs)
-            # check if redis has redisearch module installed
-            _check_redis_module_exist(redis_client, REDIS_REQUIRED_MODULES)
-        except ValueError as e:
-            raise ValueError(f"Redis failed to connect: {e}")
-
-        self.client = redis_client
-        self.content_key = content_key
-        self.metadata_key = metadata_key
-        self.vector_key = vector_key
-        self.distance_metric = distance_metric
-        self.relevance_score_fn = relevance_score_fn
-
-    @property
-    def embeddings(self) -> Optional[Embeddings]:
-        # TODO: Accept embedding object directly
-        return None
-
-    def _select_relevance_score_fn(self) -> Callable[[float], float]:
-        if self.relevance_score_fn:
-            return self.relevance_score_fn
-
-        if self.distance_metric == "COSINE":
-            return self._cosine_relevance_score_fn
-        elif self.distance_metric == "IP":
-            return self._max_inner_product_relevance_score_fn
-        elif self.distance_metric == "L2":
-            return self._euclidean_relevance_score_fn
-        else:
-            return _default_relevance_score
-
-    def _create_index(self, dim: int = 1536) -> None:
-        try:
-            from redis.commands.search.field import TextField, VectorField
-            from redis.commands.search.indexDefinition import IndexDefinition, IndexType
-        except ImportError:
-            raise ImportError(
-                "Could not import redis python package. "
-                "Please install it with `pip install redis`."
-            )
-
-        # Check if index exists
-        if not _check_index_exists(self.client, self.index_name):
-            # Define schema
-            schema = (
-                TextField(name=self.content_key),
-                TextField(name=self.metadata_key),
-                VectorField(
-                    self.vector_key,
-                    "FLAT",
-                    {
-                        "TYPE": "FLOAT32",
-                        "DIM": dim,
-                        "DISTANCE_METRIC": self.distance_metric,
-                    },
-                ),
-            )
-            prefix = _redis_prefix(self.index_name)
-
-            # Create Redis Index
-            self.client.ft(self.index_name).create_index(
-                fields=schema,
-                definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
-            )
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
-        embeddings: Optional[List[List[float]]] = None,
-        batch_size: int = 1000,
-        **kwargs: Any,
-    ) -> List[str]:
-        """Add more texts to the vectorstore.
-
-        Args:
-            texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
-            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
-                Defaults to None.
-            embeddings (Optional[List[List[float]]], optional): Optional pre-generated
-                embeddings. Defaults to None.
-            keys (List[str]) or ids (List[str]): Identifiers of entries.
-                Defaults to None.
-            batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
-
-        Returns:
-            List[str]: List of ids added to the vectorstore
-        """
-        ids = []
-        prefix = _redis_prefix(self.index_name)
-
-        # Get keys or ids from kwargs
-        # Other vectorstores use ids
-        keys_or_ids = kwargs.get("keys", kwargs.get("ids"))
-
-        # Write data to redis
-        pipeline = self.client.pipeline(transaction=False)
-        for i, text in enumerate(texts):
-            # Use provided values by default or fallback
-            key = keys_or_ids[i] if keys_or_ids else _redis_key(prefix)
-            metadata = metadatas[i] if metadatas else {}
-            embedding = embeddings[i] if embeddings else self.embedding_function(text)
-            pipeline.hset(
-                key,
-                mapping={
-                    self.content_key: text,
-                    self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
-                    self.metadata_key: json.dumps(metadata),
-                },
-            )
-            ids.append(key)
-
-            # Write batch
-            if i % batch_size == 0:
-                pipeline.execute()
-
-        # Cleanup final batch
-        pipeline.execute()
-        return ids
-
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        """
-        Returns the most similar indexed documents to the query text.
-
-        Args:
-            query (str): The query text for which to find similar documents.
-            k (int): The number of documents to return. Default is 4.
-
-        Returns:
-            List[Document]: A list of documents that are most similar to the query text.
-        """
-        docs_and_scores = self.similarity_search_with_score(query, k=k)
-        return [doc for doc, _ in docs_and_scores]
-
-    def similarity_search_limit_score(
-        self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
-    ) -> List[Document]:
-        """
-        Returns the most similar indexed documents to the query text within the
-        score_threshold range.
-
-        Args:
-            query (str): The query text for which to find similar documents.
-            k (int): The number of documents to return. Default is 4.
-            score_threshold (float): The minimum matching score required for a document
-            to be considered a match. Defaults to 0.2.
-            Because the similarity calculation algorithm is based on cosine
-            similarity, the smaller the angle, the higher the similarity.
-
-        Returns:
-            List[Document]: A list of documents that are most similar to the query text,
-            including the match score for each document.
-
-        Note:
-            If there are no documents that satisfy the score_threshold value,
-            an empty list is returned.
-
-        """
-        docs_and_scores = self.similarity_search_with_score(query, k=k)
-        return [doc for doc, score in docs_and_scores if score < score_threshold]
-
-    def _prepare_query(self, k: int) -> Query:
-        try:
-            from redis.commands.search.query import Query
-        except ImportError:
-            raise ValueError(
-                "Could not import redis python package. "
-                "Please install it with `pip install redis`."
-            )
-        # Prepare the Query
-        hybrid_fields = "*"
-        base_query = (
-            f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
-        )
-        return_fields = [self.metadata_key, self.content_key, "vector_score", "id"]
-        return (
-            Query(base_query)
-            .return_fields(*return_fields)
-            .sort_by("vector_score")
-            .paging(0, k)
-            .dialect(2)
-        )
-
-    def similarity_search_with_score(
-        self, query: str, k: int = 4
-    ) -> List[Tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query and score for each
-        """
-        # Creates embedding vector from user query
-        embedding = self.embedding_function(query)
-
-        # Creates Redis query
-        redis_query = self._prepare_query(k)
-
-        params_dict: Mapping[str, str] = {
-            "vector": np.array(embedding) # type: ignore
-            .astype(dtype=np.float32)
-            .tobytes()
-        }
-
-        # Perform vector search
-        results = self.client.ft(self.index_name).search(redis_query, params_dict)
-
-        # Prepare document results
-        docs_and_scores: List[Tuple[Document, float]] = []
-        for result in results.docs:
-            metadata = {**json.loads(result.metadata), "id": result.id}
-            doc = Document(page_content=result.content, metadata=metadata)
-            docs_and_scores.append((doc, float(result.vector_score)))
-        return docs_and_scores
-
@classmethod
|
|
||||||
def from_texts_return_keys(
|
|
||||||
cls,
|
|
||||||
texts: List[str],
|
|
||||||
embedding: Embeddings,
|
|
||||||
metadatas: Optional[List[dict]] = None,
|
|
||||||
index_name: Optional[str] = None,
|
|
||||||
content_key: str = "content",
|
|
||||||
metadata_key: str = "metadata",
|
|
||||||
vector_key: str = "content_vector",
|
|
||||||
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Tuple[Redis, List[str]]:
|
|
||||||
"""Create a Redis vectorstore from raw documents.
|
|
||||||
This is a user-friendly interface that:
|
|
||||||
1. Embeds documents.
|
|
||||||
2. Creates a new index for the embeddings in Redis.
|
|
||||||
3. Adds the documents to the newly created Redis index.
|
|
||||||
4. Returns the keys of the newly created documents.
|
|
||||||
|
|
||||||
This is intended to be a quick way to get started.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain.vectorstores import Redis
|
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
|
||||||
embeddings = OpenAIEmbeddings()
|
|
||||||
redisearch, keys = RediSearch.from_texts_return_keys(
|
|
||||||
texts,
|
|
||||||
embeddings,
|
|
||||||
redis_url="redis://username:password@localhost:6379"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
|
||||||
|
|
||||||
if "redis_url" in kwargs:
|
|
||||||
kwargs.pop("redis_url")
|
|
||||||
|
|
||||||
# Name of the search index if not given
|
|
||||||
if not index_name:
|
|
||||||
index_name = uuid.uuid4().hex
|
|
||||||
|
|
||||||
# Create instance
|
|
||||||
instance = cls(
|
|
||||||
redis_url,
|
|
||||||
index_name,
|
|
||||||
embedding.embed_query,
|
|
||||||
content_key=content_key,
|
|
||||||
metadata_key=metadata_key,
|
|
||||||
vector_key=vector_key,
|
|
||||||
distance_metric=distance_metric,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create embeddings over documents
|
|
||||||
embeddings = embedding.embed_documents(texts)
|
|
||||||
|
|
||||||
# Create the search index
|
|
||||||
instance._create_index(dim=len(embeddings[0]))
|
|
||||||
|
|
||||||
# Add data to Redis
|
|
||||||
keys = instance.add_texts(texts, metadatas, embeddings)
|
|
||||||
return instance, keys
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_texts(
|
|
||||||
cls: Type[Redis],
|
|
||||||
texts: List[str],
|
|
||||||
embedding: Embeddings,
|
|
||||||
metadatas: Optional[List[dict]] = None,
|
|
||||||
index_name: Optional[str] = None,
|
|
||||||
content_key: str = "content",
|
|
||||||
metadata_key: str = "metadata",
|
|
||||||
vector_key: str = "content_vector",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Redis:
|
|
||||||
"""Create a Redis vectorstore from raw documents.
|
|
||||||
This is a user-friendly interface that:
|
|
||||||
1. Embeds documents.
|
|
||||||
2. Creates a new index for the embeddings in Redis.
|
|
||||||
3. Adds the documents to the newly created Redis index.
|
|
||||||
|
|
||||||
This is intended to be a quick way to get started.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain.vectorstores import Redis
|
|
||||||
from langchain.embeddings import OpenAIEmbeddings
|
|
||||||
embeddings = OpenAIEmbeddings()
|
|
||||||
                redisearch = Redis.from_texts(
|
|
||||||
texts,
|
|
||||||
embeddings,
|
|
||||||
redis_url="redis://username:password@localhost:6379"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
instance, _ = cls.from_texts_return_keys(
|
|
||||||
texts,
|
|
||||||
embedding,
|
|
||||||
metadatas=metadatas,
|
|
||||||
index_name=index_name,
|
|
||||||
content_key=content_key,
|
|
||||||
metadata_key=metadata_key,
|
|
||||||
vector_key=vector_key,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return instance
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def delete(
|
|
||||||
ids: Optional[List[str]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Delete a Redis entry.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ids: List of ids (keys) to delete.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether or not the deletions were successful.
|
|
||||||
"""
|
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
|
||||||
|
|
||||||
if ids is None:
|
|
||||||
            raise ValueError("'ids' (keys) were not provided.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import redis # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import redis python package. "
|
|
||||||
"Please install it with `pip install redis`."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# We need to first remove redis_url from kwargs,
|
|
||||||
# otherwise passing it to Redis will result in an error.
|
|
||||||
if "redis_url" in kwargs:
|
|
||||||
kwargs.pop("redis_url")
|
|
||||||
client = get_client(redis_url=redis_url, **kwargs)
|
|
||||||
except ValueError as e:
|
|
||||||
            raise ValueError(f"Redis failed to connect: {e}")
|
|
||||||
        # Delete the given keys
|
|
||||||
try:
|
|
||||||
client.delete(*ids)
|
|
||||||
logger.info("Entries deleted")
|
|
||||||
return True
|
|
||||||
except: # noqa: E722
|
|
||||||
            # The delete call raised (e.g. a connection error)
|
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def drop_index(
|
|
||||||
index_name: str,
|
|
||||||
delete_documents: bool,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Drop a Redis search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name (str): Name of the index to drop.
|
|
||||||
delete_documents (bool): Whether to drop the associated documents.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether or not the drop was successful.
|
|
||||||
"""
|
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
|
||||||
try:
|
|
||||||
import redis # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import redis python package. "
|
|
||||||
"Please install it with `pip install redis`."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# We need to first remove redis_url from kwargs,
|
|
||||||
# otherwise passing it to Redis will result in an error.
|
|
||||||
if "redis_url" in kwargs:
|
|
||||||
kwargs.pop("redis_url")
|
|
||||||
client = get_client(redis_url=redis_url, **kwargs)
|
|
||||||
except ValueError as e:
|
|
||||||
            raise ValueError(f"Redis failed to connect: {e}")
|
|
||||||
# Check if index exists
|
|
||||||
try:
|
|
||||||
client.ft(index_name).dropindex(delete_documents)
|
|
||||||
logger.info("Drop index")
|
|
||||||
return True
|
|
||||||
except: # noqa: E722
|
|
||||||
            # Index does not exist
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_existing_index(
|
|
||||||
cls,
|
|
||||||
embedding: Embeddings,
|
|
||||||
index_name: str,
|
|
||||||
content_key: str = "content",
|
|
||||||
metadata_key: str = "metadata",
|
|
||||||
vector_key: str = "content_vector",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Redis:
|
|
||||||
"""Connect to an existing Redis index."""
|
|
||||||
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
|
|
||||||
try:
|
|
||||||
import redis # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError(
|
|
||||||
"Could not import redis python package. "
|
|
||||||
"Please install it with `pip install redis`."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# We need to first remove redis_url from kwargs,
|
|
||||||
# otherwise passing it to Redis will result in an error.
|
|
||||||
if "redis_url" in kwargs:
|
|
||||||
kwargs.pop("redis_url")
|
|
||||||
client = get_client(redis_url=redis_url, **kwargs)
|
|
||||||
# check if redis has redisearch module installed
|
|
||||||
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
|
|
||||||
# ensure that the index already exists
|
|
||||||
assert _check_index_exists(
|
|
||||||
client, index_name
|
|
||||||
), f"Index {index_name} does not exist"
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Redis failed to connect: {e}")
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
redis_url,
|
|
||||||
index_name,
|
|
||||||
embedding.embed_query,
|
|
||||||
content_key=content_key,
|
|
||||||
metadata_key=metadata_key,
|
|
||||||
vector_key=vector_key,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
|
|
||||||
tags = kwargs.pop("tags", None) or []
|
|
||||||
tags.extend(self._get_retriever_tags())
|
|
||||||
return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisVectorStoreRetriever(VectorStoreRetriever):
|
|
||||||
"""Retriever for `Redis` vector store."""
|
|
||||||
|
|
||||||
vectorstore: Redis
|
|
||||||
"""Redis VectorStore."""
|
|
||||||
search_type: str = "similarity"
|
|
||||||
"""Type of search to perform. Can be either 'similarity' or 'similarity_limit'."""
|
|
||||||
k: int = 4
|
|
||||||
"""Number of documents to return."""
|
|
||||||
score_threshold: float = 0.4
|
|
||||||
"""Score threshold for similarity_limit search."""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def validate_search_type(cls, values: Dict) -> Dict:
|
|
||||||
"""Validate search type."""
|
|
||||||
if "search_type" in values:
|
|
||||||
search_type = values["search_type"]
|
|
||||||
if search_type not in ("similarity", "similarity_limit"):
|
|
||||||
raise ValueError(f"search_type of {search_type} not allowed.")
|
|
||||||
return values
|
|
||||||
|
|
||||||
def _get_relevant_documents(
|
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
|
||||||
) -> List[Document]:
|
|
||||||
if self.search_type == "similarity":
|
|
||||||
docs = self.vectorstore.similarity_search(query, k=self.k)
|
|
||||||
elif self.search_type == "similarity_limit":
|
|
||||||
docs = self.vectorstore.similarity_search_limit_score(
|
|
||||||
query, k=self.k, score_threshold=self.score_threshold
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"search_type of {self.search_type} not allowed.")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
|
||||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
|
||||||
) -> List[Document]:
|
|
||||||
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
|
|
||||||
|
|
||||||
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
|
|
||||||
"""Add documents to vectorstore."""
|
|
||||||
return self.vectorstore.add_documents(documents, **kwargs)
|
|
||||||
|
|
||||||
async def aadd_documents(
|
|
||||||
self, documents: List[Document], **kwargs: Any
|
|
||||||
) -> List[str]:
|
|
||||||
"""Add documents to vectorstore."""
|
|
||||||
return await self.vectorstore.aadd_documents(documents, **kwargs)
|
|
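The constructors and static helpers in the file above all resolve `redis_url` the same way: `get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")` prefers the keyword argument and falls back to the environment, and the key is then popped so it is not forwarded to the client a second time. A minimal sketch of that pattern follows. It assumes a hypothetical stand-in `resolve_setting` (not the LangChain helper itself) and a hypothetical wrapper `split_connection_kwargs`, so the snippet runs without LangChain or Redis installed.

```python
import os
from typing import Any, Dict, Tuple


def resolve_setting(data: Dict[str, Any], key: str, env_key: str) -> str:
    """Hypothetical stand-in: prefer data[key], else the environment variable."""
    if data.get(key):
        return data[key]
    if os.environ.get(env_key):
        return os.environ[env_key]
    raise ValueError(f"Did not find {key}; pass `{key}` or set `{env_key}`.")


def split_connection_kwargs(kwargs: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Return (redis_url, remaining kwargs) without mutating the caller's dict."""
    remaining = dict(kwargs)
    redis_url = resolve_setting(remaining, "redis_url", "REDIS_URL")
    remaining.pop("redis_url", None)  # avoid passing redis_url twice downstream
    return redis_url, remaining


url, rest = split_connection_kwargs(
    {"redis_url": "redis://localhost:6379", "decode_responses": True}
)
print(url)   # redis://localhost:6379
print(rest)  # {'decode_responses': True}
```

Copying the dict before popping is a choice made for this sketch; the methods above pop from `kwargs` in place.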
9
libs/langchain/langchain/vectorstores/redis/__init__.py
Normal file
@ -0,0 +1,9 @@
from .base import Redis
from .filters import (
    RedisFilter,
    RedisNum,
    RedisTag,
    RedisText,
)

__all__ = ["Redis", "RedisFilter", "RedisTag", "RedisText", "RedisNum"]
|
1361
libs/langchain/langchain/vectorstores/redis/base.py
Normal file
File diff suppressed because it is too large
20
libs/langchain/langchain/vectorstores/redis/constants.py
Normal file
@ -0,0 +1,20 @@
from typing import Any, Dict, List

import numpy as np

# required modules
REDIS_REQUIRED_MODULES = [
    {"name": "search", "ver": 20600},
    {"name": "searchlight", "ver": 20600},
]

# distance metrics
REDIS_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"]

# supported vector datatypes
REDIS_VECTOR_DTYPE_MAP: Dict[str, Any] = {
    "FLOAT32": np.float32,
    "FLOAT64": np.float64,
}

REDIS_TAG_SEPARATOR = ","
|
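`REDIS_REQUIRED_MODULES` pairs each acceptable search module with a minimum version. A minimal sketch of how such a list could be checked against what a server reports is shown below. It assumes the server's module list has already been fetched and normalized into `{"name", "ver"}` dicts, and that one listed module at or above its minimum version is enough. `has_required_module` is a hypothetical helper for illustration, not the library's `_check_redis_module_exist`.

```python
from typing import Any, Dict, List

REQUIRED = [
    {"name": "search", "ver": 20600},
    {"name": "searchlight", "ver": 20600},
]


def has_required_module(
    installed: List[Dict[str, Any]], required: List[Dict[str, Any]]
) -> bool:
    """True if any required module is installed at or above its minimum version."""
    by_name = {m["name"]: int(m["ver"]) for m in installed}
    return any(
        r["name"] in by_name and by_name[r["name"]] >= int(r["ver"])
        for r in required
    )


print(has_required_module([{"name": "search", "ver": 20812}], REQUIRED))  # True
print(has_required_module([{"name": "search", "ver": 20400}], REQUIRED))  # False
print(has_required_module([{"name": "ReJSON", "ver": 20600}], REQUIRED))  # False
```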
420
libs/langchain/langchain/vectorstores/redis/filters.py
Normal file
@ -0,0 +1,420 @@
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union

from langchain.utilities.redis import TokenEscaper

# disable mypy error for dunder method overrides
# mypy: disable-error-code="override"


class RedisFilterOperator(Enum):
    EQ = 1
    NE = 2
    LT = 3
    GT = 4
    LE = 5
    GE = 6
    OR = 7
    AND = 8
    LIKE = 9
    IN = 10


class RedisFilter:
    @staticmethod
    def text(field: str) -> "RedisText":
        return RedisText(field)

    @staticmethod
    def num(field: str) -> "RedisNum":
        return RedisNum(field)

    @staticmethod
    def tag(field: str) -> "RedisTag":
        return RedisTag(field)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisFilterField:
    escaper: "TokenEscaper" = TokenEscaper()
    OPERATORS: Dict[RedisFilterOperator, str] = {}

    def __init__(self, field: str):
        self._field = field
        self._value: Any = None
        self._operator: RedisFilterOperator = RedisFilterOperator.EQ

    def equals(self, other: "RedisFilterField") -> bool:
        if not isinstance(other, type(self)):
            return False
        return self._field == other._field and self._value == other._value

    def _set_value(
        self, val: Any, val_type: type, operator: RedisFilterOperator
    ) -> None:
        # check that the operator is supported by this class
        if operator not in self.OPERATORS:
            raise ValueError(
                f"Operator {operator} not supported by {self.__class__.__name__}. "
                + f"Supported operators are {self.OPERATORS.values()}"
            )

        if not isinstance(val, val_type):
            raise TypeError(
                f"Right side argument passed to operator {self.OPERATORS[operator]} "
                f"with left side "
                f"argument {self.__class__.__name__} must be of type {val_type}"
            )
        self._value = val
        self._operator = operator


def check_operator_misuse(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
        # Extracting 'other' from positional arguments or keyword arguments
        other = kwargs.get("other") if "other" in kwargs else None
        if not other:
            for arg in args:
                if isinstance(arg, type(instance)):
                    other = arg
                    break

        if isinstance(other, type(instance)):
            raise ValueError(
                "Equality operators are overridden for FilterExpression creation. Use "
                ".equals() for equality checks"
            )
        return func(instance, *args, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class RedisTag(RedisFilterField):
|
||||||
|
"""A RedisTag is a RedisFilterField representing a tag in a Redis index."""
|
||||||
|
|
||||||
|
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||||
|
RedisFilterOperator.EQ: "==",
|
||||||
|
RedisFilterOperator.NE: "!=",
|
||||||
|
RedisFilterOperator.IN: "==",
|
||||||
|
}
|
||||||
|
|
||||||
|
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||||
|
RedisFilterOperator.EQ: "@%s:{%s}",
|
||||||
|
RedisFilterOperator.NE: "(-@%s:{%s})",
|
||||||
|
RedisFilterOperator.IN: "@%s:{%s}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, field: str):
|
||||||
|
"""Create a RedisTag FilterField
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field (str): The name of the RedisTag field in the index to be queried
|
||||||
|
against.
|
||||||
|
"""
|
||||||
|
super().__init__(field)
|
||||||
|
|
||||||
|
def _set_tag_value(
|
||||||
|
self, other: Union[List[str], str], operator: RedisFilterOperator
|
||||||
|
) -> None:
|
||||||
|
if isinstance(other, list):
|
||||||
|
if not all(isinstance(tag, str) for tag in other):
|
||||||
|
raise ValueError("All tags must be strings")
|
||||||
|
else:
|
||||||
|
other = [other]
|
||||||
|
self._set_value(other, list, operator)
|
||||||
|
|
||||||
|
@check_operator_misuse
|
||||||
|
def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisTag equality filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (Union[List[str], str]): The tag(s) to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisTag
|
||||||
|
>>> filter = RedisTag("brand") == "nike"
|
||||||
|
"""
|
||||||
|
self._set_tag_value(other, RedisFilterOperator.EQ)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
@check_operator_misuse
|
||||||
|
def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisTag inequality filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (Union[List[str], str]): The tag(s) to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisTag
|
||||||
|
>>> filter = RedisTag("brand") != "nike"
|
||||||
|
"""
|
||||||
|
self._set_tag_value(other, RedisFilterOperator.NE)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _formatted_tag_value(self) -> str:
|
||||||
|
return "|".join([self.escaper.escape(tag) for tag in self._value])
|
||||||
|
|
||||||
|
    def __str__(self) -> str:
        """Return the Redis Query syntax for a RedisTag filter expression"""
        if not self._value:
            raise ValueError(
                f"Operator must be used before calling __str__. Operators are "
                f"{self.OPERATORS.values()}"
            )
        return self.OPERATOR_MAP[self._operator] % (
            self._field,
            self._formatted_tag_value,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class RedisNum(RedisFilterField):
|
||||||
|
"""A RedisFilterField representing a numeric field in a Redis index."""
|
||||||
|
|
||||||
|
OPERATORS: Dict[RedisFilterOperator, str] = {
|
||||||
|
RedisFilterOperator.EQ: "==",
|
||||||
|
RedisFilterOperator.NE: "!=",
|
||||||
|
RedisFilterOperator.LT: "<",
|
||||||
|
RedisFilterOperator.GT: ">",
|
||||||
|
RedisFilterOperator.LE: "<=",
|
||||||
|
RedisFilterOperator.GE: ">=",
|
||||||
|
}
|
||||||
|
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
|
||||||
|
RedisFilterOperator.EQ: "@%s:[%i %i]",
|
||||||
|
RedisFilterOperator.NE: "(-@%s:[%i %i])",
|
||||||
|
RedisFilterOperator.GT: "@%s:[(%i +inf]",
|
||||||
|
RedisFilterOperator.LT: "@%s:[-inf (%i]",
|
||||||
|
RedisFilterOperator.GE: "@%s:[%i +inf]",
|
||||||
|
RedisFilterOperator.LE: "@%s:[-inf %i]",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return the Redis Query syntax for a Numeric filter expression"""
|
||||||
|
        if self._value is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Operator must be used before calling __str__. Operators are "
|
||||||
|
f"{self.OPERATORS.values()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._operator == RedisFilterOperator.EQ
|
||||||
|
or self._operator == RedisFilterOperator.NE
|
||||||
|
):
|
||||||
|
return self.OPERATOR_MAP[self._operator] % (
|
||||||
|
self._field,
|
||||||
|
self._value,
|
||||||
|
self._value,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
|
||||||
|
|
||||||
|
@check_operator_misuse
|
||||||
|
def __eq__(self, other: int) -> "RedisFilterExpression":
|
||||||
|
"""Create a Numeric equality filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (int): The value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
|
>>> filter = RedisNum("zipcode") == 90210
|
||||||
|
"""
|
||||||
|
self._set_value(other, int, RedisFilterOperator.EQ)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
@check_operator_misuse
|
||||||
|
def __ne__(self, other: int) -> "RedisFilterExpression":
|
||||||
|
"""Create a Numeric inequality filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (int): The value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
|
>>> filter = RedisNum("zipcode") != 90210
|
||||||
|
"""
|
||||||
|
self._set_value(other, int, RedisFilterOperator.NE)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
def __gt__(self, other: int) -> "RedisFilterExpression":
|
||||||
|
        """Create a Numeric greater than filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (int): The value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
|
>>> filter = RedisNum("age") > 18
|
||||||
|
"""
|
||||||
|
self._set_value(other, int, RedisFilterOperator.GT)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
def __lt__(self, other: int) -> "RedisFilterExpression":
|
||||||
|
"""Create a Numeric less than filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (int): The value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
|
>>> filter = RedisNum("age") < 18
|
||||||
|
"""
|
||||||
|
self._set_value(other, int, RedisFilterOperator.LT)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
def __ge__(self, other: int) -> "RedisFilterExpression":
|
||||||
|
"""Create a Numeric greater than or equal to filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (int): The value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
|
>>> filter = RedisNum("age") >= 18
|
||||||
|
"""
|
||||||
|
self._set_value(other, int, RedisFilterOperator.GE)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
def __le__(self, other: int) -> "RedisFilterExpression":
|
||||||
|
"""Create a Numeric less than or equal to filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (int): The value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisNum
|
||||||
|
>>> filter = RedisNum("age") <= 18
|
||||||
|
"""
|
||||||
|
self._set_value(other, int, RedisFilterOperator.LE)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
|
||||||
|
class RedisText(RedisFilterField):
|
||||||
|
"""A RedisText is a RedisFilterField representing a text field in a Redis index."""
|
||||||
|
|
||||||
|
OPERATORS = {
|
||||||
|
RedisFilterOperator.EQ: "==",
|
||||||
|
RedisFilterOperator.NE: "!=",
|
||||||
|
RedisFilterOperator.LIKE: "%",
|
||||||
|
}
|
||||||
|
OPERATOR_MAP = {
|
||||||
|
RedisFilterOperator.EQ: '@%s:"%s"',
|
||||||
|
RedisFilterOperator.NE: '(-@%s:"%s")',
|
||||||
|
RedisFilterOperator.LIKE: "@%s:%s",
|
||||||
|
}
|
||||||
|
|
||||||
|
@check_operator_misuse
|
||||||
|
def __eq__(self, other: str) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisText equality filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (str): The text value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisText
|
||||||
|
>>> filter = RedisText("job") == "engineer"
|
||||||
|
"""
|
||||||
|
self._set_value(other, str, RedisFilterOperator.EQ)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
@check_operator_misuse
|
||||||
|
def __ne__(self, other: str) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisText inequality filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (str): The text value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisText
|
||||||
|
>>> filter = RedisText("job") != "engineer"
|
||||||
|
"""
|
||||||
|
self._set_value(other, str, RedisFilterOperator.NE)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
def __mod__(self, other: str) -> "RedisFilterExpression":
|
||||||
|
"""Create a RedisText like filter expression
|
||||||
|
|
||||||
|
Args:
|
||||||
|
other (str): The text value to filter on.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from langchain.vectorstores.redis import RedisText
|
||||||
|
>>> filter = RedisText("job") % "engineer"
|
||||||
|
"""
|
||||||
|
self._set_value(other, str, RedisFilterOperator.LIKE)
|
||||||
|
return RedisFilterExpression(str(self))
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if not self._value:
|
||||||
|
raise ValueError(
|
||||||
|
f"Operator must be used before calling __str__. Operators are "
|
||||||
|
f"{self.OPERATORS.values()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
|
||||||
|
except KeyError:
|
||||||
|
raise Exception("Invalid operator")
|
||||||
|
|
||||||
|
|
||||||
|
class RedisFilterExpression:
|
||||||
|
"""A RedisFilterExpression is a logical expression of RedisFilterFields.
|
||||||
|
|
||||||
|
RedisFilterExpressions can be combined using the & and | operators to create
|
||||||
|
complex logical expressions that evaluate to the Redis Query language.
|
||||||
|
|
||||||
|
This presents an interface by which users can create complex queries
|
||||||
|
without having to know the Redis Query language.
|
||||||
|
|
||||||
|
Filter expressions are not initialized directly. Instead they are built
|
||||||
|
by combining RedisFilterFields using the & and | operators.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
>>> from langchain.vectorstores.redis import RedisTag, RedisNum
|
||||||
|
>>> brand_is_nike = RedisTag("brand") == "nike"
|
||||||
|
>>> price_is_under_100 = RedisNum("price") < 100
|
||||||
|
>>> filter = brand_is_nike & price_is_under_100
|
||||||
|
>>> print(str(filter))
|
||||||
|
        (@brand:{nike} @price:[-inf (100])
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
_filter: Optional[str] = None,
|
||||||
|
operator: Optional[RedisFilterOperator] = None,
|
||||||
|
left: Optional["RedisFilterExpression"] = None,
|
||||||
|
right: Optional["RedisFilterExpression"] = None,
|
||||||
|
):
|
||||||
|
self._filter = _filter
|
||||||
|
self._operator = operator
|
||||||
|
self._left = left
|
||||||
|
self._right = right
|
||||||
|
|
||||||
|
def __and__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
|
||||||
|
return RedisFilterExpression(
|
||||||
|
operator=RedisFilterOperator.AND, left=self, right=other
|
||||||
|
)
|
||||||
|
|
||||||
|
def __or__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
|
||||||
|
return RedisFilterExpression(
|
||||||
|
operator=RedisFilterOperator.OR, left=self, right=other
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# top level check that allows recursive calls to __str__
|
||||||
|
if not self._filter and not self._operator:
|
||||||
|
raise ValueError("Improperly initialized RedisFilterExpression")
|
||||||
|
|
||||||
|
# allow for single filter expression without operators as last
|
||||||
|
# expression in the chain might not have an operator
|
||||||
|
if self._operator:
|
||||||
|
operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
|
||||||
|
return f"({str(self._left)}{operator_str}{str(self._right)})"
|
||||||
|
|
||||||
|
# check that base case, the filter is set
|
||||||
|
if not self._filter:
|
||||||
|
raise ValueError("Improperly initialized RedisFilterExpression")
|
||||||
|
return self._filter
|
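The pieces in `filters.py` compose as follows: each field class renders a leaf string from its `OPERATOR_MAP` template, and `RedisFilterExpression` joins two sub-expressions with a space for AND or ` | ` for OR, wrapped in parentheses. A condensed sketch of that composition is shown below, written without importing LangChain so it runs on its own. `Expr`, `tag_eq` and `num_lt` are illustrative stand-ins rather than the library classes, and tag escaping is omitted.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Expr:
    """Illustrative stand-in for RedisFilterExpression: a leaf or a binary node."""

    leaf: Optional[str] = None
    op: Optional[str] = None  # "AND" or "OR" for inner nodes
    left: Optional["Expr"] = None
    right: Optional["Expr"] = None

    def __and__(self, other: "Expr") -> "Expr":
        return Expr(op="AND", left=self, right=other)

    def __or__(self, other: "Expr") -> "Expr":
        return Expr(op="OR", left=self, right=other)

    def __str__(self) -> str:
        if self.op:
            joiner = " | " if self.op == "OR" else " "
            return f"({self.left}{joiner}{self.right})"
        if not self.leaf:
            raise ValueError("Improperly initialized expression")
        return self.leaf


def tag_eq(field: str, *tags: str) -> Expr:
    # Same shape as the tag EQ template "@%s:{%s}"; escaping omitted here.
    return Expr(leaf="@%s:{%s}" % (field, "|".join(tags)))


def num_lt(field: str, value: int) -> Expr:
    # Same shape as the numeric LT template "@%s:[-inf (%i]".
    return Expr(leaf="@%s:[-inf (%i]" % (field, value))


cheap_nike = tag_eq("brand", "nike") & num_lt("price", 100)
combined = cheap_nike | tag_eq("brand", "adidas", "puma")
print(cheap_nike)  # (@brand:{nike} @price:[-inf (100])
print(combined)    # ((@brand:{nike} @price:[-inf (100]) | @brand:{adidas|puma})
```

Building a tree and rendering it only in `__str__` keeps `&` and `|` cheap and lets nested groups be parenthesized consistently at render time.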
276
libs/langchain/langchain/vectorstores/redis/schema.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# ignore type error here as it's a redis-py type problem
|
||||||
|
from redis.commands.search.field import ( # type: ignore
|
||||||
|
NumericField,
|
||||||
|
TagField,
|
||||||
|
TextField,
|
||||||
|
VectorField,
|
||||||
|
)
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
from langchain.pydantic_v1 import BaseModel, Field, validator
|
||||||
|
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDistanceMetric(str, Enum):
|
||||||
|
l2 = "L2"
|
||||||
|
cosine = "COSINE"
|
||||||
|
ip = "IP"
|
||||||
|
|
||||||
|
|
||||||
|
class RedisField(BaseModel):
|
||||||
|
name: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class TextFieldSchema(RedisField):
|
||||||
|
weight: float = 1
|
||||||
|
no_stem: bool = False
|
||||||
|
phonetic_matcher: Optional[str] = None
|
||||||
|
withsuffixtrie: bool = False
|
||||||
|
no_index: bool = False
|
||||||
|
sortable: Optional[bool] = False
|
||||||
|
|
||||||
|
def as_field(self) -> TextField:
|
||||||
|
return TextField(
|
||||||
|
self.name,
|
||||||
|
weight=self.weight,
|
||||||
|
no_stem=self.no_stem,
|
||||||
|
phonetic_matcher=self.phonetic_matcher,
|
||||||
|
sortable=self.sortable,
|
||||||
|
no_index=self.no_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TagFieldSchema(RedisField):
|
||||||
|
separator: str = ","
|
||||||
|
case_sensitive: bool = False
|
||||||
|
no_index: bool = False
|
||||||
|
sortable: Optional[bool] = False
|
||||||
|
|
||||||
|
def as_field(self) -> TagField:
|
||||||
|
return TagField(
|
||||||
|
self.name,
|
||||||
|
separator=self.separator,
|
||||||
|
case_sensitive=self.case_sensitive,
|
||||||
|
sortable=self.sortable,
|
||||||
|
no_index=self.no_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NumericFieldSchema(RedisField):
|
||||||
|
no_index: bool = False
|
||||||
|
sortable: Optional[bool] = False
|
||||||
|
|
||||||
|
def as_field(self) -> NumericField:
|
||||||
|
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisVectorField(RedisField):
|
||||||
|
dims: int = Field(...)
|
||||||
|
algorithm: object = Field(...)
|
||||||
|
datatype: str = Field(default="FLOAT32")
|
||||||
|
distance_metric: RedisDistanceMetric = Field(default="COSINE")
|
||||||
|
initial_cap: int = Field(default=20000)
|
||||||
|
|
||||||
|
@validator("distance_metric", pre=True)
|
||||||
|
def uppercase_strings(cls, v: str) -> str:
|
||||||
|
return v.upper()
|
||||||
|
|
||||||
|
@validator("datatype", pre=True)
|
||||||
|
def uppercase_and_check_dtype(cls, v: str) -> str:
|
||||||
|
if v.upper() not in REDIS_VECTOR_DTYPE_MAP:
|
||||||
|
raise ValueError(
|
||||||
|
f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
|
||||||
|
)
|
||||||
|
return v.upper()
|
||||||
|
|
||||||
|
|
||||||
|
class FlatVectorField(RedisVectorField):
|
||||||
|
algorithm: Literal["FLAT"] = "FLAT"
|
||||||
|
block_size: int = Field(default=1000)
|
||||||
|
|
||||||
|
def as_field(self) -> VectorField:
|
||||||
|
return VectorField(
|
||||||
|
self.name,
|
||||||
|
self.algorithm,
|
||||||
|
{
|
||||||
|
"TYPE": self.datatype,
|
||||||
|
"DIM": self.dims,
|
||||||
|
"DISTANCE_METRIC": self.distance_metric,
|
||||||
|
"INITIAL_CAP": self.initial_cap,
|
||||||
|
"BLOCK_SIZE": self.block_size,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HNSWVectorField(RedisVectorField):
|
||||||
|
algorithm: Literal["HNSW"] = "HNSW"
|
||||||
|
m: int = Field(default=16)
|
||||||
|
ef_construction: int = Field(default=200)
|
||||||
|
ef_runtime: int = Field(default=10)
|
||||||
|
epsilon: float = Field(default=0.8)
|
||||||
|
|
||||||
|
def as_field(self) -> VectorField:
|
||||||
|
return VectorField(
|
||||||
|
self.name,
|
||||||
|
self.algorithm,
|
||||||
|
{
|
||||||
|
"TYPE": self.datatype,
|
||||||
|
"DIM": self.dims,
|
||||||
|
"DISTANCE_METRIC": self.distance_metric,
|
||||||
|
"INITIAL_CAP": self.initial_cap,
|
||||||
|
"M": self.m,
|
||||||
|
"EF_CONSTRUCTION": self.ef_construction,
|
||||||
|
"EF_RUNTIME": self.ef_runtime,
|
||||||
|
"EPSILON": self.epsilon,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisModel(BaseModel):
|
||||||
|
# always have a content field for text
|
||||||
|
text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
|
||||||
|
tag: Optional[List[TagFieldSchema]] = None
|
||||||
|
numeric: Optional[List[NumericFieldSchema]] = None
|
||||||
|
extra: Optional[List[RedisField]] = None
|
||||||
|
|
||||||
|
# filled by default_vector_schema
|
||||||
|
vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
|
||||||
|
content_key: str = "content"
|
||||||
|
content_vector_key: str = "content_vector"
|
||||||
|
|
||||||
|
def add_content_field(self) -> None:
|
||||||
|
if self.text is None:
|
||||||
|
self.text = []
|
||||||
|
for field in self.text:
|
||||||
|
if field.name == self.content_key:
|
||||||
|
return
|
||||||
|
self.text.append(TextFieldSchema(name=self.content_key))
|
||||||
|
|
||||||
|
def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
|
||||||
|
# catch case where user inputted no vector field spec
|
||||||
|
# in the index schema
|
||||||
|
if self.vector is None:
|
||||||
|
self.vector = []
|
||||||
|
|
||||||
|
# ignore types as pydantic is handling type validation and conversion
|
||||||
|
if vector_field["algorithm"] == "FLAT":
|
||||||
|
self.vector.append(FlatVectorField(**vector_field)) # type: ignore
|
||||||
|
elif vector_field["algorithm"] == "HNSW":
|
||||||
|
self.vector.append(HNSWVectorField(**vector_field)) # type: ignore
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"algorithm must be either FLAT or HNSW. Got "
|
||||||
|
f"{vector_field['algorithm']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def as_dict(self) -> Dict[str, List[Any]]:
|
||||||
|
schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
|
||||||
|
# iter over all class attributes
|
||||||
|
for attr, attr_value in self.__dict__.items():
|
||||||
|
# only non-empty lists
|
||||||
|
if isinstance(attr_value, list) and len(attr_value) > 0:
|
||||||
|
field_values: List[Dict[str, Any]] = []
|
||||||
|
# iterate over all fields in each category (tag, text, etc)
|
||||||
|
for val in attr_value:
|
||||||
|
value: Dict[str, Any] = {}
|
||||||
|
# iterate over values within each field to extract
|
||||||
|
# settings for that field (i.e. name, weight, etc)
|
||||||
|
for field, field_value in val.__dict__.items():
|
||||||
|
# make enums into strings
|
||||||
|
if isinstance(field_value, Enum):
|
||||||
|
value[field] = field_value.value
|
||||||
|
# don't write null values
|
||||||
|
elif field_value is not None:
|
||||||
|
value[field] = field_value
|
||||||
|
field_values.append(value)
|
||||||
|
|
||||||
|
schemas[attr] = field_values
|
||||||
|
|
||||||
|
schema: Dict[str, List[Any]] = {}
|
||||||
|
# only write non-empty lists from defaults
|
||||||
|
for k, v in schemas.items():
|
||||||
|
if len(v) > 0:
|
||||||
|
schema[k] = v
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
|
||||||
|
if not self.vector:
|
||||||
|
raise ValueError("No vector fields found")
|
||||||
|
for field in self.vector:
|
||||||
|
if field.name == self.content_vector_key:
|
||||||
|
return field
|
||||||
|
raise ValueError("No content_vector field found")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vector_dtype(self) -> np.dtype:
|
||||||
|
# should only ever be called after pydantic has validated the schema
|
||||||
|
return REDIS_VECTOR_DTYPE_MAP[self.content_vector.datatype]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return all(
|
||||||
|
field is None for field in [self.tag, self.text, self.numeric, self.vector]
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_fields(self) -> List["RedisField"]:
|
||||||
|
redis_fields: List["RedisField"] = []
|
||||||
|
if self.is_empty:
|
||||||
|
return redis_fields
|
||||||
|
|
||||||
|
for field_name in self.__fields__.keys():
|
||||||
|
if field_name not in ["content_key", "content_vector_key", "extra"]:
|
||||||
|
field_group = getattr(self, field_name)
|
||||||
|
if field_group is not None:
|
||||||
|
for field in field_group:
|
||||||
|
redis_fields.append(field.as_field())
|
||||||
|
return redis_fields
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metadata_keys(self) -> List[str]:
|
||||||
|
keys: List[str] = []
|
||||||
|
if self.is_empty:
|
||||||
|
return keys
|
||||||
|
|
||||||
|
for field_name in self.__fields__.keys():
|
||||||
|
field_group = getattr(self, field_name)
|
||||||
|
if field_group is not None:
|
||||||
|
for field in field_group:
|
||||||
|
# check if it's a metadata field. exclude vector and content key
|
||||||
|
if not isinstance(field, str) and field.name not in [
|
||||||
|
self.content_key,
|
||||||
|
self.content_vector_key,
|
||||||
|
]:
|
||||||
|
keys.append(field.name)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
def read_schema(
|
||||||
|
index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
# if it's a dict, return it as-is; otherwise, treat it as a path and
|
||||||
|
# read in the file assuming it's a yaml file, returning the parsed dict
|
||||||
|
if isinstance(index_schema, dict):
|
||||||
|
return index_schema
|
||||||
|
elif isinstance(index_schema, Path):
|
||||||
|
with open(index_schema, "rb") as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
elif isinstance(index_schema, str):
|
||||||
|
if Path(index_schema).resolve().is_file():
|
||||||
|
with open(index_schema, "rb") as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"index_schema must be a dict or a path to a yaml file. "
|
||||||
|
f"Got {type(index_schema)}"
|
||||||
|
)
|
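The `as_dict` method above flattens each field model into a plain dict, turning enum members into their string values and skipping unset (`None`) settings. A minimal, self-contained sketch of that idea follows. `ToyField`, `Metric`, and `fields_to_dicts` are hypothetical stand-ins written with dataclasses, not the pydantic models from this file.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional


class Metric(Enum):
    """Hypothetical enum standing in for an enum-typed schema setting."""

    COSINE = "COSINE"


@dataclass
class ToyField:
    """Hypothetical field model; not the library's field schema class."""

    name: str
    weight: Optional[float] = None
    metric: Optional[Metric] = None


def fields_to_dicts(fields: List[ToyField]) -> List[Dict[str, Any]]:
    """Serialize fields, converting enums to strings and dropping None values."""
    rows: List[Dict[str, Any]] = []
    for field in fields:
        row: Dict[str, Any] = {}
        for key, val in vars(field).items():
            if isinstance(val, Enum):
                row[key] = val.value  # make enums into strings
            elif val is not None:
                row[key] = val  # don't write null values
        rows.append(row)
    return rows


rows = fields_to_dicts(
    [ToyField("content"), ToyField("content_vector", metric=Metric.COSINE)]
)
```

Dropping the `None` entries keeps a written schema file limited to the settings that were actually specified.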
@ -1,16 +1,27 @@
|
|||||||
"""Test Redis cache functionality."""
|
"""Test Redis cache functionality."""
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.cache import RedisCache, RedisSemanticCache
|
from langchain.cache import RedisCache, RedisSemanticCache
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.schema import Generation, LLMResult
|
from langchain.schema import Generation, LLMResult
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
|
ConsistentFakeEmbeddings,
|
||||||
|
FakeEmbeddings,
|
||||||
|
)
|
||||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||||
|
|
||||||
REDIS_TEST_URL = "redis://localhost:6379"
|
REDIS_TEST_URL = "redis://localhost:6379"
|
||||||
|
|
||||||
|
|
||||||
|
def random_string() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
def test_redis_cache_ttl() -> None:
|
def test_redis_cache_ttl() -> None:
|
||||||
import redis
|
import redis
|
||||||
|
|
||||||
@ -30,12 +41,10 @@ def test_redis_cache() -> None:
|
|||||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||||
output = llm.generate(["foo"])
|
output = llm.generate(["foo"])
|
||||||
print(output)
|
|
||||||
expected_output = LLMResult(
|
expected_output = LLMResult(
|
||||||
generations=[[Generation(text="fizz")]],
|
generations=[[Generation(text="fizz")]],
|
||||||
llm_output={},
|
llm_output={},
|
||||||
)
|
)
|
||||||
print(expected_output)
|
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
langchain.llm_cache.redis.flushall()
|
langchain.llm_cache.redis.flushall()
|
||||||
|
|
||||||
@ -80,14 +89,90 @@ def test_redis_semantic_cache() -> None:
|
|||||||
langchain.llm_cache.clear(llm_string=llm_string)
|
langchain.llm_cache.clear(llm_string=llm_string)
|
||||||
|
|
||||||
|
|
||||||
def test_redis_semantic_cache_chat() -> None:
|
def test_redis_semantic_cache_multi() -> None:
|
||||||
import redis
|
langchain.llm_cache = RedisSemanticCache(
|
||||||
|
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
|
||||||
|
)
|
||||||
|
llm = FakeLLM()
|
||||||
|
params = llm.dict()
|
||||||
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
langchain.llm_cache.update(
|
||||||
|
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
|
||||||
|
)
|
||||||
|
output = llm.generate(
|
||||||
|
["bar"]
|
||||||
|
) # foo and bar will have the same embedding produced by FakeEmbeddings
|
||||||
|
expected_output = LLMResult(
|
||||||
|
generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
|
||||||
|
llm_output={},
|
||||||
|
)
|
||||||
|
assert output == expected_output
|
||||||
|
# clear the cache
|
||||||
|
langchain.llm_cache.clear(llm_string=llm_string)
|
||||||
|
|
||||||
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
|
|
||||||
|
def test_redis_semantic_cache_chat() -> None:
|
||||||
|
langchain.llm_cache = RedisSemanticCache(
|
||||||
|
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
|
||||||
|
)
|
||||||
llm = FakeChatModel()
|
llm = FakeChatModel()
|
||||||
params = llm.dict()
|
params = llm.dict()
|
||||||
params["stop"] = None
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
with pytest.warns():
|
with pytest.warns():
|
||||||
llm.predict("foo")
|
llm.predict("foo")
|
||||||
llm.predict("foo")
|
llm.predict("foo")
|
||||||
langchain.llm_cache.redis.flushall()
|
langchain.llm_cache.clear(llm_string=llm_string)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"prompts, generations",
|
||||||
|
[
|
||||||
|
# Single prompt, single generation
|
||||||
|
([random_string()], [[random_string()]]),
|
||||||
|
# Single prompt, multiple generations
|
||||||
|
([random_string()], [[random_string(), random_string()]]),
|
||||||
|
# Single prompt, three generations
|
||||||
|
([random_string()], [[random_string(), random_string(), random_string()]]),
|
||||||
|
# Multiple prompts, multiple generations
|
||||||
|
(
|
||||||
|
[random_string(), random_string()],
|
||||||
|
[[random_string()], [random_string(), random_string()]],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"single_prompt_single_generation",
|
||||||
|
"single_prompt_multiple_generations",
|
||||||
|
"single_prompt_three_generations",
|
||||||
|
"multiple_prompts_multiple_generations",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_redis_semantic_cache_hit(
|
||||||
|
embedding: Embeddings, prompts: List[str], generations: List[List[str]]
|
||||||
|
) -> None:
|
||||||
|
langchain.llm_cache = RedisSemanticCache(
|
||||||
|
embedding=embedding, redis_url=REDIS_TEST_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = FakeLLM()
|
||||||
|
params = llm.dict()
|
||||||
|
params["stop"] = None
|
||||||
|
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||||
|
|
||||||
|
llm_generations = [
|
||||||
|
[
|
||||||
|
Generation(text=generation, generation_info=params)
|
||||||
|
for generation in prompt_i_generations
|
||||||
|
]
|
||||||
|
for prompt_i_generations in generations
|
||||||
|
]
|
||||||
|
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
|
||||||
|
print(prompt_i)
|
||||||
|
print(llm_generations_i)
|
||||||
|
langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
|
||||||
|
llm.generate(prompts)
|
||||||
|
assert llm.generate(prompts) == LLMResult(
|
||||||
|
generations=llm_generations, llm_output={}
|
||||||
|
)
|
||||||
|
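The cache tests above derive `llm_string` by sorting the model parameters before stringifying them. A small sketch of why that matters: the resulting key does not depend on the order in which the parameters were inserted into the dict. The helper name `make_llm_string` and the sample parameters are illustrative assumptions, not part of the test module.

```python
from typing import Any, Dict


def make_llm_string(params: Dict[str, Any]) -> str:
    """Build a deterministic key from model parameters (same expression as the tests)."""
    return str(sorted([(k, v) for k, v in params.items()]))


# Two dicts with the same items in a different insertion order.
key_a = make_llm_string({"stop": None, "model": "fake"})
key_b = make_llm_string({"model": "fake", "stop": None})
```

Because dict keys are unique, the sort never has to compare the values, so mixed value types such as `None` and strings do not raise.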
@ -52,6 +52,7 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
|||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||||
one if the text is unknown."""
|
one if the text is unknown."""
|
||||||
|
return self.embed_documents([text])[0]
|
||||||
if text not in self.known_texts:
|
if text not in self.known_texts:
|
||||||
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
|
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
|
||||||
return [float(1.0)] * (self.dimensionality - 1) + [
|
return [float(1.0)] * (self.dimensionality - 1) + [
|
||||||
|
@ -1,17 +1,28 @@
|
|||||||
"""Test Redis functionality."""
|
"""Test Redis functionality."""
|
||||||
from typing import List
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.vectorstores.redis import Redis
|
from langchain.vectorstores.redis import (
|
||||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
Redis,
|
||||||
|
RedisFilter,
|
||||||
|
RedisNum,
|
||||||
|
RedisText,
|
||||||
|
)
|
||||||
|
from langchain.vectorstores.redis.filters import RedisFilterExpression
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||||
|
ConsistentFakeEmbeddings,
|
||||||
|
FakeEmbeddings,
|
||||||
|
)
|
||||||
|
|
||||||
TEST_INDEX_NAME = "test"
|
TEST_INDEX_NAME = "test"
|
||||||
TEST_REDIS_URL = "redis://localhost:6379"
|
TEST_REDIS_URL = "redis://localhost:6379"
|
||||||
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
||||||
TEST_SINGLE_WITH_METADATA_RESULT = [Document(page_content="foo", metadata={"a": "b"})]
|
TEST_SINGLE_WITH_METADATA = {"a": "b"}
|
||||||
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
||||||
|
RANGE_SCORE = pytest.approx(0.0513, abs=0.002)
|
||||||
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
|
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
|
||||||
IP_SCORE = -8.0
|
IP_SCORE = -8.0
|
||||||
EUCLIDEAN_SCORE = 1.0
|
EUCLIDEAN_SCORE = 1.0
|
||||||
@ -23,6 +34,27 @@ def drop(index_name: str) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_bytes(data: Any) -> Any:
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return data.decode("ascii")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return dict(map(convert_bytes, data.items()))
|
||||||
|
if isinstance(data, list):
|
||||||
|
return list(map(convert_bytes, data))
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
return tuple(map(convert_bytes, data))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def make_dict(values: List[Any]) -> dict:
|
||||||
|
i = 0
|
||||||
|
di = {}
|
||||||
|
while i < len(values) - 1:
|
||||||
|
di[values[i]] = values[i + 1]
|
||||||
|
i += 2
|
||||||
|
return di
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def texts() -> List[str]:
|
def texts() -> List[str]:
|
||||||
return ["foo", "bar", "baz"]
|
return ["foo", "bar", "baz"]
|
||||||
@ -31,7 +63,7 @@ def texts() -> List[str]:
|
|||||||
def test_redis(texts: List[str]) -> None:
|
def test_redis(texts: List[str]) -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1, return_metadata=False)
|
||||||
assert output == TEST_SINGLE_RESULT
|
assert output == TEST_SINGLE_RESULT
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
@ -40,30 +72,55 @@ def test_redis_new_vector(texts: List[str]) -> None:
|
|||||||
"""Test adding a new document"""
|
"""Test adding a new document"""
|
||||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
docsearch.add_texts(["foo"])
|
docsearch.add_texts(["foo"])
|
||||||
output = docsearch.similarity_search("foo", k=2)
|
output = docsearch.similarity_search("foo", k=2, return_metadata=False)
|
||||||
assert output == TEST_RESULT
|
assert output == TEST_RESULT
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
def test_redis_from_existing(texts: List[str]) -> None:
|
def test_redis_from_existing(texts: List[str]) -> None:
|
||||||
"""Test adding a new document"""
|
"""Test adding a new document"""
|
||||||
Redis.from_texts(
|
docsearch = Redis.from_texts(
|
||||||
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||||
)
|
)
|
||||||
|
schema: Dict = docsearch.schema
|
||||||
|
|
||||||
|
# write schema for the next test
|
||||||
|
docsearch.write_schema("test_schema.yml")
|
||||||
|
|
||||||
# Test creating from an existing
|
# Test creating from an existing
|
||||||
docsearch2 = Redis.from_existing_index(
|
docsearch2 = Redis.from_existing_index(
|
||||||
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
FakeEmbeddings(),
|
||||||
|
index_name=TEST_INDEX_NAME,
|
||||||
|
redis_url=TEST_REDIS_URL,
|
||||||
|
schema=schema,
|
||||||
)
|
)
|
||||||
output = docsearch2.similarity_search("foo", k=1)
|
output = docsearch2.similarity_search("foo", k=1, return_metadata=False)
|
||||||
assert output == TEST_SINGLE_RESULT
|
assert output == TEST_SINGLE_RESULT
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_add_texts_to_existing() -> None:
|
||||||
|
"""Test adding a new document"""
|
||||||
|
# Test creating from an existing with yaml from file
|
||||||
|
docsearch = Redis.from_existing_index(
|
||||||
|
FakeEmbeddings(),
|
||||||
|
index_name=TEST_INDEX_NAME,
|
||||||
|
redis_url=TEST_REDIS_URL,
|
||||||
|
schema="test_schema.yml",
|
||||||
|
)
|
||||||
|
docsearch.add_texts(["foo"])
|
||||||
|
output = docsearch.similarity_search("foo", k=2, return_metadata=False)
|
||||||
|
assert output == TEST_RESULT
|
||||||
|
assert drop(TEST_INDEX_NAME)
|
||||||
|
# remove the test_schema.yml file
|
||||||
|
os.remove("test_schema.yml")
|
||||||
|
|
||||||
|
|
||||||
def test_redis_from_texts_return_keys(texts: List[str]) -> None:
|
def test_redis_from_texts_return_keys(texts: List[str]) -> None:
|
||||||
"""Test from_texts_return_keys constructor."""
|
"""Test from_texts_return_keys constructor."""
|
||||||
docsearch, keys = Redis.from_texts_return_keys(
|
docsearch, keys = Redis.from_texts_return_keys(
|
||||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL
|
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1, return_metadata=False)
|
||||||
assert output == TEST_SINGLE_RESULT
|
assert output == TEST_SINGLE_RESULT
|
||||||
assert len(keys) == len(texts)
|
assert len(keys) == len(texts)
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
@ -73,21 +130,124 @@ def test_redis_from_documents(texts: List[str]) -> None:
|
|||||||
"""Test from_documents constructor."""
|
"""Test from_documents constructor."""
|
||||||
docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
|
docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
|
||||||
docsearch = Redis.from_documents(docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
docsearch = Redis.from_documents(docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
output = docsearch.similarity_search("foo", k=1)
|
output = docsearch.similarity_search("foo", k=1, return_metadata=True)
|
||||||
assert output == TEST_SINGLE_WITH_METADATA_RESULT
|
assert "a" in output[0].metadata.keys()
|
||||||
|
assert "b" in output[0].metadata.values()
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
def test_redis_add_texts_to_existing() -> None:
|
# -- test filters -- #
|
||||||
"""Test adding a new document"""
|
|
||||||
# Test creating from an existing
|
|
||||||
docsearch = Redis.from_existing_index(
|
@pytest.mark.parametrize(
|
||||||
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
"filter_expr, expected_length, expected_nums",
|
||||||
|
[
|
||||||
|
(RedisText("text") == "foo", 1, None),
|
||||||
|
(RedisFilter.text("text") == "foo", 1, None),
|
||||||
|
(RedisText("text") % "ba*", 2, ["bar", "baz"]),
|
||||||
|
(RedisNum("num") > 2, 1, [3]),
|
||||||
|
(RedisNum("num") < 2, 1, [1]),
|
||||||
|
(RedisNum("num") >= 2, 2, [2, 3]),
|
||||||
|
(RedisNum("num") <= 2, 2, [1, 2]),
|
||||||
|
(RedisNum("num") != 2, 2, [1, 3]),
|
||||||
|
(RedisFilter.num("num") != 2, 2, [1, 3]),
|
||||||
|
(RedisFilter.tag("category") == "a", 3, None),
|
||||||
|
(RedisFilter.tag("category") == "b", 2, None),
|
||||||
|
(RedisFilter.tag("category") == "c", 2, None),
|
||||||
|
(RedisFilter.tag("category") == ["b", "c"], 3, None),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"text-filter-equals-foo",
|
||||||
|
"alternative-text-equals-foo",
|
||||||
|
"text-filter-fuzzy-match-ba",
|
||||||
|
"number-filter-greater-than-2",
|
||||||
|
"number-filter-less-than-2",
|
||||||
|
"number-filter-greater-equals-2",
|
||||||
|
"number-filter-less-equals-2",
|
||||||
|
"number-filter-not-equals-2",
|
||||||
|
"alternative-number-not-equals-2",
|
||||||
|
"tag-filter-equals-a",
|
||||||
|
"tag-filter-equals-b",
|
||||||
|
"tag-filter-equals-c",
|
||||||
|
"tag-filter-equals-b-or-c",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
docsearch.add_texts(["foo"])
|
def test_redis_filters_1(
|
||||||
output = docsearch.similarity_search("foo", k=2)
|
filter_expr: RedisFilterExpression,
|
||||||
assert output == TEST_RESULT
|
expected_length: int,
|
||||||
assert drop(TEST_INDEX_NAME)
|
expected_nums: Optional[list],
|
||||||
|
) -> None:
|
||||||
|
metadata = [
|
||||||
|
{"name": "joe", "num": 1, "text": "foo", "category": ["a", "b"]},
|
||||||
|
{"name": "john", "num": 2, "text": "bar", "category": ["a", "c"]},
|
||||||
|
{"name": "jane", "num": 3, "text": "baz", "category": ["b", "c", "a"]},
|
||||||
|
]
|
||||||
|
documents = [Document(page_content="foo", metadata=m) for m in metadata]
|
||||||
|
docsearch = Redis.from_documents(
|
||||||
|
documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
output = docsearch.similarity_search("foo", k=3, filter=filter_expr)
|
||||||
|
|
||||||
|
assert len(output) == expected_length
|
||||||
|
|
||||||
|
if expected_nums is not None:
|
||||||
|
for out in output:
|
||||||
|
assert (
|
||||||
|
out.metadata["text"] in expected_nums
|
||||||
|
or int(out.metadata["num"]) in expected_nums
|
||||||
|
)
|
||||||
|
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
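The parametrized filter cases above build expressions with comparison operators such as `RedisNum("num") > 2`. The sketch below shows one way operator overloading can produce RediSearch numeric range clauses. `ToyNum` is a hypothetical class written for illustration only; it is not the `RedisNum` implementation, and the real class may render its expressions differently.

```python
class ToyNum:
    """Hypothetical numeric filter builder; not the langchain RedisNum class."""

    def __init__(self, field: str) -> None:
        self.field = field

    def __gt__(self, other: float) -> str:
        # "(" marks an exclusive bound in RediSearch range syntax
        return f"@{self.field}:[({other} +inf]"

    def __lt__(self, other: float) -> str:
        return f"@{self.field}:[-inf ({other}]"

    def __ge__(self, other: float) -> str:
        return f"@{self.field}:[{other} +inf]"

    def __le__(self, other: float) -> str:
        return f"@{self.field}:[-inf {other}]"


greater = ToyNum("num") > 2
at_most = ToyNum("num") <= 2
```

Returning a query fragment from the comparison operators keeps filter definitions readable at the call site, which is the style the test parameters rely on.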
# -- test index specification -- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_index_specification_generation() -> None:
|
||||||
|
index_schema = {
|
||||||
|
"text": [{"name": "job"}, {"name": "title"}],
|
||||||
|
"numeric": [{"name": "salary"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
text = ["foo"]
|
||||||
|
meta = {"job": "engineer", "title": "principal engineer", "salary": 100000}
|
||||||
|
docs = [Document(page_content=t, metadata=meta) for t in text]
|
||||||
|
r = Redis.from_documents(
|
||||||
|
docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, index_schema=index_schema
|
||||||
|
)
|
||||||
|
|
||||||
|
output = r.similarity_search("foo", k=1, return_metadata=True)
|
||||||
|
assert output[0].metadata["job"] == "engineer"
|
||||||
|
assert output[0].metadata["title"] == "principal engineer"
|
||||||
|
assert int(output[0].metadata["salary"]) == 100000
|
||||||
|
|
||||||
|
info = convert_bytes(r.client.ft(r.index_name).info())
|
||||||
|
attributes = info["attributes"]
|
||||||
|
assert len(attributes) == 5
|
||||||
|
for attr in attributes:
|
||||||
|
d = make_dict(attr)
|
||||||
|
if d["identifier"] == "job":
|
||||||
|
assert d["type"] == "TEXT"
|
||||||
|
elif d["identifier"] == "title":
|
||||||
|
assert d["type"] == "TEXT"
|
||||||
|
elif d["identifier"] == "salary":
|
||||||
|
assert d["type"] == "NUMERIC"
|
||||||
|
elif d["identifier"] == "content":
|
||||||
|
assert d["type"] == "TEXT"
|
||||||
|
elif d["identifier"] == "content_vector":
|
||||||
|
assert d["type"] == "VECTOR"
|
||||||
|
else:
|
||||||
|
raise ValueError("Unexpected attribute in index schema")
|
||||||
|
|
||||||
|
assert drop(r.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
# -- test distance metrics -- #
|
||||||
|
|
||||||
|
cosine_schema: Dict = {"distance_metric": "cosine"}
|
||||||
|
ip_schema: Dict = {"distance_metric": "IP"}
|
||||||
|
l2_schema: Dict = {"distance_metric": "L2"}
|
||||||
|
|
||||||
|
|
||||||
def test_cosine(texts: List[str]) -> None:
|
def test_cosine(texts: List[str]) -> None:
|
||||||
@ -96,7 +256,7 @@ def test_cosine(texts: List[str]) -> None:
|
|||||||
texts,
|
texts,
|
||||||
FakeEmbeddings(),
|
FakeEmbeddings(),
|
||||||
redis_url=TEST_REDIS_URL,
|
redis_url=TEST_REDIS_URL,
|
||||||
distance_metric="COSINE",
|
vector_schema=cosine_schema,
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search_with_score("far", k=2)
|
output = docsearch.similarity_search_with_score("far", k=2)
|
||||||
_, score = output[1]
|
_, score = output[1]
|
||||||
@ -107,7 +267,7 @@ def test_cosine(texts: List[str]) -> None:
|
|||||||
def test_l2(texts: List[str]) -> None:
|
def test_l2(texts: List[str]) -> None:
|
||||||
"""Test Flat L2 distance."""
|
"""Test Flat L2 distance."""
|
||||||
docsearch = Redis.from_texts(
|
docsearch = Redis.from_texts(
|
||||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
|
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=l2_schema
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search_with_score("far", k=2)
|
output = docsearch.similarity_search_with_score("far", k=2)
|
||||||
_, score = output[1]
|
_, score = output[1]
|
||||||
@ -118,7 +278,7 @@ def test_l2(texts: List[str]) -> None:
|
|||||||
def test_ip(texts: List[str]) -> None:
|
def test_ip(texts: List[str]) -> None:
|
||||||
"""Test inner product distance."""
|
"""Test inner product distance."""
|
||||||
docsearch = Redis.from_texts(
|
docsearch = Redis.from_texts(
|
||||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
|
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=ip_schema
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search_with_score("far", k=2)
|
output = docsearch.similarity_search_with_score("far", k=2)
|
||||||
_, score = output[1]
|
_, score = output[1]
|
||||||
@ -126,29 +286,34 @@ def test_ip(texts: List[str]) -> None:
|
|||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
def test_similarity_search_limit_score(texts: List[str]) -> None:
|
def test_similarity_search_limit_distance(texts: List[str]) -> None:
|
||||||
"""Test similarity search limit score."""
|
"""Test similarity search limit score."""
|
||||||
docsearch = Redis.from_texts(
|
docsearch = Redis.from_texts(
|
||||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
|
texts,
|
||||||
|
FakeEmbeddings(),
|
||||||
|
redis_url=TEST_REDIS_URL,
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search_limit_score("far", k=2, score_threshold=0.1)
|
output = docsearch.similarity_search(texts[0], k=3, distance_threshold=0.1)
|
||||||
assert len(output) == 1
|
|
||||||
_, score = output[0]
|
# can't check score but length of output should be 2
|
||||||
assert score == COSINE_SCORE
|
assert len(output) == 2
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
def test_similarity_search_with_score_with_limit_score(texts: List[str]) -> None:
|
def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None:
|
||||||
"""Test similarity search with score with limit score."""
|
"""Test similarity search with score with limit score."""
|
||||||
|
|
||||||
docsearch = Redis.from_texts(
|
docsearch = Redis.from_texts(
|
||||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
|
texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||||
)
|
)
|
||||||
output = docsearch.similarity_search_with_relevance_scores(
|
output = docsearch.similarity_search_with_score(
|
||||||
"far", k=2, score_threshold=0.1
|
texts[0], k=3, distance_threshold=0.1, return_metadata=True
|
||||||
)
|
)
|
||||||
assert len(output) == 1
|
|
||||||
_, score = output[0]
|
assert len(output) == 2
|
||||||
assert score == COSINE_SCORE
|
for out, score in output:
|
||||||
|
if out.page_content == texts[1]:
|
||||||
|
assert score == COSINE_SCORE
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
@ -156,6 +321,48 @@ def test_delete(texts: List[str]) -> None:
|
|||||||
"""Test deleting a new document"""
|
"""Test deleting a new document"""
|
||||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
ids = docsearch.add_texts(["foo"])
|
ids = docsearch.add_texts(["foo"])
|
||||||
got = docsearch.delete(ids=ids)
|
got = docsearch.delete(ids=ids, redis_url=TEST_REDIS_URL)
|
||||||
assert got
|
assert got
|
||||||
assert drop(docsearch.index_name)
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_as_retriever() -> None:
|
||||||
|
texts = ["foo", "foo", "foo", "foo", "bar"]
|
||||||
|
docsearch = Redis.from_texts(
|
||||||
|
texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
|
||||||
|
results = retriever.get_relevant_documents("foo")
|
||||||
|
assert len(results) == 3
|
||||||
|
assert all([d.page_content == "foo" for d in results])
|
||||||
|
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_retriever_distance_threshold() -> None:
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
|
|
||||||
|
retriever = docsearch.as_retriever(
|
||||||
|
search_type="similarity_distance_threshold",
|
||||||
|
search_kwargs={"k": 3, "distance_threshold": 0.1},
|
||||||
|
)
|
||||||
|
results = retriever.get_relevant_documents("foo")
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_redis_retriever_score_threshold() -> None:
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||||
|
|
||||||
|
retriever = docsearch.as_retriever(
|
||||||
|
search_type="similarity_score_threshold",
|
||||||
|
search_kwargs={"k": 3, "score_threshold": 0.91},
|
||||||
|
)
|
||||||
|
results = retriever.get_relevant_documents("foo")
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
assert drop(docsearch.index_name)
|
||||||
|
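`test_index_specification_generation` relies on the `convert_bytes` and `make_dict` helpers to turn each index attribute row into a dict. A compact sketch of the same two steps is shown here. The sample row is an assumed shape (a flat list of alternating keys and values, returned as bytes), and `decode` and `pairs_to_dict` are illustrative names rather than the helpers from the test module.

```python
from typing import Any, Dict, List


def decode(data: Any) -> Any:
    """Recursively decode bytes to str inside nested lists."""
    if isinstance(data, bytes):
        return data.decode("ascii")
    if isinstance(data, list):
        return [decode(item) for item in data]
    return data


def pairs_to_dict(values: List[Any]) -> Dict[Any, Any]:
    """Pair up a flat [key, value, key, value, ...] list; a trailing odd item is dropped."""
    return dict(zip(values[0::2], values[1::2]))


# Assumed example row: alternating keys and values as raw bytes.
raw_row = [b"identifier", b"salary", b"attribute", b"salary", b"type", b"NUMERIC"]
attr = pairs_to_dict(decode(raw_row))
```

With the row in dict form, the test can branch on `attr["identifier"]` and check `attr["type"]` directly.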