chore(standard-tests): select ALL rules with exclusions (#31937)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Christophe Bornet authored on 2025-09-08 16:57:47 +02:00, committed by GitHub
parent 323729915a
commit c672590f42
12 changed files with 83 additions and 42 deletions

View File

@@ -1,15 +1,13 @@
 """Standard tests."""
+from abc import ABC
-class BaseStandardTests:
+class BaseStandardTests(ABC):
     """Base class for standard tests.
     :private:
     """
-    def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
+    def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:  # noqa: N802
         """Test that no standard tests are overridden.
         :private:

View File

@@ -3,7 +3,7 @@
 import gzip
 from os import PathLike
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, Union, cast
 import pytest
 import yaml
@@ -27,7 +27,13 @@ class CustomSerializer:
     def serialize(cassette_dict: dict) -> bytes:
         """Convert cassette to YAML and compress it."""
         cassette_dict["requests"] = [
-            request._to_dict() for request in cassette_dict["requests"]
+            {
+                "method": request.method,
+                "uri": request.uri,
+                "body": request.body,
+                "headers": {k: [v] for k, v in request.headers.items()},
+            }
+            for request in cassette_dict["requests"]
         ]
         yml = yaml.safe_dump(cassette_dict)
         return gzip.compress(yml.encode("utf-8"))
@@ -35,11 +41,9 @@ class CustomSerializer:
     @staticmethod
     def deserialize(data: bytes) -> dict:
         """Decompress data and convert it from YAML."""
-        text = gzip.decompress(data).decode("utf-8")
-        cassette: dict[str, Any] = yaml.safe_load(text)
-        cassette["requests"] = [
-            Request._from_dict(request) for request in cassette["requests"]
-        ]
+        decoded_yaml = gzip.decompress(data).decode("utf-8")
+        cassette = cast("dict[str, Any]", yaml.safe_load(decoded_yaml))
+        cassette["requests"] = [Request(**request) for request in cassette["requests"]]
         return cassette

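Side note on how a serializer like the CustomSerializer above is typically wired into VCR.py: it is registered under a name, and that name is passed when opening a cassette. The snippet below is a minimal, hypothetical sketch (the cassette path, serializer name, and import path are assumptions, not part of this commit):

import vcr

from my_conftest import CustomSerializer  # hypothetical import of the class shown above

# Register the gzip'd-YAML serializer under a custom name, then record/replay
# a cassette through it.
my_vcr = vcr.VCR()
my_vcr.register_serializer("yaml.gz", CustomSerializer())

with my_vcr.use_cassette("cassettes/example.yaml.gz", serializer="yaml.gz"):
    ...  # HTTP traffic here is serialized/deserialized by CustomSerializer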
View File

@@ -30,14 +30,14 @@ from pydantic import BaseModel, Field
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import Field as FieldV1
 from pytest_benchmark.fixture import BenchmarkFixture  # type: ignore[import-untyped]
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, override
 from vcr.cassette import Cassette
 from langchain_tests.unit_tests.chat_models import ChatModelTests
 from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
-def _get_joke_class(
+def _get_joke_class(  # noqa: RET503
     schema_type: Literal["pydantic", "typeddict", "json_schema"],
 ) -> Any:
     class Joke(BaseModel):
@@ -56,7 +56,7 @@ def _get_joke_class(
         punchline: Annotated[str, ..., "answer to resolve the joke"]
     def validate_joke_dict(result: Any) -> bool:
-        return all(key in ["setup", "punchline"] for key in result)
+        return all(key in {"setup", "punchline"} for key in result)
     if schema_type == "pydantic":
         return Joke, validate_joke
@@ -75,6 +75,7 @@ class _TestCallbackHandler(BaseCallbackHandler):
         super().__init__()
         self.options = []
+    @override
     def on_chat_model_start(
         self,
         serialized: Any,
@@ -1042,7 +1043,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         # Needed for langchain_core.callbacks.usage
         model_name = result.response_metadata.get("model_name")
         assert isinstance(model_name, str)
-        assert model_name != "", "model_name is empty"
+        assert model_name, "model_name is empty"
         # `input_tokens` is the total, possibly including other unclassified or
         # system-level tokens.
@@ -1056,10 +1057,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("audio"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
         if "audio_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_output()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1068,10 +1068,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(output_token_details.get("audio"), int)
             # Asserts that total output tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("output_tokens", 0) >= sum(
                 v for v in output_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
         if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_reasoning_output()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1080,10 +1079,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(output_token_details.get("reasoning"), int)
             # Asserts that total output tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("output_tokens", 0) >= sum(
                 v for v in output_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
         if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_read_input()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1092,10 +1090,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("cache_read"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
         if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_creation_input()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1104,10 +1101,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("cache_creation"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         """Test usage metadata in streaming mode.
@@ -1235,7 +1231,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         # Needed for langchain_core.callbacks.usage
         model_name = full.response_metadata.get("model_name")
         assert isinstance(model_name, str)
-        assert model_name != "", "model_name is empty"
+        assert model_name, "model_name is empty"
         if "audio_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_audio_input(stream=True)
@@ -1720,7 +1716,7 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool choice.")
         @tool
-        def get_weather(location: str) -> str:
+        def get_weather(location: str) -> str:  # noqa: ARG001
            """Get weather at a location."""
            return "It's sunny."
@@ -2130,7 +2126,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         See `example implementation <https://python.langchain.com/api_reference/_modules/langchain_openai/chat_models/base.html#BaseChatOpenAI.with_structured_output>`__
         of ``with_structured_output``.
-        """  # noqa: E501
+        """
         if not self.has_structured_output:
             pytest.skip("Test requires structured output.")
@@ -2262,8 +2258,8 @@ class ChatModelIntegrationTests(ChatModelTests):
         if not self.supports_json_mode:
             pytest.skip("Test requires json mode support.")
-        from pydantic import BaseModel as BaseModelProper
-        from pydantic import Field as FieldProper
+        from pydantic import BaseModel as BaseModelProper  # noqa: PLC0415
+        from pydantic import Field as FieldProper  # noqa: PLC0415
         class Joke(BaseModelProper):
             """Joke to tell user."""
@@ -2912,7 +2908,7 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool calling.")
         @tool
-        def get_weather(location: str) -> str:
+        def get_weather(location: str) -> str:  # noqa: ARG001
            """Call to surf the web."""
            return "It's sunny."

View File

@@ -175,7 +175,7 @@ class DocumentIndexerTestSuite(ABC):
     def test_delete_no_args(self, index: DocumentIndex) -> None:
         """Test delete with no args raises ValueError."""
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError):  # noqa: PT011
             index.delete()
     def test_delete_missing_content(self, index: DocumentIndex) -> None:
@@ -367,7 +367,7 @@ class AsyncDocumentIndexTestSuite(ABC):
     async def test_delete_no_args(self, index: DocumentIndex) -> None:
         """Test delete with no args raises ValueError."""
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError):  # noqa: PT011
             await index.adelete()
     async def test_delete_missing_content(self, index: DocumentIndex) -> None:

View File

@@ -7,7 +7,7 @@
 def get_pydantic_major_version() -> int:
     """Get the major version of Pydantic."""
     try:
-        import pydantic
+        import pydantic  # noqa: PLC0415
         return int(pydantic.__version__.split(".")[0])
     except ImportError:

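For reference, get_pydantic_major_version() above backs the PYDANTIC_MAJOR_VERSION constant imported by the chat-model tests earlier in this commit. A hedged sketch of the usual guard pattern built on it (the test name and skip reason below are illustrative, not from this commit):

import pytest

from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION


@pytest.mark.skipif(
    PYDANTIC_MAJOR_VERSION < 2,
    reason="Illustrative guard: assumes Pydantic v2 behavior.",
)
def test_pydantic_v2_only_behavior() -> None:
    ...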
View File

@@ -59,14 +59,45 @@ ignore_missing_imports = true
 target-version = "py39"
 [tool.ruff.lint]
-select = ["D", "E", "F", "I", "PGH", "T201", "UP",]
-pyupgrade.keep-runtime-typing = true
-[tool.ruff.lint.pydocstyle]
-convention = "google"
+select = [ "ALL",]
+ignore = [
+    "C90",     # McCabe complexity
+    "COM812",  # Messes with the formatter
+    "FA100",   # Can't activate since we exclude UP007 for now
+    "FIX002",  # Line contains TODO
+    "ISC001",  # Messes with the formatter
+    "PERF203", # Rarely useful
+    "PLR2004", # Magic numbers
+    "PLR09",   # Too many something (arg, statements, etc)
+    "RUF012",  # Doesn't play well with Pydantic
+    "S101",    # Asserts allowed in tests
+    "S311",    # No need for strong crypto in tests
+    "SLF001",  # Tests may call private methods
+    "TC001",   # Doesn't play well with Pydantic
+    "TC002",   # Doesn't play well with Pydantic
+    "TC003",   # Doesn't play well with Pydantic
+    "TD002",   # Missing author in TODO
+    "TD003",   # Missing issue link in TODO
+    # TODO rules
+    "ANN401",
+    "BLE",
+]
+unfixable = [
+    "B028",    # People should intentionally tune the stacklevel
+    "PLW1510", # People should intentionally set the check argument
+]
+flake8-annotations.allow-star-arg-any = true
+flake8-annotations.mypy-init-return = true
+flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
+pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
+pydocstyle.convention = "google"
+pyupgrade.keep-runtime-typing = true
 [tool.ruff.lint.per-file-ignores]
 "tests/**" = [ "D1",]
+"scripts/**" = [ "INP",]
 [tool.coverage.run]
 omit = ["tests/*"]

View File

@@ -7,6 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from pydantic import Field
+from typing_extensions import override
 class ChatParrotLink(BaseChatModel):
@@ -41,6 +42,7 @@ class ChatParrotLink(BaseChatModel):
     stop: Optional[list[str]] = None
     max_retries: int = 2
+    @override
     def _generate(
         self,
         messages: list[BaseMessage],
@@ -92,6 +94,7 @@ class ChatParrotLink(BaseChatModel):
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
+    @override
     def _stream(
         self,
         messages: list[BaseMessage],

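The @override decorators added across these files come from typing_extensions (PEP 698) and are a no-op at runtime; they only let a type checker verify that a decorated method really overrides something in a base class. A minimal sketch with illustrative class names (not from this commit):

from typing_extensions import override


class Base:
    def run(self) -> str:
        return "base"


class Child(Base):
    @override
    def run(self) -> str:  # fine: Base.run exists
        return "child"

    @override
    def runn(self) -> str:  # a type checker reports this: nothing in Base to override
        return "typo"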
View File

@@ -1,6 +1,7 @@
 from typing import Literal
 from langchain_core.tools import BaseTool
+from typing_extensions import override
 from langchain_tests.integration_tests import ToolsIntegrationTests
 from langchain_tests.unit_tests import ToolsUnitTests
@@ -12,6 +13,7 @@ class ParrotMultiplyTool(BaseTool):
         "Multiply two numbers like a parrot. Parrots always add eighty for their matey."
     )
+    @override
     def _run(self, a: int, b: int) -> int:
         return a * b + 80
@@ -23,6 +25,7 @@ class ParrotMultiplyArtifactTool(BaseTool):
     )
     response_format: Literal["content_and_artifact"] = "content_and_artifact"
+    @override
     def _run(self, a: int, b: int) -> tuple[int, str]:
         return a * b + 80, "parrot artifact"

View File

@@ -1,9 +1,8 @@
 """Test the standard tests on the custom chat model in the docs."""
-from typing import Optional
 import pytest
 from langchain_core.language_models.chat_models import BaseChatModel
+from typing_extensions import Any
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 from langchain_tests.unit_tests import ChatModelUnitTests
@@ -34,7 +33,6 @@ class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
     def test_unicode_tool_call_integration(
         self,
         model: BaseChatModel,
-        tool_choice: Optional[str] = None,
-        force_tool_call: bool = True,
+        **_: Any,
     ) -> None:
         """Expected failure as ChatParrotLink doesn't support tool calling yet."""

View File

@@ -2,6 +2,7 @@
 import pytest
 from langchain_core.stores import InMemoryStore
+from typing_extensions import override
 from langchain_tests.integration_tests.base_store import (
     BaseStoreAsyncTests,
@@ -11,19 +12,23 @@ from langchain_tests.integration_tests.base_store import (
 class TestInMemoryStore(BaseStoreSyncTests[str]):
     @pytest.fixture
+    @override
     def three_values(self) -> tuple[str, str, str]:
         return "foo", "bar", "buzz"
     @pytest.fixture
+    @override
     def kv_store(self) -> InMemoryStore:
         return InMemoryStore()
 class TestInMemoryStoreAsync(BaseStoreAsyncTests[str]):
     @pytest.fixture
+    @override
     def three_values(self) -> tuple[str, str, str]:
         return "foo", "bar", "buzz"
     @pytest.fixture
+    @override
     async def kv_store(self) -> InMemoryStore:
         return InMemoryStore()

View File

@@ -1,5 +1,6 @@
 import pytest
 from langchain_core.caches import InMemoryCache
+from typing_extensions import override
 from langchain_tests.integration_tests.cache import (
     AsyncCacheTestSuite,
@@ -9,11 +10,13 @@ from langchain_tests.integration_tests.cache import (
 class TestInMemoryCache(SyncCacheTestSuite):
     @pytest.fixture
+    @override
     def cache(self) -> InMemoryCache:
         return InMemoryCache()
 class TestInMemoryCacheAsync(AsyncCacheTestSuite):
     @pytest.fixture
+    @override
     async def cache(self) -> InMemoryCache:
         return InMemoryCache()