mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 16:01:33 +00:00
chore(standard-tests): select ALL rules with exclusions (#31937)
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
323729915a
commit
c672590f42
@@ -1,15 +1,13 @@
|
|||||||
"""Standard tests."""
|
"""Standard tests."""
|
||||||
|
|
||||||
from abc import ABC
|
|
||||||
|
|
||||||
|
class BaseStandardTests:
|
||||||
class BaseStandardTests(ABC):
|
|
||||||
"""Base class for standard tests.
|
"""Base class for standard tests.
|
||||||
|
|
||||||
:private:
|
:private:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_no_overrides_DO_NOT_OVERRIDE(self) -> None:
|
def test_no_overrides_DO_NOT_OVERRIDE(self) -> None: # noqa: N802
|
||||||
"""Test that no standard tests are overridden.
|
"""Test that no standard tests are overridden.
|
||||||
|
|
||||||
:private:
|
:private:
|
||||||
|
@@ -3,7 +3,7 @@
|
|||||||
import gzip
|
import gzip
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Union
|
from typing import Any, Union, cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
@@ -27,7 +27,13 @@ class CustomSerializer:
|
|||||||
def serialize(cassette_dict: dict) -> bytes:
|
def serialize(cassette_dict: dict) -> bytes:
|
||||||
"""Convert cassette to YAML and compress it."""
|
"""Convert cassette to YAML and compress it."""
|
||||||
cassette_dict["requests"] = [
|
cassette_dict["requests"] = [
|
||||||
request._to_dict() for request in cassette_dict["requests"]
|
{
|
||||||
|
"method": request.method,
|
||||||
|
"uri": request.uri,
|
||||||
|
"body": request.body,
|
||||||
|
"headers": {k: [v] for k, v in request.headers.items()},
|
||||||
|
}
|
||||||
|
for request in cassette_dict["requests"]
|
||||||
]
|
]
|
||||||
yml = yaml.safe_dump(cassette_dict)
|
yml = yaml.safe_dump(cassette_dict)
|
||||||
return gzip.compress(yml.encode("utf-8"))
|
return gzip.compress(yml.encode("utf-8"))
|
||||||
@@ -35,11 +41,9 @@ class CustomSerializer:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def deserialize(data: bytes) -> dict:
|
def deserialize(data: bytes) -> dict:
|
||||||
"""Decompress data and convert it from YAML."""
|
"""Decompress data and convert it from YAML."""
|
||||||
text = gzip.decompress(data).decode("utf-8")
|
decoded_yaml = gzip.decompress(data).decode("utf-8")
|
||||||
cassette: dict[str, Any] = yaml.safe_load(text)
|
cassette = cast("dict[str, Any]", yaml.safe_load(decoded_yaml))
|
||||||
cassette["requests"] = [
|
cassette["requests"] = [Request(**request) for request in cassette["requests"]]
|
||||||
Request._from_dict(request) for request in cassette["requests"]
|
|
||||||
]
|
|
||||||
return cassette
|
return cassette
|
||||||
|
|
||||||
|
|
||||||
|
@@ -30,14 +30,14 @@ from pydantic import BaseModel, Field
|
|||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from pydantic.v1 import Field as FieldV1
|
from pydantic.v1 import Field as FieldV1
|
||||||
from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-untyped]
|
from pytest_benchmark.fixture import BenchmarkFixture # type: ignore[import-untyped]
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict, override
|
||||||
from vcr.cassette import Cassette
|
from vcr.cassette import Cassette
|
||||||
|
|
||||||
from langchain_tests.unit_tests.chat_models import ChatModelTests
|
from langchain_tests.unit_tests.chat_models import ChatModelTests
|
||||||
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
from langchain_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
|
|
||||||
def _get_joke_class(
|
def _get_joke_class( # noqa: RET503
|
||||||
schema_type: Literal["pydantic", "typeddict", "json_schema"],
|
schema_type: Literal["pydantic", "typeddict", "json_schema"],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
class Joke(BaseModel):
|
class Joke(BaseModel):
|
||||||
@@ -56,7 +56,7 @@ def _get_joke_class(
|
|||||||
punchline: Annotated[str, ..., "answer to resolve the joke"]
|
punchline: Annotated[str, ..., "answer to resolve the joke"]
|
||||||
|
|
||||||
def validate_joke_dict(result: Any) -> bool:
|
def validate_joke_dict(result: Any) -> bool:
|
||||||
return all(key in ["setup", "punchline"] for key in result)
|
return all(key in {"setup", "punchline"} for key in result)
|
||||||
|
|
||||||
if schema_type == "pydantic":
|
if schema_type == "pydantic":
|
||||||
return Joke, validate_joke
|
return Joke, validate_joke
|
||||||
@@ -75,6 +75,7 @@ class _TestCallbackHandler(BaseCallbackHandler):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.options = []
|
self.options = []
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: Any,
|
serialized: Any,
|
||||||
@@ -1042,7 +1043,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
# Needed for langchain_core.callbacks.usage
|
# Needed for langchain_core.callbacks.usage
|
||||||
model_name = result.response_metadata.get("model_name")
|
model_name = result.response_metadata.get("model_name")
|
||||||
assert isinstance(model_name, str)
|
assert isinstance(model_name, str)
|
||||||
assert model_name != "", "model_name is empty"
|
assert model_name, "model_name is empty"
|
||||||
|
|
||||||
# `input_tokens` is the total, possibly including other unclassified or
|
# `input_tokens` is the total, possibly including other unclassified or
|
||||||
# system-level tokens.
|
# system-level tokens.
|
||||||
@@ -1056,10 +1057,9 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
) is not None
|
) is not None
|
||||||
assert isinstance(input_token_details.get("audio"), int)
|
assert isinstance(input_token_details.get("audio"), int)
|
||||||
# Asserts that total input tokens are at least the sum of the token counts
|
# Asserts that total input tokens are at least the sum of the token counts
|
||||||
total_detailed_tokens = sum(
|
assert usage_metadata.get("input_tokens", 0) >= sum(
|
||||||
v for v in input_token_details.values() if isinstance(v, int)
|
v for v in input_token_details.values() if isinstance(v, int)
|
||||||
)
|
)
|
||||||
assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
|
|
||||||
if "audio_output" in self.supported_usage_metadata_details["invoke"]:
|
if "audio_output" in self.supported_usage_metadata_details["invoke"]:
|
||||||
msg = self.invoke_with_audio_output()
|
msg = self.invoke_with_audio_output()
|
||||||
assert (usage_metadata := msg.usage_metadata) is not None
|
assert (usage_metadata := msg.usage_metadata) is not None
|
||||||
@@ -1068,10 +1068,9 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
) is not None
|
) is not None
|
||||||
assert isinstance(output_token_details.get("audio"), int)
|
assert isinstance(output_token_details.get("audio"), int)
|
||||||
# Asserts that total output tokens are at least the sum of the token counts
|
# Asserts that total output tokens are at least the sum of the token counts
|
||||||
total_detailed_tokens = sum(
|
assert usage_metadata.get("output_tokens", 0) >= sum(
|
||||||
v for v in output_token_details.values() if isinstance(v, int)
|
v for v in output_token_details.values() if isinstance(v, int)
|
||||||
)
|
)
|
||||||
assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
|
|
||||||
if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
|
if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
|
||||||
msg = self.invoke_with_reasoning_output()
|
msg = self.invoke_with_reasoning_output()
|
||||||
assert (usage_metadata := msg.usage_metadata) is not None
|
assert (usage_metadata := msg.usage_metadata) is not None
|
||||||
@@ -1080,10 +1079,9 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
) is not None
|
) is not None
|
||||||
assert isinstance(output_token_details.get("reasoning"), int)
|
assert isinstance(output_token_details.get("reasoning"), int)
|
||||||
# Asserts that total output tokens are at least the sum of the token counts
|
# Asserts that total output tokens are at least the sum of the token counts
|
||||||
total_detailed_tokens = sum(
|
assert usage_metadata.get("output_tokens", 0) >= sum(
|
||||||
v for v in output_token_details.values() if isinstance(v, int)
|
v for v in output_token_details.values() if isinstance(v, int)
|
||||||
)
|
)
|
||||||
assert usage_metadata.get("output_tokens", 0) >= total_detailed_tokens
|
|
||||||
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
|
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
|
||||||
msg = self.invoke_with_cache_read_input()
|
msg = self.invoke_with_cache_read_input()
|
||||||
assert (usage_metadata := msg.usage_metadata) is not None
|
assert (usage_metadata := msg.usage_metadata) is not None
|
||||||
@@ -1092,10 +1090,9 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
) is not None
|
) is not None
|
||||||
assert isinstance(input_token_details.get("cache_read"), int)
|
assert isinstance(input_token_details.get("cache_read"), int)
|
||||||
# Asserts that total input tokens are at least the sum of the token counts
|
# Asserts that total input tokens are at least the sum of the token counts
|
||||||
total_detailed_tokens = sum(
|
assert usage_metadata.get("input_tokens", 0) >= sum(
|
||||||
v for v in input_token_details.values() if isinstance(v, int)
|
v for v in input_token_details.values() if isinstance(v, int)
|
||||||
)
|
)
|
||||||
assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
|
|
||||||
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
|
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
|
||||||
msg = self.invoke_with_cache_creation_input()
|
msg = self.invoke_with_cache_creation_input()
|
||||||
assert (usage_metadata := msg.usage_metadata) is not None
|
assert (usage_metadata := msg.usage_metadata) is not None
|
||||||
@@ -1104,10 +1101,9 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
) is not None
|
) is not None
|
||||||
assert isinstance(input_token_details.get("cache_creation"), int)
|
assert isinstance(input_token_details.get("cache_creation"), int)
|
||||||
# Asserts that total input tokens are at least the sum of the token counts
|
# Asserts that total input tokens are at least the sum of the token counts
|
||||||
total_detailed_tokens = sum(
|
assert usage_metadata.get("input_tokens", 0) >= sum(
|
||||||
v for v in input_token_details.values() if isinstance(v, int)
|
v for v in input_token_details.values() if isinstance(v, int)
|
||||||
)
|
)
|
||||||
assert usage_metadata.get("input_tokens", 0) >= total_detailed_tokens
|
|
||||||
|
|
||||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||||
"""Test usage metadata in streaming mode.
|
"""Test usage metadata in streaming mode.
|
||||||
@@ -1235,7 +1231,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
# Needed for langchain_core.callbacks.usage
|
# Needed for langchain_core.callbacks.usage
|
||||||
model_name = full.response_metadata.get("model_name")
|
model_name = full.response_metadata.get("model_name")
|
||||||
assert isinstance(model_name, str)
|
assert isinstance(model_name, str)
|
||||||
assert model_name != "", "model_name is empty"
|
assert model_name, "model_name is empty"
|
||||||
|
|
||||||
if "audio_input" in self.supported_usage_metadata_details["stream"]:
|
if "audio_input" in self.supported_usage_metadata_details["stream"]:
|
||||||
msg = self.invoke_with_audio_input(stream=True)
|
msg = self.invoke_with_audio_input(stream=True)
|
||||||
@@ -1720,7 +1716,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
pytest.skip("Test requires tool choice.")
|
pytest.skip("Test requires tool choice.")
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_weather(location: str) -> str:
|
def get_weather(location: str) -> str: # noqa: ARG001
|
||||||
"""Get weather at a location."""
|
"""Get weather at a location."""
|
||||||
return "It's sunny."
|
return "It's sunny."
|
||||||
|
|
||||||
@@ -2130,7 +2126,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
See `example implementation <https://python.langchain.com/api_reference/_modules/langchain_openai/chat_models/base.html#BaseChatOpenAI.with_structured_output>`__
|
See `example implementation <https://python.langchain.com/api_reference/_modules/langchain_openai/chat_models/base.html#BaseChatOpenAI.with_structured_output>`__
|
||||||
of ``with_structured_output``.
|
of ``with_structured_output``.
|
||||||
|
|
||||||
""" # noqa: E501
|
"""
|
||||||
if not self.has_structured_output:
|
if not self.has_structured_output:
|
||||||
pytest.skip("Test requires structured output.")
|
pytest.skip("Test requires structured output.")
|
||||||
|
|
||||||
@@ -2262,8 +2258,8 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
if not self.supports_json_mode:
|
if not self.supports_json_mode:
|
||||||
pytest.skip("Test requires json mode support.")
|
pytest.skip("Test requires json mode support.")
|
||||||
|
|
||||||
from pydantic import BaseModel as BaseModelProper
|
from pydantic import BaseModel as BaseModelProper # noqa: PLC0415
|
||||||
from pydantic import Field as FieldProper
|
from pydantic import Field as FieldProper # noqa: PLC0415
|
||||||
|
|
||||||
class Joke(BaseModelProper):
|
class Joke(BaseModelProper):
|
||||||
"""Joke to tell user."""
|
"""Joke to tell user."""
|
||||||
@@ -2912,7 +2908,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
|||||||
pytest.skip("Test requires tool calling.")
|
pytest.skip("Test requires tool calling.")
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def get_weather(location: str) -> str:
|
def get_weather(location: str) -> str: # noqa: ARG001
|
||||||
"""Call to surf the web."""
|
"""Call to surf the web."""
|
||||||
return "It's sunny."
|
return "It's sunny."
|
||||||
|
|
||||||
|
@@ -175,7 +175,7 @@ class DocumentIndexerTestSuite(ABC):
|
|||||||
|
|
||||||
def test_delete_no_args(self, index: DocumentIndex) -> None:
|
def test_delete_no_args(self, index: DocumentIndex) -> None:
|
||||||
"""Test delete with no args raises ValueError."""
|
"""Test delete with no args raises ValueError."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError): # noqa: PT011
|
||||||
index.delete()
|
index.delete()
|
||||||
|
|
||||||
def test_delete_missing_content(self, index: DocumentIndex) -> None:
|
def test_delete_missing_content(self, index: DocumentIndex) -> None:
|
||||||
@@ -367,7 +367,7 @@ class AsyncDocumentIndexTestSuite(ABC):
|
|||||||
|
|
||||||
async def test_delete_no_args(self, index: DocumentIndex) -> None:
|
async def test_delete_no_args(self, index: DocumentIndex) -> None:
|
||||||
"""Test delete with no args raises ValueError."""
|
"""Test delete with no args raises ValueError."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError): # noqa: PT011
|
||||||
await index.adelete()
|
await index.adelete()
|
||||||
|
|
||||||
async def test_delete_missing_content(self, index: DocumentIndex) -> None:
|
async def test_delete_missing_content(self, index: DocumentIndex) -> None:
|
||||||
|
@@ -7,7 +7,7 @@
|
|||||||
def get_pydantic_major_version() -> int:
|
def get_pydantic_major_version() -> int:
|
||||||
"""Get the major version of Pydantic."""
|
"""Get the major version of Pydantic."""
|
||||||
try:
|
try:
|
||||||
import pydantic
|
import pydantic # noqa: PLC0415
|
||||||
|
|
||||||
return int(pydantic.__version__.split(".")[0])
|
return int(pydantic.__version__.split(".")[0])
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@@ -59,14 +59,45 @@ ignore_missing_imports = true
|
|||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["D", "E", "F", "I", "PGH", "T201", "UP",]
|
select = [ "ALL",]
|
||||||
pyupgrade.keep-runtime-typing = true
|
ignore = [
|
||||||
|
"C90", # McCabe complexity
|
||||||
|
"COM812", # Messes with the formatter
|
||||||
|
"FA100", # Can't activate since we exclude UP007 for now
|
||||||
|
"FIX002", # Line contains TODO
|
||||||
|
"ISC001", # Messes with the formatter
|
||||||
|
"PERF203", # Rarely useful
|
||||||
|
"PLR2004", # Magic numbers
|
||||||
|
"PLR09", # Too many something (arg, statements, etc)
|
||||||
|
"RUF012", # Doesn't play well with Pydantic
|
||||||
|
"S101", # Asserts allowed in tests
|
||||||
|
"S311", # No need for strong crypto in tests
|
||||||
|
"SLF001", # Tests may call private methods
|
||||||
|
"TC001", # Doesn't play well with Pydantic
|
||||||
|
"TC002", # Doesn't play well with Pydantic
|
||||||
|
"TC003", # Doesn't play well with Pydantic
|
||||||
|
"TD002", # Missing author in TODO
|
||||||
|
"TD003", # Missing issue link in TODO
|
||||||
|
|
||||||
[tool.ruff.lint.pydocstyle]
|
# TODO rules
|
||||||
convention = "google"
|
"ANN401",
|
||||||
|
"BLE",
|
||||||
|
]
|
||||||
|
unfixable = [
|
||||||
|
"B028", # People should intentionally tune the stacklevel
|
||||||
|
"PLW1510", # People should intentionally set the check argument
|
||||||
|
]
|
||||||
|
|
||||||
|
flake8-annotations.allow-star-arg-any = true
|
||||||
|
flake8-annotations.mypy-init-return = true
|
||||||
|
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
|
||||||
|
pep8-naming.classmethod-decorators = [ "classmethod", "langchain_core.utils.pydantic.pre_init", "pydantic.field_validator", "pydantic.v1.root_validator",]
|
||||||
|
pydocstyle.convention = "google"
|
||||||
|
pyupgrade.keep-runtime-typing = true
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"tests/**" = [ "D1",]
|
"tests/**" = [ "D1",]
|
||||||
|
"scripts/**" = [ "INP",]
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
omit = ["tests/*"]
|
omit = ["tests/*"]
|
||||||
|
@@ -7,6 +7,7 @@ from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
|||||||
from langchain_core.messages.ai import UsageMetadata
|
from langchain_core.messages.ai import UsageMetadata
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class ChatParrotLink(BaseChatModel):
|
class ChatParrotLink(BaseChatModel):
|
||||||
@@ -41,6 +42,7 @@ class ChatParrotLink(BaseChatModel):
|
|||||||
stop: Optional[list[str]] = None
|
stop: Optional[list[str]] = None
|
||||||
max_retries: int = 2
|
max_retries: int = 2
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@@ -92,6 +94,7 @@ class ChatParrotLink(BaseChatModel):
|
|||||||
generation = ChatGeneration(message=message)
|
generation = ChatGeneration(message=message)
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
@override
|
||||||
def _stream(
|
def _stream(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_tests.integration_tests import ToolsIntegrationTests
|
from langchain_tests.integration_tests import ToolsIntegrationTests
|
||||||
from langchain_tests.unit_tests import ToolsUnitTests
|
from langchain_tests.unit_tests import ToolsUnitTests
|
||||||
@@ -12,6 +13,7 @@ class ParrotMultiplyTool(BaseTool):
|
|||||||
"Multiply two numbers like a parrot. Parrots always add eighty for their matey."
|
"Multiply two numbers like a parrot. Parrots always add eighty for their matey."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def _run(self, a: int, b: int) -> int:
|
def _run(self, a: int, b: int) -> int:
|
||||||
return a * b + 80
|
return a * b + 80
|
||||||
|
|
||||||
@@ -23,6 +25,7 @@ class ParrotMultiplyArtifactTool(BaseTool):
|
|||||||
)
|
)
|
||||||
response_format: Literal["content_and_artifact"] = "content_and_artifact"
|
response_format: Literal["content_and_artifact"] = "content_and_artifact"
|
||||||
|
|
||||||
|
@override
|
||||||
def _run(self, a: int, b: int) -> tuple[int, str]:
|
def _run(self, a: int, b: int) -> tuple[int, str]:
|
||||||
return a * b + 80, "parrot artifact"
|
return a * b + 80, "parrot artifact"
|
||||||
|
|
||||||
|
@@ -1,9 +1,8 @@
|
|||||||
"""Test the standard tests on the custom chat model in the docs."""
|
"""Test the standard tests on the custom chat model in the docs."""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
from typing_extensions import Any
|
||||||
|
|
||||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||||
@@ -34,7 +33,6 @@ class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
|
|||||||
def test_unicode_tool_call_integration(
|
def test_unicode_tool_call_integration(
|
||||||
self,
|
self,
|
||||||
model: BaseChatModel,
|
model: BaseChatModel,
|
||||||
tool_choice: Optional[str] = None,
|
**_: Any,
|
||||||
force_tool_call: bool = True,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Expected failure as ChatParrotLink doesn't support tool calling yet."""
|
"""Expected failure as ChatParrotLink doesn't support tool calling yet."""
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.stores import InMemoryStore
|
from langchain_core.stores import InMemoryStore
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_tests.integration_tests.base_store import (
|
from langchain_tests.integration_tests.base_store import (
|
||||||
BaseStoreAsyncTests,
|
BaseStoreAsyncTests,
|
||||||
@@ -11,19 +12,23 @@ from langchain_tests.integration_tests.base_store import (
|
|||||||
|
|
||||||
class TestInMemoryStore(BaseStoreSyncTests[str]):
|
class TestInMemoryStore(BaseStoreSyncTests[str]):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@override
|
||||||
def three_values(self) -> tuple[str, str, str]:
|
def three_values(self) -> tuple[str, str, str]:
|
||||||
return "foo", "bar", "buzz"
|
return "foo", "bar", "buzz"
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@override
|
||||||
def kv_store(self) -> InMemoryStore:
|
def kv_store(self) -> InMemoryStore:
|
||||||
return InMemoryStore()
|
return InMemoryStore()
|
||||||
|
|
||||||
|
|
||||||
class TestInMemoryStoreAsync(BaseStoreAsyncTests[str]):
|
class TestInMemoryStoreAsync(BaseStoreAsyncTests[str]):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@override
|
||||||
def three_values(self) -> tuple[str, str, str]:
|
def three_values(self) -> tuple[str, str, str]:
|
||||||
return "foo", "bar", "buzz"
|
return "foo", "bar", "buzz"
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@override
|
||||||
async def kv_store(self) -> InMemoryStore:
|
async def kv_store(self) -> InMemoryStore:
|
||||||
return InMemoryStore()
|
return InMemoryStore()
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from langchain_core.caches import InMemoryCache
|
from langchain_core.caches import InMemoryCache
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain_tests.integration_tests.cache import (
|
from langchain_tests.integration_tests.cache import (
|
||||||
AsyncCacheTestSuite,
|
AsyncCacheTestSuite,
|
||||||
@@ -9,11 +10,13 @@ from langchain_tests.integration_tests.cache import (
|
|||||||
|
|
||||||
class TestInMemoryCache(SyncCacheTestSuite):
|
class TestInMemoryCache(SyncCacheTestSuite):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@override
|
||||||
def cache(self) -> InMemoryCache:
|
def cache(self) -> InMemoryCache:
|
||||||
return InMemoryCache()
|
return InMemoryCache()
|
||||||
|
|
||||||
|
|
||||||
class TestInMemoryCacheAsync(AsyncCacheTestSuite):
|
class TestInMemoryCacheAsync(AsyncCacheTestSuite):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@override
|
||||||
async def cache(self) -> InMemoryCache:
|
async def cache(self) -> InMemoryCache:
|
||||||
return InMemoryCache()
|
return InMemoryCache()
|
||||||
|
Reference in New Issue
Block a user