diff --git a/libs/langchain/langchain/cache.py b/libs/langchain/langchain/cache.py index 49a23db801c..a45c4328211 100644 --- a/libs/langchain/langchain/cache.py +++ b/libs/langchain/langchain/cache.py @@ -506,7 +506,7 @@ class RedisSemanticCache(BaseCache): index_schema=cast(Dict, self.DEFAULT_SCHEMA), ) _embedding = self.embedding.embed_query(text="test") - redis._create_index(dim=len(_embedding)) + redis._create_index_if_not_exist(dim=len(_embedding)) self._cache_dict[index_name] = redis return self._cache_dict[index_name] diff --git a/libs/langchain/langchain/vectorstores/redis/base.py b/libs/langchain/langchain/vectorstores/redis/base.py index 1306ba0177a..726f571c8cf 100644 --- a/libs/langchain/langchain/vectorstores/redis/base.py +++ b/libs/langchain/langchain/vectorstores/redis/base.py @@ -54,16 +54,6 @@ if TYPE_CHECKING: from langchain.vectorstores.redis.schema import RedisModel -def _redis_key(prefix: str) -> str: - """Redis key schema for a given prefix.""" - return f"{prefix}:{uuid.uuid4().hex}" - - -def _redis_prefix(index_name: str) -> str: - """Redis key prefix for a given index.""" - return f"doc:{index_name}" - - def _default_relevance_score(val: float) -> float: return 1 - val @@ -94,6 +84,7 @@ class Redis(VectorStore): search API available. .. code-block:: bash + # to run redis stack in docker locally docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest @@ -258,6 +249,7 @@ class Redis(VectorStore): index_schema: Optional[Union[Dict[str, str], str, os.PathLike]] = None, vector_schema: Optional[Dict[str, Union[str, int]]] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, + key_prefix: Optional[str] = None, **kwargs: Any, ): """Initialize with necessary components.""" @@ -284,6 +276,7 @@ class Redis(VectorStore): self.client = redis_client self.relevance_score_fn = relevance_score_fn self._schema = self._get_schema_with_defaults(index_schema, vector_schema) + self.key_prefix = key_prefix if key_prefix is not None else f"doc:{index_name}" @property def embeddings(self) -> Optional[Embeddings]: @@ -420,14 +413,8 @@ class Redis(VectorStore): **kwargs, ) - # Create embeddings over documents - embeddings = embedding.embed_documents(texts) - - # Create the search index - instance._create_index(dim=len(embeddings[0])) - # Add data to Redis - keys = instance.add_texts(texts, metadatas, embeddings, keys=keys) + keys = instance.add_texts(texts, metadatas, keys=keys) return instance, keys @classmethod @@ -692,7 +679,6 @@ class Redis(VectorStore): List[str]: List of ids added to the vectorstore """ ids = [] - prefix = _redis_prefix(self.index_name) # Get keys or ids from kwargs # Other vectorstores use ids @@ -705,22 +691,24 @@ class Redis(VectorStore): if not (isinstance(metadatas, list) and isinstance(metadatas[0], dict)): raise ValueError("Metadatas must be a list of dicts") + embeddings = embeddings or self._embeddings.embed_documents(list(texts)) + self._create_index_if_not_exist(dim=len(embeddings[0])) + # Write data to redis pipeline = self.client.pipeline(transaction=False) for i, text in enumerate(texts): # Use provided values by default or fallback - key = keys_or_ids[i] if keys_or_ids else _redis_key(prefix) + key = keys_or_ids[i] if keys_or_ids else str(uuid.uuid4().hex) + if not key.startswith(self.key_prefix + ":"): + key = self.key_prefix + ":" + key metadata = metadatas[i] if metadatas else {} metadata = _prepare_metadata(metadata) if clean_metadata else metadata - embedding = ( - embeddings[i] if embeddings else self._embeddings.embed_query(text) - ) pipeline.hset( key, mapping={ self._schema.content_key: text, self._schema.content_vector_key: _array_to_buffer( - embedding, self._schema.vector_dtype + embeddings[i], self._schema.vector_dtype ), **metadata, }, @@ -1212,7 +1200,7 @@ class Redis(VectorStore): schema.add_vector_field(vector_field) return schema - def _create_index(self, dim: int = 1536) -> None: + def _create_index_if_not_exist(self, dim: int = 1536) -> None: try: from redis.commands.search.indexDefinition import ( # type: ignore IndexDefinition, @@ -1232,12 +1220,12 @@ class Redis(VectorStore): # Check if index exists if not check_index_exists(self.client, self.index_name): - prefix = _redis_prefix(self.index_name) - # Create Redis Index self.client.ft(self.index_name).create_index( fields=self._schema.get_fields(), - definition=IndexDefinition(prefix=[prefix], index_type=IndexType.HASH), + definition=IndexDefinition( + prefix=[self.key_prefix], index_type=IndexType.HASH + ), ) def _calculate_fp_distance(self, distance: str) -> float: