mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00
Compare commits
27 Commits
eugene/add
...
bagatur/re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93ec5a68fc | ||
|
|
6864df2404 | ||
|
|
dee06ff2d8 | ||
|
|
03d9a23c97 | ||
|
|
e0dcf56e28 | ||
|
|
a99288987e | ||
|
|
23a3705c6f | ||
|
|
6bbbe82088 | ||
|
|
b8060a621a | ||
|
|
48b7df56c5 | ||
|
|
8acc988c84 | ||
|
|
215d944f52 | ||
|
|
0fbff0238d | ||
|
|
8adaa7805e | ||
|
|
dce67b1251 | ||
|
|
07a2410aa6 | ||
|
|
07ba934136 | ||
|
|
eaba95c9ad | ||
|
|
7839f9fb78 | ||
|
|
90da194c30 | ||
|
|
508c0e9562 | ||
|
|
5e3937bd41 | ||
|
|
8b5b2ff2e6 | ||
|
|
7e12c63065 | ||
|
|
45935c00bd | ||
|
|
32ba2c8903 | ||
|
|
194515ff22 |
File diff suppressed because it is too large
Load Diff
@@ -33,6 +33,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -302,6 +303,14 @@ class RedisSemanticCache(BaseCache):
|
||||
|
||||
# TODO - implement a TTL policy in Redis
|
||||
|
||||
DEFAULT_SCHEMA = {
|
||||
"content_key": "prompt",
|
||||
"text": [
|
||||
{"name": "prompt"},
|
||||
],
|
||||
"extra": [{"name": "return_val"}, {"name": "llm_string"}],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, redis_url: str, embedding: Embeddings, score_threshold: float = 0.2
|
||||
):
|
||||
@@ -349,15 +358,18 @@ class RedisSemanticCache(BaseCache):
|
||||
embedding=self.embedding,
|
||||
index_name=index_name,
|
||||
redis_url=self.redis_url,
|
||||
schema=cast(Dict, self.DEFAULT_SCHEMA),
|
||||
)
|
||||
except ValueError:
|
||||
redis = RedisVectorstore(
|
||||
embedding_function=self.embedding.embed_query,
|
||||
embedding=self.embedding,
|
||||
index_name=index_name,
|
||||
redis_url=self.redis_url,
|
||||
index_schema=cast(Dict, self.DEFAULT_SCHEMA),
|
||||
)
|
||||
_embedding = self.embedding.embed_query(text="test")
|
||||
redis._create_index(dim=len(_embedding))
|
||||
print(redis.index_name)
|
||||
self._cache_dict[index_name] = redis
|
||||
|
||||
return self._cache_dict[index_name]
|
||||
@@ -374,17 +386,18 @@ class RedisSemanticCache(BaseCache):
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
llm_cache = self._get_llm_cache(llm_string)
|
||||
generations = []
|
||||
generations: List = []
|
||||
# Read from a Hash
|
||||
results = llm_cache.similarity_search_limit_score(
|
||||
results = llm_cache.similarity_search(
|
||||
query=prompt,
|
||||
k=1,
|
||||
score_threshold=self.score_threshold,
|
||||
distance_threshold=self.score_threshold,
|
||||
)
|
||||
if results:
|
||||
for document in results:
|
||||
for text in document.metadata["return_val"]:
|
||||
generations.append(Generation(text=text))
|
||||
generations.extend(
|
||||
_load_generations_from_json(document.metadata["return_val"])
|
||||
)
|
||||
return generations if generations else None
|
||||
|
||||
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
|
||||
@@ -402,11 +415,10 @@ class RedisSemanticCache(BaseCache):
|
||||
)
|
||||
return
|
||||
llm_cache = self._get_llm_cache(llm_string)
|
||||
# Write to vectorstore
|
||||
metadata = {
|
||||
"llm_string": llm_string,
|
||||
"prompt": prompt,
|
||||
"return_val": [generation.text for generation in return_val],
|
||||
"return_val": _dump_generations_to_json([g for g in return_val]),
|
||||
}
|
||||
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.schema.messages import get_buffer_string # noqa: 401
|
||||
|
||||
|
||||
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,16 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Pattern
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.client import Redis as RedisType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes:
|
||||
return np.array(array).astype(dtype).tobytes()
|
||||
|
||||
|
||||
class TokenEscaper:
    """
    Backslash-escape punctuation within an input string.
    """

    # Characters that RediSearch requires us to escape during queries.
    # Source: https://redis.io/docs/stack/search/reference/escaping/#the-rules-of-text-field-tokenization
    DEFAULT_ESCAPED_CHARS = r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/]"

    def __init__(self, escape_chars_re: Optional[Pattern] = None):
        # Fall back to the default character class when no pattern is supplied.
        if not escape_chars_re:
            escape_chars_re = re.compile(self.DEFAULT_ESCAPED_CHARS)
        self.escaped_chars_re = escape_chars_re

    def escape(self, value: str) -> str:
        """Return ``value`` with every pattern match prefixed by a backslash."""
        return self.escaped_chars_re.sub(
            lambda found: "\\" + found.group(0), value
        )
|
||||
|
||||
|
||||
def check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
    """Check if the correct Redis modules are installed.

    Returns as soon as any one entry of ``required_modules`` is found among
    the modules reported by ``client.module_list()`` with a version at least
    as high as the required one.

    Args:
        client: Redis client used to list the installed modules.
        required_modules: Dicts carrying a ``"name"`` and a minimum ``"ver"``.

    Raises:
        ValueError: If none of the required modules is installed in a
            sufficient version.
    """
    # Index the installed modules by decoded name for direct lookup.
    installed_by_name = {
        module[b"name"].decode("utf-8"): module for module in client.module_list()
    }
    for required in required_modules:
        installed = installed_by_name.get(required["name"])
        if installed is not None and int(installed[b"ver"]) >= int(required["ver"]):
            return
    # otherwise raise error
    # The fragments are concatenated, so each needs its own trailing separator.
    error_message = (
        "Redis cannot be used as a vector database without RediSearch >=2.4. "
        "Please head to https://redis.io/docs/stack/search/quick_start/ "
        "to know more about installing the RediSearch module within Redis Stack."
    )
    logger.error(error_message)
    raise ValueError(error_message)
|
||||
|
||||
|
||||
def get_client(redis_url: str, **kwargs: Any) -> RedisType:
|
||||
|
||||
@@ -1,664 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.utilities.redis import get_client
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.client import Redis as RedisType
|
||||
from redis.commands.search.query import Query
|
||||
|
||||
|
||||
# required modules
|
||||
REDIS_REQUIRED_MODULES = [
|
||||
{"name": "search", "ver": 20400},
|
||||
{"name": "searchlight", "ver": 20400},
|
||||
]
|
||||
|
||||
# distance metrics
|
||||
REDIS_DISTANCE_METRICS = Literal["COSINE", "IP", "L2"]
|
||||
|
||||
|
||||
def _check_redis_module_exist(client: RedisType, required_modules: List[dict]) -> None:
|
||||
"""Check if the correct Redis modules are installed."""
|
||||
installed_modules = client.module_list()
|
||||
installed_modules = {
|
||||
module[b"name"].decode("utf-8"): module for module in installed_modules
|
||||
}
|
||||
for module in required_modules:
|
||||
if module["name"] in installed_modules and int(
|
||||
installed_modules[module["name"]][b"ver"]
|
||||
) >= int(module["ver"]):
|
||||
return
|
||||
# otherwise raise error
|
||||
error_message = (
|
||||
"Redis cannot be used as a vector database without RediSearch >=2.4"
|
||||
"Please head to https://redis.io/docs/stack/search/quick_start/"
|
||||
"to know more about installing the RediSearch module within Redis Stack."
|
||||
)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def _check_index_exists(client: RedisType, index_name: str) -> bool:
|
||||
"""Check if Redis index exists."""
|
||||
try:
|
||||
client.ft(index_name).info()
|
||||
except: # noqa: E722
|
||||
logger.info("Index does not exist")
|
||||
return False
|
||||
logger.info("Index already exists")
|
||||
return True
|
||||
|
||||
|
||||
def _redis_key(prefix: str) -> str:
|
||||
"""Redis key schema for a given prefix."""
|
||||
return f"{prefix}:{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def _redis_prefix(index_name: str) -> str:
|
||||
"""Redis key prefix for a given index."""
|
||||
return f"doc:{index_name}"
|
||||
|
||||
|
||||
def _default_relevance_score(val: float) -> float:
|
||||
return 1 - val
|
||||
|
||||
|
||||
class Redis(VectorStore):
|
||||
"""`Redis` vector store.
|
||||
|
||||
To use, you should have the ``redis`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.vectorstores import Redis
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = Redis(
|
||||
redis_url="redis://username:password@localhost:6379",
|
||||
index_name="my-index",
|
||||
embedding_function=embeddings.embed_query,
|
||||
)
|
||||
|
||||
To use a redis replication setup with multiple redis server and redis sentinels
|
||||
set "redis_url" to "redis+sentinel://" scheme. With this url format a path is
|
||||
needed holding the name of the redis service within the sentinels to get the
|
||||
correct redis server connection. The default service name is "mymaster".
|
||||
|
||||
An optional username or password is used for both connections to the Redis server
|
||||
and the sentinel, different passwords for server and sentinel are not supported.
|
||||
And as another constraint only one sentinel instance can be given:
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
vectorstore = Redis(
|
||||
redis_url="redis+sentinel://username:password@sentinelhost:26379/mymaster/0",
|
||||
index_name="my-index",
|
||||
embedding_function=embeddings.embed_query,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    redis_url: str,
    index_name: str,
    embedding_function: Callable,
    content_key: str = "content",
    metadata_key: str = "metadata",
    vector_key: str = "content_vector",
    relevance_score_fn: Optional[Callable[[float], float]] = None,
    distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
    **kwargs: Any,
):
    """Initialize with necessary components.

    Connects via ``get_client`` (extra keyword arguments are forwarded to it)
    and verifies the required Redis modules before storing the configuration.

    Raises:
        ValueError: If connecting or the module check raises ``ValueError``.
    """
    self.embedding_function = embedding_function
    self.index_name = index_name

    try:
        connection = get_client(redis_url=redis_url, **kwargs)
        # check if redis has redisearch module installed
        _check_redis_module_exist(connection, REDIS_REQUIRED_MODULES)
    except ValueError as exc:
        raise ValueError(f"Redis failed to connect: {exc}")

    # Field names used when reading and writing hashes.
    self.content_key = content_key
    self.metadata_key = metadata_key
    self.vector_key = vector_key
    # Scoring configuration.
    self.distance_metric = distance_metric
    self.relevance_score_fn = relevance_score_fn
    self.client = connection
|
||||
|
||||
@property
def embeddings(self) -> Optional[Embeddings]:
    """Embeddings object of this store; currently always ``None``."""
    # TODO: Accept embedding object directly
    return None
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
if self.relevance_score_fn:
|
||||
return self.relevance_score_fn
|
||||
|
||||
if self.distance_metric == "COSINE":
|
||||
return self._cosine_relevance_score_fn
|
||||
elif self.distance_metric == "IP":
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_metric == "L2":
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
return _default_relevance_score
|
||||
|
||||
def _create_index(self, dim: int = 1536) -> None:
    """Create the search index for this store unless it already exists.

    Args:
        dim: Value passed as ``DIM`` of the vector field. Defaults to 1536.

    Raises:
        ImportError: If the ``redis`` package cannot be imported.
    """
    try:
        from redis.commands.search.field import TextField, VectorField
        from redis.commands.search.indexDefinition import IndexDefinition, IndexType
    except ImportError:
        raise ImportError(
            "Could not import redis python package. "
            "Please install it with `pip install redis`."
        )

    # Check if index exists
    if _check_index_exists(self.client, self.index_name):
        return

    # Define schema
    vector_options = {
        "TYPE": "FLOAT32",
        "DIM": dim,
        "DISTANCE_METRIC": self.distance_metric,
    }
    fields = (
        TextField(name=self.content_key),
        TextField(name=self.metadata_key),
        VectorField(self.vector_key, "FLAT", vector_options),
    )
    key_prefix = _redis_prefix(self.index_name)

    # Create Redis Index
    search_index = self.client.ft(self.index_name)
    search_index.create_index(
        fields=fields,
        definition=IndexDefinition(prefix=[key_prefix], index_type=IndexType.HASH),
    )
|
||||
|
||||
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    embeddings: Optional[List[List[float]]] = None,
    batch_size: int = 1000,
    **kwargs: Any,
) -> List[str]:
    """Add more texts to the vectorstore.

    Args:
        texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
        metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            Defaults to None.
        embeddings (Optional[List[List[float]]], optional): Optional pre-generated
            embeddings. Defaults to None.
        keys (List[str]) or ids (List[str]): Identifiers of entries.
            Defaults to None.
        batch_size (int, optional): Batch size to use for writes. Defaults to 1000.

    Returns:
        List[str]: List of ids added to the vectorstore
    """
    ids = []

    # Get keys or ids from kwargs
    # Other vectorstores use ids
    keys_or_ids = kwargs.get("keys", kwargs.get("ids"))
    # The prefix is only needed to generate keys the caller did not supply.
    prefix = None if keys_or_ids else _redis_prefix(self.index_name)

    # Write data to redis
    pipeline = self.client.pipeline(transaction=False)
    for i, text in enumerate(texts):
        # Use provided values by default or fallback
        key = keys_or_ids[i] if keys_or_ids else _redis_key(prefix)
        metadata = metadatas[i] if metadatas else {}
        embedding = embeddings[i] if embeddings else self.embedding_function(text)
        pipeline.hset(
            key,
            mapping={
                self.content_key: text,
                self.vector_key: np.array(embedding, dtype=np.float32).tobytes(),
                self.metadata_key: json.dumps(metadata),
            },
        )
        ids.append(key)

        # Write batch once ``batch_size`` items are queued. Counting from 1
        # avoids flushing a lone first item and keeps batches aligned.
        if (i + 1) % batch_size == 0:
            pipeline.execute()

    # Cleanup final batch
    pipeline.execute()
    return ids
|
||||
|
||||
def similarity_search(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
    """
    Returns the most similar indexed documents to the query text.

    Delegates to ``similarity_search_with_score`` and drops the scores; extra
    keyword arguments are accepted but not used.

    Args:
        query (str): The query text for which to find similar documents.
        k (int): The number of documents to return. Default is 4.

    Returns:
        List[Document]: A list of documents that are most similar to the query text.
    """
    documents = []
    for document, _score in self.similarity_search_with_score(query, k=k):
        documents.append(document)
    return documents
|
||||
|
||||
def similarity_search_limit_score(
    self, query: str, k: int = 4, score_threshold: float = 0.2, **kwargs: Any
) -> List[Document]:
    """
    Returns the documents among the ``k`` most similar to the query text whose
    score is strictly below ``score_threshold``.

    Args:
        query (str): The query text for which to find similar documents.
        k (int): The number of documents fetched before filtering. Default is 4.
        score_threshold (float): Exclusive upper bound on the score reported by
            ``similarity_search_with_score``; documents scoring at or above it
            are dropped. Defaults to 0.2.

    Returns:
        List[Document]: The documents that passed the threshold, in the order
        they were returned by the search.

    Note:
        If there are no documents that satisfy the score_threshold value,
        an empty list is returned. Extra keyword arguments are not used.
    """
    kept = []
    for document, score in self.similarity_search_with_score(query, k=k):
        if score < score_threshold:
            kept.append(document)
    return kept
|
||||
|
||||
def _prepare_query(self, k: int) -> Query:
    """Build the KNN search query that returns the ``k`` nearest entries.

    Raises:
        ValueError: If the ``redis`` package cannot be imported.
    """
    try:
        from redis.commands.search.query import Query
    except ImportError:
        raise ValueError(
            "Could not import redis python package. "
            "Please install it with `pip install redis`."
        )

    # Prepare the Query
    hybrid_fields = "*"
    knn_clause = f"[KNN {k} @{self.vector_key} $vector AS vector_score]"
    query = Query(f"{hybrid_fields}=>{knn_clause}")
    query = query.return_fields(
        self.metadata_key, self.content_key, "vector_score", "id"
    )
    query = query.sort_by("vector_score")
    query = query.paging(0, k)
    return query.dialect(2)
|
||||
|
||||
def similarity_search_with_score(
    self, query: str, k: int = 4
) -> List[Tuple[Document, float]]:
    """Return docs most similar to query.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.

    Returns:
        List of Documents most similar to the query and score for each
    """
    # Creates embedding vector from user query
    embedding = self.embedding_function(query)

    # Creates Redis query
    redis_query = self._prepare_query(k)

    params_dict: Mapping[str, str] = {
        "vector": np.array(embedding)  # type: ignore
        .astype(dtype=np.float32)
        .tobytes()
    }

    # Perform vector search
    results = self.client.ft(self.index_name).search(redis_query, params_dict)

    # Prepare document results
    docs_and_scores: List[Tuple[Document, float]] = []
    for result in results.docs:
        # The query asks for the fields named by ``metadata_key`` and
        # ``content_key``, so read them under those names rather than the
        # hard-coded defaults "metadata" / "content".
        raw_metadata = getattr(result, self.metadata_key)
        content = getattr(result, self.content_key)
        metadata = {**json.loads(raw_metadata), "id": result.id}
        doc = Document(page_content=content, metadata=metadata)
        docs_and_scores.append((doc, float(result.vector_score)))
    return docs_and_scores
|
||||
|
||||
@classmethod
def from_texts_return_keys(
    cls,
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    index_name: Optional[str] = None,
    content_key: str = "content",
    metadata_key: str = "metadata",
    vector_key: str = "content_vector",
    distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
    **kwargs: Any,
) -> Tuple[Redis, List[str]]:
    """Create a Redis vectorstore from raw documents.
    This is a user-friendly interface that:
        1. Embeds documents.
        2. Creates a new index for the embeddings in Redis.
        3. Adds the documents to the newly created Redis index.
        4. Returns the keys of the newly created documents.

    The Redis URL is taken from the ``redis_url`` keyword argument or the
    ``REDIS_URL`` environment variable. A random index name is generated
    when ``index_name`` is not given.

    Example:
        .. code-block:: python

            from langchain.vectorstores import Redis
            from langchain.embeddings import OpenAIEmbeddings
            embeddings = OpenAIEmbeddings()
            redisearch, keys = Redis.from_texts_return_keys(
                texts,
                embeddings,
                redis_url="redis://username:password@localhost:6379"
            )
    """
    redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
    # The URL is passed positionally below, so it must not stay in kwargs.
    kwargs.pop("redis_url", None)

    # Name of the search index if not given
    resolved_index_name = index_name if index_name else uuid.uuid4().hex

    # Create instance
    store = cls(
        redis_url,
        resolved_index_name,
        embedding.embed_query,
        content_key=content_key,
        metadata_key=metadata_key,
        vector_key=vector_key,
        distance_metric=distance_metric,
        **kwargs,
    )

    # Create embeddings over documents
    vectors = embedding.embed_documents(texts)

    # Create the search index
    store._create_index(dim=len(vectors[0]))

    # Add data to Redis
    return store, store.add_texts(texts, metadatas, vectors)
|
||||
|
||||
@classmethod
def from_texts(
    cls: Type[Redis],
    texts: List[str],
    embedding: Embeddings,
    metadatas: Optional[List[dict]] = None,
    index_name: Optional[str] = None,
    content_key: str = "content",
    metadata_key: str = "metadata",
    vector_key: str = "content_vector",
    **kwargs: Any,
) -> Redis:
    """Create a Redis vectorstore from raw documents.
    This is a user-friendly interface that:
        1. Embeds documents.
        2. Creates a new index for the embeddings in Redis.
        3. Adds the documents to the newly created Redis index.

    Thin wrapper around ``from_texts_return_keys`` that discards the keys.

    Example:
        .. code-block:: python

            from langchain.vectorstores import Redis
            from langchain.embeddings import OpenAIEmbeddings
            embeddings = OpenAIEmbeddings()
            redisearch = Redis.from_texts(
                texts,
                embeddings,
                redis_url="redis://username:password@localhost:6379"
            )
    """
    store_and_keys = cls.from_texts_return_keys(
        texts,
        embedding,
        metadatas=metadatas,
        index_name=index_name,
        content_key=content_key,
        metadata_key=metadata_key,
        vector_key=vector_key,
        **kwargs,
    )
    store, _keys = store_and_keys
    return store
|
||||
|
||||
@staticmethod
def delete(
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> bool:
    """
    Delete a Redis entry.

    The Redis URL is taken from the ``redis_url`` keyword argument or the
    ``REDIS_URL`` environment variable; remaining keyword arguments are
    forwarded to ``get_client``.

    Args:
        ids: List of ids (keys) to delete.

    Returns:
        bool: Whether or not the deletions were successful.

    Raises:
        ValueError: If ``ids`` is None, the ``redis`` package is missing, or
            creating the client raises ``ValueError``.
    """
    redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")

    if ids is None:
        raise ValueError("'ids' (keys)() were not provided.")

    try:
        import redis  # noqa: F401
    except ImportError:
        raise ValueError(
            "Could not import redis python package. "
            "Please install it with `pip install redis`."
        )
    try:
        # We need to first remove redis_url from kwargs,
        # otherwise passing it to Redis will result in an error.
        kwargs.pop("redis_url", None)
        client = get_client(redis_url=redis_url, **kwargs)
    except ValueError as e:
        # Keep the original error attached as the cause.
        raise ValueError(f"Your redis connected error: {e}") from e
    try:
        client.delete(*ids)
    except Exception:
        # Deletion failed; report it via the return value. ``Exception``
        # (not a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        return False
    logger.info("Entries deleted")
    return True
|
||||
|
||||
@staticmethod
def drop_index(
    index_name: str,
    delete_documents: bool,
    **kwargs: Any,
) -> bool:
    """
    Drop a Redis search index.

    The Redis URL is taken from the ``redis_url`` keyword argument or the
    ``REDIS_URL`` environment variable; remaining keyword arguments are
    forwarded to ``get_client``.

    Args:
        index_name (str): Name of the index to drop.
        delete_documents (bool): Whether to drop the associated documents.

    Returns:
        bool: Whether or not the drop was successful.

    Raises:
        ValueError: If the ``redis`` package is missing or creating the
            client raises ``ValueError``.
    """
    redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
    try:
        import redis  # noqa: F401
    except ImportError:
        raise ValueError(
            "Could not import redis python package. "
            "Please install it with `pip install redis`."
        )
    try:
        # We need to first remove redis_url from kwargs,
        # otherwise passing it to Redis will result in an error.
        kwargs.pop("redis_url", None)
        client = get_client(redis_url=redis_url, **kwargs)
    except ValueError as e:
        # Keep the original error attached as the cause.
        raise ValueError(f"Your redis connected error: {e}") from e
    try:
        client.ft(index_name).dropindex(delete_documents)
    except Exception:
        # Dropping failed (e.g. the index does not exist). ``Exception`` (not
        # a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        return False
    logger.info("Drop index")
    return True
|
||||
|
||||
@classmethod
def from_existing_index(
    cls,
    embedding: Embeddings,
    index_name: str,
    content_key: str = "content",
    metadata_key: str = "metadata",
    vector_key: str = "content_vector",
    **kwargs: Any,
) -> Redis:
    """Connect to an existing Redis index.

    The Redis URL is taken from the ``redis_url`` keyword argument or the
    ``REDIS_URL`` environment variable; remaining keyword arguments are
    forwarded to ``get_client`` and to the constructor.

    Raises:
        ValueError: If the ``redis`` package is missing, the connection or
            module check fails, or the index does not exist.
    """
    redis_url = get_from_dict_or_env(kwargs, "redis_url", "REDIS_URL")
    try:
        import redis  # noqa: F401
    except ImportError:
        raise ValueError(
            "Could not import redis python package. "
            "Please install it with `pip install redis`."
        )
    try:
        # We need to first remove redis_url from kwargs,
        # otherwise passing it to Redis will result in an error.
        kwargs.pop("redis_url", None)
        client = get_client(redis_url=redis_url, **kwargs)
        # check if redis has redisearch module installed
        _check_redis_module_exist(client, REDIS_REQUIRED_MODULES)
        # ensure that the index already exists; raised explicitly because an
        # ``assert`` would be stripped when Python runs with -O
        if not _check_index_exists(client, index_name):
            raise ValueError(f"Index {index_name} does not exist")
    except Exception as e:
        raise ValueError(f"Redis failed to connect: {e}") from e

    return cls(
        redis_url,
        index_name,
        embedding.embed_query,
        content_key=content_key,
        metadata_key=metadata_key,
        vector_key=vector_key,
        **kwargs,
    )
|
||||
|
||||
def as_retriever(self, **kwargs: Any) -> RedisVectorStoreRetriever:
    """Return a ``RedisVectorStoreRetriever`` bound to this vector store.

    Any ``tags`` passed in are combined with ``self._get_retriever_tags()``;
    all other keyword arguments are forwarded to the retriever.
    """
    # Build a new list so the caller's ``tags`` object is never mutated.
    tags = list(kwargs.pop("tags", None) or [])
    tags.extend(self._get_retriever_tags())
    return RedisVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
|
||||
|
||||
|
||||
class RedisVectorStoreRetriever(VectorStoreRetriever):
    """Retriever for `Redis` vector store."""

    vectorstore: Redis
    """Redis VectorStore."""
    search_type: str = "similarity"
    """Type of search to perform. Can be either 'similarity' or 'similarity_limit'."""
    k: int = 4
    """Number of documents to return."""
    score_threshold: float = 0.4
    """Score threshold for similarity_limit search."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        if "search_type" not in values:
            return values
        requested = values["search_type"]
        if requested not in ("similarity", "similarity_limit"):
            raise ValueError(f"search_type of {requested} not allowed.")
        return values

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run the configured search type against the vector store."""
        if self.search_type == "similarity":
            return self.vectorstore.similarity_search(query, k=self.k)
        if self.search_type == "similarity_limit":
            return self.vectorstore.similarity_search_limit_score(
                query, k=self.k, score_threshold=self.score_threshold
            )
        raise ValueError(f"search_type of {self.search_type} not allowed.")

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Async retrieval is not supported; always raises NotImplementedError."""
        raise NotImplementedError("RedisVectorStoreRetriever does not support async")

    def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
        """Add documents to vectorstore."""
        store = self.vectorstore
        return store.add_documents(documents, **kwargs)

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to vectorstore."""
        store = self.vectorstore
        return await store.aadd_documents(documents, **kwargs)
|
||||
9
libs/langchain/langchain/vectorstores/redis/__init__.py
Normal file
9
libs/langchain/langchain/vectorstores/redis/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .base import Redis
|
||||
from .filters import (
|
||||
RedisFilter,
|
||||
RedisNum,
|
||||
RedisTag,
|
||||
RedisText,
|
||||
)
|
||||
|
||||
__all__ = ["Redis", "RedisFilter", "RedisTag", "RedisText", "RedisNum"]
|
||||
1361
libs/langchain/langchain/vectorstores/redis/base.py
Normal file
1361
libs/langchain/langchain/vectorstores/redis/base.py
Normal file
File diff suppressed because it is too large
Load Diff
20
libs/langchain/langchain/vectorstores/redis/constants.py
Normal file
20
libs/langchain/langchain/vectorstores/redis/constants.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# required modules
|
||||
REDIS_REQUIRED_MODULES = [
|
||||
{"name": "search", "ver": 20600},
|
||||
{"name": "searchlight", "ver": 20600},
|
||||
]
|
||||
|
||||
# distance metrics
|
||||
REDIS_DISTANCE_METRICS: List[str] = ["COSINE", "IP", "L2"]
|
||||
|
||||
# supported vector datatypes
|
||||
REDIS_VECTOR_DTYPE_MAP: Dict[str, Any] = {
|
||||
"FLOAT32": np.float32,
|
||||
"FLOAT64": np.float64,
|
||||
}
|
||||
|
||||
REDIS_TAG_SEPARATOR = ","
|
||||
420
libs/langchain/langchain/vectorstores/redis/filters.py
Normal file
420
libs/langchain/langchain/vectorstores/redis/filters.py
Normal file
@@ -0,0 +1,420 @@
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from langchain.utilities.redis import TokenEscaper
|
||||
|
||||
# disable mypy error for dunder method overrides
|
||||
# mypy: disable-error-code="override"
|
||||
|
||||
|
||||
class RedisFilterOperator(Enum):
    """Operators that a Redis filter field can carry."""

    # comparison operators
    EQ = 1
    NE = 2
    LT = 3
    GT = 4
    LE = 5
    GE = 6
    # boolean combinators
    OR = 7
    AND = 8
    # matching operators
    LIKE = 9
    IN = 10
|
||||
|
||||
|
||||
class RedisFilter:
    """Factory helpers that create filter fields for a given field name."""

    @staticmethod
    def text(field: str) -> "RedisText":
        """Create a ``RedisText`` filter field for ``field``."""
        text_field = RedisText(field)
        return text_field

    @staticmethod
    def num(field: str) -> "RedisNum":
        """Create a ``RedisNum`` filter field for ``field``."""
        num_field = RedisNum(field)
        return num_field

    @staticmethod
    def tag(field: str) -> "RedisTag":
        """Create a ``RedisTag`` filter field for ``field``."""
        tag_field = RedisTag(field)
        return tag_field
|
||||
|
||||
|
||||
class RedisFilterField:
    """Base class for a filter on a single field of a Redis index."""

    escaper: "TokenEscaper" = TokenEscaper()
    # Operators supported by the field type, mapped to their display symbol.
    OPERATORS: Dict[RedisFilterOperator, str] = {}

    def __init__(self, field: str):
        self._field = field
        self._value: Any = None
        self._operator: RedisFilterOperator = RedisFilterOperator.EQ

    def equals(self, other: "RedisFilterField") -> bool:
        """Return whether ``other`` has the same type, field name and value."""
        same_type = isinstance(other, type(self))
        return (
            same_type
            and self._field == other._field
            and self._value == other._value
        )

    def _set_value(
        self, val: Any, val_type: type, operator: RedisFilterOperator
    ) -> None:
        """Store ``val`` and ``operator`` after validating both.

        Raises:
            ValueError: If ``operator`` is not listed in ``OPERATORS``.
            TypeError: If ``val`` is not an instance of ``val_type``.
        """
        # check that the operator is supported by this class
        if operator not in self.OPERATORS:
            unsupported_msg = (
                f"Operator {operator} not supported by {self.__class__.__name__}. "
                + f"Supported operators are {self.OPERATORS.values()}"
            )
            raise ValueError(unsupported_msg)

        if not isinstance(val, val_type):
            wrong_type_msg = (
                f"Right side argument passed to operator {self.OPERATORS[operator]} "
                f"with left side "
                f"argument {self.__class__.__name__} must be of type {val_type}"
            )
            raise TypeError(wrong_type_msg)

        self._value, self._operator = val, operator
|
||||
|
||||
|
||||
def check_operator_misuse(func: Callable) -> Callable:
    """Guard an overloaded equality operator against plain equality checks.

    The wrapped method raises ``ValueError`` when it is called with an
    argument that is an instance of the receiver's own type, because the
    equality operators on filter fields build filter expressions rather
    than compare objects. Any other call is forwarded to ``func`` unchanged.

    Args:
        func: The method to wrap; its first parameter is the instance.

    Returns:
        The wrapping function, carrying ``func``'s metadata via ``wraps``.
    """

    @wraps(func)
    def wrapper(instance: Any, *args: Any, **kwargs: Any) -> Any:
        # Prefer an explicit ``other=`` keyword; otherwise fall back to the
        # first positional argument that is of the receiver's own type.
        other = kwargs.get("other")
        if not other:
            other = next(
                (arg for arg in args if isinstance(arg, type(instance))), None
            )

        if isinstance(other, type(instance)):
            raise ValueError(
                "Equality operators are overridden for FilterExpression creation. Use "
                ".equals() for equality checks"
            )
        return func(instance, *args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
class RedisTag(RedisFilterField):
    """A RedisTag is a RedisFilterField representing a tag in a Redis index."""

    # operator -> symbol shown in error messages
    OPERATORS: Dict[RedisFilterOperator, str] = {
        RedisFilterOperator.EQ: "==",
        RedisFilterOperator.NE: "!=",
        RedisFilterOperator.IN: "==",
    }

    # operator -> query template taking (field name, formatted tag values)
    OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
        RedisFilterOperator.EQ: "@%s:{%s}",
        RedisFilterOperator.NE: "(-@%s:{%s})",
        RedisFilterOperator.IN: "@%s:{%s}",
    }

    def __init__(self, field: str):
        """Create a RedisTag FilterField

        Args:
            field (str): The name of the RedisTag field in the index to be queried
                against.
        """
        super().__init__(field)

    def _set_tag_value(
        self, other: Union[List[str], str], operator: RedisFilterOperator
    ) -> None:
        """Normalize ``other`` to a list of tag strings and store it.

        Raises:
            ValueError: if ``other`` is a list containing a non-string.
            TypeError: if ``other`` is neither a string nor a list.
        """
        if isinstance(other, list):
            if not all(isinstance(tag, str) for tag in other):
                raise ValueError("All tags must be strings")
        elif isinstance(other, str):
            other = [other]
        else:
            # Fail here with a clear message instead of wrapping the value and
            # letting the escaper choke on it when the query is rendered.
            raise TypeError(
                f"Right side argument passed to operator {self.OPERATORS[operator]} "
                f"with left side argument {self.__class__.__name__} must be of "
                f"type str or List[str]"
            )
        self._set_value(other, list, operator)

    @check_operator_misuse
    def __eq__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
        """Create a RedisTag equality filter expression

        Args:
            other (Union[List[str], str]): The tag(s) to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisTag
            >>> filter = RedisTag("brand") == "nike"
        """
        self._set_tag_value(other, RedisFilterOperator.EQ)
        return RedisFilterExpression(str(self))

    @check_operator_misuse
    def __ne__(self, other: Union[List[str], str]) -> "RedisFilterExpression":
        """Create a RedisTag inequality filter expression

        Args:
            other (Union[List[str], str]): The tag(s) to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisTag
            >>> filter = RedisTag("brand") != "nike"
        """
        self._set_tag_value(other, RedisFilterOperator.NE)
        return RedisFilterExpression(str(self))

    @property
    def _formatted_tag_value(self) -> str:
        """Escaped tag values joined with ``|``."""
        return "|".join([self.escaper.escape(tag) for tag in self._value])

    def __str__(self) -> str:
        """Return the Redis Query syntax for a RedisTag filter expression

        Raises:
            ValueError: if no value has been set (or the tag list is empty).
        """
        if not self._value:
            raise ValueError(
                f"Operator must be used before calling __str__. Operators are "
                f"{self.OPERATORS.values()}"
            )
        return self.OPERATOR_MAP[self._operator] % (
            self._field,
            self._formatted_tag_value,
        )
|
||||
|
||||
|
||||
class RedisNum(RedisFilterField):
    """A RedisFilterField representing a numeric field in a Redis index."""

    # operator -> symbol shown in error messages
    OPERATORS: Dict[RedisFilterOperator, str] = {
        RedisFilterOperator.EQ: "==",
        RedisFilterOperator.NE: "!=",
        RedisFilterOperator.LT: "<",
        RedisFilterOperator.GT: ">",
        RedisFilterOperator.LE: "<=",
        RedisFilterOperator.GE: ">=",
    }
    # operator -> numeric range template; EQ/NE take the value twice
    # (as both bounds), the others take it once
    OPERATOR_MAP: Dict[RedisFilterOperator, str] = {
        RedisFilterOperator.EQ: "@%s:[%i %i]",
        RedisFilterOperator.NE: "(-@%s:[%i %i])",
        RedisFilterOperator.GT: "@%s:[(%i +inf]",
        RedisFilterOperator.LT: "@%s:[-inf (%i]",
        RedisFilterOperator.GE: "@%s:[%i +inf]",
        RedisFilterOperator.LE: "@%s:[-inf %i]",
    }

    def __str__(self) -> str:
        """Return the Redis Query syntax for a Numeric filter expression

        Raises:
            ValueError: if no operator has been applied yet.
        """
        # Compare against None explicitly: 0 is a legitimate value to filter
        # on and must not be mistaken for "no value set".
        if self._value is None:
            raise ValueError(
                f"Operator must be used before calling __str__. Operators are "
                f"{self.OPERATORS.values()}"
            )

        if self._operator in (RedisFilterOperator.EQ, RedisFilterOperator.NE):
            return self.OPERATOR_MAP[self._operator] % (
                self._field,
                self._value,
                self._value,
            )
        return self.OPERATOR_MAP[self._operator] % (self._field, self._value)

    @check_operator_misuse
    def __eq__(self, other: int) -> "RedisFilterExpression":
        """Create a Numeric equality filter expression

        Args:
            other (int): The value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("zipcode") == 90210
        """
        self._set_value(other, int, RedisFilterOperator.EQ)
        return RedisFilterExpression(str(self))

    @check_operator_misuse
    def __ne__(self, other: int) -> "RedisFilterExpression":
        """Create a Numeric inequality filter expression

        Args:
            other (int): The value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("zipcode") != 90210
        """
        self._set_value(other, int, RedisFilterOperator.NE)
        return RedisFilterExpression(str(self))

    def __gt__(self, other: int) -> "RedisFilterExpression":
        """Create a RedisNumeric greater than filter expression

        Args:
            other (int): The value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("age") > 18
        """
        self._set_value(other, int, RedisFilterOperator.GT)
        return RedisFilterExpression(str(self))

    def __lt__(self, other: int) -> "RedisFilterExpression":
        """Create a Numeric less than filter expression

        Args:
            other (int): The value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("age") < 18
        """
        self._set_value(other, int, RedisFilterOperator.LT)
        return RedisFilterExpression(str(self))

    def __ge__(self, other: int) -> "RedisFilterExpression":
        """Create a Numeric greater than or equal to filter expression

        Args:
            other (int): The value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("age") >= 18
        """
        self._set_value(other, int, RedisFilterOperator.GE)
        return RedisFilterExpression(str(self))

    def __le__(self, other: int) -> "RedisFilterExpression":
        """Create a Numeric less than or equal to filter expression

        Args:
            other (int): The value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisNum
            >>> filter = RedisNum("age") <= 18
        """
        self._set_value(other, int, RedisFilterOperator.LE)
        return RedisFilterExpression(str(self))
|
||||
|
||||
|
||||
class RedisText(RedisFilterField):
    """A RedisText is a RedisFilterField representing a text field in a Redis index."""

    # operator -> symbol shown in error messages
    OPERATORS = {
        RedisFilterOperator.EQ: "==",
        RedisFilterOperator.NE: "!=",
        RedisFilterOperator.LIKE: "%",
    }
    # operator -> query template taking (field name, value); EQ and NE wrap
    # the value in double quotes, LIKE inserts it as given
    OPERATOR_MAP = {
        RedisFilterOperator.EQ: '@%s:"%s"',
        RedisFilterOperator.NE: '(-@%s:"%s")',
        RedisFilterOperator.LIKE: "@%s:%s",
    }

    @check_operator_misuse
    def __eq__(self, other: str) -> "RedisFilterExpression":
        """Build an equality filter expression on this text field.

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisText
            >>> filter = RedisText("job") == "engineer"
        """
        self._set_value(other, str, RedisFilterOperator.EQ)
        rendered = str(self)
        return RedisFilterExpression(rendered)

    @check_operator_misuse
    def __ne__(self, other: str) -> "RedisFilterExpression":
        """Build an inequality filter expression on this text field.

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisText
            >>> filter = RedisText("job") != "engineer"
        """
        self._set_value(other, str, RedisFilterOperator.NE)
        rendered = str(self)
        return RedisFilterExpression(rendered)

    def __mod__(self, other: str) -> "RedisFilterExpression":
        """Build a "like" filter expression on this text field.

        Args:
            other (str): The text value to filter on.

        Example:
            >>> from langchain.vectorstores.redis import RedisText
            >>> filter = RedisText("job") % "engineer"
        """
        self._set_value(other, str, RedisFilterOperator.LIKE)
        rendered = str(self)
        return RedisFilterExpression(rendered)

    def __str__(self) -> str:
        """Render this field as a query fragment.

        Raises:
            ValueError: if no value is set (this includes an empty string).
            Exception: if the stored operator has no entry in OPERATOR_MAP.
        """
        has_value = bool(self._value)
        if not has_value:
            raise ValueError(
                f"Operator must be used before calling __str__. Operators are "
                f"{self.OPERATORS.values()}"
            )

        try:
            template = self.OPERATOR_MAP[self._operator]
            return template % (self._field, self._value)
        except KeyError:
            raise Exception("Invalid operator")
|
||||
|
||||
|
||||
class RedisFilterExpression:
    """A logical expression built from RedisFilterFields.

    An expression is either a leaf holding an already rendered filter string,
    or a node joining a left and a right expression with an operator. The
    ``&`` and ``|`` operators build such nodes, so users can compose complex
    queries without having to know the Redis Query language.

    Filter expressions are not meant to be initialized directly; they are
    produced by applying operators to RedisFilterFields and then combined
    with ``&`` and ``|``.

    Examples:

        >>> from langchain.vectorstores.redis import RedisTag, RedisNum
        >>> brand_is_nike = RedisTag("brand") == "nike"
        >>> price_is_under_100 = RedisNum("price") < 100
        >>> filter = brand_is_nike & price_is_under_100
        >>> print(str(filter))
        (@brand:{nike} @price:[-inf (100])

    """

    def __init__(
        self,
        _filter: Optional[str] = None,
        operator: Optional[RedisFilterOperator] = None,
        left: Optional["RedisFilterExpression"] = None,
        right: Optional["RedisFilterExpression"] = None,
    ):
        """Store either a leaf filter string or an operator with two operands."""
        self._filter = _filter
        self._operator = operator
        self._left = left
        self._right = right

    def __and__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
        """Return a node that joins this expression and ``other`` with AND."""
        return RedisFilterExpression(
            operator=RedisFilterOperator.AND, left=self, right=other
        )

    def __or__(self, other: "RedisFilterExpression") -> "RedisFilterExpression":
        """Return a node that joins this expression and ``other`` with OR."""
        return RedisFilterExpression(
            operator=RedisFilterOperator.OR, left=self, right=other
        )

    def __str__(self) -> str:
        """Render the expression tree as a query string.

        Nodes render as ``(<left><sep><right>)`` where the separator is
        `` | `` for OR and a single space for any other operator; leaves
        render as their stored filter string.

        Raises:
            ValueError: if neither an operator nor a filter string is set.
        """
        # node: recurse into both operands
        if self._operator:
            separator = " "
            if self._operator == RedisFilterOperator.OR:
                separator = " | "
            return f"({self._left!s}{separator}{self._right!s})"

        # leaf: the filter string must be present
        if not self._filter:
            raise ValueError("Improperly initialized RedisFilterExpression")
        return self._filter
|
||||
276
libs/langchain/langchain/vectorstores/redis/schema.py
Normal file
276
libs/langchain/langchain/vectorstores/redis/schema.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
# ignore type error here as it's a redis-py type problem
|
||||
from redis.commands.search.field import ( # type: ignore
|
||||
NumericField,
|
||||
TagField,
|
||||
TextField,
|
||||
VectorField,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Field, validator
|
||||
from langchain.vectorstores.redis.constants import REDIS_VECTOR_DTYPE_MAP
|
||||
|
||||
|
||||
class RedisDistanceMetric(str, Enum):
    """Distance metrics for vector fields; each member is also a ``str``."""

    l2 = "L2"
    cosine = "COSINE"
    ip = "IP"
|
||||
|
||||
|
||||
class RedisField(BaseModel):
    """Base model for every schema field; only the field name is required."""

    name: str = Field(...)
|
||||
|
||||
|
||||
class TextFieldSchema(RedisField):
    """Settings for a text field of the index."""

    weight: float = 1
    no_stem: bool = False
    phonetic_matcher: Optional[str] = None
    withsuffixtrie: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TextField:
        """Build the redis-py ``TextField`` described by this schema.

        NOTE(review): ``withsuffixtrie`` is declared on the model but is not
        forwarded to ``TextField`` here -- confirm whether that is intended.
        """
        options = {
            "weight": self.weight,
            "no_stem": self.no_stem,
            "phonetic_matcher": self.phonetic_matcher,
            "sortable": self.sortable,
            "no_index": self.no_index,
        }
        return TextField(self.name, **options)
|
||||
|
||||
|
||||
class TagFieldSchema(RedisField):
    """Settings for a tag field of the index."""

    separator: str = ","
    case_sensitive: bool = False
    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> TagField:
        """Build the redis-py ``TagField`` described by this schema."""
        options = {
            "separator": self.separator,
            "case_sensitive": self.case_sensitive,
            "sortable": self.sortable,
            "no_index": self.no_index,
        }
        return TagField(self.name, **options)
|
||||
|
||||
|
||||
class NumericFieldSchema(RedisField):
    """Settings for a numeric field of the index."""

    no_index: bool = False
    sortable: Optional[bool] = False

    def as_field(self) -> NumericField:
        """Build the redis-py ``NumericField`` described by this schema."""
        options = {"sortable": self.sortable, "no_index": self.no_index}
        return NumericField(self.name, **options)
|
||||
|
||||
|
||||
class RedisVectorField(RedisField):
    """Settings shared by all vector field schemas.

    ``dims`` and ``algorithm`` are required; the remaining settings have
    defaults. ``distance_metric`` and ``datatype`` inputs are upper-cased
    before pydantic validates them.
    """

    dims: int = Field(...)
    algorithm: object = Field(...)
    datatype: str = Field(default="FLOAT32")
    distance_metric: RedisDistanceMetric = Field(default="COSINE")
    initial_cap: int = Field(default=20000)

    @validator("distance_metric", pre=True)
    def uppercase_strings(cls, v: str) -> str:
        """Upper-case the metric name so lower-case input is accepted."""
        uppercased = v.upper()
        return uppercased

    @validator("datatype", pre=True)
    def uppercase_and_check_dtype(cls, v: str) -> str:
        """Upper-case the datatype and reject names missing from the dtype map.

        Raises:
            ValueError: if the upper-cased name is not a key of
                ``REDIS_VECTOR_DTYPE_MAP``.
        """
        uppercased = v.upper()
        if uppercased in REDIS_VECTOR_DTYPE_MAP:
            return uppercased
        raise ValueError(
            f"datatype must be one of {REDIS_VECTOR_DTYPE_MAP.keys()}. Got {v}"
        )
|
||||
|
||||
|
||||
class FlatVectorField(RedisVectorField):
    """Vector field schema whose algorithm is fixed to ``"FLAT"``."""

    algorithm: Literal["FLAT"] = "FLAT"
    block_size: int = Field(default=1000)

    def as_field(self) -> VectorField:
        """Build the redis-py ``VectorField`` with the FLAT attributes."""
        attributes = {
            "TYPE": self.datatype,
            "DIM": self.dims,
            "DISTANCE_METRIC": self.distance_metric,
            "INITIAL_CAP": self.initial_cap,
            "BLOCK_SIZE": self.block_size,
        }
        return VectorField(self.name, self.algorithm, attributes)
|
||||
|
||||
|
||||
class HNSWVectorField(RedisVectorField):
    """Vector field schema whose algorithm is fixed to ``"HNSW"``."""

    algorithm: Literal["HNSW"] = "HNSW"
    m: int = Field(default=16)
    ef_construction: int = Field(default=200)
    ef_runtime: int = Field(default=10)
    epsilon: float = Field(default=0.8)

    def as_field(self) -> VectorField:
        """Build the redis-py ``VectorField`` with the HNSW attributes."""
        attributes = {
            "TYPE": self.datatype,
            "DIM": self.dims,
            "DISTANCE_METRIC": self.distance_metric,
            "INITIAL_CAP": self.initial_cap,
            "M": self.m,
            "EF_CONSTRUCTION": self.ef_construction,
            "EF_RUNTIME": self.ef_runtime,
            "EPSILON": self.epsilon,
        }
        return VectorField(self.name, self.algorithm, attributes)
|
||||
|
||||
|
||||
class RedisModel(BaseModel):
    """Schema of a Redis index, with fields grouped by type.

    ``text``, ``tag``, ``numeric``, ``extra`` and ``vector`` hold lists of
    field schemas; ``content_key`` and ``content_vector_key`` name the text
    field and the vector field that carry the document content.
    """

    # always have a content field for text
    text: List[TextFieldSchema] = [TextFieldSchema(name="content")]
    tag: Optional[List[TagFieldSchema]] = None
    numeric: Optional[List[NumericFieldSchema]] = None
    extra: Optional[List[RedisField]] = None

    # filled by default_vector_schema
    vector: Optional[List[Union[FlatVectorField, HNSWVectorField]]] = None
    content_key: str = "content"
    content_vector_key: str = "content_vector"

    def add_content_field(self) -> None:
        """Ensure a text field named ``content_key`` exists, appending it if not."""
        if self.text is None:
            self.text = []
        if any(existing.name == self.content_key for existing in self.text):
            return
        self.text.append(TextFieldSchema(name=self.content_key))

    def add_vector_field(self, vector_field: Dict[str, Any]) -> None:
        """Append a vector field built from ``vector_field``.

        The ``"algorithm"`` entry selects the schema class.

        Raises:
            ValueError: if the algorithm is neither ``"FLAT"`` nor ``"HNSW"``.
        """
        # catch case where user inputted no vector field spec
        # in the index schema
        if self.vector is None:
            self.vector = []

        algorithm = vector_field["algorithm"]
        # ignore types as pydantic is handling type validation and conversion
        if algorithm == "FLAT":
            self.vector.append(FlatVectorField(**vector_field))  # type: ignore
            return
        if algorithm == "HNSW":
            self.vector.append(HNSWVectorField(**vector_field))  # type: ignore
            return
        raise ValueError(
            f"algorithm must be either FLAT or HNSW. Got "
            f"{vector_field['algorithm']}"
        )

    def as_dict(self) -> Dict[str, List[Any]]:
        """Serialize every non-empty field group to plain dictionaries.

        Enum settings are written as their value and ``None`` settings are
        left out.
        """
        collected: Dict[str, List[Any]] = {"text": [], "tag": [], "numeric": []}
        # iter over all class attributes
        for group_name, group in self.__dict__.items():
            # only non-empty lists
            if not isinstance(group, list) or len(group) == 0:
                continue
            # one dict of settings (name, weight, etc) per field in the group
            serialized: List[Dict[str, Any]] = []
            for member in group:
                settings: Dict[str, Any] = {
                    # make enums into strings
                    key: (setting.value if isinstance(setting, Enum) else setting)
                    for key, setting in member.__dict__.items()
                    # don't write null values
                    if setting is not None
                }
                serialized.append(settings)
            collected[group_name] = serialized

        # only write non-empty lists from defaults
        return {
            group_name: entries
            for group_name, entries in collected.items()
            if len(entries) > 0
        }

    @property
    def content_vector(self) -> Union[FlatVectorField, HNSWVectorField]:
        """The vector field named ``content_vector_key``.

        Raises:
            ValueError: if there are no vector fields, or none has that name.
        """
        if not self.vector:
            raise ValueError("No vector fields found")
        for candidate in self.vector:
            if candidate.name == self.content_vector_key:
                return candidate
        raise ValueError("No content_vector field found")

    @property
    def vector_dtype(self) -> np.dtype:
        """The numpy type mapped to the content vector's datatype."""
        # should only ever be called after pydantic has validated the schema
        datatype = self.content_vector.datatype
        return REDIS_VECTOR_DTYPE_MAP[datatype]

    @property
    def is_empty(self) -> bool:
        """True when tag, text, numeric and vector are all ``None``."""
        groups = (self.tag, self.text, self.numeric, self.vector)
        return all(group is None for group in groups)

    def get_fields(self) -> List["RedisField"]:
        """Collect ``as_field()`` of every field, skipping ``extra`` and the keys."""
        redis_fields: List["RedisField"] = []
        if self.is_empty:
            return redis_fields

        skipped = ["content_key", "content_vector_key", "extra"]
        for group_name in self.__fields__.keys():
            if group_name in skipped:
                continue
            group = getattr(self, group_name)
            if group is None:
                continue
            for member in group:
                redis_fields.append(member.as_field())
        return redis_fields

    @property
    def metadata_keys(self) -> List[str]:
        """Names of all fields other than the content and content-vector fields."""
        keys: List[str] = []
        if self.is_empty:
            return keys

        reserved = [self.content_key, self.content_vector_key]
        for group_name in self.__fields__.keys():
            group = getattr(self, group_name)
            if group is None:
                continue
            for member in group:
                # the two *_key attributes are plain strings, so iterating
                # them yields characters; skip those
                if isinstance(member, str):
                    continue
                # check if it's a metadata field. exclude vector and content key
                if member.name not in reserved:
                    keys.append(member.name)
        return keys
|
||||
|
||||
|
||||
def read_schema(
    index_schema: Optional[Union[Dict[str, str], str, os.PathLike]]
) -> Dict[str, Any]:
    """Return an index schema as a dictionary.

    Args:
        index_schema: Either the schema dictionary itself, which is returned
            unchanged, or a path (``str`` or any ``os.PathLike``) to a YAML
            file, which is parsed with ``yaml.safe_load``.

    Returns:
        The schema dictionary.

    Raises:
        FileNotFoundError: if a path is given that is not an existing file.
        TypeError: if ``index_schema`` is neither a dict nor a path.
    """
    if isinstance(index_schema, dict):
        return index_schema

    # Accept every os.PathLike, as the signature advertises, not just
    # pathlib.Path, and give strings and path objects the same treatment.
    if isinstance(index_schema, (str, os.PathLike)):
        raw_path = os.fspath(index_schema)
        if not Path(raw_path).resolve().is_file():
            raise FileNotFoundError(f"index_schema file {raw_path} does not exist")
        with open(raw_path, "rb") as f:
            return yaml.safe_load(f)

    raise TypeError(
        f"index_schema must be a dict, or path to a yaml file "
        f"Got {type(index_schema)}"
    )
|
||||
@@ -1,16 +1,27 @@
|
||||
"""Test Redis cache functionality."""
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
import langchain
|
||||
from langchain.cache import RedisCache, RedisSemanticCache
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
REDIS_TEST_URL = "redis://localhost:6379"
|
||||
|
||||
|
||||
def random_string() -> str:
    """Return a freshly generated UUID4 rendered as a string."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
|
||||
|
||||
|
||||
def test_redis_cache_ttl() -> None:
|
||||
import redis
|
||||
|
||||
@@ -30,12 +41,10 @@ def test_redis_cache() -> None:
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
langchain.llm_cache.update("foo", llm_string, [Generation(text="fizz")])
|
||||
output = llm.generate(["foo"])
|
||||
print(output)
|
||||
expected_output = LLMResult(
|
||||
generations=[[Generation(text="fizz")]],
|
||||
llm_output={},
|
||||
)
|
||||
print(expected_output)
|
||||
assert output == expected_output
|
||||
langchain.llm_cache.redis.flushall()
|
||||
|
||||
@@ -80,14 +89,84 @@ def test_redis_semantic_cache() -> None:
|
||||
langchain.llm_cache.clear(llm_string=llm_string)
|
||||
|
||||
|
||||
def test_redis_semantic_cache_chat() -> None:
|
||||
import redis
|
||||
def test_redis_semantic_cache_multi() -> None:
|
||||
langchain.llm_cache = RedisSemanticCache(
|
||||
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
|
||||
)
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
langchain.llm_cache.update(
|
||||
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
|
||||
)
|
||||
output = llm.generate(
|
||||
["bar"]
|
||||
) # foo and bar will have the same embedding produced by FakeEmbeddings
|
||||
expected_output = LLMResult(
|
||||
generations=[[Generation(text="fizz"), Generation(text="Buzz")]],
|
||||
llm_output={},
|
||||
)
|
||||
assert output == expected_output
|
||||
# clear the cache
|
||||
langchain.llm_cache.clear(llm_string=llm_string)
|
||||
|
||||
langchain.llm_cache = RedisCache(redis_=redis.Redis.from_url(REDIS_TEST_URL))
|
||||
|
||||
def test_redis_semantic_cache_chat() -> None:
|
||||
langchain.llm_cache = RedisSemanticCache(
|
||||
embedding=FakeEmbeddings(), redis_url=REDIS_TEST_URL, score_threshold=0.1
|
||||
)
|
||||
llm = FakeChatModel()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
with pytest.warns():
|
||||
llm.predict("foo")
|
||||
llm.predict("foo")
|
||||
langchain.llm_cache.redis.flushall()
|
||||
langchain.llm_cache.clear(llm_string=llm_string)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("embedding", [ConsistentFakeEmbeddings()])
|
||||
@pytest.mark.parametrize(
|
||||
"prompts, generations",
|
||||
[
|
||||
# Single prompt, single generation
|
||||
([random_string()], [[random_string()]]),
|
||||
# Single prompt, multiple generations
|
||||
([random_string()], [[random_string(), random_string()]]),
|
||||
# Single prompt, multiple generations
|
||||
([random_string()], [[random_string(), random_string(), random_string()]]),
|
||||
# Multiple prompts, multiple generations
|
||||
(
|
||||
[random_string(), random_string()],
|
||||
[[random_string()], [random_string(), random_string()]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_redis_semantic_cache_hit(
|
||||
embedding: Embeddings, prompts: List[str], generations: List[List[str]]
|
||||
) -> None:
|
||||
langchain.llm_cache = RedisSemanticCache(
|
||||
embedding=embedding, redis_url=REDIS_TEST_URL
|
||||
)
|
||||
|
||||
llm = FakeLLM()
|
||||
params = llm.dict()
|
||||
params["stop"] = None
|
||||
llm_string = str(sorted([(k, v) for k, v in params.items()]))
|
||||
|
||||
llm_generations = [
|
||||
[
|
||||
Generation(text=generation, generation_info=params)
|
||||
for generation in prompt_i_generations
|
||||
]
|
||||
for prompt_i_generations in generations
|
||||
]
|
||||
for prompt_i, llm_generations_i in zip(prompts, llm_generations):
|
||||
print(prompt_i)
|
||||
print(llm_generations_i)
|
||||
langchain.llm_cache.update(prompt_i, llm_string, llm_generations_i)
|
||||
|
||||
assert llm.generate(prompts) == LLMResult(
|
||||
generations=llm_generations, llm_output={}
|
||||
)
|
||||
|
||||
@@ -52,6 +52,7 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return consistent embeddings for the text, if seen before, or a constant
|
||||
one if the text is unknown."""
|
||||
return self.embed_documents([text])[0]
|
||||
if text not in self.known_texts:
|
||||
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
|
||||
return [float(1.0)] * (self.dimensionality - 1) + [
|
||||
|
||||
@@ -1,17 +1,28 @@
|
||||
"""Test Redis functionality."""
|
||||
from typing import List
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.redis import Redis
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
from langchain.vectorstores.redis import (
|
||||
Redis,
|
||||
RedisFilter,
|
||||
RedisNum,
|
||||
RedisText,
|
||||
)
|
||||
from langchain.vectorstores.redis.filters import RedisFilterExpression
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
|
||||
TEST_INDEX_NAME = "test"
|
||||
TEST_REDIS_URL = "redis://localhost:6379"
|
||||
TEST_SINGLE_RESULT = [Document(page_content="foo")]
|
||||
TEST_SINGLE_WITH_METADATA_RESULT = [Document(page_content="foo", metadata={"a": "b"})]
|
||||
TEST_SINGLE_WITH_METADATA = {"a": "b"}
|
||||
TEST_RESULT = [Document(page_content="foo"), Document(page_content="foo")]
|
||||
RANGE_SCORE = pytest.approx(0.0513, abs=0.002)
|
||||
COSINE_SCORE = pytest.approx(0.05, abs=0.002)
|
||||
IP_SCORE = -8.0
|
||||
EUCLIDEAN_SCORE = 1.0
|
||||
@@ -23,6 +34,27 @@ def drop(index_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def convert_bytes(data: Any) -> Any:
    """Recursively decode ``bytes`` found in ``data`` to ASCII strings.

    Dictionaries (keys and values), lists and tuples are rebuilt with their
    contents converted and keep their container type; any other value is
    returned unchanged.
    """
    if isinstance(data, bytes):
        return data.decode("ascii")
    if isinstance(data, dict):
        return {convert_bytes(key): convert_bytes(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_bytes(item) for item in data]
    if isinstance(data, tuple):
        # build a real tuple; a bare map() would hand callers a one-shot iterator
        return tuple(convert_bytes(item) for item in data)
    return data
|
||||
|
||||
|
||||
def make_dict(values: List[Any]) -> dict:
    """Pair up a flat ``[key, value, key, value, ...]`` list into a dict.

    A trailing element without a partner is ignored.
    """
    keys = values[::2]
    vals = values[1::2]
    return dict(zip(keys, vals))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def texts() -> List[str]:
|
||||
return ["foo", "bar", "baz"]
|
||||
@@ -31,7 +63,7 @@ def texts() -> List[str]:
|
||||
def test_redis(texts: List[str]) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
output = docsearch.similarity_search("foo", k=1, return_metadata=False)
|
||||
assert output == TEST_SINGLE_RESULT
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
@@ -40,30 +72,55 @@ def test_redis_new_vector(texts: List[str]) -> None:
|
||||
"""Test adding a new document"""
|
||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||
docsearch.add_texts(["foo"])
|
||||
output = docsearch.similarity_search("foo", k=2)
|
||||
output = docsearch.similarity_search("foo", k=2, return_metadata=False)
|
||||
assert output == TEST_RESULT
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_redis_from_existing(texts: List[str]) -> None:
|
||||
"""Test adding a new document"""
|
||||
Redis.from_texts(
|
||||
docsearch = Redis.from_texts(
|
||||
texts, FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||
)
|
||||
schema: Dict = docsearch.schema
|
||||
|
||||
# write schema for the next test
|
||||
docsearch.write_schema("test_schema.yml")
|
||||
|
||||
# Test creating from an existing
|
||||
docsearch2 = Redis.from_existing_index(
|
||||
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||
FakeEmbeddings(),
|
||||
index_name=TEST_INDEX_NAME,
|
||||
redis_url=TEST_REDIS_URL,
|
||||
schema=schema,
|
||||
)
|
||||
output = docsearch2.similarity_search("foo", k=1)
|
||||
output = docsearch2.similarity_search("foo", k=1, return_metadata=False)
|
||||
assert output == TEST_SINGLE_RESULT
|
||||
|
||||
|
||||
def test_redis_add_texts_to_existing() -> None:
|
||||
"""Test adding a new document"""
|
||||
# Test creating from an existing with yaml from file
|
||||
docsearch = Redis.from_existing_index(
|
||||
FakeEmbeddings(),
|
||||
index_name=TEST_INDEX_NAME,
|
||||
redis_url=TEST_REDIS_URL,
|
||||
schema="test_schema.yml",
|
||||
)
|
||||
docsearch.add_texts(["foo"])
|
||||
output = docsearch.similarity_search("foo", k=2, return_metadata=False)
|
||||
assert output == TEST_RESULT
|
||||
assert drop(TEST_INDEX_NAME)
|
||||
# remove the test_schema.yml file
|
||||
os.remove("test_schema.yml")
|
||||
|
||||
|
||||
def test_redis_from_texts_return_keys(texts: List[str]) -> None:
|
||||
"""Test from_texts_return_keys constructor."""
|
||||
docsearch, keys = Redis.from_texts_return_keys(
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
output = docsearch.similarity_search("foo", k=1, return_metadata=False)
|
||||
assert output == TEST_SINGLE_RESULT
|
||||
assert len(keys) == len(texts)
|
||||
assert drop(docsearch.index_name)
|
||||
@@ -73,21 +130,124 @@ def test_redis_from_documents(texts: List[str]) -> None:
|
||||
"""Test from_documents constructor."""
|
||||
docs = [Document(page_content=t, metadata={"a": "b"}) for t in texts]
|
||||
docsearch = Redis.from_documents(docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == TEST_SINGLE_WITH_METADATA_RESULT
|
||||
output = docsearch.similarity_search("foo", k=1, return_metadata=True)
|
||||
assert "a" in output[0].metadata.keys()
|
||||
assert "b" in output[0].metadata.values()
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_redis_add_texts_to_existing() -> None:
|
||||
"""Test adding a new document"""
|
||||
# Test creating from an existing
|
||||
docsearch = Redis.from_existing_index(
|
||||
FakeEmbeddings(), index_name=TEST_INDEX_NAME, redis_url=TEST_REDIS_URL
|
||||
# -- test filters -- #
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filter_expr, expected_length, expected_nums",
|
||||
[
|
||||
(RedisText("text") == "foo", 1, None),
|
||||
(RedisFilter.text("text") == "foo", 1, None),
|
||||
(RedisText("text") % "ba*", 2, ["bar", "baz"]),
|
||||
(RedisNum("num") > 2, 1, [3]),
|
||||
(RedisNum("num") < 2, 1, [1]),
|
||||
(RedisNum("num") >= 2, 2, [2, 3]),
|
||||
(RedisNum("num") <= 2, 2, [1, 2]),
|
||||
(RedisNum("num") != 2, 2, [1, 3]),
|
||||
(RedisFilter.num("num") != 2, 2, [1, 3]),
|
||||
(RedisFilter.tag("category") == "a", 3, None),
|
||||
(RedisFilter.tag("category") == "b", 2, None),
|
||||
(RedisFilter.tag("category") == "c", 2, None),
|
||||
(RedisFilter.tag("category") == ["b", "c"], 3, None),
|
||||
],
|
||||
ids=[
|
||||
"text-filter-equals-foo",
|
||||
"alternative-text-equals-foo",
|
||||
"text-filter-fuzzy-match-ba",
|
||||
"number-filter-greater-than-2",
|
||||
"number-filter-less-than-2",
|
||||
"number-filter-greater-equals-2",
|
||||
"number-filter-less-equals-2",
|
||||
"number-filter-not-equals-2",
|
||||
"alternative-number-not-equals-2",
|
||||
"tag-filter-equals-a",
|
||||
"tag-filter-equals-b",
|
||||
"tag-filter-equals-c",
|
||||
"tag-filter-equals-b-or-c",
|
||||
],
|
||||
)
|
||||
def test_redis_filters_1(
|
||||
filter_expr: RedisFilterExpression,
|
||||
expected_length: int,
|
||||
expected_nums: Optional[list],
|
||||
) -> None:
|
||||
metadata = [
|
||||
{"name": "joe", "num": 1, "text": "foo", "category": ["a", "b"]},
|
||||
{"name": "john", "num": 2, "text": "bar", "category": ["a", "c"]},
|
||||
{"name": "jane", "num": 3, "text": "baz", "category": ["b", "c", "a"]},
|
||||
]
|
||||
documents = [Document(page_content="foo", metadata=m) for m in metadata]
|
||||
docsearch = Redis.from_documents(
|
||||
documents, FakeEmbeddings(), redis_url=TEST_REDIS_URL
|
||||
)
|
||||
docsearch.add_texts(["foo"])
|
||||
output = docsearch.similarity_search("foo", k=2)
|
||||
assert output == TEST_RESULT
|
||||
assert drop(TEST_INDEX_NAME)
|
||||
|
||||
output = docsearch.similarity_search("foo", k=3, filter=filter_expr)
|
||||
|
||||
assert len(output) == expected_length
|
||||
|
||||
if expected_nums is not None:
|
||||
for out in output:
|
||||
assert (
|
||||
out.metadata["text"] in expected_nums
|
||||
or int(out.metadata["num"]) in expected_nums
|
||||
)
|
||||
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
# -- test index specification -- #
|
||||
|
||||
|
||||
def test_index_specification_generation() -> None:
    """Check that a user-supplied index schema produces the expected fields."""
    index_schema = {
        "text": [{"name": "job"}, {"name": "title"}],
        "numeric": [{"name": "salary"}],
    }
    meta = {"job": "engineer", "title": "principal engineer", "salary": 100000}
    docs = [Document(page_content=t, metadata=meta) for t in ["foo"]]

    r = Redis.from_documents(
        docs, FakeEmbeddings(), redis_url=TEST_REDIS_URL, index_schema=index_schema
    )

    # The metadata stored with the document must come back from a search.
    output = r.similarity_search("foo", k=1, return_metadata=True)
    top_metadata = output[0].metadata
    assert top_metadata["job"] == "engineer"
    assert top_metadata["title"] == "principal engineer"
    assert int(top_metadata["salary"]) == 100000

    # Field type expected for every attribute reported by the index.
    expected_types = {
        "job": "TEXT",
        "title": "TEXT",
        "salary": "NUMERIC",
        "content": "TEXT",
        "content_vector": "VECTOR",
    }

    info = convert_bytes(r.client.ft(r.index_name).info())
    attributes = info["attributes"]
    assert len(attributes) == 5
    for attr in attributes:
        d = make_dict(attr)
        identifier = d["identifier"]
        if identifier not in expected_types:
            raise ValueError("Unexpected attribute in index schema")
        assert d["type"] == expected_types[identifier]

    assert drop(r.index_name)
|
||||
|
||||
|
||||
# -- test distance metrics -- #
|
||||
|
||||
cosine_schema: Dict = {"distance_metric": "cosine"}
|
||||
ip_schema: Dict = {"distance_metric": "IP"}
|
||||
l2_schema: Dict = {"distance_metric": "L2"}
|
||||
|
||||
|
||||
def test_cosine(texts: List[str]) -> None:
|
||||
@@ -96,7 +256,7 @@ def test_cosine(texts: List[str]) -> None:
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
redis_url=TEST_REDIS_URL,
|
||||
distance_metric="COSINE",
|
||||
vector_schema=cosine_schema,
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("far", k=2)
|
||||
_, score = output[1]
|
||||
@@ -107,7 +267,7 @@ def test_cosine(texts: List[str]) -> None:
|
||||
def test_l2(texts: List[str]) -> None:
|
||||
"""Test Flat L2 distance."""
|
||||
docsearch = Redis.from_texts(
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="L2"
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=l2_schema
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("far", k=2)
|
||||
_, score = output[1]
|
||||
@@ -118,7 +278,7 @@ def test_l2(texts: List[str]) -> None:
|
||||
def test_ip(texts: List[str]) -> None:
|
||||
"""Test inner product distance."""
|
||||
docsearch = Redis.from_texts(
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="IP"
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, vector_schema=ip_schema
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("far", k=2)
|
||||
_, score = output[1]
|
||||
@@ -126,29 +286,34 @@ def test_ip(texts: List[str]) -> None:
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_similarity_search_limit_score(texts: List[str]) -> None:
|
||||
def test_similarity_search_limit_distance(texts: List[str]) -> None:
|
||||
"""Test similarity search limit score."""
|
||||
docsearch = Redis.from_texts(
|
||||
texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL, distance_metric="COSINE"
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
redis_url=TEST_REDIS_URL,
|
||||
)
|
||||
output = docsearch.similarity_search_limit_score("far", k=2, score_threshold=0.1)
|
||||
assert len(output) == 1
|
||||
_, score = output[0]
|
||||
assert score == COSINE_SCORE
|
||||
output = docsearch.similarity_search(texts[0], k=3, distance_threshold=0.1)
|
||||
|
||||
# can't check score but length of output should be 2
|
||||
assert len(output) == 2
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_similarity_search_with_score_with_limit_score(texts: List[str]) -> None:
|
||||
def test_similarity_search_with_score_with_limit_distance(texts: List[str]) -> None:
    """Test similarity search with score when a distance threshold is applied."""
    docsearch = Redis.from_texts(
        texts, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
    )
    output = docsearch.similarity_search_with_score(
        texts[0], k=3, distance_threshold=0.1, return_metadata=True
    )

    assert len(output) == 2
    for out, score in output:
        if out.page_content == texts[1]:
            # A bare `score == COSINE_SCORE` expression discards its result
            # and verifies nothing; it must be asserted to be checked.
            # NOTE(review): confirm COSINE_SCORE is the expected distance for
            # ConsistentFakeEmbeddings before relying on this assertion.
            assert score == COSINE_SCORE
    assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
@@ -156,6 +321,48 @@ def test_delete(texts: List[str]) -> None:
|
||||
"""Test deleting a new document"""
|
||||
docsearch = Redis.from_texts(texts, FakeEmbeddings(), redis_url=TEST_REDIS_URL)
|
||||
ids = docsearch.add_texts(["foo"])
|
||||
got = docsearch.delete(ids=ids)
|
||||
got = docsearch.delete(ids=ids, redis_url=TEST_REDIS_URL)
|
||||
assert got
|
||||
assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_redis_as_retriever() -> None:
    """Fetch the top three matches for "foo" through the retriever interface."""
    corpus = ["foo", "foo", "foo", "foo", "bar"]
    docsearch = Redis.from_texts(
        corpus, ConsistentFakeEmbeddings(), redis_url=TEST_REDIS_URL
    )

    retriever = docsearch.as_retriever(
        search_type="similarity", search_kwargs={"k": 3}
    )
    results = retriever.get_relevant_documents("foo")

    # Exactly k documents come back, and every one of them is a "foo" entry.
    assert len(results) == 3
    assert all(doc.page_content == "foo" for doc in results)

    assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_redis_retriever_distance_threshold() -> None:
    """Retrieve with the ``similarity_distance_threshold`` search type."""
    docsearch = Redis.from_texts(
        ["foo", "bar", "baz"], FakeEmbeddings(), redis_url=TEST_REDIS_URL
    )

    search_kwargs = {"k": 3, "distance_threshold": 0.1}
    retriever = docsearch.as_retriever(
        search_type="similarity_distance_threshold",
        search_kwargs=search_kwargs,
    )

    # Only two of the three stored texts are expected within the threshold.
    results = retriever.get_relevant_documents("foo")
    assert len(results) == 2

    assert drop(docsearch.index_name)
|
||||
|
||||
|
||||
def test_redis_retriever_score_threshold() -> None:
    """Retrieve with the ``similarity_score_threshold`` search type."""
    docsearch = Redis.from_texts(
        ["foo", "bar", "baz"], FakeEmbeddings(), redis_url=TEST_REDIS_URL
    )

    search_kwargs = {"k": 3, "score_threshold": 0.91}
    retriever = docsearch.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs=search_kwargs,
    )

    # Only two of the three stored texts are expected to clear the threshold.
    results = retriever.get_relevant_documents("foo")
    assert len(results) == 2

    assert drop(docsearch.index_name)
|
||||
|
||||
Reference in New Issue
Block a user