mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 20:35:43 +00:00
**description:** the ChatModel[Integration]Tests classes are powerful and helpful, this change allows sub-classes to add additional tests. for instance, ``` class TestChatMyServiceIntegration(ChatModelIntegrationTests): ... def test_myservice(self, model: BaseChatModel) -> None: ... ``` --------- Co-authored-by: ccurme <chester.curme@gmail.com>
71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
from abc import ABC
|
|
from typing import Type
|
|
|
|
|
|
class BaseStandardTests(ABC):
|
|
"""
|
|
:private:
|
|
"""
|
|
|
|
def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
|
|
"""
|
|
Test that no standard tests are overridden.
|
|
|
|
:private:
|
|
"""
|
|
# find path to standard test implementations
|
|
comparison_class = None
|
|
|
|
def explore_bases(cls: Type) -> None:
|
|
nonlocal comparison_class
|
|
for base in cls.__bases__:
|
|
if base.__module__.startswith("langchain_tests."):
|
|
if comparison_class is None:
|
|
comparison_class = base
|
|
else:
|
|
raise ValueError(
|
|
"Multiple standard test base classes found: "
|
|
f"{comparison_class}, {base}"
|
|
)
|
|
else:
|
|
explore_bases(base)
|
|
|
|
explore_bases(self.__class__)
|
|
assert comparison_class is not None, "No standard test base class found."
|
|
|
|
print(f"Comparing {self.__class__} to {comparison_class}") # noqa: T201
|
|
|
|
running_tests = set(
|
|
[method for method in dir(self) if method.startswith("test_")]
|
|
)
|
|
base_tests = set(
|
|
[method for method in dir(comparison_class) if method.startswith("test_")]
|
|
)
|
|
deleted_tests = base_tests - running_tests
|
|
assert not deleted_tests, f"Standard tests deleted: {deleted_tests}"
|
|
|
|
overridden_tests = [
|
|
method
|
|
for method in base_tests
|
|
if getattr(self.__class__, method) is not getattr(comparison_class, method)
|
|
]
|
|
|
|
def is_xfail(method: str) -> bool:
|
|
m = getattr(self.__class__, method)
|
|
if not hasattr(m, "pytestmark"):
|
|
return False
|
|
marks = m.pytestmark
|
|
return any(
|
|
mark.name == "xfail" and mark.kwargs.get("reason") for mark in marks
|
|
)
|
|
|
|
overridden_not_xfail = [
|
|
method for method in overridden_tests if not is_xfail(method)
|
|
]
|
|
assert not overridden_not_xfail, (
|
|
"Standard tests overridden without "
|
|
f'@pytest.mark.xfail(reason="..."): {overridden_not_xfail}\n'
|
|
"Note: reason is required to explain why the standard test has an expected "
|
|
"failure."
|
|
)
|