Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-12 04:01:05 +00:00

Compare commits: langchain ... eugene/rat (9 commits)

SHA1: de339acbc1, c929c0e38d, 2ec1b96527, 0fd685fb7d, be89bf47e8, b3419a3018, 98cbfb9643, 5776402f92, f5e36246e3
@@ -60,6 +60,7 @@ from langchain_core.pydantic_v1 import (
    Field,
    root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler

@@ -210,6 +211,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """[DEPRECATED] Callback manager to add to the run trace."""

    rate_limiter: Optional[BaseRateLimiter] = None
    """An optional rate limiter to use for limiting the number of requests."""

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used.

@@ -341,6 +345,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
            batch_size=1,
        )
        generation: Optional[ChatGenerationChunk] = None

        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        try:
            for chunk in self._stream(messages, stop=stop, **kwargs):
                if chunk.message.id is None:

@@ -412,6 +420,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
            batch_size=1,
        )

        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        generation: Optional[ChatGenerationChunk] = None
        try:
            async for chunk in self._astream(

@@ -742,6 +753,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )

        # Apply the rate limiter after checking the cache, since
        # we usually don't want to rate limit cache lookups, but
        # we do want to rate limit API requests.
        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        # If stream is not explicitly set, check if implicitly requested by
        # astream_events() or astream_log(). Bail out if _stream not implemented
        if type(self)._stream != BaseChatModel._stream and kwargs.pop(

@@ -822,6 +840,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )

        # Apply the rate limiter after checking the cache, since
        # we usually don't want to rate limit cache lookups, but
        # we do want to rate limit API requests.
        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        # If stream is not explicitly set, check if implicitly requested by
        # astream_events() or astream_log(). Bail out if _astream not implemented
        if (
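
Taken together, the BaseChatModel hunks above add a single optional rate_limiter field and call acquire(blocking=True) right before each underlying request, after any cache lookup. A minimal usage sketch, built only from the GenericFakeChatModel and InMemoryRateLimiter that the new tests below rely on (the parameter values are illustrative, not recommendations):

import time

from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter

# 0.5 requests/second on average; the bucket starts empty, so even the
# first invoke() blocks until the first token has been minted.
model = GenericFakeChatModel(
    messages=iter(["hello", "world"]),
    rate_limiter=InMemoryRateLimiter(
        requests_per_second=0.5, check_every_n_seconds=0.1, max_bucket_size=1
    ),
)

tic = time.time()
model.invoke("foo")  # waits in rate_limiter.acquire(blocking=True)
model.invoke("foo")  # waits roughly another 1 / requests_per_second seconds
print(f"Two rate-limited calls took {time.time() - tic:.1f}s")
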
@@ -1,11 +1,4 @@
"""Interface and implementation for time based rate limiters.

This module defines an interface for rate limiting requests based on time.

The interface cannot account for the size of the request or any other factors.

The module also provides an in-memory implementation of the rate limiter.
"""
"""Interface for a rate limiter and an in-memory rate limiter."""

from __future__ import annotations

@@ -14,22 +7,14 @@ import asyncio
import threading
import time
from typing import (
    Any,
    Optional,
    cast,
)

from langchain_core._api import beta
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import (
    Input,
    Output,
    Runnable,
)


@beta(message="Introduced in 0.2.24. API subject to change.")
class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
class BaseRateLimiter(abc.ABC):
    """Base class for rate limiters.

    Usage of the base limiter is through the acquire and aacquire methods depending

@@ -41,18 +26,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):

    Current limitations:

    - The rate limiter is not designed to work across different processes. It is
      an in-memory rate limiter, but it is thread safe.
    - The rate limiter only supports time-based rate limiting. It does not take
      into account the size of the request or any other factors.
    - The current implementation does not handle streaming inputs well and will
      consume all inputs even if the rate limit has not been reached. Better support
      for streaming inputs will be added in the future.
    - When the rate limiter is combined with another runnable via a RunnableSequence,
      usage of .batch() or .abatch() will only respect the average rate limit.
      There will be bursty behavior as .batch() and .abatch() wait for each step
      to complete before starting the next step. One way to mitigate this is to
      use batch_as_completed() or abatch_as_completed().
    - Rate limiting information is not surfaced in tracing or callbacks. This means
      that the total time it takes to invoke a chat model will encompass both
      the time spent waiting for tokens and the time spent making the request.


    .. versionadded:: 0.2.24
    """
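
The limitation note above about bursty .batch() behavior points at batch_as_completed() / abatch_as_completed() as a mitigation. A hedged sketch of that pattern with a rate-limited chat model; the (index, result) tuple shape yielded by batch_as_completed is my assumption about the Runnable API, so verify it before relying on it:

from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter

model = GenericFakeChatModel(
    messages=iter(["a", "b", "c"]),
    rate_limiter=InMemoryRateLimiter(
        requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=10
    ),
)

# model.batch(inputs) waits for every item in a step before starting the next,
# so token waits pile up; draining results as they complete spreads them out.
for idx, message in model.batch_as_completed(["q1", "q2", "q3"]):
    print(idx, message.content)
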
@@ -95,55 +72,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
            True if the tokens were successfully acquired, False otherwise.
        """

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the rate limiter.

        This is a blocking call that waits until the given number of tokens are
        available.

        Args:
            input: The input to the rate limiter.
            config: The configuration for the rate limiter.
            **kwargs: Additional keyword arguments.

        Returns:
            The output of the rate limiter.
        """

        def _invoke(input: Input) -> Output:
            """Invoke the rate limiter. Internal function."""
            self.acquire(blocking=True)
            return cast(Output, input)

        return self._call_with_config(_invoke, input, config, **kwargs)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the rate limiter. Async version.

        This is a blocking call that waits until the given number of tokens are
        available.

        Args:
            input: The input to the rate limiter.
            config: The configuration for the rate limiter.
            **kwargs: Additional keyword arguments.
        """

        async def _ainvoke(input: Input) -> Output:
            """Invoke the rate limiter. Internal function."""
            await self.aacquire(blocking=True)
            return cast(Output, input)

        return await self._acall_with_config(_ainvoke, input, config, **kwargs)


@beta(message="Introduced in 0.2.24. API subject to change.")
class InMemoryRateLimiter(BaseRateLimiter):
    """An in memory rate limiter.
    """An in memory rate limiter based on a token bucket algorithm.

    This is an in memory rate limiter, so it cannot rate limit across
    different processes.

@@ -168,19 +100,13 @@ class InMemoryRateLimiter(BaseRateLimiter):
      an in-memory rate limiter, but it is thread safe.
    - The rate limiter only supports time-based rate limiting. It does not take
      into account the size of the request or any other factors.
    - The current implementation does not handle streaming inputs well and will
      consume all inputs even if the rate limit has not been reached. Better support
      for streaming inputs will be added in the future.
    - When the rate limiter is combined with another runnable via a RunnableSequence,
      usage of .batch() or .abatch() will only respect the average rate limit.
      There will be bursty behavior as .batch() and .abatch() wait for each step
      to complete before starting the next step. One way to mitigate this is to
      use batch_as_completed() or abatch_as_completed().

    Example:

        .. code-block:: python

            from langchain_core import InMemoryRateLimiter
            from langchain_core.runnables import RunnableLambda, InMemoryRateLimiter

            rate_limiter = InMemoryRateLimiter(

@@ -239,7 +165,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
        self.check_every_n_seconds = check_every_n_seconds

    def _consume(self) -> bool:
        """Consume the given amount of tokens if possible.
        """Try to consume a token.

        Returns:
            True means that the tokens were consumed, and the caller can proceed to
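
The body of _consume() is not shown in this hunk, but the tests below pin its behavior down to a token bucket that only mints tokens once at least one whole token has accrued. A sketch consistent with those assertions; attribute names (available_tokens, last, requests_per_second, max_bucket_size) are taken from the tests, and the lock name is assumed:

import time


def _consume_sketch(limiter) -> bool:
    """Refill the bucket from elapsed wall-clock time, then try to take one token."""
    with limiter._consume_lock:  # assumed: a threading.Lock guarding bucket state
        now = time.time()
        elapsed = now - limiter.last
        # Mint tokens only once at least one whole token has accrued; this is
        # what keeps available_tokens at 0 after a 0.1 s tick at 2 rps in the tests.
        if elapsed * limiter.requests_per_second >= 1:
            limiter.available_tokens += elapsed * limiter.requests_per_second
            limiter.last = now
        # Never exceed the configured bucket size.
        limiter.available_tokens = min(
            limiter.available_tokens, limiter.max_bucket_size
        )
        if limiter.available_tokens >= 1:
            limiter.available_tokens -= 1
            return True
        return False
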
@@ -317,3 +243,9 @@ class InMemoryRateLimiter(BaseRateLimiter):
        while not self._consume():
            await asyncio.sleep(self.check_every_n_seconds)
        return True


__all__ = [
    "BaseRateLimiter",
    "InMemoryRateLimiter",
]
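
Because BaseRateLimiter is now a plain abstract class rather than a Runnable, a custom limiter only needs to implement acquire() and aacquire(), returning True once a token has been obtained or False immediately when blocking=False. A toy sketch under that contract; the keyword-only blocking signature mirrors how the tests call acquire, and FixedDelayRateLimiter is a made-up name:

import asyncio
import time

from langchain_core.rate_limiters import BaseRateLimiter


class FixedDelayRateLimiter(BaseRateLimiter):
    """Illustrative limiter that simply waits a fixed delay per acquisition."""

    def __init__(self, delay: float) -> None:
        self.delay = delay

    def acquire(self, *, blocking: bool = True) -> bool:
        if not blocking:
            return False  # never hand out a token without waiting
        time.sleep(self.delay)
        return True

    async def aacquire(self, *, blocking: bool = True) -> bool:
        if not blocking:
            return False
        await asyncio.sleep(self.delay)
        return True
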
@@ -43,7 +43,6 @@ from langchain_core.runnables.passthrough import (
    RunnablePassthrough,
    RunnablePick,
)
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
    AddableDict,

@@ -65,7 +64,6 @@ __all__ = [
    "ensure_config",
    "run_in_executor",
    "patch_config",
    "InMemoryRateLimiter",
    "RouterInput",
    "RouterRunnable",
    "Runnable",
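
The two hunks above drop the langchain_core.runnables re-export, so callers that imported the limiter from the runnables package switch to the new top-level module, as the rewritten tests below do:

# Old locations, removed by this change:
# from langchain_core.runnables import InMemoryRateLimiter
# from langchain_core.runnables.rate_limiter import InMemoryRateLimiter

# New location:
from langchain_core.rate_limiters import InMemoryRateLimiter
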
@@ -0,0 +1,345 @@
"""Test rate limiter."""

import time

import pytest
from freezegun import freeze_time

from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter


@pytest.fixture
def rate_limiter() -> InMemoryRateLimiter:
    """Return an instance of InMemoryRateLimiter."""
    return InMemoryRateLimiter(
        requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
    )


def test_initial_state(rate_limiter: InMemoryRateLimiter) -> None:
    """Test the initial state of the rate limiter."""
    assert rate_limiter.available_tokens == 0.0


def test_sync_wait(rate_limiter: InMemoryRateLimiter) -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter.last = time.time()
        assert not rate_limiter.acquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not rate_limiter.acquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not rate_limiter.acquire(blocking=False)
        frozen_time.tick(1.8)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1.0
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 0
        frozen_time.tick(2.1)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1
        frozen_time.tick(0.9)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1

        # Check max bucket size
        frozen_time.tick(100)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1


async def test_async_wait(rate_limiter: InMemoryRateLimiter) -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter.last = time.time()
        assert not await rate_limiter.aacquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not await rate_limiter.aacquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not await rate_limiter.aacquire(blocking=False)
        frozen_time.tick(1.8)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 1.0
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 0
        frozen_time.tick(2.1)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 1
        frozen_time.tick(0.9)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 1


def test_sync_wait_max_bucket_size() -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter = InMemoryRateLimiter(
            requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
        )
        rate_limiter.last = time.time()
        frozen_time.tick(100)  # Increment by 100 seconds
        assert rate_limiter.acquire(blocking=False)
        # After 100 seconds we manage to refill the bucket with 200 tokens
        # After consuming 1 token, we should have 199 tokens left
        assert rate_limiter.available_tokens == 199.0
        frozen_time.tick(10000)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 499.0
        # Assert that sync wait can proceed without blocking
        # since we have enough tokens
        rate_limiter.acquire(blocking=True)


async def test_async_wait_max_bucket_size() -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter = InMemoryRateLimiter(
            requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
        )
        rate_limiter.last = time.time()
        frozen_time.tick(100)  # Increment by 100 seconds
        assert await rate_limiter.aacquire(blocking=False)
        # After 100 seconds we manage to refill the bucket with 200 tokens
        # After consuming 1 token, we should have 199 tokens left
        assert rate_limiter.available_tokens == 199.0
        frozen_time.tick(10000)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 499.0
        # Assert that sync wait can proceed without blocking
        # since we have enough tokens
        await rate_limiter.aacquire(blocking=True)


def test_rate_limit_invoke() -> None:
    """Add rate limiter."""

    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02

    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # The second time we call the model, we should have 1 extra token
    # to proceed immediately.
    assert toc - tic < 0.005

    # The third time we call the model, we need to wait again for a token
    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02


async def test_rate_limit_ainvoke() -> None:
    """Add rate limiter."""

    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
        ),
    )
    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.1 < toc - tic < 0.2

    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # The second time we call the model, we should have 1 extra token
    # to proceed immediately.
    assert toc - tic < 0.01

    # The third time we call the model, we need to wait again for a token
    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.1 < toc - tic < 0.2


def test_rate_limit_batch() -> None:
    """Test that batch and stream calls work with rate limiters."""
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    # Need 2 tokens to proceed
    time_to_fill = 2 / 200.0
    tic = time.time()
    model.batch(["foo", "foo"])
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert time_to_fill < toc - tic < time_to_fill + 0.01


async def test_rate_limit_abatch() -> None:
    """Test that batch and stream calls work with rate limiters."""
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    # Need 2 tokens to proceed
    time_to_fill = 2 / 200.0
    tic = time.time()
    await model.abatch(["foo", "foo"])
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert time_to_fill < toc - tic < time_to_fill + 0.01


def test_rate_limit_stream() -> None:
    """Test rate limit by stream."""
    model = GenericFakeChatModel(
        messages=iter(["hello world", "hello world", "hello world"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    # Check astream
    tic = time.time()
    response = list(model.stream("foo"))
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert 0.01 < toc - tic < 0.02  # Slightly smaller than check every n seconds

    # Second time around we should have 1 token left
    tic = time.time()
    response = list(model.stream("foo"))
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert toc - tic < 0.005  # Slightly smaller than check every n seconds

    # Third time around we should have 0 tokens left
    tic = time.time()
    response = list(model.stream("foo"))
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert 0.01 < toc - tic < 0.02  # Slightly smaller than check every n seconds


async def test_rate_limit_astream() -> None:
    """Test rate limiting astream."""
    rate_limiter = InMemoryRateLimiter(
        requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
    )
    model = GenericFakeChatModel(
        messages=iter(["hello world", "hello world", "hello world"]),
        rate_limiter=rate_limiter,
    )
    # Check astream
    tic = time.time()
    response = [chunk async for chunk in model.astream("foo")]
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    assert 0.1 < toc - tic < 0.2

    # Second time around we should have 1 token left
    tic = time.time()
    response = [chunk async for chunk in model.astream("foo")]
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert toc - tic < 0.01  # Slightly smaller than check every n seconds

    # Third time around we should have 0 tokens left
    tic = time.time()
    response = [chunk async for chunk in model.astream("foo")]
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    assert 0.1 < toc - tic < 0.2


def test_rate_limit_skips_cache() -> None:
    """Test that rate limiting does not rate limit cache look ups."""
    cache = InMemoryCache()
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
        ),
        cache=cache,
    )

    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02

    for _ in range(2):
        # Cache hits
        tic = time.time()
        model.invoke("foo")
        toc = time.time()
        # Should be larger than check every n seconds since the token bucket starts
        # with 0 tokens.
        assert toc - tic < 0.005

    # Test verifies that there's only a single key
    # Test also verifies that rate_limiter information is not part of the
    # cache key
    assert list(cache._cache) == [
        (
            '[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", '
            '"messages", '
            '"HumanMessage"], "kwargs": {"content": "foo", "type": "human"}}]',
            "[('_type', 'generic-fake-chat-model'), ('stop', None)]",
        )
    ]


async def test_rate_limit_skips_cache_async() -> None:
    """Test that rate limiting does not rate limit cache look ups."""
    cache = InMemoryCache()
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
        ),
        cache=cache,
    )

    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02

    for _ in range(2):
        # Cache hits
        tic = time.time()
        await model.ainvoke("foo")
        toc = time.time()
        # Should be larger than check every n seconds since the token bucket starts
        # with 0 tokens.
        assert toc - tic < 0.005
@@ -11,7 +11,6 @@ EXPECTED_ALL = [
    "run_in_executor",
    "patch_config",
    "RouterInput",
    "InMemoryRateLimiter",
    "RouterRunnable",
    "Runnable",
    "RunnableSerializable",
@@ -1,145 +0,0 @@
"""Test rate limiter."""

import time

import pytest
from freezegun import freeze_time

from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter


@pytest.fixture
def rate_limiter() -> InMemoryRateLimiter:
    """Return an instance of InMemoryRateLimiter."""
    return InMemoryRateLimiter(
        requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
    )


def test_initial_state(rate_limiter: InMemoryRateLimiter) -> None:
    """Test the initial state of the rate limiter."""
    assert rate_limiter.available_tokens == 0.0


def test_sync_wait(rate_limiter: InMemoryRateLimiter) -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter.last = time.time()
        assert not rate_limiter.acquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not rate_limiter.acquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not rate_limiter.acquire(blocking=False)
        frozen_time.tick(1.8)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1.0
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 0
        frozen_time.tick(2.1)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1
        frozen_time.tick(0.9)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1

        # Check max bucket size
        frozen_time.tick(100)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 1


async def test_async_wait(rate_limiter: InMemoryRateLimiter) -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter.last = time.time()
        assert not await rate_limiter.aacquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not await rate_limiter.aacquire(blocking=False)
        frozen_time.tick(0.1)  # Increment by 0.1 seconds
        assert rate_limiter.available_tokens == 0
        assert not await rate_limiter.aacquire(blocking=False)
        frozen_time.tick(1.8)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 1.0
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 0
        frozen_time.tick(2.1)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 1
        frozen_time.tick(0.9)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 1


def test_sync_wait_max_bucket_size() -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter = InMemoryRateLimiter(
            requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
        )
        rate_limiter.last = time.time()
        frozen_time.tick(100)  # Increment by 100 seconds
        assert rate_limiter.acquire(blocking=False)
        # After 100 seconds we manage to refill the bucket with 200 tokens
        # After consuming 1 token, we should have 199 tokens left
        assert rate_limiter.available_tokens == 199.0
        frozen_time.tick(10000)
        assert rate_limiter.acquire(blocking=False)
        assert rate_limiter.available_tokens == 499.0
        # Assert that sync wait can proceed without blocking
        # since we have enough tokens
        rate_limiter.acquire(blocking=True)


async def test_async_wait_max_bucket_size() -> None:
    with freeze_time("2023-01-01 00:00:00") as frozen_time:
        rate_limiter = InMemoryRateLimiter(
            requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
        )
        rate_limiter.last = time.time()
        frozen_time.tick(100)  # Increment by 100 seconds
        assert await rate_limiter.aacquire(blocking=False)
        # After 100 seconds we manage to refill the bucket with 200 tokens
        # After consuming 1 token, we should have 199 tokens left
        assert rate_limiter.available_tokens == 199.0
        frozen_time.tick(10000)
        assert await rate_limiter.aacquire(blocking=False)
        assert rate_limiter.available_tokens == 499.0
        # Assert that sync wait can proceed without blocking
        # since we have enough tokens
        await rate_limiter.aacquire(blocking=True)


def test_add_rate_limiter() -> None:
    """Add rate limiter."""

    def foo(x: int) -> int:
        """Return x."""
        return x

    rate_limiter = InMemoryRateLimiter(
        requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
    )

    foo_ = RunnableLambda(foo)
    chain = rate_limiter | foo_
    assert chain.invoke(1) == 1


async def test_async_add_rate_limiter() -> None:
    """Add rate limiter."""

    async def foo(x: int) -> int:
        """Return x."""
        return x

    rate_limiter = InMemoryRateLimiter(
        requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
    )

    # mypy is unable to follow the type information when
    # RunnableLambda is used with an async function
    foo_ = RunnableLambda(foo)  # type: ignore
    chain = rate_limiter | foo_
    assert (await chain.ainvoke(1)) == 1