diff --git a/libs/partners/groq/tests/unit_tests/test_standard.py b/libs/partners/groq/tests/unit_tests/test_standard.py index 5301d5abe4a..db84602711f 100644 --- a/libs/partners/groq/tests/unit_tests/test_standard.py +++ b/libs/partners/groq/tests/unit_tests/test_standard.py @@ -2,12 +2,10 @@ from typing import Type +import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.runnables import RunnableBinding -from langchain_standard_tests.unit_tests.chat_models import ( # type: ignore[import-not-found] +from langchain_standard_tests.unit_tests.chat_models import ( ChatModelUnitTests, - Person, - my_adder_tool, ) from langchain_groq import ChatGroq @@ -18,10 +16,6 @@ class TestGroqStandard(ChatModelUnitTests): def chat_model_class(self) -> Type[BaseChatModel]: return ChatGroq + @pytest.mark.xfail(reason="Groq does not support tool_choice='any'") def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: - """Does not currently support tool_choice='any'.""" - if not self.has_tool_calling: - return - - tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool]) - assert isinstance(tool_model, RunnableBinding) + super().test_bind_tool_pydantic(model) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py index d75cf98bc81..0465fcbc7a4 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py @@ -2,10 +2,9 @@ from typing import Type +import pytest from langchain_core.language_models import BaseChatModel -from langchain_core.runnables import RunnableBinding from langchain_standard_tests.unit_tests import ChatModelUnitTests -from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool from langchain_openai import AzureChatOpenAI @@ -23,10 +22,6 @@ class TestOpenAIStandard(ChatModelUnitTests): "azure_endpoint": "https://test.azure.com", } + @pytest.mark.xfail(reason="AzureOpenAI does not support tool_choice='any'") def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: - """Does not currently support tool_choice='any'.""" - if not self.has_tool_calling: - return - - tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool]) - assert isinstance(tool_model, RunnableBinding) + super().test_bind_tool_pydantic(model) diff --git a/libs/standard-tests/langchain_standard_tests/base.py b/libs/standard-tests/langchain_standard_tests/base.py new file mode 100644 index 00000000000..e9f71802737 --- /dev/null +++ b/libs/standard-tests/langchain_standard_tests/base.py @@ -0,0 +1,66 @@ +from abc import ABC +from typing import Type + + +class BaseStandardTests(ABC): + def test_no_overrides_DO_NOT_OVERRIDE(self) -> None: + """ + Test that no standard tests are overridden. + """ + # find path to standard test implementations + comparison_class = None + + def explore_bases(cls: Type) -> None: + nonlocal comparison_class + for base in cls.__bases__: + if base.__module__.startswith("langchain_standard_tests."): + if comparison_class is None: + comparison_class = base + else: + raise ValueError( + "Multiple standard test base classes found: " + f"{comparison_class}, {base}" + ) + else: + explore_bases(base) + + explore_bases(self.__class__) + assert comparison_class is not None, "No standard test base class found." + + print(f"Comparing {self.__class__} to {comparison_class}") # noqa: T201 + + running_tests = set( + [method for method in dir(self) if method.startswith("test_")] + ) + base_tests = set( + [method for method in dir(comparison_class) if method.startswith("test_")] + ) + non_standard_tests = running_tests - base_tests + assert not non_standard_tests, f"Non-standard tests found: {non_standard_tests}" + deleted_tests = base_tests - running_tests + assert not deleted_tests, f"Standard tests deleted: {deleted_tests}" + + overridden_tests = [ + method + for method in running_tests + if getattr(self.__class__, method) is not getattr(comparison_class, method) + ] + + def is_xfail(method: str) -> bool: + m = getattr(self.__class__, method) + if not hasattr(m, "pytestmark"): + return False + marks = m.pytestmark + return any( + mark.name == "xfail" and mark.kwargs.get("reason") for mark in marks + ) + + overridden_not_xfail = [ + method for method in overridden_tests if not is_xfail(method) + ] + assert not overridden_not_xfail, ( + "Standard tests overridden without " + f'@pytest.mark.xfail(reason="..."): {overridden_not_xfail}\n' + "Note: reason is required to explain why the standard test has an expected " + "failure." + ) diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/base_store.py b/libs/standard-tests/langchain_standard_tests/integration_tests/base_store.py index 8f74d066a45..e4b461d9822 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/base_store.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/base_store.py @@ -1,13 +1,15 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import AsyncGenerator, Generator, Generic, Tuple, TypeVar import pytest from langchain_core.stores import BaseStore +from langchain_standard_tests.base import BaseStandardTests + V = TypeVar("V") -class BaseStoreSyncTests(ABC, Generic[V]): +class BaseStoreSyncTests(BaseStandardTests, Generic[V]): """Test suite for checking the key-value API of a BaseStore. This test suite verifies the basic key-value API of a BaseStore. @@ -138,7 +140,7 @@ class BaseStoreSyncTests(ABC, Generic[V]): assert sorted(kv_store.yield_keys(prefix="foo")) == ["foo"] -class BaseStoreAsyncTests(ABC): +class BaseStoreAsyncTests(BaseStandardTests): """Test suite for checking the key-value API of a BaseStore. This test suite verifies the basic key-value API of a BaseStore. diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/cache.py b/libs/standard-tests/langchain_standard_tests/integration_tests/cache.py index fe84d8450cf..7d1359f5154 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/cache.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/cache.py @@ -1,11 +1,13 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod import pytest from langchain_core.caches import BaseCache from langchain_core.outputs import Generation +from langchain_standard_tests.base import BaseStandardTests -class SyncCacheTestSuite(ABC): + +class SyncCacheTestSuite(BaseStandardTests): """Test suite for checking the BaseCache API of a caching layer for LLMs. This test suite verifies the basic caching API of a caching layer for LLMs. @@ -95,7 +97,7 @@ class SyncCacheTestSuite(ABC): assert cache.lookup(prompt, llm_string) == generations -class AsyncCacheTestSuite(ABC): +class AsyncCacheTestSuite(BaseStandardTests): """Test suite for checking the BaseCache API of a caching layer for LLMs. This test suite verifies the basic caching API of a caching layer for LLMs. diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py b/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py index 83770099ab3..83e76aaff80 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/vectorstores.py @@ -1,19 +1,21 @@ """Test suite to test vectostores.""" import inspect -from abc import ABC, abstractmethod +from abc import abstractmethod import pytest from langchain_core.documents import Document from langchain_core.embeddings.fake import DeterministicFakeEmbedding, Embeddings from langchain_core.vectorstores import VectorStore +from langchain_standard_tests.base import BaseStandardTests + # Arbitrarily chosen. Using a small embedding size # so tests are faster and easier to debug. EMBEDDING_SIZE = 6 -class ReadWriteTestSuite(ABC): +class ReadWriteTestSuite(BaseStandardTests): """Test suite for checking the read-write API of a vectorstore. This test suite verifies the basic read-write API of a vectorstore. @@ -201,7 +203,7 @@ class ReadWriteTestSuite(ABC): assert "ids" not in signature.parameters -class AsyncReadWriteTestSuite(ABC): +class AsyncReadWriteTestSuite(BaseStandardTests): """Test suite for checking the **async** read-write API of a vectorstore. This test suite verifies the basic read-write API of a vectorstore. diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index 138b5f54e5c..ed73771dbda 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -1,5 +1,6 @@ """Unit tests for chat models.""" -from abc import ABC, abstractmethod + +from abc import abstractmethod from typing import Any, List, Literal, Optional, Type import pytest @@ -8,6 +9,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables import RunnableBinding from langchain_core.tools import tool +from langchain_standard_tests.base import BaseStandardTests from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION @@ -64,7 +66,7 @@ def my_adder(a: int, b: int) -> int: return a + b -class ChatModelTests(ABC): +class ChatModelTests(BaseStandardTests): @property @abstractmethod def chat_model_class(self) -> Type[BaseChatModel]: