langchain[patch]: Allow specifying other hashing functions in embeddings (#31561)

Allow specifying other hashing functions in embeddings
Eugene Yurtsev 2025-06-11 10:18:07 -04:00 committed by GitHub
parent 4071670f56
commit d10fd02bb3
2 changed files with 190 additions and 14 deletions


@@ -12,9 +12,9 @@ from __future__ import annotations
import hashlib
import json
import uuid
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Callable, Optional, Union, cast
from typing import Callable, Literal, Optional, Union, cast
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore
@@ -25,20 +25,51 @@ from langchain.storage.encoder_backed import EncoderBackedStore
NAMESPACE_UUID = uuid.UUID(int=1985)
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
"""Hash a string and returns the corresponding UUID."""
hash_value = hashlib.sha1(input_string.encode("utf-8")).hexdigest()
return uuid.uuid5(NAMESPACE_UUID, hash_value)
def _sha1_hash_to_uuid(text: str) -> uuid.UUID:
"""Return a UUID derived from *text* using SHA1 (deterministic).
Deterministic and fast, **but not collisionresistant**.
A malicious attacker could try to create two different texts that hash to the same
UUID. This may not necessarily be an issue in the context of caching embeddings,
but new applications should swap this out for a collision-resistant hash
function such as BLAKE2b or SHA-256.
"""
sha1_hex = hashlib.sha1(text.encode("utf-8")).hexdigest()
# Embed the hex string in `uuid5` to obtain a valid UUID.
return uuid.uuid5(NAMESPACE_UUID, sha1_hex)
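For context, the digest is routed through uuid5 (rather than used directly) so that the result is a fixed-length, well-formed UUID and the legacy cache keys keep their namespace-plus-UUID shape. A minimal sketch of that determinism, not part of this diff (the helper name demo_sha1_to_uuid is made up for illustration):
import hashlib
import uuid

NAMESPACE_UUID = uuid.UUID(int=1985)  # mirrors the module-level constant above

def demo_sha1_to_uuid(text: str) -> uuid.UUID:
    # Same recipe as _sha1_hash_to_uuid: hash the text, then wrap the hex
    # digest in uuid5 to obtain a well-formed, deterministic UUID.
    return uuid.uuid5(NAMESPACE_UUID, hashlib.sha1(text.encode("utf-8")).hexdigest())

assert demo_sha1_to_uuid("hello") == demo_sha1_to_uuid("hello")  # stable across calls
assert demo_sha1_to_uuid("hello") != demo_sha1_to_uuid("world")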
def _key_encoder(key: str, namespace: str) -> str:
"""Encode a key."""
return namespace + str(_hash_string_to_uuid(key))
def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str], str]:
"""Create a default key encoder function.
Args:
namespace: Prefix that segregates keys from different embedding models.
algorithm:
* `sha1` - fast but not collision-resistant
* `blake2b` - cryptographically strong, faster than SHA-1
* `sha256` - cryptographically strong, slower than SHA-1
* `sha512` - cryptographically strong, slower than SHA-1
def _create_key_encoder(namespace: str) -> Callable[[str], str]:
"""Create an encoder for a key."""
return partial(_key_encoder, namespace=namespace)
Returns:
A function that encodes a key using the specified algorithm.
"""
if algorithm == "sha1":
_warn_about_sha1_encoder()
def _key_encoder(key: str) -> str:
"""Encode a key using the specified algorithm."""
if algorithm == "sha1":
return f"{namespace}{_sha1_hash_to_uuid(key)}"
if algorithm == "blake2b":
return f"{namespace}{hashlib.blake2b(key.encode('utf-8')).hexdigest()}"
if algorithm == "sha256":
return f"{namespace}{hashlib.sha256(key.encode('utf-8')).hexdigest()}"
if algorithm == "sha512":
return f"{namespace}{hashlib.sha512(key.encode('utf-8')).hexdigest()}"
raise ValueError(f"Unsupported algorithm: {algorithm}")
return _key_encoder
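Roughly, the factory returns a closure that prefixes the namespace and appends a digest whose shape depends on the algorithm: the legacy sha1 path yields a 36-character UUID string, while the other options emit a raw hex digest (64 chars for SHA-256, 128 for BLAKE2b and SHA-512). A simplified restatement for illustration, not the library function itself (sketch_key_encoder is a made-up name):
import hashlib
import uuid

NAMESPACE_UUID = uuid.UUID(int=1985)

def sketch_key_encoder(namespace: str, algorithm: str):
    # Simplified stand-in for _make_default_key_encoder, without the warning.
    def encode(key: str) -> str:
        data = key.encode("utf-8")
        if algorithm == "sha1":
            # Legacy format: namespace + 36-character UUID string.
            return f"{namespace}{uuid.uuid5(NAMESPACE_UUID, hashlib.sha1(data).hexdigest())}"
        # Other algorithms: namespace + raw hex digest.
        return f"{namespace}{getattr(hashlib, algorithm)(data).hexdigest()}"
    return encode

print(sketch_key_encoder("ns_", "sha1")("hello"))    # ns_ + UUID
print(sketch_key_encoder("ns_", "sha256")("hello"))  # ns_ + 64 hex chars
The difference in key shape is one reason the docstrings below suggest starting a new cache when switching encoders: keys written under one algorithm will never be found under another.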
def _value_serializer(value: Sequence[float]) -> bytes:
@@ -51,6 +82,28 @@ def _value_deserializer(serialized_value: bytes) -> list[float]:
return cast(list[float], json.loads(serialized_value.decode()))
# The warning is global; track emission, so it appears only once.
_warned_about_sha1: bool = False
def _warn_about_sha1_encoder() -> None:
"""Emit a onetime warning about SHA1 collision weaknesses."""
global _warned_about_sha1
if not _warned_about_sha1:
warnings.warn(
"Using default key encoder: SHA1 is *not* collisionresistant. "
"While acceptable for most cache scenarios, a motivated attacker "
"can craft two different payloads that map to the same cache key. "
"If that risk matters in your environment, supply a stronger "
"encoder (e.g. SHA256 or BLAKE2) via the `key_encoder` argument. "
"If you change the key encoder, consider also creating a new cache, "
"to avoid (the potential for) collisions with existing keys.",
category=UserWarning,
stacklevel=2,
)
_warned_about_sha1 = True
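Since this is an ordinary UserWarning emitted once per process, a caller who has deliberately accepted SHA-1 can silence it with the standard warnings machinery instead of changing encoders; a hedged sketch (the message pattern below is matched against the start of the warning text above):
import warnings

warnings.filterwarnings(
    "ignore",
    message="Using default key encoder",  # regex matched against the message start
    category=UserWarning,
)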
class CacheBackedEmbeddings(Embeddings):
"""Interface for caching results from embedding models.
@@ -234,6 +287,9 @@ class CacheBackedEmbeddings(Embeddings):
namespace: str = "",
batch_size: Optional[int] = None,
query_embedding_cache: Union[bool, ByteStore] = False,
key_encoder: Union[
Callable[[str], str], Literal["sha1", "blake2b", "sha256", "sha512"]
] = "sha1",
) -> CacheBackedEmbeddings:
"""On-ramp that adds the necessary serialization and encoding to the store.
@@ -248,9 +304,39 @@ class CacheBackedEmbeddings(Embeddings):
query_embedding_cache: The cache to use for storing query embeddings.
True to use the same cache as document embeddings.
False to not cache query embeddings.
key_encoder: Optional callable to encode keys. If not provided,
a default encoder using SHA-1 will be used. SHA-1 is not
collision-resistant, and a motivated attacker could craft two
different texts that hash to the same cache key.
New applications should use one of the alternative encoders
or provide a custom and strong key encoder function to avoid this risk.
If you change a key encoder in an existing cache, consider
just creating a new cache, to avoid (the potential for)
collisions with existing keys or having duplicate keys
for the same text in the cache.
Returns:
An instance of CacheBackedEmbeddings that uses the provided cache.
"""
namespace = namespace
key_encoder = _create_key_encoder(namespace)
if isinstance(key_encoder, str):
key_encoder = _make_default_key_encoder(namespace, key_encoder)
elif callable(key_encoder):
# If a custom key encoder is provided, it should not be used with a
# namespace.
# A user can handle namespacing directly in their custom key encoder.
if namespace:
raise ValueError(
"Do not supply `namespace` when using a custom key_encoder; "
"add any prefixing inside the encoder itself."
)
else:
raise ValueError(
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
"or a callable that encodes keys."
)
document_embedding_store = EncoderBackedStore[str, list[float]](
document_embedding_cache,
key_encoder,
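Put together, a caller can now either pick a built-in algorithm by name (a namespace prefix is allowed) or pass a custom callable, in which case, per the validation above, namespacing must happen inside the encoder. A usage sketch assuming the fake embeddings and in-memory byte store shipped with langchain_core (my_key_encoder and the "fake-model-v1:" prefix are illustrative):
import hashlib

from langchain.embeddings import CacheBackedEmbeddings
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.stores import InMemoryByteStore

underlying = DeterministicFakeEmbedding(size=8)

# Built-in algorithm selected by name; a namespace prefix is fine here.
cached = CacheBackedEmbeddings.from_bytes_store(
    underlying,
    InMemoryByteStore(),
    namespace="fake-model-v1:",
    key_encoder="blake2b",
)
cached.embed_documents(["hello", "world"])

# Custom callable: namespacing is handled inside the encoder itself,
# so the `namespace` argument is left at its default ("").
def my_key_encoder(text: str) -> str:
    return "fake-model-v1:" + hashlib.sha256(text.encode("utf-8")).hexdigest()

cached_custom = CacheBackedEmbeddings.from_bytes_store(
    underlying,
    InMemoryByteStore(),
    key_encoder=my_key_encoder,
)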


@@ -1,5 +1,9 @@
"""Embeddings tests."""
import hashlib
import importlib
import warnings
import pytest
from langchain_core.embeddings import Embeddings
@@ -146,3 +150,89 @@ async def test_aembed_query_cached(
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
assert len(keys) == 1
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
def test_blake2b_encoder() -> None:
"""Test that the blake2b encoder is used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
cbe = CacheBackedEmbeddings.from_bytes_store(
emb, store, namespace="ns_", key_encoder="blake2b"
)
text = "blake"
cbe.embed_documents([text])
# rebuild the key exactly as the library does
expected_key = "ns_" + hashlib.blake2b(text.encode()).hexdigest()
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
def test_sha256_encoder() -> None:
"""Test that the sha256 encoder is used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
cbe = CacheBackedEmbeddings.from_bytes_store(
emb, store, namespace="ns_", key_encoder="sha256"
)
text = "foo"
cbe.embed_documents([text])
# rebuild the key exactly as the library does
expected_key = "ns_" + hashlib.sha256(text.encode()).hexdigest()
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
def test_sha512_encoder() -> None:
"""Test that the sha512 encoder is used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
cbe = CacheBackedEmbeddings.from_bytes_store(
emb, store, namespace="ns_", key_encoder="sha512"
)
text = "foo"
cbe.embed_documents([text])
# rebuild the key exactly as the library does
expected_key = "ns_" + hashlib.sha512(text.encode()).hexdigest()
assert list(cbe.document_embedding_store.yield_keys()) == [expected_key]
def test_sha1_warning_emitted_once() -> None:
"""Test that a warning is emitted when using SHA1 as the default key encoder."""
module = importlib.import_module(CacheBackedEmbeddings.__module__)
# Create a *temporary* MonkeyPatch object whose effects disappear
# automatically when the with-block exits.
with pytest.MonkeyPatch.context() as mp:
# We're monkey patching the module to reset the `_warned_about_sha1` flag
# which may have been set while testing other parts of the codebase.
mp.setattr(module, "_warned_about_sha1", False, raising=False)
store = InMemoryStore()
emb = MockEmbeddings()
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
CacheBackedEmbeddings.from_bytes_store(emb, store) # triggers warning
CacheBackedEmbeddings.from_bytes_store(emb, store) # silent
sha1_msgs = [w for w in caught if "SHA-1" in str(w.message)]
assert len(sha1_msgs) == 1
def test_custom_encoder() -> None:
"""Test that a custom encoder can be used to encode keys in the cache store."""
store = InMemoryStore()
emb = MockEmbeddings()
def custom_upper(text: str) -> str: # very simple demo encoder
return "CUSTOM_" + text.upper()
cbe = CacheBackedEmbeddings.from_bytes_store(emb, store, key_encoder=custom_upper)
txt = "x"
cbe.embed_documents([txt])
assert list(cbe.document_embedding_store.yield_keys()) == ["CUSTOM_X"]