Mirror of https://github.com/hwchase17/langchain.git
Synced 2025-08-09 21:08:59 +00:00
ai21: apply rate limiter in integration tests (#24717)

Apply rate limiter in integration tests.

parent 03d62a737a
commit 3a5365a33e
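
The pattern the commit applies is LangChain's client-side rate limiting: an InMemoryRateLimiter created once at module scope and handed to every ChatAI21 instance under test, so the whole run stays at or below 0.5 requests per second (about one call every two seconds). A minimal sketch of the same pattern outside the test suite, assuming a valid AI21 API key in the environment; the model name here is illustrative:

from langchain_core.rate_limiters import InMemoryRateLimiter

from langchain_ai21.chat_models import ChatAI21

# One limiter shared by every model instance: at 0.5 requests/second,
# the token bucket refills roughly once every two seconds.
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)

# The model blocks on the limiter before each request goes out.
llm = ChatAI21(model="jamba-instruct", rate_limiter=rate_limiter)
print(llm.invoke("Hello").content)
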
Changes to the AI21 chat model integration tests:

@@ -3,10 +3,13 @@
 import pytest
 from langchain_core.messages import AIMessageChunk, HumanMessage
 from langchain_core.outputs import ChatGeneration
+from langchain_core.rate_limiters import InMemoryRateLimiter
 
 from langchain_ai21.chat_models import ChatAI21
 from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
 
+rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
+
 
 @pytest.mark.parametrize(
     ids=[
@@ -21,7 +24,7 @@ from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
 )
 def test_invoke(model: str) -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model=model)  # type: ignore[call-arg]
+    llm = ChatAI21(model=model, rate_limiter=rate_limiter)  # type: ignore[call-arg]
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
@@ -48,7 +51,7 @@ def test_generation(model: str, num_results: int) -> None:
     config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
 
     # Create the model instance using the appropriate key for the result count
-    llm = ChatAI21(model=model, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
+    llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]
 
     message = HumanMessage(content="Hello, this is a test. Can you help me please?")
 
@@ -75,7 +78,7 @@ def test_generation(model: str, num_results: int) -> None:
 )
 async def test_ageneration(model: str) -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model=model)  # type: ignore[call-arg]
+    llm = ChatAI21(model=model, rate_limiter=rate_limiter)  # type: ignore[call-arg]
     message = HumanMessage(content="Hello")
 
     result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
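
The limiter is what spaces these calls out: each invoke or agenerate acquires a token from the shared bucket before hitting the API and blocks until one is available. A standalone timing sketch (not part of the commit) that exercises the limiter's blocking acquire() directly:

import time

from langchain_core.rate_limiters import InMemoryRateLimiter

limiter = InMemoryRateLimiter(requests_per_second=0.5)

start = time.monotonic()
limiter.acquire()  # may already wait here if no token is available yet
limiter.acquire()  # waits for a refill at 0.5 tokens/second
print(f"two acquisitions took {time.monotonic() - start:.1f}s")  # roughly 2-4s

Slower test runs are the price; the payoff is that the parametrized tests above no longer hit the AI21 endpoint in a burst.
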
Changes to the standard chat model integration test suite:

@@ -5,12 +5,13 @@ from typing import Type
 
 import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_standard_tests.integration_tests import (  # type: ignore[import-not-found]
-    ChatModelIntegrationTests,  # type: ignore[import-not-found]
-)
+from langchain_core.rate_limiters import InMemoryRateLimiter
+from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
 
 from langchain_ai21 import ChatAI21
 
+rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
+
 
 class BaseTestAI21(ChatModelIntegrationTests):
     def teardown(self) -> None:
@@ -31,6 +32,7 @@ class TestAI21J2(BaseTestAI21):
     def chat_model_params(self) -> dict:
         return {
             "model": "j2-ultra",
+            "rate_limiter": rate_limiter,
         }
 
     @pytest.mark.xfail(reason="Streaming is not supported for Jurassic models.")
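
The second file shows why one dict entry is enough: ChatModelIntegrationTests constructs every model it tests from the chat_model_class and chat_model_params properties, so adding "rate_limiter" to the params throttles the entire inherited suite. A simplified sketch of that wiring, reusing the class and model names from the diff:

from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_ai21 import ChatAI21

rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)


class TestAI21J2(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatAI21

    @property
    def chat_model_params(self) -> dict:
        # The suite instantiates chat_model_class(**chat_model_params)
        # for each test, so the shared limiter covers all of them.
        return {"model": "j2-ultra", "rate_limiter": rate_limiter}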