ai21: apply rate limiter in integration tests (#24717)

Apply rate limiter in integration tests
This commit is contained in:
Eugene Yurtsev 2024-07-26 11:15:36 -04:00 committed by GitHub
parent 03d62a737a
commit 3a5365a33e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 6 deletions

View File

@@ -3,10 +3,13 @@
import pytest
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.outputs import ChatGeneration
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
@pytest.mark.parametrize(
ids=[
@@ -21,7 +24,7 @@ from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
)
def test_invoke(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=model) # type: ignore[call-arg]
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
@@ -48,7 +51,7 @@ def test_generation(model: str, num_results: int) -> None:
config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
# Create the model instance using the appropriate key for the result count
llm = ChatAI21(model=model, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
message = HumanMessage(content="Hello, this is a test. Can you help me please?")
@@ -75,7 +78,7 @@ def test_generation(model: str, num_results: int) -> None:
)
async def test_ageneration(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=model) # type: ignore[call-arg]
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
message = HumanMessage(content="Hello")
result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))

View File

@@ -5,12 +5,13 @@ from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
)
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_ai21 import ChatAI21
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
class BaseTestAI21(ChatModelIntegrationTests):
def teardown(self) -> None:
@@ -31,6 +32,7 @@ class TestAI21J2(BaseTestAI21):
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
"rate_limiter": rate_limiter,
}
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")