Redis metadata filtering and specification, index customization (#8612)

### Description

The previous Redis implementation did not allow the user to specify the
index configuration (e.g. changing the underlying algorithm) or to add
additional metadata for use in querying (i.e. hybrid or "filtered"
search).

This PR introduces the ability to specify custom index attributes and
metadata attributes, as well as to use that metadata in filtered queries.
Overall, more structure was introduced to the Redis implementation,
which should make it easier to maintain going forward.

# New Features

The following features are now available with the Redis integration in
LangChain.

## Index schema generation

The schema for the index will now be generated automatically if not
specified by the user. For example, suppose the example data has several
metadata fields: ``user``, ``job``, and ``credit_score`` as strings, and
``age`` as a number. The following loads it into Redis:

```python
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.redis import Redis

embeddings = OpenAIEmbeddings()

# `texts` is a list of strings and `metadata` is a matching list of dicts
# with the fields described above (user, job, credit_score, age)
rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users"
)
```

Loading data through this method, as well as through ``from_documents``
(sketched below) and ``from_texts``, now generates an index schema in
Redis.
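
For instance, ``from_documents`` can be used in the same way (a minimal sketch, reusing the ``texts``, ``metadata``, and ``embeddings`` from the example above):

```python
from langchain.docstore.document import Document

# wrap the raw texts and metadata into Document objects
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadata)]
rds = Redis.from_documents(
    docs,
    embeddings,
    redis_url="redis://localhost:6379",
    index_name="users",
)
```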

You can view the generated index schema with the [``redisvl``](redisvl.com) tool:

```bash
$ rvl index info -i users
```


Index Information:

| Index Name | Storage Type | Prefixes      | Index Options | Indexing |
|------------|--------------|---------------|---------------|----------|
| users      | HASH         | ['doc:users'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |


### Custom Metadata specification

The metadata schema generation follows two rules:
1. All string fields are indexed as ``TEXT`` fields.
2. All numeric fields are indexed as ``NUMERIC`` fields.
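
Applied to the example metadata, these rules produce a schema equivalent to the following dictionary (taken from the warning output shown later in this description):

```python
# schema the generator derives automatically from the example metadata
generated_schema = {
    "text": [{"name": "user"}, {"name": "job"}, {"name": "credit_score"}],
    "numeric": [{"name": "age"}],
}
```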

If you would like a string field to be indexed as a ``TAG`` field instead,
you can specify overrides like the following for the example data:

```python
# this can also be a path to a yaml file
index_schema = {
    "text": [{"name": "user"}, {"name": "job"}],
    "tag": [{"name": "credit_score"}],
    "numeric": [{"name": "age"}],
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users2",
    index_schema=index_schema
)
```
This changes the index specification to the following:

Index Information:

| Index Name | Storage Type | Prefixes       | Index Options | Indexing |
|------------|--------------|----------------|---------------|----------|
| users2     | HASH         | ['doc:users2'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TAG     | SEPARATOR    | ,            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |


It will also log a warning that the generated schema does not match the
specified schema:

```text
index_schema does not match generated schema from metadata.
index_schema: {'text': [{'name': 'user'}, {'name': 'job'}], 'tag': [{'name': 'credit_score'}], 'numeric': [{'name': 'age'}]}
generated_schema: {'text': [{'name': 'user'}, {'name': 'job'}, {'name': 'credit_score'}], 'numeric': [{'name': 'age'}]}
```

As long as the mismatch is intentional, this warning can safely be ignored.

The schema can be defined either as a dictionary (as above) or as a YAML file:

```yaml
text:
  - name: user
  - name: job
tag:
  - name: credit_score
numeric:
  - name: age
```

You then pass in a path to the file:

```python
from pathlib import Path

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    index_schema=Path("sample1.yml").resolve()
)
```

This creates the same schema as the dictionary example:


Index Information:

| Index Name | Storage Type | Prefixes       | Index Options | Indexing |
|------------|--------------|----------------|---------------|----------|
| users3     | HASH         | ['doc:users3'] | []            | 0        |

Index Fields:

| Name           | Attribute      | Type    | Field Option | Option Value |
|----------------|----------------|---------|--------------|--------------|
| user           | user           | TEXT    | WEIGHT       | 1            |
| job            | job            | TEXT    | WEIGHT       | 1            |
| content        | content        | TEXT    | WEIGHT       | 1            |
| credit_score   | credit_score   | TAG     | SEPARATOR    | ,            |
| age            | age            | NUMERIC |              |              |
| content_vector | content_vector | VECTOR  |              |              |



### Custom Vector Indexing Schema

Users with large-scale workloads may want to change how the vector index
created by LangChain is configured.

To take full advantage of Redis for vector database use cases like this,
you can now pass in index attribute modifiers, such as changing the
indexing algorithm to HNSW:

```python
vector_schema = {
    "algorithm": "HNSW"
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    vector_schema=vector_schema
)
```

A more complex example might look like this:

```python
vector_schema = {
    "algorithm": "HNSW",
    "ef_construction": 200,
    "ef_runtime": 20
}

rds, keys = Redis.from_texts_return_keys(
    texts,
    embeddings,
    metadatas=metadata,
    redis_url="redis://localhost:6379",
    index_name="users3",
    vector_schema=vector_schema
)
```

All names correspond to the arguments you would set when using redis-py or
RedisVL directly. (doc link to be added later)
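
For reference, the HNSW example above maps roughly onto the following redis-py ``VectorField`` definition (a sketch; the unannotated values are the defaults used by this PR, and the 1536 dimension assumes OpenAI embeddings):

```python
from redis.commands.search.field import VectorField

content_vector = VectorField(
    "content_vector",
    "HNSW",
    {
        "TYPE": "FLOAT32",           # default datatype
        "DIM": 1536,                 # assumed embedding dimension (OpenAI)
        "DISTANCE_METRIC": "COSINE", # default metric
        "INITIAL_CAP": 20000,        # default
        "M": 16,                     # default
        "EF_CONSTRUCTION": 200,      # from vector_schema
        "EF_RUNTIME": 20,            # from vector_schema
        "EPSILON": 0.8,              # default
    },
)
```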


### Better Querying

Both vector queries and range (limit) queries are now available, and
metadata is returned by default. Example outputs are shown below:

```python
>>> query = "foo"
>>> results = rds.similarity_search(query, k=1)
>>> print(results)
[Document(page_content='foo', metadata={'user': 'derrick', 'job': 'doctor', 'credit_score': 'low', 'age': '14', 'id': 'doc:users:657a47d7db8b447e88598b83da879b9d', 'score': '7.15255737305e-07'})]

>>> results = rds.similarity_search_with_score(query, k=1, return_metadata=False)
>>> print(results) # no metadata, but with scores
[(Document(page_content='foo', metadata={}), 7.15255737305e-07)]

>>> results = rds.similarity_search_limit_score(query, k=6, score_threshold=0.0001)
>>> print(len(results)) # range query (only above threshold even if k is higher)
4
```
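
Distance-threshold filtering is also available directly on ``similarity_search`` (a sketch; ``distance_threshold`` is the same argument the updated ``RedisSemanticCache`` in this PR passes internally):

```python
# only return documents within the given vector distance of the query
results = rds.similarity_search("foo", k=4, distance_threshold=0.1)
```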

### Custom metadata filtering

A big advantage of Redis in this space is the ability to filter on
metadata stored alongside the vector itself. With the example above, the
following is now possible in LangChain. The comparison operators are
overridden to build a new expression language that mimics that of
[redisvl](redisvl.com). This allows arbitrarily long sequences of filters,
resembling SQL predicates, to be used directly with vector queries and
range queries.

There are two interfaces for constructing filters, and both are shown below:

```python

>>> from langchain.vectorstores.redis import RedisFilter, RedisNum, RedisText

>>> age_filter = RedisFilter.num("age") > 18
>>> age_filter = RedisNum("age") > 18 # equivalent
>>> results = rds.similarity_search(query, filter=age_filter)
>>> print(len(results))
3

>>> job_filter = RedisFilter.text("job") == "engineer" 
>>> job_filter = RedisText("job") == "engineer" # equivalent
>>> results = rds.similarity_search(query, filter=job_filter)
>>> print(len(results))
2

# wildcard ("LIKE") text match
>>> job_filter = RedisFilter.text("job") % "eng*"
>>> results = rds.similarity_search(query, filter=job_filter)
>>> print(len(results))
2


# combined filters (AND)
>>> combined = age_filter & job_filter
>>> results = rds.similarity_search(query, filter=combined)
>>> print(len(results))
1

# combined filters (OR)
>>> combined = age_filter | job_filter
>>> results = rds.similarity_search(query, filter=combined)
>>> print(len(results))
4
```

All of the filter results above can be verified against the example data.
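
Under the hood, each filter expression renders to raw Redis query syntax, which can be inspected with ``str()`` (a sketch based on the filter classes added in this PR):

```python
>>> from langchain.vectorstores.redis import RedisNum, RedisTag

>>> brand_is_nike = RedisTag("brand") == "nike"
>>> price_is_under_100 = RedisNum("price") < 100
>>> combined = brand_is_nike & price_is_under_100
>>> print(str(combined))  # raw RediSearch query string
(@brand:{nike} @price:[-inf (100])
```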


### Other

  - Issue: #3967 
  - Dependencies: No added dependencies
  - Tag maintainer: @hwchase17 @baskaryan @rlancemartin 
  - Twitter handle: @sampartee

---------

Co-authored-by: Naresh Rangan <naresh.rangan0@walmart.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Sam Partee committed on 2023-08-25 (commit a28eea5767, parent fa0b8f3368).
13 changed files with 3480 additions and 821 deletions.

(diff for one file suppressed because it is too large)


@ -33,6 +33,7 @@ from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
@ -302,6 +303,14 @@ class RedisSemanticCache(BaseCache):
# TODO - implement a TTL policy in Redis
DEFAULT_SCHEMA = {
"content_key": "prompt",
"text": [
{"name": "prompt"},
],
"extra": [{"name": "return_val"}, {"name": "llm_string"}],
}
def __init__(
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
):
@ -349,12 +358,14 @@ class RedisSemanticCache(BaseCache):
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
schema=cast(Dict, self.DEFAULT_SCHEMA),
)
except ValueError:
redis = RedisVectorstore(
embedding_function=self.embedding.embed_query,
embedding=self.embedding,
index_name=index_name,
redis_url=self.redis_url,
index_schema=cast(Dict, self.DEFAULT_SCHEMA),
)
_embedding = self.embedding.embed_query(text="test")
redis._create_index(dim=len(_embedding))
@ -374,17 +385,18 @@ class RedisSemanticCache(BaseCache):
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
llm_cache = self._get_llm_cache(llm_string)
generations = []
generations: List = []
# Read from a Hash
results = llm_cache.similarity_search_limit_score(
results = llm_cache.similarity_search(
query=prompt,
k=1,
score_threshold=self.score_threshold,
distance_threshold=self.score_threshold,
)
if results:
for document in results:
for text in document.metadata["return_val"]:
generations.append(Generation(text=text))
generations.extend(
_load_generations_from_json(document.metadata["return_val"])
)
return generations if generations else None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@ -402,11 +414,11 @@ class RedisSemanticCache(BaseCache):
)
return
llm_cache = self._get_llm_cache(llm_string)
# Write to vectorstore
_dump_generations_to_json([g for g in return_val])
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": [generation.text for generation in return_val],
"return_val": _dump_generations_to_json([g for g in return_val]),
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])


@ -1,7 +1,5 @@
from typing import Any, Dict, List
from langchain.schema.messages import get_buffer_string # noqa: 401
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
"""


@ -1,16 +1,64 @@
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Any,
)
import re
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
from urllib.parse import urlparse
import numpy as np
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.client import Redis as RedisType
logger = logging.getLogger(__name__)
def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
return np.array(array).astype(dtype).tobytes()
class TokenEscaper:
"""
Escape punctuation within an input string.
"""
# Characters that RediSearch requires us to escape during queries.
# Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"
def __init__(self, escape_chars_re: Optional[Pattern] = None):
if escape_chars_re:
self.escaped_chars_re = escape_chars_re
else:
self.escaped_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
def escape(self, value: str) -> str:
def escape_symbol(match: re.Match) -> str:
value = match.group(0)
return f"\\{value}"
return self.escaped_chars_re.sub(escape_symbol, value)
def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"Redis cannot be used as a vector database without RediSearch >=2.4"
"Please head to https://redis.io/docs/stack/search/quick_start/"
"to know more about installing the RediSearch module within Redis Stack."
)
logger.error(error_message)
raise ValueError(error_message)
def get_client(redis_url: str, **kwargs: Any) -> RedisType:


@ -1,664 +0,0 @@
from __future__ import annotations
import json
import logging
import uuid
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Type,
)
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import root_validator
from langchain.utilities.redis import get_client
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.client import Redis as RedisType
from redis.commands.search.query import Query
# required modules
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20400},
{"name": "searchlight", "ver": 20400},
]
# distance mmetrics
REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"]
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
"""Check if the correct Redis modules are installed."""
installed_modules = client.module_list()
installed_modules = {
module[b"name"].decode("utf-8"): module for module in installed_modules
}
for module in required_modules:
if module["name"] in installed_modules and int(
installed_modules[module["name"]][b"ver"]
) >= int(module["ver"]):
return
# otherwise raise error
error_message = (
"Redis cannot be used as a vector database without RediSearch >=2.4"
"Please head to https://redis.io/docs/stack/search/quick_start/"
"to know more about installing the RediSearch module within Redis Stack."
)
logger.error(error_message)
raise ValueError(error_message)
def _check_index_exists(client: RedisType, index_name: str) -> bool:
"""Check if Redis index exists."""
try:
client.ft(index_name).info()
except: # noqa: E722
logger.info("Index does not exist")
return False
logger.info("Index already exists")
return True
def _redis_key(prefix: str) -> str:
"""Redis key schema for a given prefix."""
return f"{prefix}:{uuid.uuid4().hex}"
def _redis_prefix(index_name: str) -> str:
"""Redis key prefix for a given index."""
return f"doc:{index_name}"
def _default_relevance_score(val: float) -> float:
return 1 - val
class Redis(VectorStore):
"""`Redis` vector store.
To use, you should have the ``redis`` python package installed.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = Redis(
redis_url="redis://username:password@localhost:6379"
index_name="my-index",
embedding_function=embeddings.embed_query,
)
To use a redis replication setup with multiple redis server and redis sentinels
set "redis_url" to "redis+sentinel://" scheme. With this url format a path is
needed holding the name of the redis service within the sentinels to get the
correct redis server connection. The default service name is "mymaster".
An optional username or password is used for booth connections to the rediserver
and the sentinel, different passwords for server and sentinel are not supported.
And as another constraint only one sentinel instance can be given:
Example:
.. code-block:: python
vectorstore = Redis(
redis_url="redis+sentinel://username:password@sentinelhost:26379/mymaster/0"
index_name="my-index",
embedding_function=embeddings.embed_query,
)
"""
def __init__(
self,
redis_url: str,
index_name: str,
embedding_function: Callable,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
relevance_score_fn: Optional[Callable[[float], float]] = None,
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
**kwargs: Any,
):
"""Initialize with necessary components."""
self.embedding_function = embedding_function
self.index_name = index_name
try:
redis_client = get_client(redis_url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(redis_client, REDIS_REQUIRED_MODULES)
except ValueError as e:
raise ValueError(f"Redis failed to connect: {e}")
self.client = redis_client
self.content_key = content_key
self.metadata_key = metadata_key
self.vector_key = vector_key
self.distance_metric = distance_metric
self.relevance_score_fn = relevance_score_fn
@property
def embeddings(self) -> Optional[Embeddings]:
# TODO: Accept embedding object directly
return None
def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self.relevance_score_fn:
return self.relevance_score_fn
if self.distance_metric == "COSINE":
return self._cosine_relevance_score_fn
elif self.distance_metric == "IP":
return self._max_inner_product_relevance_score_fn
elif self.distance_metric == "L2":
return self._euclidean_relevance_score_fn
else:
return _default_relevance_score
def _create_index(self, dim: int = 1536) -> None:
try:
from redis.commands.search.field import TextField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
except ImportError:
raise ImportError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Check if index exists
if not _check_index_exists(self.client, self.index_name):
# Define schema
schema = (
TextField(name=self.content_key),
TextField(name=self.metadata_key),
VectorField(
self.vector_key,
"FLAT",
{
"TYPE": "FLOAT32",
"DIM": dim,
"DISTANCE_METRIC": self.distance_metric,
},
),
)
prefix = _redis_prefix(self.index_name)
# Create Redis Index
self.client.ft(self.index_name).create_index(
fields=schema,
definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH),
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
embeddings: Optional[List[List[float]]] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Add more texts to the vectorstore.
Args:
texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
Defaults to None.
embeddings (Optional[List[List[float]]], optional): Optional pre-generated
embeddings. Defaults to None.
keys (List[str]) or ids (List[str]): Identifiers of entries.
Defaults to None.
batch_size (int, optional): Batch size to use for writes. Defaults to 1000.
Returns:
List[str]: List of ids added to the vectorstore
"""
ids = []
prefix = _redis_prefix(self.index_name)
# Get keys or ids from kwargs
# Other vectorstores use ids
keys_or_ids = kwargs.get("keys", kwargs.get("ids"))
# Write data to redis
pipeline = self.client.pipeline(transaction=False)
for i, text in enumerate(texts):
# Use provided values by default or fallback
key = keys_or_ids[i] if keys_or_ids else _redis_key(prefix)
metadata = metadatas[i] if metadatas else {}
embedding = embeddings[i] if embeddings else self.embedding_function(text)
pipeline.hset(
key,
mapping={
self.content_key: text,
self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
self.metadata_key: json.dumps(metadata),
},
)
ids.append(key)
# Write batch
if i % batch_size == 0:
pipeline.execute()
# Cleanup final batch
pipeline.execute()
return ids
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""
Returns the most similar indexed documents to the query text.
Args:
query (str): The query text for which to find similar documents.
k (int): The number of documents to return. Default is 4.
Returns:
List[Document]: A list of documents that are most similar to the query text.
"""
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, _ in docs_and_scores]
def similarity_search_limit_score(
self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
) -> List[Document]:
"""
Returns the most similar indexed documents to the query text within the
score_threshold range.
Args:
query (str): The query text for which to find similar documents.
k (int): The number of documents to return. Default is 4.
score_threshold (float): The minimum matching score required for a document
to be considered a match. Defaults to 0.2.
Because the similarity calculation algorithm is based on cosine
similarity, the smaller the angle, the higher the similarity.
Returns:
List[Document]: A list of documents that are most similar to the query text,
including the match score for each document.
Note:
If there are no documents that satisfy the score_threshold value,
an empty list is returned.
"""
docs_and_scores = self.similarity_search_with_score(query, k=k)
return [doc for doc, score in docs_and_scores if score < score_threshold]
def _prepare_query(self, k: int) -> Query:
try:
from redis.commands.search.query import Query
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
# Prepare the Query
hybrid_fields = "*"
base_query = (
f"{hybrid_fields}=>[KNN {k} @{self.vector_key} $vector AS vector_score]"
)
return_fields = [self.metadata_key, self.content_key, "vector_score", "id"]
return (
Query(base_query)
.return_fields(*return_fields)
.sort_by("vector_score")
.paging(0, k)
.dialect(2)
)
def similarity_search_with_score(
self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query and score for each
"""
# Creates embedding vector from user query
embedding = self.embedding_function(query)
# Creates Redis query
redis_query = self._prepare_query(k)
params_dict: Mapping[str, str] = {
"vector": np.array(embedding) # type: ignore
.astype(dtype=np.float32)
.tobytes()
}
# Perform vector search
results = self.client.ft(self.index_name).search(redis_query, params_dict)
# Prepare document results
docs_and_scores: List[Tuple[Document, float]] = []
for result in results.docs:
metadata = {**json.loads(result.metadata), "id": result.id}
doc = Document(page_content=result.content, metadata=metadata)
docs_and_scores.append((doc, float(result.vector_score)))
return docs_and_scores
@classmethod
def from_texts_return_keys(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
index_name: Optional[str] = None,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
**kwargs: Any,
) -> Tuple[Redis, List[str]]:
"""Create a Redis vectorstore from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created Redis index.
4. Returns the keys of the newly created documents.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
redisearch, keys = RediSearch.from_texts_return_keys(
texts,
embeddings,
redis_url="redis://username:password@localhost:6379"
)
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
if "redis_url" in kwargs:
kwargs.pop("redis_url")
# Name of the search index if not given
if not index_name:
index_name = uuid.uuid4().hex
# Create instance
instance = cls(
redis_url,
index_name,
embedding.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
distance_metric=distance_metric,
**kwargs,
)
# Create embeddings over documents
embeddings = embedding.embed_documents(texts)
# Create the search index
instance._create_index(dim=len(embeddings[0]))
# Add data to Redis
keys = instance.add_texts(texts, metadatas, embeddings)
return instance, keys
@classmethod
def from_texts(
cls: Type[Redis],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
index_name: Optional[str] = None,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Create a Redis vectorstore from raw documents.
This is a user-friendly interface that:
1. Embeds documents.
2. Creates a new index for the embeddings in Redis.
3. Adds the documents to the newly created Redis index.
This is intended to be a quick way to get started.
Example:
.. code-block:: python
from langchain.vectorstores import Redis
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
redisearch = RediSearch.from_texts(
texts,
embeddings,
redis_url="redis://username:password@localhost:6379"
)
"""
instance, _ = cls.from_texts_return_keys(
texts,
embedding,
metadatas=metadatas,
index_name=index_name,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
**kwargs,
)
return instance
@staticmethod
def delete(
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> bool:
"""
Delete a Redis entry.
Args:
ids: List of ids (keys) to delete.
Returns:
bool: Whether or not the deletions were successful.
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
if ids is None:
raise ValueError("'ids' (keys)() were not provided.")
try:
import redis # noqa: F401
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = get_client(redis_url=redis_url, **kwargs)
except ValueError as e:
raise ValueError(f"Your redis connected error: {e}")
# Check if index exists
try:
client.delete(*ids)
logger.info("Entries deleted")
return True
except: # noqa: E722
# ids does not exist
return False
@staticmethod
def drop_index(
index_name: str,
delete_documents: bool,
**kwargs: Any,
) -> bool:
"""
Drop a Redis search index.
Args:
index_name (str): Name of the index to drop.
delete_documents (bool): Whether to drop the associated documents.
Returns:
bool: Whether or not the drop was successful.
"""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis # noqa: F401
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = get_client(redis_url=redis_url, **kwargs)
except ValueError as e:
raise ValueError(f"Your redis connected error: {e}")
# Check if index exists
try:
client.ft(index_name).dropindex(delete_documents)
logger.info("Drop index")
return True
except: # noqa: E722
# Index not exist
return False
@classmethod
def from_existing_index(
cls,
embedding: Embeddings,
index_name: str,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
**kwargs: Any,
) -> Redis:
"""Connect to an existing Redis index."""
redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
try:
import redis # noqa: F401
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Please install it with `pip install redis`."
)
try:
# We need to first remove redis_url from kwargs,
# otherwise passing it to Redis will result in an error.
if "redis_url" in kwargs:
kwargs.pop("redis_url")
client = get_client(redis_url=redis_url, **kwargs)
# check if redis has redisearch module installed
_check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
# ensure that the index already exists
assert _check_index_exists(
client, index_name
), f"Index {index_name} does not exist"
except Exception as e:
raise ValueError(f"Redis failed to connect: {e}")
return cls(
redis_url,
index_name,
embedding.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
**kwargs,
)
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
tags = kwargs.pop("tags", None) or []
tags.extend(self._get_retriever_tags())
return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class RedisVectorStoreRetriever(VectorStoreRetriever):
"""Retriever for `Redis` vector store."""
vectorstore: Redis
"""Redis VectorStore."""
search_type: str = "similarity"
"""Type of search to perform. Can be either 'similarity' or 'similarity_limit'."""
k: int = 4
"""Number of documents to return."""
score_threshold: float = 0.4
"""Score threshold for similarity_limit search."""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator()
def validate_search_type(cls, values: Dict) -> Dict:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "similarity_limit"):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
elif self.search_type == "similarity_limit":
docs = self.vectorstore.similarity_search_limit_score(
query, k=self.k, score_threshold=self.score_threshold
)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
return self.vectorstore.add_documents(documents, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
"""Add documents to vectorstore."""
return await self.vectorstore.aadd_documents(documents, **kwargs)


@ -0,0 +1,9 @@
from .base import Redis
from .filters import (
RedisFilter,
RedisNum,
RedisTag,
RedisText,
)
__all__ = ["Redis", "RedisFilter", "RedisTag", "RedisText", "RedisNum"]

(diff for one file suppressed because it is too large)


@ -0,0 +1,20 @@
from typing import Any, Dict, List
import numpy as np
# required modules
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20600},
{"name": "searchlight", "ver": 20600},
]
# distance metrics
REDIS_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"]
# supported vector datatypes
REDIS_VECTOR_DTYPE_MAP: Dict[str, Any] = {
"FLOAT32": np.float32,
"FLOAT64": np.float64,
}
REDIS_TAG_SEPARATOR = ","


@ -0,0 +1,420 @@
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union
from langchain.utilities.redis import TokenEscaper
# disable mypy error for dunder method overrides
# mypy: disable-error-code="override"
class RedisFilterOperator(Enum):
EQ = 1
NE = 2
LT = 3
GT = 4
LE = 5
GE = 6
OR = 7
AND = 8
LIKE = 9
IN = 10
class RedisFilter:
@staticmethod
def text(field: str) -> "RedisText":
return RedisText(field)
@staticmethod
def num(field: str) -> "RedisNum":
return RedisNum(field)
@staticmethod
def tag(field: str) -> "RedisTag":
return RedisTag(field)
class RedisFilterField:
escaper: "TokenEscaper" = TokenEscaper()
OPERATORS: Dict[RedisFilterOperator, str] = {}
def __init__(self, field: str):
self._field = field
self._value: Any = None
self._operator: RedisFilterOperator = RedisFilterOperator.EQ
def equals(self, other: "RedisFilterField") -> bool:
if not isinstance(other, type(self)):
return False
return self._field == other._field and self._value == other._value
def _set_value(
self, val: Any, val_type: type, operator: RedisFilterOperator
) -> None:
# check that the operator is supported by this class
if operator not in self.OPERATORS:
raise ValueError(
f"Operator {operator} not supported by {self.__class__.__name__}. "
+ f"Supported operators are {self.OPERATORS.values()}"
)
if not isinstance(val, val_type):
raise TypeError(
f"Right side argument passed to operator {self.OPERATORS[operator]} "
f"with left side "
f"argument {self.__class__.__name__} must be of type {val_type}"
)
self._value = val
self._operator = operator
def check_operator_misuse(func: Callable) -> Callable:
@wraps(func)
def wrapper(instance: Any, *args: List[Any], **kwargs: Dict[str, Any]) -> Any:
# Extracting 'other' from positional arguments or keyword arguments
other = kwargs.get("other") if "other" in kwargs else None
if not other:
for arg in args:
if isinstance(arg, type(instance)):
other = arg
break
if isinstance(other, type(instance)):
raise ValueError(
"Equality operators are overridden for FilterExpression creation. Use "
".equals() for equality checks"
)
return func(instance, *args, **kwargs)
return wrapper
class RedisTag(RedisFilterField):
"""A RedisTag is a RedisFilterField representing a tag in a Redis index."""
OPERATORS: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "==",
RedisFilterOperator.NE: "!=",
RedisFilterOperator.IN: "==",
}
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "@%s:{%s}",
RedisFilterOperator.NE: "(-@%s:{%s})",
RedisFilterOperator.IN: "@%s:{%s}",
}
def __init__(self, field: str):
"""Create a RedisTag FilterField
Args:
field (str): The name of the RedisTag field in the index to be queried
against.
"""
super().__init__(field)
def _set_tag_value(
self, other: Union[List[str], str], operator: RedisFilterOperator
) -> None:
if isinstance(other, list):
if not all(isinstance(tag, str) for tag in other):
raise ValueError("All tags must be strings")
else:
other = [other]
self._set_value(other, list, operator)
@check_operator_misuse
def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
"""Create a RedisTag equality filter expression
Args:
other (Union[List[str], str]): The tag(s) to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisTag
>>> filter = RedisTag("brand") == "nike"
"""
self._set_tag_value(other, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
"""Create a RedisTag inequality filter expression
Args:
other (Union[List[str], str]): The tag(s) to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisTag
>>> filter = RedisTag("brand") != "nike"
"""
self._set_tag_value(other, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
@property
def _formatted_tag_value(self) -> str:
return "|".join([self.escaper.escape(tag) for tag in self._value])
def __str__(self) -> str:
if not self._value:
raise ValueError(
f"Operator must be used before calling __str__. Operators are "
f"{self.OPERATORS.values()}"
)
"""Return the Redis Query syntax for a RedisTag filter expression"""
return self.OPERATOR_MAP[self._operator] % (
self._field,
self._formatted_tag_value,
)
class RedisNum(RedisFilterField):
"""A RedisFilterField representing a numeric field in a Redis index."""
OPERATORS: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "==",
RedisFilterOperator.NE: "!=",
RedisFilterOperator.LT: "<",
RedisFilterOperator.GT: ">",
RedisFilterOperator.LE: "<=",
RedisFilterOperator.GE: ">=",
}
OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
RedisFilterOperator.EQ: "@%s:[%i %i]",
RedisFilterOperator.NE: "(-@%s:[%i %i])",
RedisFilterOperator.GT: "@%s:[(%i +inf]",
RedisFilterOperator.LT: "@%s:[-inf (%i]",
RedisFilterOperator.GE: "@%s:[%i +inf]",
RedisFilterOperator.LE: "@%s:[-inf %i]",
}
def __str__(self) -> str:
"""Return the Redis Query syntax for a Numeric filter expression"""
if not self._value:
raise ValueError(
f"Operator must be used before calling __str__. Operators are "
f"{self.OPERATORS.values()}"
)
if (
self._operator == RedisFilterOperator.EQ
or self._operator == RedisFilterOperator.NE
):
return self.OPERATOR_MAP[self._operator] % (
self._field,
self._value,
self._value,
)
else:
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
@check_operator_misuse
def __eq__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric equality filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") == 90210
"""
self._set_value(other, int, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric inequality filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("zipcode") != 90210
"""
self._set_value(other, int, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
def __gt__(self, other: int) -> "RedisFilterExpression":
"""Create a RedisNumeric greater than filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") > 18
"""
self._set_value(other, int, RedisFilterOperator.GT)
return RedisFilterExpression(str(self))
def __lt__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric less than filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") < 18
"""
self._set_value(other, int, RedisFilterOperator.LT)
return RedisFilterExpression(str(self))
def __ge__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric greater than or equal to filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") >= 18
"""
self._set_value(other, int, RedisFilterOperator.GE)
return RedisFilterExpression(str(self))
def __le__(self, other: int) -> "RedisFilterExpression":
"""Create a Numeric less than or equal to filter expression
Args:
other (int): The value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisNum
>>> filter = RedisNum("age") <= 18
"""
self._set_value(other, int, RedisFilterOperator.LE)
return RedisFilterExpression(str(self))
class RedisText(RedisFilterField):
"""A RedisText is a RedisFilterField representing a text field in a Redis index."""
OPERATORS = {
RedisFilterOperator.EQ: "==",
RedisFilterOperator.NE: "!=",
RedisFilterOperator.LIKE: "%",
}
OPERATOR_MAP = {
RedisFilterOperator.EQ: '@%s:"%s"',
RedisFilterOperator.NE: '(-@%s:"%s")',
RedisFilterOperator.LIKE: "@%s:%s",
}
@check_operator_misuse
def __eq__(self, other: str) -> "RedisFilterExpression":
"""Create a RedisText equality filter expression
Args:
other (str): The text value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisText
>>> filter = RedisText("job") == "engineer"
"""
self._set_value(other, str, RedisFilterOperator.EQ)
return RedisFilterExpression(str(self))
@check_operator_misuse
def __ne__(self, other: str) -> "RedisFilterExpression":
"""Create a RedisText inequality filter expression
Args:
other (str): The text value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisText
>>> filter = RedisText("job") != "engineer"
"""
self._set_value(other, str, RedisFilterOperator.NE)
return RedisFilterExpression(str(self))
def __mod__(self, other: str) -> "RedisFilterExpression":
"""Create a RedisText like filter expression
Args:
other (str): The text value to filter on.
Example:
>>> from langchain.vectorstores.redis import RedisText
>>> filter = RedisText("job") % "engineer"
"""
self._set_value(other, str, RedisFilterOperator.LIKE)
return RedisFilterExpression(str(self))
def __str__(self) -> str:
if not self._value:
raise ValueError(
f"Operator must be used before calling __str__. Operators are "
f"{self.OPERATORS.values()}"
)
try:
return self.OPERATOR_MAP[self._operator] % (self._field, self._value)
except KeyError:
raise Exception("Invalid operator")
class RedisFilterExpression:
"""A RedisFilterExpression is a logical expression of RedisFilterFields.
RedisFilterExpressions can be combined using the & and | operators to create
complex logical expressions that evaluate to the Redis Query language.
This presents an interface by which users can create complex queries
without having to know the Redis Query language.
Filter expressions are not initialized directly. Instead they are built
by combining RedisFilterFields using the & and | operators.
Examples:
>>> from langchain.vectorstores.redis import RedisTag, RedisNum
>>> brand_is_nike = RedisTag("brand") == "nike"
>>> price_is_under_100 = RedisNum("price") < 100
>>> filter = brand_is_nike & price_is_under_100
>>> print(str(filter))
(@brand:{nike} @price:[-inf (100)])
"""
def __init__(
self,
_filter: Optional[str] = None,
operator: Optional[RedisFilterOperator] = None,
left: Optional["RedisFilterExpression"] = None,
right: Optional["RedisFilterExpression"] = None,
):
self._filter = _filter
self._operator = operator
self._left = left
self._right = right
def __and__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
return RedisFilterExpression(
operator=RedisFilterOperator.AND, left=self, right=other
)
def __or__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
return RedisFilterExpression(
operator=RedisFilterOperator.OR, left=self, right=other
)
def __str__(self) -> str:
# top level check that allows recursive calls to __str__
if not self._filter and not self._operator:
raise ValueError("Improperly initialized RedisFilterExpression")
# allow for single filter expression without operators as last
# expression in the chain might not have an operator
if self._operator:
operator_str = " | " if self._operator == RedisFilterOperator.OR else " "
return f"({str(self._left)}{operator_str}{str(self._right)})"
# check that base case, the filter is set
if not self._filter:
raise ValueError("Improperly initialized RedisFilterExpression")
return self._filter


@ -0,0 +1,276 @@
import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import yaml
# ignore type error here as it's a redis-py type problem
from redis.commands.search.field import ( # type: ignore
NumericField,
TagField,
TextField,
VectorField,
)
from typing_extensions import Literal
from langchain.pydantic_v1 import BaseModel, Field, validator
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
class RedisDistanceMetric(str, Enum):
l2 = "L2"
cosine = "COSINE"
ip = "IP"
class RedisField(BaseModel):
name: str = Field(...)
class TextFieldSchema(RedisField):
weight: float = 1
no_stem: bool = False
phonetic_matcher: Optional[str] = None
withsuffixtrie: bool = False
no_index: bool = False
sortable: Optional[bool] = False
def as_field(self) -> TextField:
return TextField(
self.name,
weight=self.weight,
no_stem=self.no_stem,
phonetic_matcher=self.phonetic_matcher,
sortable=self.sortable,
no_index=self.no_index,
)
class TagFieldSchema(RedisField):
separator: str = ","
case_sensitive: bool = False
no_index: bool = False
sortable: Optional[bool] = False
def as_field(self) -> TagField:
return TagField(
self.name,
separator=self.separator,
case_sensitive=self.case_sensitive,
sortable=self.sortable,
no_index=self.no_index,
)
class NumericFieldSchema(RedisField):
no_index: bool = False
sortable: Optional[bool] = False
def as_field(self) -> NumericField:
return NumericField(self.name, sortable=self.sortable, no_index=self.no_index)
class RedisVectorField(RedisField):
dims: int = Field(...)
algorithm: object = Field(...)
datatype: str = Field(default="FLOAT32")
distance_metric: RedisDistanceMetric = Field(default="COSINE")
initial_cap: int = Field(default=20000)
@validator("distance_metric", pre=True)
def uppercase_strings(cls, v: str) -> str:
return v.upper()
@validator("datatype", pre=True)
def uppercase_and_check_dtype(cls, v: str) -> str:
if v.upper() not in REDIS_VECTOR_DTYPE_MAP:
raise ValueError(
f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
)
return v.upper()
class FlatVectorField(RedisVectorField):
algorithm: Literal["FLAT"] = "FLAT"
block_size: int = Field(default=1000)
def as_field(self) -> VectorField:
return VectorField(
self.name,
self.algorithm,
{
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
"INITIAL_CAP": self.initial_cap,
"BLOCK_SIZE": self.block_size,
},
)
class HNSWVectorField(RedisVectorField):
algorithm: Literal["HNSW"] = "HNSW"
m: int = Field(default=16)
ef_construction: int = Field(default=200)
ef_runtime: int = Field(default=10)
epsilon: float = Field(default=0.8)
def as_field(self) -> VectorField:
return VectorField(
self.name,
self.algorithm,
{
"TYPE": self.datatype,
"DIM": self.dims,
"DISTANCE_METRIC": self.distance_metric,
"INITIAL_CAP": self.initial_cap,
"M": self.m,
"EF_CONSTRUCTION": self.ef_construction,
"EF_RUNTIME": self.ef_runtime,
"EPSILON": self.epsilon,
},
)
class RedisModel(BaseModel):
# always have a content field for text
text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
tag: Optional[List[TagFieldSchema]] = None
numeric: Optional[List[NumericFieldSchema]] = None
extra: Optional[List[RedisField]] = None
# filled by default_vector_schema
vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
content_key: str = "content"
content_vector_key: str = "content_vector"
def add_content_field(self) -> None:
if self.text is None:
self.text = []
for field in self.text:
if field.name == self.content_key:
return
self.text.append(TextFieldSchema(name=self.content_key))
def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
# catch case where user inputted no vector field spec
# in the index schema
if self.vector is None:
self.vector = []
# ignore types as pydantic is handling type validation and conversion
if vector_field["algorithm"] == "FLAT":
self.vector.append(FlatVectorField(**vector_field)) # type: ignore
elif vector_field["algorithm"] == "HNSW":
self.vector.append(HNSWVectorField(**vector_field)) # type: ignore
else:
raise ValueError(
f"algorithm must be either FLAT or HNSW. Got "
f"{vector_field['algorithm']}"
)
def as_dict(self) -> Dict[str, List[Any]]:
schemas: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
# iter over all class attributes
for attr, attr_value in self.__dict__.items():
# only non-empty lists
if isinstance(attr_value, list) and len(attr_value) > 0:
field_values: List[Dict[str, Any]] = []
# iterate over all fields in each category (tag, text, etc)
for val in attr_value:
value: Dict[str, Any] = {}
# iterate over values within each field to extract
# settings for that field (i.e. name, weight, etc)
for field, field_value in val.__dict__.items():
# make enums into strings
if isinstance(field_value, Enum):
value[field] = field_value.value
# don't write null values
elif field_value is not None:
value[field] = field_value
field_values.append(value)
schemas[attr] = field_values
schema: Dict[str, List[Any]] = {}
# only write non-empty lists from defaults
for k, v in schemas.items():
if len(v) > 0:
schema[k] = v
return schema
@property
def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
if not self.vector:
raise ValueError("No vector fields found")
for field in self.vector:
if field.name == self.content_vector_key:
return field
raise ValueError("No content_vector field found")
@property
def vector_dtype(self) -> np.dtype:
# should only ever be called after pydantic has validated the schema
return REDIS_VECTOR_DTYPE_MAP[self.content_vector.datatype]
@property
def is_empty(self) -> bool:
return all(
field is None for field in [self.tag, self.text, self.numeric, self.vector]
)
def get_fields(self) -> List["RedisField"]:
redis_fields: List["RedisField"] = []
if self.is_empty:
return redis_fields
for field_name in self.__fields__.keys():
if field_name not in ["content_key", "content_vector_key", "extra"]:
field_group = getattr(self, field_name)
if field_group is not None:
for field in field_group:
redis_fields.append(field.as_field())
return redis_fields
@property
def metadata_keys(self) -> List[str]:
keys: List[str] = []
if self.is_empty:
return keys
for field_name in self.__fields__.keys():
field_group = getattr(self, field_name)
if field_group is not None:
for field in field_group:
# check if it's a metadata field. exclude vector and content key
if not isinstance(field, str) and field.name not in [
self.content_key,
self.content_vector_key,
]:
keys.append(field.name)
return keys
def read_schema(
index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
) -> Dict[str, Any]:
# check if its a dict and return RedisModel otherwise, check if it's a path and
# read in the file assuming it's a yaml file and return a RedisModel
if isinstance(index_schema, dict):
return index_schema
elif isinstance(index_schema, Path):
with open(index_schema, "rb") as f:
return yaml.safe_load(f)
elif isinstance(index_schema, str):
if Path(index_schema).resolve().is_file():
with open(index_schema, "rb") as f:
return yaml.safe_load(f)
else:
raise FileNotFoundError(f"index_schema file {index_schema} does not exist")
else:
raise TypeError(
f"index_schema must be a dict, or path to a yaml file "
f"Got {type(index_schema)}"
)


@ -1,16 +1,27 @@
"""Test Redis cache functionality."""
import uuid
from typing import List
import pytest
import langchain
from langchain.cache import RedisCache, RedisSemanticCache
from langchain.embeddings.base import Embeddings
from langchain.schema import Generation, LLMResult
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
REDIS_TEST_URL = "redis://localhost:6379"
def random_string() -> str:
return str(uuid.uuid4())
def test_redis_cache_ttl() -> None:
import redis
@ -30,12 +41,10 @@ def test_redis_cache() -> None:
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
print(output)
expected_output = LLMResult(
generations=[[Generation(text="fizz")]],
llm_output={},
)
print(expected_output)
assert output == expected_output
langchain.llm_cache.redis.flushall()
@ -80,14 +89,90 @@ def test_redis_semantic_cache() -> None:
langchain.llm_cache.clear(llm_string=llm_string)
def test_redis_semantic_cache_chat() -> None:
import redis
def test_redis_semantic_cache_multi() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
langchain.llm_cache.update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
output = llm.generate(
["bar"]
) # foo and bar will have the same embedding produced by FakeEmbeddings
expected_output = LLMResult(
generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
llm_output={},
)
assert output == expected_output
# clear the cache
langchain.llm_cache.clear(llm_string=llm_string)
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
def test_redis_semantic_cache_chat() -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
)
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
with pytest.warns():
llm.predict("foo")
llm.predict("foo")
langchain.llm_cache.redis.flushall()
langchain.llm_cache.clear(llm_string=llm_string)
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
@pytest.mark.parametrize(
"prompts, generations",
[
# Single prompt, single generation
([random_string()], [[random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# Multiple prompts, multiple generations
(
[random_string(), random_string()],
[[random_string()], [random_string(), random_string()]],
),
],
ids=[
"single_prompt_single_generation",
"single_prompt_multiple_generations",
"single_prompt_multiple_generations",
"multiple_prompts_multiple_generations",
],
)
def test_redis_semantic_cache_hit(
embedding: Embeddings, prompts: List[str], generations: List[List[str]]
) -> None:
langchain.llm_cache = RedisSemanticCache(
embedding=embedding, redis_url=REDIS_TEST_URL
)
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_generations = [
[
Generation(text=generation, generation_info=params)
for generation in prompt_i_generations
]
for prompt_i_generations in generations
]
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
print(prompt_i)
print(llm_generations_i)
langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
llm.generate(prompts)
assert llm.generate(prompts) == LLMResult(
generations=llm_generations, llm_output={}
)


@ -52,6 +52,7 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
return self.embed_documents([text])[0]
if text not in self.known_texts:
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
return [float(1.0)] * (self.dimensionality - 1) + [


@ -1,17 +1,28 @@
"""Test Redis functionality."""
from typing import List
import os
from typing import Any, Dict, List, Optional
import pytest
from langchain.docstore.document import Document
from langchain.vectorstores.redis import Redis
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from langchain.vectorstores.redis import (
Redis,
RedisFilter,
RedisNum,
RedisText,
)
from langchain.vectorstores.redis.filters import RedisFilterExpression
from tests.integration_tests.vectorstores.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
TEST_INDEX_NAME = "test"
TEST_REDIS_URL = "redis://localhost:6379"
TEST_SINGLE_RESULT = [Document(page_content="foo")]
TEST_SINGLE_WITH_METADATA_RESULT = [Document(page_content="foo", metadata={"a": "b"})]
TEST_SINGLE_WITH_METADATA = {"a": "b"}
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
RANGE_SCORE = pytest.approx(0.0513, abs=0.002)
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
IP_SCORE = -8.0
EUCLIDEAN_SCORE = 1.0
@ -23,6 +34,27 @@ def drop(index_name: str) -> bool:
)
def convert_bytes(data: Any) -> Any:
if isinstance(data, bytes):
return data.decode("ascii")
if isinstance(data, dict):
return dict(map(convert_bytes, data.items()))
if isinstance(data, list):
return list(map(convert_bytes, data))
if isinstance(data, tuple):
return map(convert_bytes, data)
return data
def make_dict(values: List[Any]) -> dict:
i = 0
di = {}
while i < len(values) - 1:
di[values[i]] = values[i + 1]
i += 2
return di
@pytest.fixture
def texts() -> List[str]:
return ["foo", "bar", "baz"]
@@ -31,7 +63,7 @@ def texts() -> List[str]:
 def test_redis(texts: List[str]) -> None:
     """Test end to end construction and search."""
     docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
-    output = docsearch.similarity_search("foo", k=1)
+    output = docsearch.similarity_search("foo", k=1, return_metadata=False)
     assert output == TEST_SINGLE_RESULT
     assert drop(docsearch.index_name)
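This and the following tests pass the new `return_metadata` flag. With `return_metadata=False`, the returned `Document`s carry empty metadata, which is why they compare equal to the bare expected documents:

```python
docs = docsearch.similarity_search("foo", k=1, return_metadata=False)
assert docs[0].metadata == {}  # only page_content is populated
```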
@@ -40,30 +72,55 @@ def test_redis_new_vector(texts: List[str]) -> None:
     """Test adding a new document"""
     docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
     docsearch.add_texts(["foo"])
-    output = docsearch.similarity_search("foo", k=2)
+    output = docsearch.similarity_search("foo", k=2, return_metadata=False)
     assert output == TEST_RESULT
     assert drop(docsearch.index_name)


 def test_redis_from_existing(texts: List[str]) -> None:
     """Test adding a new document"""
-    Redis.from_texts(
+    docsearch = Redis.from_texts(
         texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
     )
+    schema: Dict = docsearch.schema
+
+    # write schema for the next test
+    docsearch.write_schema("test_schema.yml")
+
     # Test creating from an existing
     docsearch2 = Redis.from_existing_index(
-        FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
+        FakeEmbeddings(),
+        index_name=TEST_INDEX_NAME,
+        redis_url=TEST_REDIS_URL,
+        schema=schema,
     )
-    output = docsearch2.similarity_search("foo", k=1)
+    output = docsearch2.similarity_search("foo", k=1, return_metadata=False)
     assert output == TEST_SINGLE_RESULT


+def test_redis_add_texts_to_existing() -> None:
+    """Test adding a new document"""
+    # Test creating from an existing with yaml from file
+    docsearch = Redis.from_existing_index(
+        FakeEmbeddings(),
+        index_name=TEST_INDEX_NAME,
+        redis_url=TEST_REDIS_URL,
+        schema="test_schema.yml",
+    )
+    docsearch.add_texts(["foo"])
+    output = docsearch.similarity_search("foo", k=2, return_metadata=False)
+    assert output == TEST_RESULT
+    assert drop(TEST_INDEX_NAME)
+    # remove the test_schema.yml file
+    os.remove("test_schema.yml")


 def test_redis_from_texts_return_keys(texts: List[str]) -> None:
     """Test from_texts_return_keys constructor."""
     docsearch, keys = Redis.from_texts_return_keys(
         texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL
     )
-    output = docsearch.similarity_search("foo", k=1)
+    output = docsearch.similarity_search("foo", k=1, return_metadata=False)
     assert output == TEST_SINGLE_RESULT
     assert len(keys) == len(texts)
     assert drop(docsearch.index_name)
@@ -73,21 +130,124 @@ def test_redis_from_documents(texts: List[str]) -> None:
     """Test from_documents constructor."""
     docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
     docsearch = Redis.from_documents(docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
-    output = docsearch.similarity_search("foo", k=1)
-    assert output == TEST_SINGLE_WITH_METADATA_RESULT
+    output = docsearch.similarity_search("foo", k=1, return_metadata=True)
+    assert "a" in output[0].metadata.keys()
+    assert "b" in output[0].metadata.values()
     assert drop(docsearch.index_name)


-def test_redis_add_texts_to_existing() -> None:
-    """Test adding a new document"""
-    # Test creating from an existing
-    docsearch = Redis.from_existing_index(
-        FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
-    )
-    docsearch.add_texts(["foo"])
-    output = docsearch.similarity_search("foo", k=2)
-    assert output == TEST_RESULT
-    assert drop(TEST_INDEX_NAME)
+# -- test filters -- #
+@pytest.mark.parametrize(
+    "filter_expr, expected_length, expected_nums",
+    [
+        (RedisText("text") == "foo", 1, None),
+        (RedisFilter.text("text") == "foo", 1, None),
+        (RedisText("text") % "ba*", 2, ["bar", "baz"]),
+        (RedisNum("num") > 2, 1, [3]),
+        (RedisNum("num") < 2, 1, [1]),
+        (RedisNum("num") >= 2, 2, [2, 3]),
+        (RedisNum("num") <= 2, 2, [1, 2]),
+        (RedisNum("num") != 2, 2, [1, 3]),
+        (RedisFilter.num("num") != 2, 2, [1, 3]),
+        (RedisFilter.tag("category") == "a", 3, None),
+        (RedisFilter.tag("category") == "b", 2, None),
+        (RedisFilter.tag("category") == "c", 2, None),
+        (RedisFilter.tag("category") == ["b", "c"], 3, None),
+    ],
+    ids=[
+        "text-filter-equals-foo",
+        "alternative-text-equals-foo",
+        "text-filter-fuzzy-match-ba",
+        "number-filter-greater-than-2",
+        "number-filter-less-than-2",
+        "number-filter-greater-equals-2",
+        "number-filter-less-equals-2",
+        "number-filter-not-equals-2",
+        "alternative-number-not-equals-2",
+        "tag-filter-equals-a",
+        "tag-filter-equals-b",
+        "tag-filter-equals-c",
+        "tag-filter-equals-b-or-c",
+    ],
+)
+def test_redis_filters_1(
+    filter_expr: RedisFilterExpression,
+    expected_length: int,
+    expected_nums: Optional[list],
+) -> None:
+    metadata = [
+        {"name": "joe", "num": 1, "text": "foo", "category": ["a", "b"]},
+        {"name": "john", "num": 2, "text": "bar", "category": ["a", "c"]},
+        {"name": "jane", "num": 3, "text": "baz", "category": ["b", "c", "a"]},
+    ]
+    documents = [Document(page_content="foo", metadata=m) for m in metadata]
+    docsearch = Redis.from_documents(
+        documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL
+    )
+    output = docsearch.similarity_search("foo", k=3, filter=filter_expr)
+    assert len(output) == expected_length
+    if expected_nums is not None:
+        for out in output:
+            assert (
+                out.metadata["text"] in expected_nums
+                or int(out.metadata["num"]) in expected_nums
+            )
+    assert drop(docsearch.index_name)
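Each parametrized case above applies one predicate, but the filter expressions compose, which is what enables the hybrid queries this PR targets. A minimal sketch against the `users` metadata from the description, assuming `&`/`|` combine expressions as they do in redisvl:

```python
from langchain.vectorstores.redis import RedisFilter, RedisNum, RedisText

# engineers under 40, or anyone with high credit
flt = ((RedisText("job") == "engineer") & (RedisNum("age") < 40)) | (
    RedisFilter.tag("credit_score") == "high"
)
results = rds.similarity_search("knows redis", k=3, filter=flt)
```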
+# -- test index specification -- #
+def test_index_specification_generation() -> None:
+    index_schema = {
+        "text": [{"name": "job"}, {"name": "title"}],
+        "numeric": [{"name": "salary"}],
+    }
+    text = ["foo"]
+    meta = {"job": "engineer", "title": "principal engineer", "salary": 100000}
+    docs = [Document(page_content=t, metadata=meta) for t in text]
+    r = Redis.from_documents(
+        docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, index_schema=index_schema
+    )
+
+    output = r.similarity_search("foo", k=1, return_metadata=True)
+    assert output[0].metadata["job"] == "engineer"
+    assert output[0].metadata["title"] == "principal engineer"
+    assert int(output[0].metadata["salary"]) == 100000
+
+    info = convert_bytes(r.client.ft(r.index_name).info())
+    attributes = info["attributes"]
+    assert len(attributes) == 5
+    for attr in attributes:
+        d = make_dict(attr)
+        if d["identifier"] == "job":
+            assert d["type"] == "TEXT"
+        elif d["identifier"] == "title":
+            assert d["type"] == "TEXT"
+        elif d["identifier"] == "salary":
+            assert d["type"] == "NUMERIC"
+        elif d["identifier"] == "content":
+            assert d["type"] == "TEXT"
+        elif d["identifier"] == "content_vector":
+            assert d["type"] == "VECTOR"
+        else:
+            raise ValueError("Unexpected attribute in index schema")
+
+    assert drop(r.index_name)
+
+
+# -- test distance metrics -- #
+cosine_schema: Dict = {"distance_metric": "cosine"}
+ip_schema: Dict = {"distance_metric": "IP"}
+l2_schema: Dict = {"distance_metric": "L2"}
+
+
 def test_cosine(texts: List[str]) -> None:
@@ -96,7 +256,7 @@ def test_cosine(texts: List[str]) -> None:
         texts,
         FakeEmbeddings(),
         redis_url=TEST_REDIS_URL,
-        distance_metric="COSINE",
+        vector_schema=cosine_schema,
     )
     output = docsearch.similarity_search_with_score("far", k=2)
     _, score = output[1]
@@ -107,7 +267,7 @@ def test_cosine(texts: List[str]) -> None:
 def test_l2(texts: List[str]) -> None:
     """Test Flat L2 distance."""
     docsearch = Redis.from_texts(
-        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
+        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=l2_schema
     )
     output = docsearch.similarity_search_with_score("far", k=2)
     _, score = output[1]
@@ -118,7 +278,7 @@ def test_l2(texts: List[str]) -> None:
 def test_ip(texts: List[str]) -> None:
     """Test inner product distance."""
     docsearch = Redis.from_texts(
-        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
+        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=ip_schema
    )
     output = docsearch.similarity_search_with_score("far", k=2)
     _, score = output[1]
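The three distance tests above also show the API change: `distance_metric` is no longer a top-level kwarg but a key inside `vector_schema`. The same dict should accept the other vector-field options; the `algorithm` key below is an assumption based on the new schema format, not something these tests exercise:

```python
# hypothetical vector_schema; "distance_metric" appears in the tests above,
# "algorithm" is assumed from the new index schema format
vector_schema = {
    "algorithm": "HNSW",
    "distance_metric": "COSINE",
}
rds = Redis.from_texts(
    texts,
    embeddings,
    redis_url="redis://localhost:6379",
    vector_schema=vector_schema,
)
```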
@@ -126,29 +286,34 @@ def test_ip(texts: List[str]) -> None:
     assert drop(docsearch.index_name)


-def test_similarity_search_limit_score(texts: List[str]) -> None:
+def test_similarity_search_limit_distance(texts: List[str]) -> None:
     """Test similarity search limit score."""
     docsearch = Redis.from_texts(
-        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
+        texts,
+        FakeEmbeddings(),
+        redis_url=TEST_REDIS_URL,
     )
-    output = docsearch.similarity_search_limit_score("far", k=2, score_threshold=0.1)
-    assert len(output) == 1
-    _, score = output[0]
-    assert score == COSINE_SCORE
+    output = docsearch.similarity_search(texts[0], k=3, distance_threshold=0.1)
+    # can't check score but length of output should be 2
+    assert len(output) == 2
     assert drop(docsearch.index_name)


-def test_similarity_search_with_score_with_limit_score(texts: List[str]) -> None:
+def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None:
     """Test similarity search with score with limit score."""
     docsearch = Redis.from_texts(
-        texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
+        texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
     )
-    output = docsearch.similarity_search_with_relevance_scores(
-        "far", k=2, score_threshold=0.1
+    output = docsearch.similarity_search_with_score(
+        texts[0], k=3, distance_threshold=0.1, return_metadata=True
     )
-    assert len(output) == 1
-    _, score = output[0]
-    assert score == COSINE_SCORE
+    assert len(output) == 2
+    for out, score in output:
+        if out.page_content == texts[1]:
+            assert score == COSINE_SCORE
     assert drop(docsearch.index_name)
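The renames reflect the API change: the old `similarity_search_limit_score` helper is gone, and the regular search methods take a `distance_threshold` kwarg instead:

```python
# keep only hits within a vector distance of 0.1 from the query
docs = rds.similarity_search("foo", k=5, distance_threshold=0.1)
```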
@@ -156,6 +321,48 @@ def test_delete(texts: List[str]) -> None:
     """Test deleting a new document"""
     docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
     ids = docsearch.add_texts(["foo"])
-    got = docsearch.delete(ids=ids)
+    got = docsearch.delete(ids=ids, redis_url=TEST_REDIS_URL)
     assert got
     assert drop(docsearch.index_name)


+def test_redis_as_retriever() -> None:
+    texts = ["foo", "foo", "foo", "foo", "bar"]
+    docsearch = Redis.from_texts(
+        texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
+    )
+    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+    results = retriever.get_relevant_documents("foo")
+    assert len(results) == 3
+    assert all([d.page_content == "foo" for d in results])
+    assert drop(docsearch.index_name)
+
+
+def test_redis_retriever_distance_threshold() -> None:
+    texts = ["foo", "bar", "baz"]
+    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
+    retriever = docsearch.as_retriever(
+        search_type="similarity_distance_threshold",
+        search_kwargs={"k": 3, "distance_threshold": 0.1},
+    )
+    results = retriever.get_relevant_documents("foo")
+    assert len(results) == 2
+    assert drop(docsearch.index_name)
+
+
+def test_redis_retriever_score_threshold() -> None:
+    texts = ["foo", "bar", "baz"]
+    docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
+    retriever = docsearch.as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": 3, "score_threshold": 0.91},
+    )
+    results = retriever.get_relevant_documents("foo")
+    assert len(results) == 2
+    assert drop(docsearch.index_name)
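The retriever tests correspond to two new `search_type` options accepted by `as_retriever`; the distance-threshold variant looks like this in application code (sketch):

```python
retriever = rds.as_retriever(
    search_type="similarity_distance_threshold",
    search_kwargs={"k": 4, "distance_threshold": 0.2},
)
docs = retriever.get_relevant_documents("What does the user do for work?")
```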