mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 04:50:37 +00:00
standard-tests: add override check (#24407)
This commit is contained in:
parent
1639ccfd15
commit
2c6b9e8771
@ -2,12 +2,10 @@
|
|||||||
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.runnables import RunnableBinding
|
from langchain_standard_tests.unit_tests.chat_models import (
|
||||||
from langchain_standard_tests.unit_tests.chat_models import ( # type: ignore[import-not-found]
|
|
||||||
ChatModelUnitTests,
|
ChatModelUnitTests,
|
||||||
Person,
|
|
||||||
my_adder_tool,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_groq import ChatGroq
|
from langchain_groq import ChatGroq
|
||||||
@ -18,10 +16,6 @@ class TestGroqStandard(ChatModelUnitTests):
|
|||||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||||
return ChatGroq
|
return ChatGroq
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Groq does not support tool_choice='any'")
|
||||||
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
||||||
"""Does not currently support tool_choice='any'."""
|
super().test_bind_tool_pydantic(model)
|
||||||
if not self.has_tool_calling:
|
|
||||||
return
|
|
||||||
|
|
||||||
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
|
|
||||||
assert isinstance(tool_model, RunnableBinding)
|
|
||||||
|
@ -2,10 +2,9 @@
|
|||||||
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.runnables import RunnableBinding
|
|
||||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||||
from langchain_standard_tests.unit_tests.chat_models import Person, my_adder_tool
|
|
||||||
|
|
||||||
from langchain_openai import AzureChatOpenAI
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
|
||||||
@ -23,10 +22,6 @@ class TestOpenAIStandard(ChatModelUnitTests):
|
|||||||
"azure_endpoint": "https://test.azure.com",
|
"azure_endpoint": "https://test.azure.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="AzureOpenAI does not support tool_choice='any'")
|
||||||
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
def test_bind_tool_pydantic(self, model: BaseChatModel) -> None:
|
||||||
"""Does not currently support tool_choice='any'."""
|
super().test_bind_tool_pydantic(model)
|
||||||
if not self.has_tool_calling:
|
|
||||||
return
|
|
||||||
|
|
||||||
tool_model = model.bind_tools([Person, Person.schema(), my_adder_tool])
|
|
||||||
assert isinstance(tool_model, RunnableBinding)
|
|
||||||
|
66
libs/standard-tests/langchain_standard_tests/base.py
Normal file
66
libs/standard-tests/langchain_standard_tests/base.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStandardTests(ABC):
|
||||||
|
def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that no standard tests are overridden.
|
||||||
|
"""
|
||||||
|
# find path to standard test implementations
|
||||||
|
comparison_class = None
|
||||||
|
|
||||||
|
def explore_bases(cls: Type) -> None:
|
||||||
|
nonlocal comparison_class
|
||||||
|
for base in cls.__bases__:
|
||||||
|
if base.__module__.startswith("langchain_standard_tests."):
|
||||||
|
if comparison_class is None:
|
||||||
|
comparison_class = base
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Multiple standard test base classes found: "
|
||||||
|
f"{comparison_class}, {base}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
explore_bases(base)
|
||||||
|
|
||||||
|
explore_bases(self.__class__)
|
||||||
|
assert comparison_class is not None, "No standard test base class found."
|
||||||
|
|
||||||
|
print(f"Comparing {self.__class__} to {comparison_class}") # noqa: T201
|
||||||
|
|
||||||
|
running_tests = set(
|
||||||
|
[method for method in dir(self) if method.startswith("test_")]
|
||||||
|
)
|
||||||
|
base_tests = set(
|
||||||
|
[method for method in dir(comparison_class) if method.startswith("test_")]
|
||||||
|
)
|
||||||
|
non_standard_tests = running_tests - base_tests
|
||||||
|
assert not non_standard_tests, f"Non-standard tests found: {non_standard_tests}"
|
||||||
|
deleted_tests = base_tests - running_tests
|
||||||
|
assert not deleted_tests, f"Standard tests deleted: {deleted_tests}"
|
||||||
|
|
||||||
|
overridden_tests = [
|
||||||
|
method
|
||||||
|
for method in running_tests
|
||||||
|
if getattr(self.__class__, method) is not getattr(comparison_class, method)
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_xfail(method: str) -> bool:
|
||||||
|
m = getattr(self.__class__, method)
|
||||||
|
if not hasattr(m, "pytestmark"):
|
||||||
|
return False
|
||||||
|
marks = m.pytestmark
|
||||||
|
return any(
|
||||||
|
mark.name == "xfail" and mark.kwargs.get("reason") for mark in marks
|
||||||
|
)
|
||||||
|
|
||||||
|
overridden_not_xfail = [
|
||||||
|
method for method in overridden_tests if not is_xfail(method)
|
||||||
|
]
|
||||||
|
assert not overridden_not_xfail, (
|
||||||
|
"Standard tests overridden without "
|
||||||
|
f'@pytest.mark.xfail(reason="..."): {overridden_not_xfail}\n'
|
||||||
|
"Note: reason is required to explain why the standard test has an expected "
|
||||||
|
"failure."
|
||||||
|
)
|
@ -1,13 +1,15 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import AsyncGenerator, Generator, Generic, Tuple, TypeVar
|
from typing import AsyncGenerator, Generator, Generic, Tuple, TypeVar
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.stores import BaseStore
|
from langchain_core.stores import BaseStore
|
||||||
|
|
||||||
|
from langchain_standard_tests.base import BaseStandardTests
|
||||||
|
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
|
|
||||||
|
|
||||||
class BaseStoreSyncTests(ABC, Generic[V]):
|
class BaseStoreSyncTests(BaseStandardTests, Generic[V]):
|
||||||
"""Test suite for checking the key-value API of a BaseStore.
|
"""Test suite for checking the key-value API of a BaseStore.
|
||||||
|
|
||||||
This test suite verifies the basic key-value API of a BaseStore.
|
This test suite verifies the basic key-value API of a BaseStore.
|
||||||
@ -138,7 +140,7 @@ class BaseStoreSyncTests(ABC, Generic[V]):
|
|||||||
assert sorted(kv_store.yield_keys(prefix="foo")) == ["foo"]
|
assert sorted(kv_store.yield_keys(prefix="foo")) == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
class BaseStoreAsyncTests(ABC):
|
class BaseStoreAsyncTests(BaseStandardTests):
|
||||||
"""Test suite for checking the key-value API of a BaseStore.
|
"""Test suite for checking the key-value API of a BaseStore.
|
||||||
|
|
||||||
This test suite verifies the basic key-value API of a BaseStore.
|
This test suite verifies the basic key-value API of a BaseStore.
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.caches import BaseCache
|
from langchain_core.caches import BaseCache
|
||||||
from langchain_core.outputs import Generation
|
from langchain_core.outputs import Generation
|
||||||
|
|
||||||
|
from langchain_standard_tests.base import BaseStandardTests
|
||||||
|
|
||||||
class SyncCacheTestSuite(ABC):
|
|
||||||
|
class SyncCacheTestSuite(BaseStandardTests):
|
||||||
"""Test suite for checking the BaseCache API of a caching layer for LLMs.
|
"""Test suite for checking the BaseCache API of a caching layer for LLMs.
|
||||||
|
|
||||||
This test suite verifies the basic caching API of a caching layer for LLMs.
|
This test suite verifies the basic caching API of a caching layer for LLMs.
|
||||||
@ -95,7 +97,7 @@ class SyncCacheTestSuite(ABC):
|
|||||||
assert cache.lookup(prompt, llm_string) == generations
|
assert cache.lookup(prompt, llm_string) == generations
|
||||||
|
|
||||||
|
|
||||||
class AsyncCacheTestSuite(ABC):
|
class AsyncCacheTestSuite(BaseStandardTests):
|
||||||
"""Test suite for checking the BaseCache API of a caching layer for LLMs.
|
"""Test suite for checking the BaseCache API of a caching layer for LLMs.
|
||||||
|
|
||||||
This test suite verifies the basic caching API of a caching layer for LLMs.
|
This test suite verifies the basic caching API of a caching layer for LLMs.
|
||||||
|
@ -1,19 +1,21 @@
|
|||||||
"""Test suite to test vectostores."""
|
"""Test suite to test vectostores."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings.fake import DeterministicFakeEmbedding, Embeddings
|
from langchain_core.embeddings.fake import DeterministicFakeEmbedding, Embeddings
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
|
from langchain_standard_tests.base import BaseStandardTests
|
||||||
|
|
||||||
# Arbitrarily chosen. Using a small embedding size
|
# Arbitrarily chosen. Using a small embedding size
|
||||||
# so tests are faster and easier to debug.
|
# so tests are faster and easier to debug.
|
||||||
EMBEDDING_SIZE = 6
|
EMBEDDING_SIZE = 6
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteTestSuite(ABC):
|
class ReadWriteTestSuite(BaseStandardTests):
|
||||||
"""Test suite for checking the read-write API of a vectorstore.
|
"""Test suite for checking the read-write API of a vectorstore.
|
||||||
|
|
||||||
This test suite verifies the basic read-write API of a vectorstore.
|
This test suite verifies the basic read-write API of a vectorstore.
|
||||||
@ -201,7 +203,7 @@ class ReadWriteTestSuite(ABC):
|
|||||||
assert "ids" not in signature.parameters
|
assert "ids" not in signature.parameters
|
||||||
|
|
||||||
|
|
||||||
class AsyncReadWriteTestSuite(ABC):
|
class AsyncReadWriteTestSuite(BaseStandardTests):
|
||||||
"""Test suite for checking the **async** read-write API of a vectorstore.
|
"""Test suite for checking the **async** read-write API of a vectorstore.
|
||||||
|
|
||||||
This test suite verifies the basic read-write API of a vectorstore.
|
This test suite verifies the basic read-write API of a vectorstore.
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Unit tests for chat models."""
|
"""Unit tests for chat models."""
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
from abc import abstractmethod
|
||||||
from typing import Any, List, Literal, Optional, Type
|
from typing import Any, List, Literal, Optional, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -8,6 +9,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field
|
|||||||
from langchain_core.runnables import RunnableBinding
|
from langchain_core.runnables import RunnableBinding
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from langchain_standard_tests.base import BaseStandardTests
|
||||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
|
|
||||||
@ -64,7 +66,7 @@ def my_adder(a: int, b: int) -> int:
|
|||||||
return a + b
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
class ChatModelTests(ABC):
|
class ChatModelTests(BaseStandardTests):
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||||
|
Loading…
Reference in New Issue
Block a user