diff --git a/libs/standard-tests/langchain_tests/base.py b/libs/standard-tests/langchain_tests/base.py
index 0f9156673c1..1afd1e379ae 100644
--- a/libs/standard-tests/langchain_tests/base.py
+++ b/libs/standard-tests/langchain_tests/base.py
@@ -1,15 +1,13 @@
 """Standard tests."""
 
-from abc import ABC
-
 
-class BaseStandardTests(ABC):
+class BaseStandardTests:
     """Base class for standard tests.
 
     :private:
     """
 
-    def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
+    def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:  # noqa: N802
         """Test that no standard tests are overridden.
 
         :private:
diff --git a/libs/standard-tests/langchain_tests/conftest.py b/libs/standard-tests/langchain_tests/conftest.py
index fb441f65f36..8b6b61cda24 100644
--- a/libs/standard-tests/langchain_tests/conftest.py
+++ b/libs/standard-tests/langchain_tests/conftest.py
@@ -3,7 +3,7 @@
 import gzip
 from os import PathLike
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, Union, cast
 
 import pytest
 import yaml
@@ -27,7 +27,13 @@ class CustomSerializer:
     def serialize(cassette_dict: dict) -> bytes:
         """Convert cassette to YAML and compress it."""
         cassette_dict["requests"] = [
-            request._to_dict() for request in cassette_dict["requests"]
+            {
+                "method": request.method,
+                "uri": request.uri,
+                "body": request.body,
+                "headers": {k: [v] for k, v in request.headers.items()},
+            }
+            for request in cassette_dict["requests"]
         ]
         yml = yaml.safe_dump(cassette_dict)
         return gzip.compress(yml.encode("utf-8"))
@@ -35,11 +41,9 @@ class CustomSerializer:
 
     @staticmethod
     def deserialize(data: bytes) -> dict:
         """Decompress data and convert it from YAML."""
-        text = gzip.decompress(data).decode("utf-8")
-        cassette: dict[str, Any] = yaml.safe_load(text)
-        cassette["requests"] = [
-            Request._from_dict(request) for request in cassette["requests"]
-        ]
+        decoded_yaml = gzip.decompress(data).decode("utf-8")
+        cassette = cast("dict[str, Any]", yaml.safe_load(decoded_yaml))
+        cassette["requests"] = [Request(**request) for request in cassette["requests"]]
         return cassette
 
diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
index dce989b3a6a..bf7240c2971 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
@@ -30,14 +30,14 @@ from pydantic import BaseModel, Field
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import Field as FieldV1
 from pytest_benchmark.fixture import BenchmarkFixture  # type: ignore[import-untyped]
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, override
 from vcr.cassette import Cassette
 
 from langchain_tests.unit_tests.chat_models import ChatModelTests
 from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
 
 
-def _get_joke_class(
+def _get_joke_class(  # noqa: RET503
     schema_type: Literal["pydantic", "typeddict", "json_schema"],
 ) -> Any:
     class Joke(BaseModel):
@@ -56,7 +56,7 @@ def _get_joke_class(
         punchline: Annotated[str, ..., "answer to resolve the joke"]
 
     def validate_joke_dict(result: Any) -> bool:
-        return all(key in ["setup", "punchline"] for key in result)
+        return all(key in {"setup", "punchline"} for key in result)
 
     if schema_type == "pydantic":
         return Joke, validate_joke
@@ -75,6 +75,7 @@ class _TestCallbackHandler(BaseCallbackHandler):
         super().__init__()
         self.options = []
 
+    @override
     def on_chat_model_start(
         self,
         serialized: Any,
@@ -1042,7 +1043,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         # Needed for langchain_core.callbacks.usage
         model_name = result.response_metadata.get("model_name")
         assert isinstance(model_name, str)
-        assert model_name != "", "model_name is empty"
+        assert model_name, "model_name is empty"
 
         # `input_tokens` is the total, possibly including other unclassified or
         # system-level tokens.
@@ -1056,10 +1057,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("audio"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
         if "audio_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_output()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1068,10 +1068,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(output_token_details.get("audio"), int)
             # Asserts that total output tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("output_tokens", 0) >= sum(
                 v for v in output_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
         if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_reasoning_output()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1080,10 +1079,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(output_token_details.get("reasoning"), int)
             # Asserts that total output tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("output_tokens", 0) >= sum(
                 v for v in output_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
         if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_read_input()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1092,10 +1090,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("cache_read"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
         if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_creation_input()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1104,10 +1101,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("cache_creation"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
 
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         """Test usage metadata in streaming mode.
@@ -1235,7 +1231,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         # Needed for langchain_core.callbacks.usage
         model_name = full.response_metadata.get("model_name")
         assert isinstance(model_name, str)
-        assert model_name != "", "model_name is empty"
+        assert model_name, "model_name is empty"
 
         if "audio_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_audio_input(stream=True)
@@ -1720,7 +1716,7 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool choice.")
 
         @tool
-        def get_weather(location: str) -> str:
+        def get_weather(location: str) -> str:  # noqa: ARG001
             """Get weather at a location."""
             return "It's sunny."
 
@@ -2130,7 +2126,7 @@ class ChatModelIntegrationTests(ChatModelTests):
 
         See `example implementation `__ of ``with_structured_output``.
 
-        """  # noqa: E501
+        """
         if not self.has_structured_output:
             pytest.skip("Test requires structured output.")
 
@@ -2262,8 +2258,8 @@ class ChatModelIntegrationTests(ChatModelTests):
         if not self.supports_json_mode:
             pytest.skip("Test requires json mode support.")
 
-        from pydantic import BaseModel as BaseModelProper
-        from pydantic import Field as FieldProper
+        from pydantic import BaseModel as BaseModelProper  # noqa: PLC0415
+        from pydantic import Field as FieldProper  # noqa: PLC0415
 
         class Joke(BaseModelProper):
             """Joke to tell user."""
@@ -2912,7 +2908,7 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool calling.")
 
         @tool
-        def get_weather(location: str) -> str:
+        def get_weather(location: str) -> str:  # noqa: ARG001
             """Call to surf the web."""
             return "It's sunny."
 
diff --git a/libs/standard-tests/langchain_tests/integration_tests/indexer.py b/libs/standard-tests/langchain_tests/integration_tests/indexer.py
index 4a887b5c8ad..94378a45d6c 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/indexer.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/indexer.py
@@ -175,7 +175,7 @@ class DocumentIndexerTestSuite(ABC):
 
     def test_delete_no_args(self, index: DocumentIndex) -> None:
         """Test delete with no args raises ValueError."""
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError):  # noqa: PT011
             index.delete()
 
     def test_delete_missing_content(self, index: DocumentIndex) -> None:
@@ -367,7 +367,7 @@ class AsyncDocumentIndexTestSuite(ABC):
 
     async def test_delete_no_args(self, index: DocumentIndex) -> None:
         """Test delete with no args raises ValueError."""
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError):  # noqa: PT011
             await index.adelete()
 
     async def test_delete_missing_content(self, index: DocumentIndex) -> None:
diff --git a/libs/standard-tests/langchain_tests/utils/pydantic.py b/libs/standard-tests/langchain_tests/utils/pydantic.py
index 34cfce3c2ab..6a6e7ba0ed2 100644
--- a/libs/standard-tests/langchain_tests/utils/pydantic.py
+++ b/libs/standard-tests/langchain_tests/utils/pydantic.py
@@ -7,7 +7,7 @@
 def get_pydantic_major_version() -> int:
     """Get the major version of Pydantic."""
     try:
-        import pydantic
+        import pydantic  # noqa: PLC0415
 
         return int(pydantic.__version__.split(".")[0])
     except ImportError:
diff --git a/libs/standard-tests/pyproject.toml b/libs/standard-tests/pyproject.toml
index 613ab8c315f..53c33258664 100644
--- a/libs/standard-tests/pyproject.toml
+++ b/libs/standard-tests/pyproject.toml
@@ -59,14 +59,45 @@ ignore_missing_imports = true
 target-version = "py39"
 
 [tool.ruff.lint]
-select = ["D", "E", "F", "I", "PGH", "T201", "UP",]
-pyupgrade.keep-runtime-typing = true
+select = [ "ALL",]
+ignore = [
+    "C90",  # McCabe complexity
+    "COM812",  # Messes with the formatter
+    "FA100",  # Can't activate since we exclude UP007 for now
+    "FIX002",  # Line contains TODO
+    "ISC001",  # Messes with the formatter
+    "PERF203",  # Rarely useful
+    "PLR2004",  # Magic numbers
+    "PLR09",  # Too many something (arg, statements, etc)
+    "RUF012",  # Doesn't play well with Pydantic
+    "S101",  # Asserts allowed in tests
+    "S311",  # No need for strong crypto in tests
+    "SLF001",  # Tests may call private methods
+    "TC001",  # Doesn't play well with Pydantic
+    "TC002",  # Doesn't play well with Pydantic
+    "TC003",  # Doesn't play well with Pydantic
+    "TD002",  # Missing author in TODO
+    "TD003",  # Missing issue link in TODO
 
-[tool.ruff.lint.pydocstyle]
-convention = "google"
+    # TODO rules
+    "ANN401",
+    "BLE",
+]
+unfixable = [
+    "B028",  # People should intentionally tune the stacklevel
+    "PLW1510",  # People should intentionally set the check argument
+]
+
+flake8-annotations.allow-star-arg-any = true
+flake8-annotations.mypy-init-return = true
+flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
+pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
+pydocstyle.convention = "google"
+pyupgrade.keep-runtime-typing = true
 
 [tool.ruff.lint.per-file-ignores]
 "tests/**" = [ "D1",]
+"scripts/**" = [ "INP",]
 
 [tool.coverage.run]
 omit = ["tests/*"]
diff --git a/libs/standard-tests/tests/integration_tests/__init__.py b/libs/standard-tests/tests/integration_tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/libs/standard-tests/tests/unit_tests/custom_chat_model.py b/libs/standard-tests/tests/unit_tests/custom_chat_model.py
index d8f9f7c668f..f91786aa62c 100644
--- a/libs/standard-tests/tests/unit_tests/custom_chat_model.py
+++ b/libs/standard-tests/tests/unit_tests/custom_chat_model.py
@@ -7,6 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from pydantic import Field
+from typing_extensions import override
 
 
 class ChatParrotLink(BaseChatModel):
@@ -41,6 +42,7 @@ class ChatParrotLink(BaseChatModel):
     stop: Optional[list[str]] = None
     max_retries: int = 2
 
+    @override
     def _generate(
         self,
         messages: list[BaseMessage],
@@ -92,6 +94,7 @@ class ChatParrotLink(BaseChatModel):
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
 
+    @override
     def _stream(
         self,
         messages: list[BaseMessage],
diff --git a/libs/standard-tests/tests/unit_tests/test_basic_tool.py b/libs/standard-tests/tests/unit_tests/test_basic_tool.py
index 1439e414316..d7ab342f043 100644
--- a/libs/standard-tests/tests/unit_tests/test_basic_tool.py
+++ b/libs/standard-tests/tests/unit_tests/test_basic_tool.py
@@ -1,6 +1,7 @@
 from typing import Literal
 
 from langchain_core.tools import BaseTool
+from typing_extensions import override
 
 from langchain_tests.integration_tests import ToolsIntegrationTests
 from langchain_tests.unit_tests import ToolsUnitTests
@@ -12,6 +13,7 @@ class ParrotMultiplyTool(BaseTool):
         "Multiply two numbers like a parrot. Parrots always add eighty for their matey."
     )
 
+    @override
     def _run(self, a: int, b: int) -> int:
         return a * b + 80
 
@@ -23,6 +25,7 @@
     )
     response_format: Literal["content_and_artifact"] = "content_and_artifact"
 
+    @override
     def _run(self, a: int, b: int) -> tuple[int, str]:
         return a * b + 80, "parrot artifact"
 
diff --git a/libs/standard-tests/tests/unit_tests/test_custom_chat_model.py b/libs/standard-tests/tests/unit_tests/test_custom_chat_model.py
index ac821bd77dc..36ccb1d256a 100644
--- a/libs/standard-tests/tests/unit_tests/test_custom_chat_model.py
+++ b/libs/standard-tests/tests/unit_tests/test_custom_chat_model.py
@@ -1,9 +1,8 @@
 """Test the standard tests on the custom chat model in the docs."""
 
-from typing import Optional
-
 import pytest
 from langchain_core.language_models.chat_models import BaseChatModel
+from typing_extensions import Any
 
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 from langchain_tests.unit_tests import ChatModelUnitTests
@@ -34,7 +33,6 @@ class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
     def test_unicode_tool_call_integration(
         self,
         model: BaseChatModel,
-        tool_choice: Optional[str] = None,
-        force_tool_call: bool = True,
+        **_: Any,
     ) -> None:
         """Expected failure as ChatParrotLink doesn't support tool calling yet."""
diff --git a/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py b/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py
index 7349d3ffc5e..42fe15cbf90 100644
--- a/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py
+++ b/libs/standard-tests/tests/unit_tests/test_in_memory_base_store.py
@@ -2,6 +2,7 @@
 
 import pytest
 from langchain_core.stores import InMemoryStore
+from typing_extensions import override
 
 from langchain_tests.integration_tests.base_store import (
     BaseStoreAsyncTests,
@@ -11,19 +12,23 @@ from langchain_tests.integration_tests.base_store import (
 
 class TestInMemoryStore(BaseStoreSyncTests[str]):
     @pytest.fixture
+    @override
     def three_values(self) -> tuple[str, str, str]:
         return "foo", "bar", "buzz"
 
     @pytest.fixture
+    @override
     def kv_store(self) -> InMemoryStore:
         return InMemoryStore()
 
 
 class TestInMemoryStoreAsync(BaseStoreAsyncTests[str]):
     @pytest.fixture
+    @override
     def three_values(self) -> tuple[str, str, str]:
         return "foo", "bar", "buzz"
 
     @pytest.fixture
+    @override
     async def kv_store(self) -> InMemoryStore:
         return InMemoryStore()
diff --git a/libs/standard-tests/tests/unit_tests/test_in_memory_cache.py b/libs/standard-tests/tests/unit_tests/test_in_memory_cache.py
index 6c1a1647ade..5d5e67df04d 100644
--- a/libs/standard-tests/tests/unit_tests/test_in_memory_cache.py
+++ b/libs/standard-tests/tests/unit_tests/test_in_memory_cache.py
@@ -1,5 +1,6 @@
 import pytest
 from langchain_core.caches import InMemoryCache
+from typing_extensions import override
 
 from langchain_tests.integration_tests.cache import (
     AsyncCacheTestSuite,
@@ -9,11 +10,13 @@ from langchain_tests.integration_tests.cache import (
 
 class TestInMemoryCache(SyncCacheTestSuite):
     @pytest.fixture
+    @override
     def cache(self) -> InMemoryCache:
         return InMemoryCache()
 
 
 class TestInMemoryCacheAsync(AsyncCacheTestSuite):
     @pytest.fixture
+    @override
     async def cache(self) -> InMemoryCache:
         return InMemoryCache()