mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-13 10:26:26 +00:00
langchain[patch]: Allow specifying other hashing functions in embeddings (#31561)
Allow specifying other hashing functions in embeddings
This commit is contained in:
parent
4071670f56
commit
d10fd02bb3
@ -12,9 +12,9 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Union, cast
|
||||
from typing import Callable, Literal, Optional, Union, cast
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
@ -25,20 +25,51 @@ from langchain.storage.encoder_backed import EncoderBackedStore
|
||||
NAMESPACE_UUID = uuid.UUID(int=1985)
|
||||
|
||||
|
||||
def _hash_string_to_uuid(input_string: str) -> uuid.UUID:
    """Derive a deterministic UUID from the SHA-1 hex digest of a string.

    Args:
        input_string: Text to hash; it is encoded as UTF-8 before hashing.

    Returns:
        The version-5 UUID built from ``NAMESPACE_UUID`` and the hex digest.
    """
    encoded = input_string.encode("utf-8")
    digest_hex = hashlib.sha1(encoded).hexdigest()
    return uuid.uuid5(NAMESPACE_UUID, digest_hex)
|
||||
def _sha1_hash_to_uuid(text: str) -> uuid.UUID:
    """Map *text* to a deterministic UUID via its SHA-1 digest.

    The result is deterministic and fast to compute, **but SHA-1 is not
    collision-resistant**: a malicious party could try to craft two different
    texts that map to the same UUID. That may be acceptable when caching
    embeddings, but new applications should prefer a cryptographic hash such
    as BLAKE2 or SHA-256.

    Args:
        text: Text to hash; it is encoded as UTF-8 before hashing.

    Returns:
        The version-5 UUID built from ``NAMESPACE_UUID`` and the SHA-1 hex
        digest of *text*.
    """
    digest = hashlib.sha1(text.encode("utf-8"))
    # Feeding the hex digest through `uuid5` yields a well-formed UUID.
    return uuid.uuid5(NAMESPACE_UUID, digest.hexdigest())
|
||||
|
||||
|
||||
def _key_encoder(key: str, namespace: str) -> str:
    """Build the cache key for *key*: the namespace followed by its hashed UUID.

    Args:
        key: Raw text to hash.
        namespace: Prefix placed in front of the hashed key.

    Returns:
        ``namespace`` concatenated with the string form of the UUID.
    """
    hashed = _hash_string_to_uuid(key)
    return namespace + str(hashed)
|
||||
def _make_default_key_encoder(namespace: str, algorithm: str) -> Callable[[str], str]:
|
||||
"""Create a default key encoder function.
|
||||
|
||||
Args:
|
||||
namespace: Prefix that segregates keys from different embedding models.
|
||||
algorithm:
|
||||
* `sha1` - fast but not collision‑resistant
|
||||
* `blake2b` - cryptographically strong, faster than SHA‑1
|
||||
* `sha256` - cryptographically strong, slower than SHA‑1
|
||||
* `sha512` - cryptographically strong, slower than SHA‑1
|
||||
|
||||
def _create_key_encoder(namespace: str) -> Callable[[str], str]:
|
||||
"""Create an encoder for a key."""
|
||||
return partial(_key_encoder, namespace=namespace)
|
||||
Returns:
|
||||
A function that encodes a key using the specified algorithm.
|
||||
"""
|
||||
if algorithm == "sha1":
|
||||
_warn_about_sha1_encoder()
|
||||
|
||||
def _key_encoder(key: str) -> str:
|
||||
"""Encode a key using the specified algorithm."""
|
||||
if algorithm == "sha1":
|
||||
return f"{namespace}{_sha1_hash_to_uuid(key)}"
|
||||
if algorithm == "blake2b":
|
||||
return f"{namespace}{hashlib.blake2b(key.encode('utf-8')).hexdigest()}"
|
||||
if algorithm == "sha256":
|
||||
return f"{namespace}{hashlib.sha256(key.encode('utf-8')).hexdigest()}"
|
||||
if algorithm == "sha512":
|
||||
return f"{namespace}{hashlib.sha512(key.encode('utf-8')).hexdigest()}"
|
||||
raise ValueError(f"Unsupported algorithm: {algorithm}")
|
||||
|
||||
return _key_encoder
|
||||
|
||||
|
||||
def _value_serializer(value: Sequence[float]) -> bytes:
|
||||
@ -51,6 +82,28 @@ def _value_deserializer(serialized_value: bytes) -> list[float]:
|
||||
return cast(list[float], json.loads(serialized_value.decode()))
|
||||
|
||||
|
||||
# The warning is global; track emission, so it appears only once.
|
||||
_warned_about_sha1: bool = False
|
||||
|
||||
|
||||
def _warn_about_sha1_encoder() -> None:
|
||||
"""Emit a one‑time warning about SHA‑1 collision weaknesses."""
|
||||
global _warned_about_sha1
|
||||
if not _warned_about_sha1:
|
||||
warnings.warn(
|
||||
"Using default key encoder: SHA‑1 is *not* collision‑resistant. "
|
||||
"While acceptable for most cache scenarios, a motivated attacker "
|
||||
"can craft two different payloads that map to the same cache key. "
|
||||
"If that risk matters in your environment, supply a stronger "
|
||||
"encoder (e.g. SHA‑256 or BLAKE2) via the `key_encoder` argument. "
|
||||
"If you change the key encoder, consider also creating a new cache, "
|
||||
"to avoid (the potential for) collisions with existing keys.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
_warned_about_sha1 = True
|
||||
|
||||
|
||||
class CacheBackedEmbeddings(Embeddings):
|
||||
"""Interface for caching results from embedding models.
|
||||
|
||||
@ -234,6 +287,9 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
namespace: str = "",
|
||||
batch_size: Optional[int] = None,
|
||||
query_embedding_cache: Union[bool, ByteStore] = False,
|
||||
key_encoder: Union[
|
||||
Callable[[str], str], Literal["sha1", "blake2b", "sha256", "sha512"]
|
||||
] = "sha1",
|
||||
) -> CacheBackedEmbeddings:
|
||||
"""On-ramp that adds the necessary serialization and encoding to the store.
|
||||
|
||||
@ -248,9 +304,39 @@ class CacheBackedEmbeddings(Embeddings):
|
||||
query_embedding_cache: The cache to use for storing query embeddings.
|
||||
True to use the same cache as document embeddings.
|
||||
False to not cache query embeddings.
|
||||
key_encoder: Either a callable that encodes keys, or the name of a
|
||||
built-in algorithm: "sha1", "blake2b", "sha256" or "sha512".
|
||||
Defaults to "sha1". SHA-1 is not
|
||||
collision-resistant, and a motivated attacker could craft two
|
||||
different texts that hash to the same cache key.
|
||||
|
||||
New applications should use one of the alternative encoders
|
||||
or provide a custom and strong key encoder function to avoid this risk.
|
||||
|
||||
If you change a key encoder in an existing cache, consider
|
||||
just creating a new cache, to avoid (the potential for)
|
||||
collisions with existing keys or having duplicate keys
|
||||
for the same text in the cache.
|
||||
|
||||
Returns:
|
||||
An instance of CacheBackedEmbeddings that uses the provided cache.
|
||||
"""
|
||||
namespace = namespace
|
||||
key_encoder = _create_key_encoder(namespace)
|
||||
if isinstance(key_encoder, str):
|
||||
key_encoder = _make_default_key_encoder(namespace, key_encoder)
|
||||
elif callable(key_encoder):
|
||||
# If a custom key encoder is provided, it should not be used with a
|
||||
# namespace.
|
||||
# A user can handle namespacing directly in their custom key encoder.
|
||||
if namespace:
|
||||
raise ValueError(
|
||||
"Do not supply `namespace` when using a custom key_encoder; "
|
||||
"add any prefixing inside the encoder itself."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"key_encoder must be either 'blake2b', 'sha1', 'sha256', 'sha512' "
|
||||
"or a callable that encodes keys."
|
||||
)
|
||||
|
||||
document_embedding_store = EncoderBackedStore[str, list[float]](
|
||||
document_embedding_cache,
|
||||
key_encoder,
|
||||
|
@ -1,5 +1,9 @@
|
||||
"""Embeddings tests."""
|
||||
|
||||
import hashlib
|
||||
import importlib
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
@ -146,3 +150,89 @@ async def test_aembed_query_cached(
|
||||
keys = list(cache_embeddings_with_query.query_embedding_store.yield_keys()) # type: ignore[union-attr]
|
||||
assert len(keys) == 1
|
||||
assert keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
|
||||
|
||||
|
||||
def test_blake2b_encoder() -> None:
    """Check that cache keys are the namespace plus the BLAKE2b hex digest."""
    backing_store = InMemoryStore()
    embedder = MockEmbeddings()
    cached = CacheBackedEmbeddings.from_bytes_store(
        embedder,
        backing_store,
        namespace="ns_",
        key_encoder="blake2b",
    )

    sample = "blake"
    cached.embed_documents([sample])

    # Recompute the key independently, following the same scheme as the library.
    digest = hashlib.blake2b(sample.encode()).hexdigest()
    stored_keys = list(cached.document_embedding_store.yield_keys())
    assert stored_keys == ["ns_" + digest]
|
||||
|
||||
|
||||
def test_sha256_encoder() -> None:
    """Check that cache keys are the namespace plus the SHA-256 hex digest."""
    backing_store = InMemoryStore()
    embedder = MockEmbeddings()
    cached = CacheBackedEmbeddings.from_bytes_store(
        embedder,
        backing_store,
        namespace="ns_",
        key_encoder="sha256",
    )

    sample = "foo"
    cached.embed_documents([sample])

    # Recompute the key independently, following the same scheme as the library.
    digest = hashlib.sha256(sample.encode()).hexdigest()
    stored_keys = list(cached.document_embedding_store.yield_keys())
    assert stored_keys == ["ns_" + digest]
|
||||
|
||||
|
||||
def test_sha512_encoder() -> None:
    """Check that cache keys are the namespace plus the SHA-512 hex digest."""
    backing_store = InMemoryStore()
    embedder = MockEmbeddings()
    cached = CacheBackedEmbeddings.from_bytes_store(
        embedder,
        backing_store,
        namespace="ns_",
        key_encoder="sha512",
    )

    sample = "foo"
    cached.embed_documents([sample])

    # Recompute the key independently, following the same scheme as the library.
    digest = hashlib.sha512(sample.encode()).hexdigest()
    stored_keys = list(cached.document_embedding_store.yield_keys())
    assert stored_keys == ["ns_" + digest]
|
||||
|
||||
|
||||
def test_sha1_warning_emitted_once() -> None:
    """Check that the default SHA-1 encoder warns on first use only."""
    target_module = importlib.import_module(CacheBackedEmbeddings.__module__)

    # The MonkeyPatch context undoes its changes automatically on exit.
    with pytest.MonkeyPatch.context() as patcher:
        # Earlier tests may already have tripped the module-level flag, so
        # reset it to make sure the first call below can warn.
        patcher.setattr(target_module, "_warned_about_sha1", False, raising=False)

        backing_store = InMemoryStore()
        embedder = MockEmbeddings()

        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            # First construction should warn; the second should stay silent.
            for _ in range(2):
                CacheBackedEmbeddings.from_bytes_store(embedder, backing_store)

        matching = [entry for entry in recorded if "SHA‑1" in str(entry.message)]
        assert len(matching) == 1
|
||||
|
||||
|
||||
def test_custom_encoder() -> None:
    """Check that a user-supplied callable determines the stored cache keys."""
    backing_store = InMemoryStore()
    embedder = MockEmbeddings()

    def custom_upper(text: str) -> str:
        # Minimal demo encoder: fixed prefix plus the upper-cased text.
        return "CUSTOM_" + text.upper()

    cached = CacheBackedEmbeddings.from_bytes_store(
        embedder,
        backing_store,
        key_encoder=custom_upper,
    )
    sample = "x"
    cached.embed_documents([sample])

    stored_keys = list(cached.document_embedding_store.yield_keys())
    assert stored_keys == ["CUSTOM_X"]
|
||||
|
Loading…
Reference in New Issue
Block a user