diff --git a/libs/partners/mistralai/Makefile b/libs/partners/mistralai/Makefile index e2365801c12..4a9b915e6b4 100644 --- a/libs/partners/mistralai/Makefile +++ b/libs/partners/mistralai/Makefile @@ -21,7 +21,7 @@ test_watch: integration_test integration_tests: - uv run --group test --group test_integration pytest -v --tb=short -n auto $(TEST_FILE) + uv run --group test --group test_integration pytest -v --tb=short -n auto --retries 3 --retry-delay 2 $(TEST_FILE) ###################### diff --git a/libs/partners/mistralai/pyproject.toml b/libs/partners/mistralai/pyproject.toml index 807d37e1d34..aa53bf2ecf9 100644 --- a/libs/partners/mistralai/pyproject.toml +++ b/libs/partners/mistralai/pyproject.toml @@ -44,6 +44,7 @@ Reddit = "https://www.reddit.com/r/LangChain/" test = [ "pytest>=9.0.3,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", + "pytest-retry>=1.7.0,<1.8.0", "pytest-watcher>=0.3.4,<1.0.0", "pytest-xdist>=3.6.1,<4.0.0", "langchain-core>=1.4.0,<2.0.0", diff --git a/libs/partners/mistralai/tests/integration_tests/_rate_limiter.py b/libs/partners/mistralai/tests/integration_tests/_rate_limiter.py new file mode 100644 index 00000000000..afefdbc35cf --- /dev/null +++ b/libs/partners/mistralai/tests/integration_tests/_rate_limiter.py @@ -0,0 +1,18 @@ +"""Shared rate limiter for Mistral integration tests. + +Scaled by ``PYTEST_XDIST_WORKER_COUNT`` so aggregate QPS across all xdist +workers stays bounded near the target rate. +""" + +from __future__ import annotations + +import os + +from langchain_core.rate_limiters import InMemoryRateLimiter + +_TARGET_REQUESTS_PER_SECOND = 0.5 +_WORKER_COUNT = max(1, int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))) + +rate_limiter = InMemoryRateLimiter( + requests_per_second=_TARGET_REQUESTS_PER_SECOND / _WORKER_COUNT, +) diff --git a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py index bef96ae3e14..a1d42a47ca2 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/integration_tests/test_chat_models.py @@ -13,11 +13,12 @@ from pydantic import BaseModel from typing_extensions import TypedDict from langchain_mistralai.chat_models import ChatMistralAI +from tests.integration_tests._rate_limiter import rate_limiter async def test_astream() -> None: """Test streaming tokens from ChatMistralAI.""" - llm = ChatMistralAI() + llm = ChatMistralAI(rate_limiter=rate_limiter) full: BaseMessageChunk | None = None chunks_with_token_counts = 0 @@ -70,7 +71,7 @@ def _check_parsed_result(result: Any, schema: Any) -> None: @pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()]) def test_structured_output_json_schema(schema: Any) -> None: - llm = ChatMistralAI(model="ministral-8b-latest") # type: ignore[call-arg] + llm = ChatMistralAI(model="ministral-8b-latest", rate_limiter=rate_limiter) # type: ignore[call-arg] structured_llm = llm.with_structured_output(schema, method="json_schema") messages = [ @@ -91,7 +92,7 @@ def test_structured_output_json_schema(schema: Any) -> None: @pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()]) async def test_structured_output_json_schema_async(schema: Any) -> None: - llm = ChatMistralAI(model="ministral-8b-latest") # type: ignore[call-arg] + llm = ChatMistralAI(model="ministral-8b-latest", rate_limiter=rate_limiter) # type: ignore[call-arg] structured_llm = llm.with_structured_output(schema, method="json_schema") messages = [ @@ -116,6 +117,7 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None: mistral = ChatMistralAI( timeout=1, # Very short timeout to trigger timeouts max_retries=3, # Should retry 3 times + rate_limiter=rate_limiter, ) # Simple test input that should take longer than 1 second to process @@ -148,7 +150,7 @@ def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None: def test_reasoning() -> None: - model = ChatMistralAI(model="magistral-medium-latest") # type: ignore[call-arg] + model = ChatMistralAI(model="magistral-medium-latest", rate_limiter=rate_limiter) # type: ignore[call-arg] input_message = { "role": "user", "content": "Hello, my name is Bob.", @@ -172,7 +174,11 @@ def test_reasoning() -> None: def test_reasoning_v1() -> None: - model = ChatMistralAI(model="magistral-medium-latest", output_version="v1") # type: ignore[call-arg] + model = ChatMistralAI( # type: ignore[call-arg] + model="magistral-medium-latest", + output_version="v1", + rate_limiter=rate_limiter, + ) input_message = { "role": "user", "content": "Hello, my name is Bob.", diff --git a/libs/partners/mistralai/tests/integration_tests/test_standard.py b/libs/partners/mistralai/tests/integration_tests/test_standard.py index ce67e0db4ed..c1281898ea9 100644 --- a/libs/partners/mistralai/tests/integration_tests/test_standard.py +++ b/libs/partners/mistralai/tests/integration_tests/test_standard.py @@ -7,6 +7,7 @@ from langchain_tests.integration_tests import ( # type: ignore[import-not-found ) from langchain_mistralai import ChatMistralAI +from tests.integration_tests._rate_limiter import rate_limiter class TestMistralStandard(ChatModelIntegrationTests): @@ -16,7 +17,11 @@ class TestMistralStandard(ChatModelIntegrationTests): @property def chat_model_params(self) -> dict: - return {"model": "mistral-large-latest", "temperature": 0} + return { + "model": "mistral-large-latest", + "temperature": 0, + "rate_limiter": rate_limiter, + } @property def supports_json_mode(self) -> bool: diff --git a/libs/partners/mistralai/uv.lock b/libs/partners/mistralai/uv.lock index 4abee098371..6d773ee3aae 100644 --- a/libs/partners/mistralai/uv.lock +++ b/libs/partners/mistralai/uv.lock @@ -454,6 +454,7 @@ test = [ { name = "langchain-tests" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-retry" }, { name = "pytest-watcher" }, { name = "pytest-xdist" }, ] @@ -479,6 +480,7 @@ test = [ { name = "langchain-tests", editable = "../../standard-tests" }, { name = "pytest", specifier = ">=9.0.3,<10.0.0" }, { name = "pytest-asyncio", specifier = ">=1.3.0,<2.0.0" }, + { name = "pytest-retry", specifier = ">=1.7.0,<1.8.0" }, { name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" }, { name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" }, ] @@ -1137,6 +1139,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/c2/ce34735972cc42d912173e79f200fe66530225190c06655c5632a9d88f1e/pytest_recording-0.13.4-py3-none-any.whl", hash = "sha256:ad49a434b51b1c4f78e85b1e6b74fdcc2a0a581ca16e52c798c6ace971f7f439", size = 13723, upload-time = "2025-05-08T10:41:09.684Z" }, ] +[[package]] +name = "pytest-retry" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/5b/607b017994cca28de3a1ad22a3eee8418e5d428dcd8ec25b26b18e995a73/pytest_retry-1.7.0.tar.gz", hash = "sha256:f8d52339f01e949df47c11ba9ee8d5b362f5824dff580d3870ec9ae0057df80f", size = 19977, upload-time = "2025-01-19T01:56:13.115Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/ff/3266c8a73b9b93c4b14160a7e2b31d1e1088e28ed29f4c2d93ae34093bfd/pytest_retry-1.7.0-py3-none-any.whl", hash = "sha256:a2dac85b79a4e2375943f1429479c65beb6c69553e7dae6b8332be47a60954f4", size = 13775, upload-time = "2025-01-19T01:56:11.199Z" }, +] + [[package]] name = "pytest-socket" version = "0.7.0"