mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 00:29:57 +00:00
ai21: apply rate limiter in integration tests (#24717)
Apply rate limiter in integration tests
This commit is contained in:
parent
03d62a737a
commit
3a5365a33e
@ -3,10 +3,13 @@
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk, HumanMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
|
||||
from langchain_ai21.chat_models import ChatAI21
|
||||
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
|
||||
|
||||
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
@ -21,7 +24,7 @@ from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
|
||||
)
|
||||
def test_invoke(model: str) -> None:
|
||||
"""Test invoke tokens from AI21."""
|
||||
llm = ChatAI21(model=model) # type: ignore[call-arg]
|
||||
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
||||
@ -48,7 +51,7 @@ def test_generation(model: str, num_results: int) -> None:
|
||||
config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
|
||||
|
||||
# Create the model instance using the appropriate key for the result count
|
||||
llm = ChatAI21(model=model, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
|
||||
llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results}) # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
|
||||
|
||||
message = HumanMessage(content="Hello, this is a test. Can you help me please?")
|
||||
|
||||
@ -75,7 +78,7 @@ def test_generation(model: str, num_results: int) -> None:
|
||||
)
|
||||
async def test_ageneration(model: str) -> None:
|
||||
"""Test invoke tokens from AI21."""
|
||||
llm = ChatAI21(model=model) # type: ignore[call-arg]
|
||||
llm = ChatAI21(model=model, rate_limiter=rate_limiter) # type: ignore[call-arg]
|
||||
message = HumanMessage(content="Hello")
|
||||
|
||||
result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
|
||||
|
@ -5,12 +5,13 @@ from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||
ChatModelIntegrationTests, # type: ignore[import-not-found]
|
||||
)
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_ai21 import ChatAI21
|
||||
|
||||
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
|
||||
|
||||
|
||||
class BaseTestAI21(ChatModelIntegrationTests):
|
||||
def teardown(self) -> None:
|
||||
@ -31,6 +32,7 @@ class TestAI21J2(BaseTestAI21):
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "j2-ultra",
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
|
||||
|
Loading…
Reference in New Issue
Block a user