standard-tests: add override check (#24407)

This commit is contained in:
Erick Friis 2024-07-22 16:38:01 -07:00 committed by GitHub
parent 1639ccfd15
commit 2c6b9e8771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 92 additions and 29 deletions

View File

@ -2,12 +2,10 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import RunnableBinding from langchain_standard_tests.unit_tests.chat_models import (
from langchain_standard_tests.unit_tests.chat_models import ( # type: ignore[import-not-found]
ChatModelUnitTests, ChatModelUnitTests,
Person,
my_adder_tool,
) )
from langchain_groq import ChatGroq from langchain_groq import ChatGroq
@ -18,10 +16,6 @@ class TestGroqStandard(ChatModelUnitTests):
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]:
return ChatGroq return ChatGroq
@pytest.mark.xfail(reason="Groq does not support tool_choice='any'")
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
"""Does not currently support tool_choice='any'.""" super().test_bind_tool_pydantic(model)
if not self.has_tool_calling:
return
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
assert isinstance(tool_model, RunnableBinding)

View File

@ -2,10 +2,9 @@
from typing import Type from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.runnables import RunnableBinding
from langchain_standard_tests.unit_tests import ChatModelUnitTests from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool
from langchain_openai import AzureChatOpenAI from langchain_openai import AzureChatOpenAI
@ -23,10 +22,6 @@ class TestOpenAIStandard(ChatModelUnitTests):
"azure_endpoint": "https://test.azure.com", "azure_endpoint": "https://test.azure.com",
} }
@pytest.mark.xfail(reason="AzureOpenAI does not support tool_choice='any'")
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
"""Does not currently support tool_choice='any'.""" super().test_bind_tool_pydantic(model)
if not self.has_tool_calling:
return
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
assert isinstance(tool_model, RunnableBinding)

View File

@ -0,0 +1,66 @@
from abc import ABC
from typing import Type
class BaseStandardTests(ABC):
def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
"""
Test that no standard tests are overridden.
"""
# find path to standard test implementations
comparison_class = None
def explore_bases(cls: Type) -> None:
nonlocal comparison_class
for base in cls.__bases__:
if base.__module__.startswith("langchain_standard_tests."):
if comparison_class is None:
comparison_class = base
else:
raise ValueError(
"Multiple standard test base classes found: "
f"{comparison_class}, {base}"
)
else:
explore_bases(base)
explore_bases(self.__class__)
assert comparison_class is not None, "No standard test base class found."
print(f"Comparing {self.__class__} to {comparison_class}") # noqa: T201
running_tests = set(
[method for method in dir(self) if method.startswith("test_")]
)
base_tests = set(
[method for method in dir(comparison_class) if method.startswith("test_")]
)
non_standard_tests = running_tests - base_tests
assert not non_standard_tests, f"Non-standard tests found: {non_standard_tests}"
deleted_tests = base_tests - running_tests
assert not deleted_tests, f"Standard tests deleted: {deleted_tests}"
overridden_tests = [
method
for method in running_tests
if getattr(self.__class__, method) is not getattr(comparison_class, method)
]
def is_xfail(method: str) -> bool:
m = getattr(self.__class__, method)
if not hasattr(m, "pytestmark"):
return False
marks = m.pytestmark
return any(
mark.name == "xfail" and mark.kwargs.get("reason") for mark in marks
)
overridden_not_xfail = [
method for method in overridden_tests if not is_xfail(method)
]
assert not overridden_not_xfail, (
"Standard tests overridden without "
f'@pytest.mark.xfail(reason="..."): {overridden_not_xfail}\n'
"Note: reason is required to explain why the standard test has an expected "
"failure."
)

View File

@ -1,13 +1,15 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import AsyncGenerator, Generator, Generic, Tuple, TypeVar from typing import AsyncGenerator, Generator, Generic, Tuple, TypeVar
import pytest import pytest
from langchain_core.stores import BaseStore from langchain_core.stores import BaseStore
from langchain_standard_tests.base import BaseStandardTests
V = TypeVar("V") V = TypeVar("V")
class BaseStoreSyncTests(ABC, Generic[V]): class BaseStoreSyncTests(BaseStandardTests, Generic[V]):
"""Test suite for checking the key-value API of a BaseStore. """Test suite for checking the key-value API of a BaseStore.
This test suite verifies the basic key-value API of a BaseStore. This test suite verifies the basic key-value API of a BaseStore.
@ -138,7 +140,7 @@ class BaseStoreSyncTests(ABC, Generic[V]):
assert sorted(kv_store.yield_keys(prefix="foo")) == ["foo"] assert sorted(kv_store.yield_keys(prefix="foo")) == ["foo"]
class BaseStoreAsyncTests(ABC): class BaseStoreAsyncTests(BaseStandardTests):
"""Test suite for checking the key-value API of a BaseStore. """Test suite for checking the key-value API of a BaseStore.
This test suite verifies the basic key-value API of a BaseStore. This test suite verifies the basic key-value API of a BaseStore.

View File

@ -1,11 +1,13 @@
from abc import ABC, abstractmethod from abc import abstractmethod
import pytest import pytest
from langchain_core.caches import BaseCache from langchain_core.caches import BaseCache
from langchain_core.outputs import Generation from langchain_core.outputs import Generation
from langchain_standard_tests.base import BaseStandardTests
class SyncCacheTestSuite(ABC):
class SyncCacheTestSuite(BaseStandardTests):
"""Test suite for checking the BaseCache API of a caching layer for LLMs. """Test suite for checking the BaseCache API of a caching layer for LLMs.
This test suite verifies the basic caching API of a caching layer for LLMs. This test suite verifies the basic caching API of a caching layer for LLMs.
@ -95,7 +97,7 @@ class SyncCacheTestSuite(ABC):
assert cache.lookup(prompt, llm_string) == generations assert cache.lookup(prompt, llm_string) == generations
class AsyncCacheTestSuite(ABC): class AsyncCacheTestSuite(BaseStandardTests):
"""Test suite for checking the BaseCache API of a caching layer for LLMs. """Test suite for checking the BaseCache API of a caching layer for LLMs.
This test suite verifies the basic caching API of a caching layer for LLMs. This test suite verifies the basic caching API of a caching layer for LLMs.

View File

@ -1,19 +1,21 @@
"""Test suite to test vectostores.""" """Test suite to test vectostores."""
import inspect import inspect
from abc import ABC, abstractmethod from abc import abstractmethod
import pytest import pytest
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings.fake import DeterministicFakeEmbedding, Embeddings from langchain_core.embeddings.fake import DeterministicFakeEmbedding, Embeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from langchain_standard_tests.base import BaseStandardTests
# Arbitrarily chosen. Using a small embedding size # Arbitrarily chosen. Using a small embedding size
# so tests are faster and easier to debug. # so tests are faster and easier to debug.
EMBEDDING_SIZE = 6 EMBEDDING_SIZE = 6
class ReadWriteTestSuite(ABC): class ReadWriteTestSuite(BaseStandardTests):
"""Test suite for checking the read-write API of a vectorstore. """Test suite for checking the read-write API of a vectorstore.
This test suite verifies the basic read-write API of a vectorstore. This test suite verifies the basic read-write API of a vectorstore.
@ -201,7 +203,7 @@ class ReadWriteTestSuite(ABC):
assert "ids" not in signature.parameters assert "ids" not in signature.parameters
class AsyncReadWriteTestSuite(ABC): class AsyncReadWriteTestSuite(BaseStandardTests):
"""Test suite for checking the **async** read-write API of a vectorstore. """Test suite for checking the **async** read-write API of a vectorstore.
This test suite verifies the basic read-write API of a vectorstore. This test suite verifies the basic read-write API of a vectorstore.

View File

@ -1,5 +1,6 @@
"""Unit tests for chat models.""" """Unit tests for chat models."""
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import Any, List, Literal, Optional, Type from typing import Any, List, Literal, Optional, Type
import pytest import pytest
@ -8,6 +9,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableBinding from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_standard_tests.base import BaseStandardTests
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
@ -64,7 +66,7 @@ def my_adder(a: int, b: int) -> int:
return a + b return a + b
class ChatModelTests(ABC): class ChatModelTests(BaseStandardTests):
@property @property
@abstractmethod @abstractmethod
def chat_model_class(self) -> Type[BaseChatModel]: def chat_model_class(self) -> Type[BaseChatModel]: