From d10fd02bb33d36d43d86ec6a3c7489b15edab8c8 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 11 Jun 2025 10:18:07 -0400 Subject: [PATCH] langchain[patch]: Allow specifying other hashing functions in embeddings (#31561) Allow specifying other hashing functions in embeddings --- libs/langchain/langchain/embeddings/cache.py | 114 +++++++++++++++--- .../unit_tests/embeddings/test_caching.py | 90 ++++++++++++++ 2 files changed, 190 insertions(+), 14 deletions(-) diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py index 163fd942683..401d17b6fb0 100644 --- a/libs/langchain/langchain/embeddings/cache.py +++ b/libs/langchain/langchain/embeddings/cache.py @@ -12,9 +12,9 @@ from __future__ import annotations import hashlib import json import uuid +import warnings from collections.abc import Sequence -from functools import partial -from typing import Callable, Optional, Union, cast +from typing import Callable, Literal, Optional, Union, cast from langchain_core.embeddings import Embeddings from langchain_core.stores import BaseStore, ByteStore @@ -25,20 +25,51 @@ from langchain.storage.encoder_backed import EncoderBackedStore NAMESPACE_UUID = uuid.UUID(int=1985) -def _hash_string_to_uuid(input_string: str) -> uuid.UUID: - """Hash a string and returns the corresponding UUID.""" - hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest() - return uuid.uuid5(NAMESPACE_UUID, hash_value) +def _sha1_hash_to_uuid(text: str) -> uuid.UUID: + """Return a UUID derived from *text* using SHA‑1 (deterministic). + + Deterministic and fast, **but not collision‑resistant**. + + A malicious attacker could try to create two different texts that hash to the same + UUID. This may not necessarily be an issue in the context of caching embeddings, + but new applications should swap this out for a stronger hash function like + xxHash, BLAKE2 or SHA‑256, which are collision-resistant. + """ + sha1_hex = hashlib.sha1(text.encode("utf-8")).hexdigest() + # Embed the hex string in `uuid5` to obtain a valid UUID. + return uuid.uuid5(NAMESPACE_UUID, sha1_hex) -def _key_encoder(key: str, namespace: str) -> str: - """Encode a key.""" - return namespace + str(_hash_string_to_uuid(key)) +def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str], str]: + """Create a default key encoder function. + Args: + namespace: Prefix that segregates keys from different embedding models. + algorithm: + * `sha1` - fast but not collision‑resistant + * `blake2b` - cryptographically strong, faster than SHA‑1 + * `sha256` - cryptographically strong, slower than SHA‑1 + * `sha512` - cryptographically strong, slower than SHA‑1 -def _create_key_encoder(namespace: str) -> Callable[[str], str]: - """Create an encoder for a key.""" - return partial(_key_encoder, namespace=namespace) + Returns: + A function that encodes a key using the specified algorithm. + """ + if algorithm == "sha1": + _warn_about_sha1_encoder() + + def _key_encoder(key: str) -> str: + """Encode a key using the specified algorithm.""" + if algorithm == "sha1": + return f"{namespace}{_sha1_hash_to_uuid(key)}" + if algorithm == "blake2b": + return f"{namespace}{hashlib.blake2b(key.encode('utf-8')).hexdigest()}" + if algorithm == "sha256": + return f"{namespace}{hashlib.sha256(key.encode('utf-8')).hexdigest()}" + if algorithm == "sha512": + return f"{namespace}{hashlib.sha512(key.encode('utf-8')).hexdigest()}" + raise ValueError(f"Unsupported algorithm: {algorithm}") + + return _key_encoder def _value_serializer(value: Sequence[float]) -> bytes: @@ -51,6 +82,28 @@ def _value_deserializer(serialized_value: bytes) -> list[float]: return cast(list[float], json.loads(serialized_value.decode())) +# The warning is global; track emission, so it appears only once. +_warned_about_sha1: bool = False + + +def _warn_about_sha1_encoder() -> None: + """Emit a one‑time warning about SHA‑1 collision weaknesses.""" + global _warned_about_sha1 + if not _warned_about_sha1: + warnings.warn( + "Using default key encoder: SHA‑1 is *not* collision‑resistant. " + "While acceptable for most cache scenarios, a motivated attacker " + "can craft two different payloads that map to the same cache key. " + "If that risk matters in your environment, supply a stronger " + "encoder (e.g. SHA‑256 or BLAKE2) via the `key_encoder` argument. " + "If you change the key encoder, consider also creating a new cache, " + "to avoid (the potential for) collisions with existing keys.", + category=UserWarning, + stacklevel=2, + ) + _warned_about_sha1 = True + + class CacheBackedEmbeddings(Embeddings): """Interface for caching results from embedding models. @@ -234,6 +287,9 @@ class CacheBackedEmbeddings(Embeddings): namespace: str = "", batch_size: Optional[int] = None, query_embedding_cache: Union[bool, ByteStore] = False, + key_encoder: Union[ + Callable[[str], str], Literal["sha1", "blake2b", "sha256", "sha512"] + ] = "sha1", ) -> CacheBackedEmbeddings: """On-ramp that adds the necessary serialization and encoding to the store. @@ -248,9 +304,39 @@ class CacheBackedEmbeddings(Embeddings): query_embedding_cache: The cache to use for storing query embeddings. True to use the same cache as document embeddings. False to not cache query embeddings. + key_encoder: Optional callable to encode keys. If not provided, + a default encoder using SHA‑1 will be used. SHA-1 is not + collision-resistant, and a motivated attacker could craft two + different texts that hash to the same cache key. + + New applications should use one of the alternative encoders + or provide a custom and strong key encoder function to avoid this risk. + + If you change a key encoder in an existing cache, consider + just creating a new cache, to avoid (the potential for) + collisions with existing keys or having duplicate keys + for the same text in the cache. + + Returns: + An instance of CacheBackedEmbeddings that uses the provided cache. """ - namespace = namespace - key_encoder = _create_key_encoder(namespace) + if isinstance(key_encoder, str): + key_encoder = _make_default_key_encoder(namespace, key_encoder) + elif callable(key_encoder): + # If a custom key encoder is provided, it should not be used with a + # namespace. + # A user can handle namespacing in directly their custom key encoder. + if namespace: + raise ValueError( + "Do not supply `namespace` when using a custom key_encoder; " + "add any prefixing inside the encoder itself." + ) + else: + raise ValueError( + "key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' " + "or a callable that encodes keys." + ) + document_embedding_store = EncoderBackedStore[str, list[float]]( document_embedding_cache, key_encoder, diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py index b9545d052b8..7cb2d68f101 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py @@ -1,5 +1,9 @@ """Embeddings tests.""" +import hashlib +import importlib +import warnings + import pytest from langchain_core.embeddings import Embeddings @@ -146,3 +150,89 @@ async def test_aembed_query_cached( keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr] assert len(keys) == 1 assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15" + + +def test_blake2b_encoder() -> None: + """Test that the blake2b encoder is used to encode keys in the cache store.""" + store = InMemoryStore() + emb = MockEmbeddings() + cbe = CacheBackedEmbeddings.from_bytes_store( + emb, store, namespace="ns_", key_encoder="blake2b" + ) + + text = "blake" + cbe.embed_documents([text]) + + # rebuild the key exactly as the library does + expected_key = "ns_" + hashlib.blake2b(text.encode()).hexdigest() + assert list(cbe.document_embedding_store.yield_keys()) == [expected_key] + + +def test_sha256_encoder() -> None: + """Test that the sha256 encoder is used to encode keys in the cache store.""" + store = InMemoryStore() + emb = MockEmbeddings() + cbe = CacheBackedEmbeddings.from_bytes_store( + emb, store, namespace="ns_", key_encoder="sha256" + ) + + text = "foo" + cbe.embed_documents([text]) + + # rebuild the key exactly as the library does + expected_key = "ns_" + hashlib.sha256(text.encode()).hexdigest() + assert list(cbe.document_embedding_store.yield_keys()) == [expected_key] + + +def test_sha512_encoder() -> None: + """Test that the sha512 encoder is used to encode keys in the cache store.""" + store = InMemoryStore() + emb = MockEmbeddings() + cbe = CacheBackedEmbeddings.from_bytes_store( + emb, store, namespace="ns_", key_encoder="sha512" + ) + + text = "foo" + cbe.embed_documents([text]) + + # rebuild the key exactly as the library does + expected_key = "ns_" + hashlib.sha512(text.encode()).hexdigest() + assert list(cbe.document_embedding_store.yield_keys()) == [expected_key] + + +def test_sha1_warning_emitted_once() -> None: + """Test that a warning is emitted when using SHA‑1 as the default key encoder.""" + module = importlib.import_module(CacheBackedEmbeddings.__module__) + + # Create a *temporary* MonkeyPatch object whose effects disappear + # automatically when the with‑block exits. + with pytest.MonkeyPatch.context() as mp: + # We're monkey patching the module to reset the `_warned_about_sha1` flag + # which may have been set while testing other parts of the codebase. + mp.setattr(module, "_warned_about_sha1", False, raising=False) + + store = InMemoryStore() + emb = MockEmbeddings() + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + CacheBackedEmbeddings.from_bytes_store(emb, store) # triggers warning + CacheBackedEmbeddings.from_bytes_store(emb, store) # silent + + sha1_msgs = [w for w in caught if "SHA‑1" in str(w.message)] + assert len(sha1_msgs) == 1 + + +def test_custom_encoder() -> None: + """Test that a custom encoder can be used to encode keys in the cache store.""" + store = InMemoryStore() + emb = MockEmbeddings() + + def custom_upper(text: str) -> str: # very simple demo encoder + return "CUSTOM_" + text.upper() + + cbe = CacheBackedEmbeddings.from_bytes_store(emb, store, key_encoder=custom_upper) + txt = "x" + cbe.embed_documents([txt]) + + assert list(cbe.document_embedding_store.yield_keys()) == ["CUSTOM_X"]