Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-11 16:01:33 +00:00)

commit c672590f42 (parent 323729915a), committed via GitHub

chore(standard-tests): select ALL rules with exclusions (#31937)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
@@ -1,15 +1,13 @@
 """Standard tests."""
 
-from abc import ABC
-
 
-class BaseStandardTests(ABC):
+class BaseStandardTests:
     """Base class for standard tests.
 
     :private:
     """
 
-    def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
+    def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:  # noqa: N802
         """Test that no standard tests are overridden.
 
         :private:
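For reference, N802 is pep8-naming's "function name should be lowercase" rule; the suppression above keeps the deliberately loud method name rather than renaming it. A tiny illustration (the first two functions are hypothetical):

    def fetchData() -> None:  # would trigger N802: function name should be lowercase
        ...


    def fetch_data() -> None:  # compliant spelling
        ...


    class BaseStandardTests:
        def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:  # noqa: N802 -- loud name is intentional
            ...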
@@ -3,7 +3,7 @@
 import gzip
 from os import PathLike
 from pathlib import Path
-from typing import Any, Union
+from typing import Any, Union, cast
 
 import pytest
 import yaml
@@ -27,7 +27,13 @@ class CustomSerializer:
     def serialize(cassette_dict: dict) -> bytes:
         """Convert cassette to YAML and compress it."""
         cassette_dict["requests"] = [
-            request._to_dict() for request in cassette_dict["requests"]
+            {
+                "method": request.method,
+                "uri": request.uri,
+                "body": request.body,
+                "headers": {k: [v] for k, v in request.headers.items()},
+            }
+            for request in cassette_dict["requests"]
         ]
         yml = yaml.safe_dump(cassette_dict)
         return gzip.compress(yml.encode("utf-8"))
@@ -35,11 +41,9 @@ class CustomSerializer:
     @staticmethod
     def deserialize(data: bytes) -> dict:
         """Decompress data and convert it from YAML."""
-        text = gzip.decompress(data).decode("utf-8")
-        cassette: dict[str, Any] = yaml.safe_load(text)
-        cassette["requests"] = [
-            Request._from_dict(request) for request in cassette["requests"]
-        ]
+        decoded_yaml = gzip.decompress(data).decode("utf-8")
+        cassette = cast("dict[str, Any]", yaml.safe_load(decoded_yaml))
+        cassette["requests"] = [Request(**request) for request in cassette["requests"]]
         return cassette
 
 
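For context, `CustomSerializer` is the kind of object vcrpy lets you plug in for cassette (de)serialization. A minimal sketch of how such a serializer is typically registered and used; the import path, serializer name, and cassette path are illustrative assumptions, not taken from this diff:

    import vcr

    # CustomSerializer as shown in the hunks above; assumed importable from the
    # test suite's conftest module for this sketch.
    from conftest import CustomSerializer

    my_vcr = vcr.VCR()
    # vcrpy resolves serializers by name; "yaml.gz" is an arbitrary label here.
    my_vcr.register_serializer("yaml.gz", CustomSerializer())

    # Interactions recorded inside this block are stored as gzip-compressed YAML.
    with my_vcr.use_cassette("tests/cassettes/example.yaml.gz", serializer="yaml.gz"):
        ...  # make the HTTP calls to record or replay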
@@ -30,14 +30,14 @@ from pydantic import BaseModel, Field
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import Field as FieldV1
 from pytest_benchmark.fixture import BenchmarkFixture  # type: ignore[import-untyped]
-from typing_extensions import TypedDict
+from typing_extensions import TypedDict, override
 from vcr.cassette import Cassette
 
 from langchain_tests.unit_tests.chat_models import ChatModelTests
 from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
 
 
-def _get_joke_class(
+def _get_joke_class(  # noqa: RET503
     schema_type: Literal["pydantic", "typeddict", "json_schema"],
 ) -> Any:
     class Joke(BaseModel):
@@ -56,7 +56,7 @@ def _get_joke_class(
         punchline: Annotated[str, ..., "answer to resolve the joke"]
 
     def validate_joke_dict(result: Any) -> bool:
-        return all(key in ["setup", "punchline"] for key in result)
+        return all(key in {"setup", "punchline"} for key in result)
 
     if schema_type == "pydantic":
        return Joke, validate_joke
@@ -75,6 +75,7 @@ class _TestCallbackHandler(BaseCallbackHandler):
         super().__init__()
         self.options = []
 
+    @override
     def on_chat_model_start(
         self,
         serialized: Any,
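The `@override` decorator added throughout this change comes from `typing_extensions` (PEP 698): it lets a type checker verify that a decorated method really overrides something in a base class. A minimal sketch with hypothetical class names, not code from this diff:

    from typing_extensions import override


    class BaseHandler:
        def on_event(self, payload: str) -> None:
            print(payload)


    class LoggingHandler(BaseHandler):
        @override
        def on_event(self, payload: str) -> None:
            # OK: BaseHandler defines on_event, so this is a real override.
            print(f"event: {payload}")

        @override
        def on_evnet(self, payload: str) -> None:
            # Flagged by a type checker: no base method named "on_evnet".
            print(payload)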
@@ -1042,7 +1043,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         # Needed for langchain_core.callbacks.usage
         model_name = result.response_metadata.get("model_name")
         assert isinstance(model_name, str)
-        assert model_name != "", "model_name is empty"
+        assert model_name, "model_name is empty"
 
         # `input_tokens` is the total, possibly including other unclassified or
         # system-level tokens.
@@ -1056,10 +1057,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("audio"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
         if "audio_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_output()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1068,10 +1068,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(output_token_details.get("audio"), int)
             # Asserts that total output tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("output_tokens", 0) >= sum(
                 v for v in output_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
         if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_reasoning_output()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1080,10 +1079,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(output_token_details.get("reasoning"), int)
             # Asserts that total output tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("output_tokens", 0) >= sum(
                 v for v in output_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
         if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_read_input()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1092,10 +1090,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("cache_read"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
         if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_creation_input()
             assert (usage_metadata := msg.usage_metadata) is not None
@@ -1104,10 +1101,9 @@ class ChatModelIntegrationTests(ChatModelTests):
             ) is not None
             assert isinstance(input_token_details.get("cache_creation"), int)
             # Asserts that total input tokens are at least the sum of the token counts
-            total_detailed_tokens = sum(
+            assert usage_metadata.get("input_tokens", 0) >= sum(
                 v for v in input_token_details.values() if isinstance(v, int)
             )
-            assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
 
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         """Test usage metadata in streaming mode.
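The assertion pattern used in the hunks above folds the temporary `total_detailed_tokens` variable into a single check: the reported total must be at least the sum of the per-category details. A small self-contained illustration of the same invariant on a plain dict shaped like `UsageMetadata` (the numbers are made up):

    # Shaped like langchain_core's UsageMetadata; values are illustrative only.
    usage_metadata = {
        "input_tokens": 120,
        "output_tokens": 45,
        "total_tokens": 165,
        "input_token_details": {"audio": 20, "cache_read": 80},
    }

    input_token_details = usage_metadata["input_token_details"]

    # Total input tokens may include unclassified or system-level tokens, so it
    # must be >= the sum of whatever integer categories the provider reported.
    assert usage_metadata.get("input_tokens", 0) >= sum(
        v for v in input_token_details.values() if isinstance(v, int)
    )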
@@ -1235,7 +1231,7 @@ class ChatModelIntegrationTests(ChatModelTests):
         # Needed for langchain_core.callbacks.usage
         model_name = full.response_metadata.get("model_name")
         assert isinstance(model_name, str)
-        assert model_name != "", "model_name is empty"
+        assert model_name, "model_name is empty"
 
         if "audio_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_audio_input(stream=True)
@@ -1720,7 +1716,7 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool choice.")
 
         @tool
-        def get_weather(location: str) -> str:
+        def get_weather(location: str) -> str:  # noqa: ARG001
            """Get weather at a location."""
            return "It's sunny."
 
@@ -2130,7 +2126,7 @@ class ChatModelIntegrationTests(ChatModelTests):
        See `example implementation <https://python.langchain.com/api_reference/_modules/langchain_openai/chat_models/base.html#BaseChatOpenAI.with_structured_output>`__
        of ``with_structured_output``.
 
-        """  # noqa: E501
+        """
        if not self.has_structured_output:
            pytest.skip("Test requires structured output.")
 
@@ -2262,8 +2258,8 @@ class ChatModelIntegrationTests(ChatModelTests):
         if not self.supports_json_mode:
             pytest.skip("Test requires json mode support.")
 
-        from pydantic import BaseModel as BaseModelProper
-        from pydantic import Field as FieldProper
+        from pydantic import BaseModel as BaseModelProper  # noqa: PLC0415
+        from pydantic import Field as FieldProper  # noqa: PLC0415
 
         class Joke(BaseModelProper):
             """Joke to tell user."""
@@ -2912,7 +2908,7 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool calling.")
 
         @tool
-        def get_weather(location: str) -> str:
+        def get_weather(location: str) -> str:  # noqa: ARG001
            """Call to surf the web."""
            return "It's sunny."
 
@@ -175,7 +175,7 @@ class DocumentIndexerTestSuite(ABC):
 
     def test_delete_no_args(self, index: DocumentIndex) -> None:
         """Test delete with no args raises ValueError."""
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError):  # noqa: PT011
             index.delete()
 
     def test_delete_missing_content(self, index: DocumentIndex) -> None:
@@ -367,7 +367,7 @@ class AsyncDocumentIndexTestSuite(ABC):
 
     async def test_delete_no_args(self, index: DocumentIndex) -> None:
         """Test delete with no args raises ValueError."""
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError):  # noqa: PT011
             await index.adelete()
 
     async def test_delete_missing_content(self, index: DocumentIndex) -> None:
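Ruff's PT011 flags a bare `pytest.raises(ValueError)` as too broad. The suppressions above keep the existing behavior, since these shared suites cannot know the exact message each `DocumentIndex` implementation raises. For a concrete implementation whose message is known, the rule's preferred form would look roughly like this; the `delete` stand-in and its message are illustrative assumptions:

    from typing import Optional

    import pytest


    def delete(ids: Optional[list[str]] = None) -> None:
        """Stand-in for a concrete DocumentIndex.delete (hypothetical)."""
        if ids is None:
            raise ValueError("ids must be provided")


    def test_delete_no_args() -> None:
        # Pinning the expected message with `match=` is what PT011 asks for.
        with pytest.raises(ValueError, match="ids must be provided"):
            delete()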
@@ -7,7 +7,7 @@
 def get_pydantic_major_version() -> int:
     """Get the major version of Pydantic."""
     try:
-        import pydantic
+        import pydantic  # noqa: PLC0415
 
         return int(pydantic.__version__.split(".")[0])
     except ImportError:
@@ -59,14 +59,45 @@ ignore_missing_imports = true
 target-version = "py39"
 
 [tool.ruff.lint]
-select = ["D", "E", "F", "I", "PGH", "T201", "UP",]
-pyupgrade.keep-runtime-typing = true
+select = [ "ALL",]
+ignore = [
+    "C90",     # McCabe complexity
+    "COM812",  # Messes with the formatter
+    "FA100",   # Can't activate since we exclude UP007 for now
+    "FIX002",  # Line contains TODO
+    "ISC001",  # Messes with the formatter
+    "PERF203", # Rarely useful
+    "PLR2004", # Magic numbers
+    "PLR09",   # Too many something (arg, statements, etc)
+    "RUF012",  # Doesn't play well with Pydantic
+    "S101",    # Asserts allowed in tests
+    "S311",    # No need for strong crypto in tests
+    "SLF001",  # Tests may call private methods
+    "TC001",   # Doesn't play well with Pydantic
+    "TC002",   # Doesn't play well with Pydantic
+    "TC003",   # Doesn't play well with Pydantic
+    "TD002",   # Missing author in TODO
+    "TD003",   # Missing issue link in TODO
+    # TODO rules
+    "ANN401",
+    "BLE",
+]
+unfixable = [
+    "B028",    # People should intentionally tune the stacklevel
+    "PLW1510", # People should intentionally set the check argument
+]
 
-[tool.ruff.lint.pydocstyle]
-convention = "google"
+flake8-annotations.allow-star-arg-any = true
+flake8-annotations.mypy-init-return = true
+flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
+pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
+pydocstyle.convention = "google"
+pyupgrade.keep-runtime-typing = true
 
 [tool.ruff.lint.per-file-ignores]
 "tests/**" = [ "D1",]
+"scripts/**" = [ "INP",]
 
 [tool.coverage.run]
 omit = ["tests/*"]
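With `select = [ "ALL",]`, every Ruff rule is on by default: project-wide exceptions live in `ignore`, per-directory exceptions in `per-file-ignores`, and one-off exceptions sit next to the offending line as `# noqa` comments, which is where the suppressions added throughout this diff come from. A small illustrative sketch of the two inline suppression forms (the file contents and names are hypothetical):

    # ruff: noqa: D100
    # The comment above is a file-level suppression: skip "missing module
    # docstring" (D100) for this whole file.
    import subprocess


    def run_linter() -> None:
        """Run Ruff over the current directory."""
        # Line-level suppression: PLW1510 wants an explicit `check=` argument;
        # leaving it unset is intentional here, so only this line is exempted.
        subprocess.run(["ruff", "check", "."])  # noqa: PLW1510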
@@ -7,6 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from pydantic import Field
+from typing_extensions import override
 
 
 class ChatParrotLink(BaseChatModel):
@@ -41,6 +42,7 @@ class ChatParrotLink(BaseChatModel):
     stop: Optional[list[str]] = None
     max_retries: int = 2
 
+    @override
     def _generate(
         self,
         messages: list[BaseMessage],
@@ -92,6 +94,7 @@ class ChatParrotLink(BaseChatModel):
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
 
+    @override
     def _stream(
         self,
         messages: list[BaseMessage],
@@ -1,6 +1,7 @@
 from typing import Literal
 
 from langchain_core.tools import BaseTool
+from typing_extensions import override
 
 from langchain_tests.integration_tests import ToolsIntegrationTests
 from langchain_tests.unit_tests import ToolsUnitTests
@@ -12,6 +13,7 @@ class ParrotMultiplyTool(BaseTool):
         "Multiply two numbers like a parrot. Parrots always add eighty for their matey."
     )
 
+    @override
     def _run(self, a: int, b: int) -> int:
         return a * b + 80
 
@@ -23,6 +25,7 @@ class ParrotMultiplyArtifactTool(BaseTool):
     )
     response_format: Literal["content_and_artifact"] = "content_and_artifact"
 
+    @override
     def _run(self, a: int, b: int) -> tuple[int, str]:
         return a * b + 80, "parrot artifact"
 
@@ -1,9 +1,8 @@
 """Test the standard tests on the custom chat model in the docs."""
 
 from typing import Optional
 
-import pytest
 from langchain_core.language_models.chat_models import BaseChatModel
 from typing_extensions import Any
 
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 from langchain_tests.unit_tests import ChatModelUnitTests
@@ -34,7 +33,6 @@ class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
     def test_unicode_tool_call_integration(
         self,
         model: BaseChatModel,
         tool_choice: Optional[str] = None,
         force_tool_call: bool = True,
         **_: Any,
     ) -> None:
         """Expected failure as ChatParrotLink doesn't support tool calling yet."""
@@ -2,6 +2,7 @@
 
 import pytest
 from langchain_core.stores import InMemoryStore
+from typing_extensions import override
 
 from langchain_tests.integration_tests.base_store import (
     BaseStoreAsyncTests,
@@ -11,19 +12,23 @@ from langchain_tests.integration_tests.base_store import (
 
 class TestInMemoryStore(BaseStoreSyncTests[str]):
     @pytest.fixture
+    @override
     def three_values(self) -> tuple[str, str, str]:
         return "foo", "bar", "buzz"
 
     @pytest.fixture
+    @override
     def kv_store(self) -> InMemoryStore:
         return InMemoryStore()
 
 
 class TestInMemoryStoreAsync(BaseStoreAsyncTests[str]):
     @pytest.fixture
+    @override
     def three_values(self) -> tuple[str, str, str]:
         return "foo", "bar", "buzz"
 
     @pytest.fixture
+    @override
     async def kv_store(self) -> InMemoryStore:
         return InMemoryStore()
@@ -1,5 +1,6 @@
 import pytest
 from langchain_core.caches import InMemoryCache
+from typing_extensions import override
 
 from langchain_tests.integration_tests.cache import (
     AsyncCacheTestSuite,
@@ -9,11 +10,13 @@ from langchain_tests.integration_tests.cache import (
 
 class TestInMemoryCache(SyncCacheTestSuite):
     @pytest.fixture
+    @override
     def cache(self) -> InMemoryCache:
         return InMemoryCache()
 
 
 class TestInMemoryCacheAsync(AsyncCacheTestSuite):
     @pytest.fixture
+    @override
     async def cache(self) -> InMemoryCache:
         return InMemoryCache()