From 3a5365a33ea60bbc6accfad164f8e043933c0b1e Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Fri, 26 Jul 2024 11:15:36 -0400
Subject: [PATCH] ai21: apply rate limiter in integration tests (#24717)

Apply rate limiter in integration tests
---
 .../ai21/tests/integration_tests/test_chat_models.py | 9 ++++++---
 .../ai21/tests/integration_tests/test_standard.py    | 8 +++++---
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/libs/partners/ai21/tests/integration_tests/test_chat_models.py b/libs/partners/ai21/tests/integration_tests/test_chat_models.py
index f34341c6a98..d0e6a7e09bb 100644
--- a/libs/partners/ai21/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/ai21/tests/integration_tests/test_chat_models.py
@@ -3,10 +3,13 @@
 import pytest
 from langchain_core.messages import AIMessageChunk, HumanMessage
 from langchain_core.outputs import ChatGeneration
+from langchain_core.rate_limiters import InMemoryRateLimiter
 
 from langchain_ai21.chat_models import ChatAI21
 from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
 
+rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
+
 
 @pytest.mark.parametrize(
     ids=[
@@ -21,7 +24,7 @@ from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
 )
 def test_invoke(model: str) -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model=model)  # type: ignore[call-arg]
+    llm = ChatAI21(model=model, rate_limiter=rate_limiter)  # type: ignore[call-arg]
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
@@ -48,7 +51,7 @@ def test_generation(model: str, num_results: int) -> None:
     config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
 
     # Create the model instance using the appropriate key for the result count
-    llm = ChatAI21(model=model, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
+    llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
 
     message = HumanMessage(content="Hello, this is a test. Can you help me please?")
 
@@ -75,7 +78,7 @@ def test_generation(model: str, num_results: int) -> None:
 )
 async def test_ageneration(model: str) -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model=model)  # type: ignore[call-arg]
+    llm = ChatAI21(model=model, rate_limiter=rate_limiter)  # type: ignore[call-arg]
 
     message = HumanMessage(content="Hello")
     result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
diff --git a/libs/partners/ai21/tests/integration_tests/test_standard.py b/libs/partners/ai21/tests/integration_tests/test_standard.py
index 0774faf5e26..5896573102d 100644
--- a/libs/partners/ai21/tests/integration_tests/test_standard.py
+++ b/libs/partners/ai21/tests/integration_tests/test_standard.py
@@ -5,12 +5,13 @@ from typing import Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_standard_tests.integration_tests import (  # type: ignore[import-not-found]
-    ChatModelIntegrationTests,  # type: ignore[import-not-found]
-)
+from langchain_core.rate_limiters import InMemoryRateLimiter
+from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
 
 from langchain_ai21 import ChatAI21
 
+rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
+
 
 class BaseTestAI21(ChatModelIntegrationTests):
     def teardown(self) -> None:
@@ -31,6 +32,7 @@ class TestAI21J2(BaseTestAI21):
     def chat_model_params(self) -> dict:
         return {
             "model": "j2-ultra",
+            "rate_limiter": rate_limiter,
         }
 
     @pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")