chore: update branch with changes from master (#32277)

Co-authored-by: Maxime Grenu <69890511+cluster2600@users.noreply.github.com>
Co-authored-by: Claude <claude@anthropic.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: jmaillefaud <jonathan.maillefaud@evooq.ch>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: tanwirahmad <tanwirahmad@users.noreply.github.com>
Co-authored-by: Christophe Bornet <cbornet@hotmail.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: niceg <79145285+growmuye@users.noreply.github.com>
Co-authored-by: Chaitanya varma <varmac301@gmail.com>
Co-authored-by: dishaprakash <57954147+dishaprakash@users.noreply.github.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Kanav Bansal <13186335+bansalkanav@users.noreply.github.com>
Co-authored-by: Aleksandr Filippov <71711753+alex-feel@users.noreply.github.com>
Co-authored-by: Alex Feel <afilippov@spotware.com>
This commit is contained in:
Mason Daugherty
2025-07-28 10:39:41 -04:00
committed by GitHub
parent 3496e1739e
commit 5e9eb19a83
449 changed files with 16481 additions and 5327 deletions

View File

@@ -0,0 +1 @@
"""All tests for this package."""

View File

@@ -0,0 +1 @@
"""All integration tests (tests that call out to an external API)."""

View File

@@ -0,0 +1 @@
"""All integration tests for Cache objects."""

View File

@@ -0,0 +1,81 @@
"""Fake Embedding class for testing purposes."""
import math
from langchain_core.embeddings import Embeddings
# Sample input texts (presumably imported by tests that use these fakes — confirm).
fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
    """Deterministic stand-in for a real embedding model, for use in tests."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed every text as nine ``1.0`` values followed by its position.

        The content of each text is ignored; only its index in ``texts``
        ends up in the vector.
        """
        vectors: list[list[float]] = []
        for position, _ in enumerate(texts):
            vector = [1.0] * 9
            vector.append(float(position))
            vectors.append(vector)
        return vectors

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Async variant; delegates to :meth:`embed_documents`."""
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> list[float]:
        """Return the same vector for every query.

        The vector equals ``embed_documents(texts)[0]``, so it differs from
        the i-th document vector only in the last component, by ``i``.
        """
        vector = [1.0] * 9
        vector.append(0.0)
        return vector

    async def aembed_query(self, text: str) -> list[float]:
        """Async variant; delegates to :meth:`embed_query`."""
        return self.embed_query(text)
class ConsistentFakeEmbeddings(FakeEmbeddings):
    """Fake embeddings that give a text the same vector every time it is embedded.

    Each distinct text is appended to ``known_texts`` the first time it is
    seen; its position in that list becomes the last vector component.
    """

    def __init__(self, dimensionality: int = 10) -> None:
        self.dimensionality = dimensionality
        self.known_texts: list[str] = []

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts``, registering any text that has not been seen yet."""
        vectors = []
        for text in texts:
            if text not in self.known_texts:
                self.known_texts.append(text)
            position = self.known_texts.index(text)
            padding = [1.0] * (self.dimensionality - 1)
            vectors.append(padding + [float(position)])
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text exactly as :meth:`embed_documents` would.

        A text that has not been seen before is registered in the process.
        """
        (vector,) = self.embed_documents([text])
        return vector
class AngularTwoDimensionalEmbeddings(Embeddings):
    """Map angles, given as strings in units of pi, to unit vectors on a circle."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each text with :meth:`embed_query`, preserving order."""
        return list(map(self.embed_query, texts))

    def embed_query(self, text: str) -> list[float]:
        """Turn ``text`` into a two-component vector.

        A numeric text is read as an angle in units of pi and mapped to
        ``[cos, sin]`` of that angle. If a ``ValueError`` is raised while
        doing so, the result is ``[0.0, 0.0]``.
        """
        try:
            # The trig calls stay inside the ``try``: they can raise
            # ValueError too (e.g. for an infinite angle).
            radians = float(text) * math.pi
            return [math.cos(radians), math.sin(radians)]
        except ValueError:
            # Not a usable number: treated as an arbitrary test string.
            return [0.0, 0.0]

View File

@@ -0,0 +1,59 @@
from typing import cast
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_tests.integration_tests import ChatModelIntegrationTests
from pydantic import BaseModel
from langchain.chat_models import init_chat_model
# Schema handed to ``model.bind_tools([multiply])`` in the test below.
class multiply(BaseModel):
    """Product of two ints."""

    x: int
    y: int
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
async def test_init_chat_model_chain() -> None:
    """Run a prompt | configurable-model chain via ``invoke`` and ``astream_events``."""
    base_model = init_chat_model(
        "gpt-4o",
        configurable_fields="any",
        config_prefix="bar",
    )
    tool_model = base_model.bind_tools([multiply])
    configured_model = tool_model.with_config(
        RunnableConfig(tags=["foo"]),
        configurable={"bar_model": "claude-3-sonnet-20240229"},
    )
    prompt = ChatPromptTemplate.from_messages([("system", "foo"), ("human", "{input}")])
    chain = prompt | configured_model

    output = chain.invoke({"input": "bar"})
    assert isinstance(output, AIMessage)

    events = []
    async for event in chain.astream_events({"input": "bar"}, version="v2"):
        events.append(event)
    assert events
class TestStandard(ChatModelIntegrationTests):
    """Run the shared chat-model integration suite against ``init_chat_model``."""

    @property
    def chat_model_class(self) -> type[BaseChatModel]:
        """The factory function, cast so it can stand in for a model class."""
        factory = init_chat_model
        return cast("type[BaseChatModel]", factory)

    @property
    def chat_model_params(self) -> dict:
        """Keyword arguments supplied when the model is created."""
        return dict(model="gpt-4o", configurable_fields="any")

    @property
    def supports_image_inputs(self) -> bool:
        """Declare image input support."""
        return True

    @property
    def has_tool_calling(self) -> bool:
        """Declare tool calling support."""
        return True

    @property
    def has_structured_output(self) -> bool:
        """Declare structured output support."""
        return True

View File

@@ -0,0 +1,34 @@
from pathlib import Path
import pytest
# Absolute, symlink-resolved path of the directory that contains this file.
ABS_PATH = Path(__file__).resolve().parent
# Two levels above ABS_PATH; used below as the project root.
PROJECT_DIR = ABS_PATH.parent.parent
def _load_env() -> None:
    """Load ``tests/integration_tests/.env`` under PROJECT_DIR if that file exists."""
    env_file = PROJECT_DIR.joinpath("tests", "integration_tests", ".env")
    if not env_file.exists():
        return

    # Imported lazily so ``dotenv`` is only needed when a .env file is present.
    from dotenv import load_dotenv

    load_dotenv(env_file)


# Runs once, when this module is imported.
_load_env()
@pytest.fixture(scope="module")
def test_dir() -> Path:
    """Return the ``tests/integration_tests`` directory under PROJECT_DIR."""
    return PROJECT_DIR.joinpath("tests", "integration_tests")
@pytest.fixture(scope="module")
def vcr_cassette_dir(request: pytest.FixtureRequest) -> str:
    """Return ``<module dir>/cassettes/<module name>`` for the requesting module.

    The path is derived from the requesting test module's ``__file__`` and
    returned as a string.
    """
    module_path = Path(request.module.__file__)
    cassette_dir = module_path.parent.joinpath("cassettes", module_path.stem)
    return str(cassette_dir)

View File

@@ -0,0 +1,44 @@
"""Test embeddings base module."""
import importlib
import pytest
from langchain_core.embeddings import Embeddings
from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings
@pytest.mark.parametrize(
    ("provider", "model"),
    [
        ("openai", "text-embedding-3-large"),
        ("google_vertexai", "text-embedding-gecko@003"),
        ("bedrock", "amazon.titan-embed-text-v1"),
        ("cohere", "embed-english-v2.0"),
    ],
)
async def test_init_embedding_model(provider: str, model: str) -> None:
    """Initialise a model via both call styles and embed a query with each.

    The test is skipped when the provider's integration package cannot be
    imported.
    """
    package = _SUPPORTED_PROVIDERS[provider]
    try:
        importlib.import_module(package)
    except ImportError:
        pytest.skip(f"Package {package} is not installed")

    # "provider:model" string form.
    model_colon = init_embeddings(f"{provider}:{model}")
    assert isinstance(model_colon, Embeddings)

    # Explicit keyword form.
    model_explicit = init_embeddings(model=model, provider=provider)
    assert isinstance(model_explicit, Embeddings)

    text = "Hello world"
    for embedder in (model_colon, model_explicit):
        embedding = await embedder.aembed_query(text)
        assert isinstance(embedding, list)
        assert all(isinstance(component, float) for component in embedding)

View File

@@ -0,0 +1,6 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
    """Do nothing; exists so the integration tests can be compiled without running real tests."""

View File

@@ -0,0 +1,237 @@
import os
from typing import TYPE_CHECKING, Optional
from unittest import mock
import pytest
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from pydantic import SecretStr
from langchain.chat_models import __all__, init_chat_model
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
# Names that ``langchain.chat_models.__all__`` is expected to contain.
EXPECTED_ALL = [
    "init_chat_model",
    "BaseChatModel",
]


def test_all_imports() -> None:
    """Check that ``__all__`` and EXPECTED_ALL hold the same names, ignoring order."""
    exported = set(__all__)
    expected = set(EXPECTED_ALL)
    assert exported == expected
@pytest.mark.requires(
    "langchain_openai",
    "langchain_anthropic",
    "langchain_fireworks",
    "langchain_groq",
)
@pytest.mark.parametrize(
    ("model_name", "model_provider"),
    [
        ("gpt-4o", "openai"),
        ("claude-3-opus-20240229", "anthropic"),
        ("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
        ("mixtral-8x7b-32768", "groq"),
    ],
)
def test_init_chat_model(model_name: str, model_provider: Optional[str]) -> None:
    """Both ways of naming the provider must produce equivalent models.

    The provider is given once as a keyword argument and once as a
    ``provider:model`` prefix; the two resulting ``dict()`` dumps are compared.
    """
    via_keyword: BaseChatModel = init_chat_model(
        model_name,
        model_provider=model_provider,
        api_key="foo",
    )
    prefixed_name = f"{model_provider}:{model_name}"
    via_prefix: BaseChatModel = init_chat_model(prefixed_name, api_key="foo")
    assert via_keyword.dict() == via_prefix.dict()
def test_init_missing_dep() -> None:
    """Requesting the groq provider is expected to raise ``ImportError`` here."""
    model_name, provider = "mixtral-8x7b-32768", "groq"
    with pytest.raises(ImportError):
        init_chat_model(model_name, model_provider=provider)
def test_init_unknown_provider() -> None:
    """An unrecognised provider name must raise ``ValueError`` with a clear message."""
    expected_message = "Unsupported model_provider='bar'."
    with pytest.raises(ValueError, match=expected_message):
        init_chat_model("foo", model_provider="bar")
@pytest.mark.requires("langchain_openai")
@mock.patch.dict(
    os.environ,
    {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"},
    clear=True,
)
def test_configurable() -> None:
    """Exercise a configurable model created without a default model.

    Checks which attributes are reachable before and after a model is
    configured, that ``bind_tools`` leaves the original object untouched,
    and the full ``model_dump()`` of the configured result.
    """
    runnable_methods = (
        "invoke",
        "ainvoke",
        "batch",
        "abatch",
        "stream",
        "astream",
        "batch_as_completed",
        "abatch_as_completed",
    )
    token_methods = ("get_num_tokens", "get_num_tokens_from_messages")

    model = init_chat_model()
    for method_name in runnable_methods:
        assert hasattr(model, method_name)

    # Non-configurable, non-declarative methods are not reachable until a
    # config has been provided.
    for method_name in token_methods:
        with pytest.raises(AttributeError):
            getattr(model, method_name)

    # Declarative methods work even without a default model.
    model_with_tools = model.bind_tools(
        [{"name": "foo", "description": "foo", "parameters": {}}],
    )

    # The declarative call must not have mutated the original model.
    assert model._queued_declarative_operations == []

    # Declarative methods can be chained.
    model_with_config = model_with_tools.with_config(
        RunnableConfig(tags=["foo"]),
        configurable={"model": "gpt-4o"},
    )

    assert model_with_config.model_name == "gpt-4o"  # type: ignore[attr-defined]

    for method_name in token_methods:
        assert hasattr(model_with_config, method_name)

    expected_bound = {
        "name": None,
        "disable_streaming": False,
        "disabled_params": None,
        "model_name": "gpt-4o",
        "temperature": None,
        "model_kwargs": {},
        "openai_api_key": SecretStr("foo"),
        "openai_api_base": None,
        "openai_organization": None,
        "openai_proxy": None,
        "output_version": "v0",
        "request_timeout": None,
        "max_retries": None,
        "presence_penalty": None,
        "reasoning": None,
        "reasoning_effort": None,
        "frequency_penalty": None,
        "include": None,
        "seed": None,
        "service_tier": None,
        "logprobs": None,
        "top_logprobs": None,
        "logit_bias": None,
        "streaming": False,
        "n": None,
        "top_p": None,
        "truncation": None,
        "max_tokens": None,
        "tiktoken_model_name": None,
        "default_headers": None,
        "default_query": None,
        "stop": None,
        "store": None,
        "extra_body": None,
        "include_response_headers": False,
        "stream_usage": False,
        "use_previous_response_id": False,
        "use_responses_api": None,
    }
    expected_dump = {
        "name": None,
        "bound": expected_bound,
        "kwargs": {
            "tools": [
                {
                    "type": "function",
                    "function": {"name": "foo", "description": "foo", "parameters": {}},
                },
            ],
        },
        "config": {"tags": ["foo"], "configurable": {}},
        "config_factories": [],
        "custom_input_type": None,
        "custom_output_type": None,
    }
    assert model_with_config.model_dump() == expected_dump  # type: ignore[attr-defined]
@pytest.mark.requires("langchain_openai", "langchain_anthropic")
@mock.patch.dict(
    os.environ,
    {"OPENAI_API_KEY": "foo", "ANTHROPIC_API_KEY": "bar"},
    clear=True,
)
def test_configurable_with_default() -> None:
    """Exercise a configurable model that starts from a default (``gpt-4o``).

    The default model's attributes must be reachable straight away, and a
    ``bar_``-prefixed config must switch the bound model to the Anthropic
    one, as verified through ``model_dump()``.
    """
    runnable_methods = (
        "invoke",
        "ainvoke",
        "batch",
        "abatch",
        "stream",
        "astream",
        "batch_as_completed",
        "abatch_as_completed",
    )
    default_only_methods = ("get_num_tokens", "get_num_tokens_from_messages", "dict")

    model = init_chat_model("gpt-4o", configurable_fields="any", config_prefix="bar")
    for method_name in runnable_methods:
        assert hasattr(model, method_name)

    # Default params are provided, so non-configurable, non-declarative
    # methods are reachable immediately.
    for method_name in default_only_methods:
        assert hasattr(model, method_name)

    assert model.model_name == "gpt-4o"

    model_with_tools = model.bind_tools(
        [{"name": "foo", "description": "foo", "parameters": {}}],
    )

    model_with_config = model_with_tools.with_config(
        RunnableConfig(tags=["foo"]),
        configurable={"bar_model": "claude-3-sonnet-20240229"},
    )

    assert model_with_config.model == "claude-3-sonnet-20240229"  # type: ignore[attr-defined]

    expected_bound = {
        "name": None,
        "disable_streaming": False,
        "model": "claude-3-sonnet-20240229",
        "mcp_servers": None,
        "max_tokens": 1024,
        "temperature": None,
        "thinking": None,
        "top_k": None,
        "top_p": None,
        "default_request_timeout": None,
        "max_retries": 2,
        "stop_sequences": None,
        "anthropic_api_url": "https://api.anthropic.com",
        "anthropic_api_key": SecretStr("bar"),
        "betas": None,
        "default_headers": None,
        "model_kwargs": {},
        "streaming": False,
        "stream_usage": True,
    }
    expected_dump = {
        "name": None,
        "bound": expected_bound,
        "kwargs": {
            "tools": [{"name": "foo", "description": "foo", "input_schema": {}}],
        },
        "config": {"tags": ["foo"], "configurable": {}},
        "config_factories": [],
        "custom_input_type": None,
        "custom_output_type": None,
    }
    assert model_with_config.model_dump() == expected_dump  # type: ignore[attr-defined]

    prompt = ChatPromptTemplate.from_messages([("system", "foo")])
    chain = prompt | model_with_config
    assert isinstance(chain, RunnableSequence)

View File

@@ -0,0 +1,123 @@
"""Configuration for unit tests."""
from collections.abc import Iterator, Sequence
from importlib import util
import pytest
from blockbuster import blockbuster_ctx
@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[None]:
    """Wrap every test in a ``blockbuster_ctx("langchain")`` context.

    Before yielding, a handful of known call sites are registered through
    ``can_block_in`` so that the listed functions may be called from them.
    """
    with blockbuster_ctx("langchain") as guard:
        guard.functions["io.TextIOWrapper.read"].can_block_in(
            "langchain/__init__.py",
            "<module>",
        )

        for func_name in ("os.stat", "os.path.abspath"):
            rule = guard.functions[func_name]
            rule = rule.can_block_in("langchain_core/runnables/base.py", "__repr__")
            rule.can_block_in(
                "langchain_core/beta/runnables/context.py",
                "aconfig_with_context",
            )

        for func_name in ("os.stat", "io.TextIOWrapper.read"):
            guard.functions[func_name].can_block_in(
                "langsmith/client.py",
                "_default_retry_config",
            )

        # Applied to every function known to the context.
        for rule in guard.functions.values():
            rule.can_block_in(
                "freezegun/api.py",
                "_get_cached_module_attributes",
            )

        yield
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the ``--only-extended`` and ``--only-core`` command line flags."""
    flags = (
        (
            "--only-extended",
            "Only run extended tests. Does not allow skipping any extended tests.",
        ),
        (
            "--only-core",
            "Only run core tests. Never runs any extended tests.",
        ),
    )
    for flag, help_text in flags:
        parser.addoption(flag, action="store_true", help=help_text)
def pytest_collection_modifyitems(
    config: pytest.Config, items: Sequence[pytest.Function]
) -> None:
    """Implement the custom ``requires`` marker and the two ``--only-*`` flags.

    A test decorated with ``@pytest.mark.requires("pkg1", "pkg2")`` needs all
    listed packages to be importable:

    * with ``--only-core`` such tests are always skipped;
    * with ``--only-extended`` a missing package fails the run, and tests
      without the marker are skipped;
    * otherwise a test is skipped at the first missing package.

    Raises:
        ValueError: If both ``--only-extended`` and ``--only-core`` are given.
    """
    only_extended = config.getoption("--only-extended") or False
    only_core = config.getoption("--only-core") or False

    if only_extended and only_core:
        msg = "Cannot specify both `--only-extended` and `--only-core`."
        raise ValueError(msg)

    # Package name -> installed?  Avoids repeated ``util.find_spec`` calls.
    install_cache: dict[str, bool] = {}

    def is_installed(pkg: str) -> bool:
        """Look ``pkg`` up once and remember the answer."""
        if pkg not in install_cache:
            try:
                found = util.find_spec(pkg) is not None
            except Exception:
                found = False
            install_cache[pkg] = found
        return install_cache[pkg]

    for item in items:
        requires_marker = item.get_closest_marker("requires")

        if requires_marker is None:
            if only_extended:
                item.add_marker(
                    pytest.mark.skip(reason="Skipping not an extended test."),
                )
            continue

        if only_core:
            item.add_marker(pytest.mark.skip(reason="Skipping not a core test."))
            continue

        for pkg in requires_marker.args:
            if is_installed(pkg):
                continue

            if only_extended:
                pytest.fail(
                    f"Package `{pkg}` is not installed but is required for "
                    f"extended tests. Please install the given package and "
                    f"try again.",
                )

            # First missing package: mark the test as skipped and stop looking.
            item.add_marker(
                pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
            )
            break

View File

@@ -0,0 +1,111 @@
"""Test embeddings base module."""
import pytest
from langchain.embeddings.base import (
_SUPPORTED_PROVIDERS,
_infer_model_and_provider,
_parse_model_string,
)
def test_parse_model_string() -> None:
    """Check that ``provider:model`` strings split into the expected pair."""
    cases = [
        ("openai:text-embedding-3-small", ("openai", "text-embedding-3-small")),
        (
            "bedrock:amazon.titan-embed-text-v1",
            ("bedrock", "amazon.titan-embed-text-v1"),
        ),
        # Only the first colon separates provider from model.
        ("huggingface:BAAI/bge-base-en:v1.5", ("huggingface", "BAAI/bge-base-en:v1.5")),
    ]
    for model_string, expected in cases:
        assert _parse_model_string(model_string) == expected
def test_parse_model_string_errors() -> None:
    """Test error cases for model string parsing.

    Each malformed input must raise ``ValueError`` with a message matching
    the given pattern. For an unsupported provider, the message must also
    mention every supported provider.
    """
    with pytest.raises(ValueError, match="Model name must be"):
        _parse_model_string("just-a-model-name")

    with pytest.raises(ValueError, match="Invalid model format "):
        _parse_model_string("")

    with pytest.raises(ValueError, match="is not supported"):
        _parse_model_string(":model-name")

    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _parse_model_string("openai:")

    with pytest.raises(
        ValueError,
        match="Provider 'invalid-provider' is not supported",
    ) as exc:
        _parse_model_string("invalid-provider:model-name")

    # Plain substring checks on the one captured message: provider names are
    # not treated as regular expressions, and the parser is not re-run once
    # per provider.
    message = str(exc.value)
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in message
def test_infer_model_and_provider() -> None:
    """Check inference of (provider, model) from the supported input forms."""
    # Combined "provider:model" string, passed positionally.
    combined = _infer_model_and_provider("openai:text-embedding-3-small")
    assert combined == ("openai", "text-embedding-3-small")

    # Model and provider given separately.
    separate = _infer_model_and_provider(
        model="text-embedding-3-small",
        provider="openai",
    )
    assert separate == ("openai", "text-embedding-3-small")

    # A colon inside the model name is kept when the provider is explicit.
    colon_model = _infer_model_and_provider(
        model="ft:text-embedding-3-small",
        provider="openai",
    )
    assert colon_model == ("openai", "ft:text-embedding-3-small")

    # Without an explicit provider, only the first colon splits the string.
    colon_combined = _infer_model_and_provider(model="openai:ft:text-embedding-3-small")
    assert colon_combined == ("openai", "ft:text-embedding-3-small")
def test_infer_model_and_provider_errors() -> None:
    """Check the ``ValueError`` raised for each kind of bad input."""
    # No provider anywhere.
    with pytest.raises(ValueError, match="Must specify either"):
        _infer_model_and_provider("text-embedding-3-small")

    # Empty model name.
    with pytest.raises(ValueError, match="Model name cannot be empty"):
        _infer_model_and_provider("")

    # Model given, provider empty.
    with pytest.raises(ValueError, match="Must specify either"):
        _infer_model_and_provider("model", provider="")

    # Unknown provider.
    with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc:
        _infer_model_and_provider("model", provider="invalid")

    # The message must mention every supported provider.
    message = str(exc.value)
    for provider in _SUPPORTED_PROVIDERS:
        assert provider in message
@pytest.mark.parametrize(
    "provider",
    sorted(_SUPPORTED_PROVIDERS.keys()),
)
def test_supported_providers_package_names(provider: str) -> None:
    """Check the package name registered for ``provider``.

    It must contain no hyphen, start with ``langchain_`` and be lowercase.
    """
    package_name = _SUPPORTED_PROVIDERS[provider]
    has_hyphen = "-" in package_name
    assert not has_hyphen
    assert package_name.startswith("langchain_")
    assert package_name.islower()
def test_is_sorted() -> None:
    """Check that the keys of ``_SUPPORTED_PROVIDERS`` are in sorted order."""
    providers = list(_SUPPORTED_PROVIDERS)
    assert providers == sorted(providers)

View File

@@ -0,0 +1,250 @@
"""Embeddings tests."""
import contextlib
import hashlib
import importlib
import warnings
import pytest
from langchain_core.embeddings import Embeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage.in_memory import InMemoryStore
class MockEmbeddings(Embeddings):
    """Toy embedding model whose vectors depend only on text length."""

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each text as ``[len(text), len(text) + 1]``.

        Raises:
            ValueError: When a text equals ``"RAISE_EXCEPTION"``; used by the
                tests to simulate a failing embedding call.
        """
        vectors: list[list[float]] = []
        for text in texts:
            if text == "RAISE_EXCEPTION":
                msg = "Simulated embedding failure"
                raise ValueError(msg)
            size = len(text)
            vectors.append([size, size + 1])
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Return the fixed vector ``[5.0, 6.0]`` for any query."""
        return [5.0, 6.0]
@pytest.fixture
def cache_embeddings() -> CacheBackedEmbeddings:
    """Cache-backed ``MockEmbeddings`` over a fresh in-memory store."""
    return CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        InMemoryStore(),
        namespace="test_namespace",
    )
@pytest.fixture
def cache_embeddings_batch() -> CacheBackedEmbeddings:
    """Cache-backed ``MockEmbeddings`` configured with ``batch_size=3``."""
    return CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        InMemoryStore(),
        namespace="test_namespace",
        batch_size=3,
    )
@pytest.fixture
def cache_embeddings_with_query() -> CacheBackedEmbeddings:
    """Cache-backed ``MockEmbeddings`` with separate document and query stores."""
    return CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        document_embedding_cache=InMemoryStore(),
        namespace="test_namespace",
        query_embedding_cache=InMemoryStore(),
    )
def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
    """Embedding four texts returns their vectors and stores four cache keys."""
    vectors = cache_embeddings.embed_documents(["1", "22", "a", "333"])
    expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
    assert vectors == expected_vectors

    stored_keys = list(cache_embeddings.document_embedding_store.yield_keys())
    assert len(stored_keys) == 4
    # The UUID part of the key is expected to be stable for a given text.
    first_key = stored_keys[0]
    assert first_key == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
    """Only the first batch is cached when the second batch fails."""
    # "RAISE_EXCEPTION" makes the mock raise while embedding the second batch.
    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
    with contextlib.suppress(ValueError):
        cache_embeddings_batch.embed_documents(texts)

    stored_keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
    # Just the three embeddings of the first batch should exist.
    assert len(stored_keys) == 3
    # The UUID part of the key is expected to be stable for a given text.
    first_key = stored_keys[0]
    assert first_key == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
def test_embed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
    """A query is embedded and no query store is configured on this fixture."""
    vector = cache_embeddings.embed_query("query_text")
    assert vector == [5.0, 6.0]
    assert cache_embeddings.query_embedding_store is None
def test_embed_cached_query(cache_embeddings_with_query: CacheBackedEmbeddings) -> None:
    """With a query store configured, embedding a query writes one key to it."""
    vector = cache_embeddings_with_query.embed_query("query_text")
    assert vector == [5.0, 6.0]

    query_store = cache_embeddings_with_query.query_embedding_store
    stored_keys = list(query_store.yield_keys())  # type: ignore[union-attr]
    assert len(stored_keys) == 1
    assert stored_keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
async def test_aembed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
    """Async embedding of four texts returns their vectors and stores four keys."""
    vectors = await cache_embeddings.aembed_documents(["1", "22", "a", "333"])
    expected_vectors: list[list[float]] = [[1, 2.0], [2.0, 3.0], [1.0, 2.0], [3.0, 4.0]]
    assert vectors == expected_vectors

    stored_keys = []
    async for key in cache_embeddings.document_embedding_store.ayield_keys():
        stored_keys.append(key)
    assert len(stored_keys) == 4
    # The UUID part of the key is expected to be stable for a given text.
    assert stored_keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
async def test_aembed_documents_batch(
    cache_embeddings_batch: CacheBackedEmbeddings,
) -> None:
    """Async variant: only the first batch is cached when the second fails."""
    # "RAISE_EXCEPTION" makes the mock raise while embedding the second batch.
    texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
    with contextlib.suppress(ValueError):
        await cache_embeddings_batch.aembed_documents(texts)

    stored_keys = []
    async for key in cache_embeddings_batch.document_embedding_store.ayield_keys():
        stored_keys.append(key)
    # Just the three embeddings of the first batch should exist.
    assert len(stored_keys) == 3
    # The UUID part of the key is expected to be stable for a given text.
    assert stored_keys[0] == "test_namespace812b86c1-8ebf-5483-95c6-c95cf2b52d12"
async def test_aembed_query(cache_embeddings: CacheBackedEmbeddings) -> None:
    """Async query embedding returns the mock's fixed vector."""
    vector = await cache_embeddings.aembed_query("query_text")
    assert vector == [5.0, 6.0]
async def test_aembed_query_cached(
    cache_embeddings_with_query: CacheBackedEmbeddings,
) -> None:
    """Async query embedding writes exactly one key to the query store."""
    await cache_embeddings_with_query.aembed_query("query_text")

    query_store = cache_embeddings_with_query.query_embedding_store
    stored_keys = list(query_store.yield_keys())  # type: ignore[union-attr]
    assert len(stored_keys) == 1
    assert stored_keys[0] == "test_namespace89ec3dae-a4d9-5636-a62e-ff3b56cdfa15"
def test_blake2b_encoder() -> None:
    """With ``key_encoder="blake2b"`` the cache key is namespace + blake2b hex digest."""
    cbe = CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        InMemoryStore(),
        namespace="ns_",
        key_encoder="blake2b",
    )

    text = "blake"
    cbe.embed_documents([text])

    # Rebuild the expected key independently with hashlib.
    digest = hashlib.blake2b(text.encode()).hexdigest()
    stored_keys = list(cbe.document_embedding_store.yield_keys())
    assert stored_keys == ["ns_" + digest]
def test_sha256_encoder() -> None:
    """With ``key_encoder="sha256"`` the cache key is namespace + sha256 hex digest."""
    cbe = CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        InMemoryStore(),
        namespace="ns_",
        key_encoder="sha256",
    )

    text = "foo"
    cbe.embed_documents([text])

    # Rebuild the expected key independently with hashlib.
    digest = hashlib.sha256(text.encode()).hexdigest()
    stored_keys = list(cbe.document_embedding_store.yield_keys())
    assert stored_keys == ["ns_" + digest]
def test_sha512_encoder() -> None:
    """With ``key_encoder="sha512"`` the cache key is namespace + sha512 hex digest."""
    cbe = CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        InMemoryStore(),
        namespace="ns_",
        key_encoder="sha512",
    )

    text = "foo"
    cbe.embed_documents([text])

    # Rebuild the expected key independently with hashlib.
    digest = hashlib.sha512(text.encode()).hexdigest()
    stored_keys = list(cbe.document_embedding_store.yield_keys())
    assert stored_keys == ["ns_" + digest]
def test_sha1_warning_emitted_once() -> None:
    """Building two caches with the default encoder yields exactly one SHA-1 warning."""
    module = importlib.import_module(CacheBackedEmbeddings.__module__)

    # A temporary MonkeyPatch whose effects are undone when the block exits.
    with pytest.MonkeyPatch.context() as patcher:
        # Reset the module-level flag, which other tests may already have set.
        patcher.setattr(module, "_warned_about_sha1", False, raising=False)

        store = InMemoryStore()
        emb = MockEmbeddings()

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            # First construction is expected to warn, the second to stay silent.
            for _ in range(2):
                CacheBackedEmbeddings.from_bytes_store(emb, store)

        sha1_count = sum(1 for warning in caught if "SHA-1" in str(warning.message))
        assert sha1_count == 1
def test_custom_encoder() -> None:
    """A callable passed as ``key_encoder`` determines the stored cache key."""

    def custom_upper(text: str) -> str:
        # Minimal demo encoder: fixed prefix plus the upper-cased text.
        return f"CUSTOM_{text.upper()}"

    cbe = CacheBackedEmbeddings.from_bytes_store(
        MockEmbeddings(),
        InMemoryStore(),
        key_encoder=custom_upper,
    )
    cbe.embed_documents(["x"])

    stored_keys = list(cbe.document_embedding_store.yield_keys())
    assert stored_keys == ["CUSTOM_X"]

View File

@@ -0,0 +1,11 @@
from langchain import embeddings
# Names that ``langchain.embeddings.__all__`` is expected to contain.
EXPECTED_ALL = [
    "CacheBackedEmbeddings",
    "Embeddings",
    "init_embeddings",
]


def test_all_imports() -> None:
    """Check that ``embeddings.__all__`` and EXPECTED_ALL match, ignoring order."""
    exported = set(embeddings.__all__)
    expected = set(EXPECTED_ALL)
    assert exported == expected

View File

@@ -0,0 +1,12 @@
from langchain import storage
# Names that ``langchain.storage.__all__`` is expected to contain.
EXPECTED_ALL = [
    "EncoderBackedStore",
    "InMemoryStore",
    "InMemoryByteStore",
    "InvalidKeyException",
]


def test_all_imports() -> None:
    """Check that ``storage.__all__`` and EXPECTED_ALL match, ignoring order."""
    exported = set(storage.__all__)
    expected = set(EXPECTED_ALL)
    assert exported == expected

View File

@@ -0,0 +1,46 @@
from typing import Any
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
class AnyStr(str):
    """String placeholder that compares equal to every ``str`` instance."""

    __slots__ = ()

    def __eq__(self, other: object) -> bool:
        """Return True for any string, whatever its content, and False otherwise."""
        is_string = isinstance(other, str)
        return is_string
# The code below creates versions of pydantic models
# that will work in unit tests with AnyStr as the id field.
# Please note that the `id` field is assigned AFTER the model is created
# to work around an issue with pydantic ignoring the __eq__ method on
# subclassed strings.
def _AnyIdDocument(**kwargs: Any) -> Document:
    """Build a ``Document`` from ``kwargs`` whose id is an ``AnyStr``."""
    document = Document(**kwargs)
    # Assigned after construction, not passed to the constructor.
    document.id = AnyStr()
    return document
def _AnyIdAIMessage(**kwargs: Any) -> AIMessage:
    """Build an ``AIMessage`` from ``kwargs`` whose id is an ``AnyStr``."""
    ai_message = AIMessage(**kwargs)
    # Assigned after construction, not passed to the constructor.
    ai_message.id = AnyStr()
    return ai_message
def _AnyIdAIMessageChunk(**kwargs: Any) -> AIMessageChunk:
    """Build an ``AIMessageChunk`` from ``kwargs`` whose id is an ``AnyStr``."""
    chunk = AIMessageChunk(**kwargs)
    # Assigned after construction, not passed to the constructor.
    chunk.id = AnyStr()
    return chunk
def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
    """Build a ``HumanMessage`` from ``kwargs`` whose id is an ``AnyStr``."""
    human_message = HumanMessage(**kwargs)
    # Assigned after construction, not passed to the constructor.
    human_message.id = AnyStr()
    return human_message

View File

@@ -0,0 +1,35 @@
"""A unit test meant to catch accidental introduction of non-optional dependencies."""
from collections.abc import Mapping
from pathlib import Path
from typing import Any
import pytest
import toml
from packaging.requirements import Requirement
# Directory that contains this test file.
HERE = Path(__file__).parent
# pyproject.toml located two directories above this file.
PYPROJECT_TOML = HERE / "../../pyproject.toml"
@pytest.fixture
def uv_conf() -> dict[str, Any]:
    """Load the pyproject.toml file.

    Returns:
        The parsed contents of ``PYPROJECT_TOML``.
    """
    # TOML documents are UTF-8 by specification, so the encoding is pinned
    # instead of relying on the platform's locale default.
    with PYPROJECT_TOML.open(encoding="utf-8") as f:
        return toml.load(f)
def test_required_dependencies(uv_conf: Mapping[str, Any]) -> None:
    """Guard against new non-optional dependencies being introduced.

    A failure here means the set of required dependencies has changed.
    Adding a required dependency should be avoided in most situations.
    """
    expected_names = ["langchain-core", "langchain-text-splitters", "langgraph", "pydantic"]

    # Requirement strings come from the ``[project]`` table's ``dependencies``.
    declared = uv_conf["project"]["dependencies"]
    declared_names = {Requirement(spec).name for spec in declared}

    assert sorted(declared_names) == sorted(expected_names)

View File

@@ -0,0 +1,60 @@
import importlib
import warnings
from pathlib import Path
# Directory three levels above this file; the tests below look for the
# ``langchain`` package inside it and import every module they find.
PKG_ROOT = Path(__file__).parent.parent.parent
def test_import_all() -> None:
    """Import every module under ``langchain`` and resolve each name in its ``__all__``.

    ``UserWarning`` is ignored while importing. A name that cannot be
    retrieved, or that resolves to ``None``, raises ``AssertionError``.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=UserWarning)
        package_dir = PKG_ROOT / "langchain"
        for source_file in package_dir.rglob("*.py"):
            # Dotted module name relative to PKG_ROOT.
            relative = source_file.relative_to(PKG_ROOT).with_suffix("")
            module_name = ".".join(relative.parts)
            if module_name.endswith("__init__"):
                # A package is imported by its own name, without ``__init__``.
                module_name = module_name.rsplit(".", 1)[0]

            mod = importlib.import_module(module_name)

            for name in getattr(mod, "__all__", []):
                try:
                    exported = getattr(mod, name)
                    assert exported is not None
                except Exception as e:
                    msg = f"Could not import {module_name}.{name}"
                    raise AssertionError(msg) from e
def test_import_all_using_dir() -> None:
    """Import every module under ``langchain`` and access each public ``dir()`` name.

    Names starting with an underscore are skipped. A module that cannot be
    found is reported with ``ModuleNotFoundError`` naming that module.
    """
    package_dir = PKG_ROOT / "langchain"
    for source_file in package_dir.rglob("*.py"):
        # Dotted module name relative to PKG_ROOT.
        relative = source_file.relative_to(PKG_ROOT).with_suffix("")
        module_name = ".".join(relative.parts)
        if module_name.endswith("__init__"):
            # A package is imported by its own name, without ``__init__``.
            module_name = module_name.rsplit(".", 1)[0]

        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            msg = f"Could not import {module_name}"
            raise ModuleNotFoundError(msg) from e

        public_names = [
            name for name in dir(mod) if not name.strip().startswith("_")
        ]
        for name in public_names:
            # Accessing the attribute is the check; it must not raise.
            getattr(mod, name)

View File

@@ -0,0 +1,11 @@
import pytest
import pytest_socket
import requests
def test_socket_disabled() -> None:
    """Check that an outgoing HTTP request raises ``SocketBlockedError``."""
    url = "https://www.example.com"
    with pytest.raises(pytest_socket.SocketBlockedError):
        # S113 is ignored: no timeout is needed because the request is
        # expected to fail immediately.
        requests.get(url)  # noqa: S113

View File

@@ -0,0 +1,13 @@
from langchain import tools
# Names that ``langchain.tools.__all__`` is expected to contain.
EXPECTED_ALL = {
    "BaseTool",
    "InjectedToolArg",
    "InjectedToolCallId",
    "ToolException",
    "tool",
}


def test_all_imports() -> None:
    """Check that ``tools.__all__`` holds exactly the names in EXPECTED_ALL."""
    exported = set(tools.__all__)
    assert exported == EXPECTED_ALL