Compare commits

...

9 Commits

Author          SHA1        Message  Date
Eugene Yurtsev  de339acbc1  x        2024-07-25 16:19:57 -04:00
Eugene Yurtsev  c929c0e38d  x        2024-07-25 16:18:03 -04:00
Eugene Yurtsev  2ec1b96527  update   2024-07-25 16:14:50 -04:00
Eugene Yurtsev  0fd685fb7d  x        2024-07-25 11:34:50 -04:00
Eugene Yurtsev  be89bf47e8  x        2024-07-25 11:34:34 -04:00
Eugene Yurtsev  b3419a3018  x        2024-07-25 10:27:47 -04:00
Eugene Yurtsev  98cbfb9643  x        2024-07-25 10:18:07 -04:00
Eugene Yurtsev  5776402f92  update   2024-07-25 10:13:41 -04:00
Eugene Yurtsev  f5e36246e3  qxqx     2024-07-25 09:41:21 -04:00
7 changed files with 386 additions and 232 deletions

View File

@@ -60,6 +60,7 @@ from langchain_core.pydantic_v1 import (
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
@@ -210,6 +211,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
rate_limiter: Optional[BaseRateLimiter] = None
"""An optional rate limiter to use for limiting the number of requests."""
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.
@@ -341,6 +345,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
try:
for chunk in self._stream(messages, stop=stop, **kwargs):
if chunk.message.id is None:
@@ -412,6 +420,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=1,
)
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
generation: Optional[ChatGenerationChunk] = None
try:
async for chunk in self._astream(
@@ -742,6 +753,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
# Apply the rate limiter after checking the cache, since
# we usually don't want to rate limit cache lookups, but
# we do want to rate limit API requests.
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
@@ -822,6 +840,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
# Apply the rate limiter after checking the cache, since
# we usually don't want to rate limit cache lookups, but
# we do want to rate limit API requests.
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
if (
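Net effect of the chat model changes above: BaseChatModel gains an optional rate_limiter field, and acquire(blocking=True) is called once per request in generation (after the cache check) and once per stream/astream call. A minimal usage sketch, using the fake chat model from the test suite in this diff; a real provider model inheriting from BaseChatModel would accept the same parameter:

    from langchain_core.language_models import GenericFakeChatModel
    from langchain_core.rate_limiters import InMemoryRateLimiter

    # Token-bucket limiter: ~2 requests/second, a blocking acquire re-checks
    # the bucket every 0.1 s, and at most 2 tokens accumulate while idle.
    rate_limiter = InMemoryRateLimiter(
        requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
    )

    model = GenericFakeChatModel(
        messages=iter(["hello", "world"]), rate_limiter=rate_limiter
    )
    print(model.invoke("foo"))  # blocks in acquire() until a token is available
    print(model.invoke("foo"))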

View File

@@ -1,11 +1,4 @@
"""Interface and implementation for time based rate limiters.
This module defines an interface for rate limiting requests based on time.
The interface cannot account for the size of the request or any other factors.
The module also provides an in-memory implementation of the rate limiter.
"""
"""Interface for a rate limiter and an in-memory rate limiter."""
from __future__ import annotations
@@ -14,22 +7,14 @@ import asyncio
import threading
import time
from typing import (
Any,
Optional,
cast,
)
from langchain_core._api import beta
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import (
Input,
Output,
Runnable,
)
@beta(message="Introduced in 0.2.24. API subject to change.")
class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
class BaseRateLimiter(abc.ABC):
"""Base class for rate limiters.
Usage of the base limiter is through the acquire and aacquire methods depending
@@ -41,18 +26,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
Current limitations:
- The rate limiter is not designed to work across different processes. It is
an in-memory rate limiter, but it is thread safe.
- The rate limiter only supports time-based rate limiting. It does not take
into account the size of the request or any other factors.
- The current implementation does not handle streaming inputs well and will
consume all inputs even if the rate limit has not been reached. Better support
for streaming inputs will be added in the future.
- When the rate limiter is combined with another runnable via a RunnableSequence,
usage of .batch() or .abatch() will only respect the average rate limit.
There will be bursty behavior as .batch() and .abatch() wait for each step
to complete before starting the next step. One way to mitigate this is to
use batch_as_completed() or abatch_as_completed().
- Rate limiting information is not surfaced in tracing or callbacks. This means
that the total time it takes to invoke a chat model will encompass both
the time spent waiting for tokens and the time spent making the request.
.. versionadded:: 0.2.24
"""
@@ -95,55 +72,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
True if the tokens were successfully acquired, False otherwise.
"""
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the rate limiter.
This is a blocking call that waits until the given number of tokens are
available.
Args:
input: The input to the rate limiter.
config: The configuration for the rate limiter.
**kwargs: Additional keyword arguments.
Returns:
The output of the rate limiter.
"""
def _invoke(input: Input) -> Output:
"""Invoke the rate limiter. Internal function."""
self.acquire(blocking=True)
return cast(Output, input)
return self._call_with_config(_invoke, input, config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the rate limiter. Async version.
This is a blocking call that waits until the given number of tokens are
available.
Args:
input: The input to the rate limiter.
config: The configuration for the rate limiter.
**kwargs: Additional keyword arguments.
"""
async def _ainvoke(input: Input) -> Output:
"""Invoke the rate limiter. Internal function."""
await self.aacquire(blocking=True)
return cast(Output, input)
return await self._acall_with_config(_ainvoke, input, config, **kwargs)
@beta(message="Introduced in 0.2.24. API subject to change.")
class InMemoryRateLimiter(BaseRateLimiter):
"""An in memory rate limiter.
"""An in memory rate limiter based on a token bucket algorithm.
This is an in memory rate limiter, so it cannot rate limit across
different processes.
@@ -168,19 +100,13 @@ class InMemoryRateLimiter(BaseRateLimiter):
an in-memory rate limiter, but it is thread safe.
- The rate limiter only supports time-based rate limiting. It does not take
into account the size of the request or any other factors.
- The current implementation does not handle streaming inputs well and will
consume all inputs even if the rate limit has not been reached. Better support
for streaming inputs will be added in the future.
- When the rate limiter is combined with another runnable via a RunnableSequence,
usage of .batch() or .abatch() will only respect the average rate limit.
There will be bursty behavior as .batch() and .abatch() wait for each step
to complete before starting the next step. One way to mitigate this is to
use batch_as_completed() or abatch_as_completed().
Example:
.. code-block:: python
from langchain_core import InMemoryRateLimiter
from langchain_core.runnables import RunnableLambda, InMemoryRateLimiter
rate_limiter = InMemoryRateLimiter(
@@ -239,7 +165,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
self.check_every_n_seconds = check_every_n_seconds
def _consume(self) -> bool:
"""Consume the given amount of tokens if possible.
"""Try to consume a token.
Returns:
True means that the tokens were consumed, and the caller can proceed to
@@ -317,3 +243,9 @@ class InMemoryRateLimiter(BaseRateLimiter):
while not self._consume():
await asyncio.sleep(self.check_every_n_seconds)
return True
__all__ = [
"BaseRateLimiter",
"InMemoryRateLimiter",
]
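With the Runnable surface removed, the limiter is used directly through acquire()/aacquire(). A small standalone sketch, assuming the constructor arguments exercised in the tests below:

    import time

    from langchain_core.rate_limiters import InMemoryRateLimiter

    limiter = InMemoryRateLimiter(
        requests_per_second=2,      # tokens added to the bucket per second
        check_every_n_seconds=0.1,  # how often a blocking acquire re-checks the bucket
        max_bucket_size=2,          # cap on tokens accumulated while idle
    )

    # Non-blocking: returns False immediately if no token is available
    # (the bucket starts empty, so this first call fails).
    print(limiter.acquire(blocking=False))

    # Blocking: sleeps in check_every_n_seconds increments until a token is consumed.
    for _ in range(3):
        start = time.time()
        limiter.acquire(blocking=True)
        print(f"got a token after {time.time() - start:.2f}s")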

View File

@@ -43,7 +43,6 @@ from langchain_core.runnables.passthrough import (
RunnablePassthrough,
RunnablePick,
)
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
AddableDict,
@@ -65,7 +64,6 @@ __all__ = [
"ensure_config",
"run_in_executor",
"patch_config",
"InMemoryRateLimiter",
"RouterInput",
"RouterRunnable",
"Runnable",

View File

@@ -0,0 +1,345 @@
"""Test rate limiter."""
import time
import pytest
from freezegun import freeze_time
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
@pytest.fixture
def rate_limiter() -> InMemoryRateLimiter:
"""Return an instance of InMemoryRateLimiter."""
return InMemoryRateLimiter(
requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
)
def test_initial_state(rate_limiter: InMemoryRateLimiter) -> None:
"""Test the initial state of the rate limiter."""
assert rate_limiter.available_tokens == 0.0
def test_sync_wait(rate_limiter: InMemoryRateLimiter) -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter.last = time.time()
assert not rate_limiter.acquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not rate_limiter.acquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not rate_limiter.acquire(blocking=False)
frozen_time.tick(1.8)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1.0
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 0
frozen_time.tick(2.1)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1
frozen_time.tick(0.9)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1
# Check max bucket size
frozen_time.tick(100)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1
async def test_async_wait(rate_limiter: InMemoryRateLimiter) -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter.last = time.time()
assert not await rate_limiter.aacquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not await rate_limiter.aacquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not await rate_limiter.aacquire(blocking=False)
frozen_time.tick(1.8)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 1.0
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 0
frozen_time.tick(2.1)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 1
frozen_time.tick(0.9)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 1
def test_sync_wait_max_bucket_size() -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter = InMemoryRateLimiter(
requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
)
rate_limiter.last = time.time()
frozen_time.tick(100) # Increment by 100 seconds
assert rate_limiter.acquire(blocking=False)
# After 100 seconds we manage to refill the bucket with 200 tokens
# After consuming 1 token, we should have 199 tokens left
assert rate_limiter.available_tokens == 199.0
frozen_time.tick(10000)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 499.0
# Assert that sync wait can proceed without blocking
# since we have enough tokens
rate_limiter.acquire(blocking=True)
async def test_async_wait_max_bucket_size() -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter = InMemoryRateLimiter(
requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
)
rate_limiter.last = time.time()
frozen_time.tick(100) # Increment by 100 seconds
assert await rate_limiter.aacquire(blocking=False)
# After 100 seconds we manage to refill the bucket with 200 tokens
# After consuming 1 token, we should have 199 tokens left
assert rate_limiter.available_tokens == 199.0
frozen_time.tick(10000)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 499.0
# Assert that sync wait can proceed without blocking
# since we have enough tokens
await rate_limiter.aacquire(blocking=True)
def test_rate_limit_invoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
tic = time.time()
model.invoke("foo")
toc = time.time()
# The second time we call the model, we should have 1 extra token
# to proceed immediately.
assert toc - tic < 0.005
# The third time we call the model, we need to wait again for a token
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
async def test_rate_limit_ainvoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
),
)
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.1 < toc - tic < 0.2
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# The second time we call the model, we should have 1 extra token
# to proceed immediately.
assert toc - tic < 0.01
# The third time we call the model, we need to wait again for a token
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.1 < toc - tic < 0.2
def test_rate_limit_batch() -> None:
"""Test that batch and stream calls work with rate limiters."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
# Need 2 tokens to proceed
time_to_fill = 2 / 200.0
tic = time.time()
model.batch(["foo", "foo"])
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert time_to_fill < toc - tic < time_to_fill + 0.01
async def test_rate_limit_abatch() -> None:
"""Test that batch and stream calls work with rate limiters."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
# Need 2 tokens to proceed
time_to_fill = 2 / 200.0
tic = time.time()
await model.abatch(["foo", "foo"])
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert time_to_fill < toc - tic < time_to_fill + 0.01
def test_rate_limit_stream() -> None:
"""Test rate limit by stream."""
model = GenericFakeChatModel(
messages=iter(["hello world", "hello world", "hello world"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
# Check astream
tic = time.time()
response = list(model.stream("foo"))
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert 0.01 < toc - tic < 0.02 # Slightly smaller than check every n seconds
# Second time around we should have 1 token left
tic = time.time()
response = list(model.stream("foo"))
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert toc - tic < 0.005 # Slightly smaller than check every n seconds
# Third time around we should have 0 tokens left
tic = time.time()
response = list(model.stream("foo"))
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert 0.01 < toc - tic < 0.02 # Slightly smaller than check every n seconds
async def test_rate_limit_astream() -> None:
"""Test rate limiting astream."""
rate_limiter = InMemoryRateLimiter(
requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
)
model = GenericFakeChatModel(
messages=iter(["hello world", "hello world", "hello world"]),
rate_limiter=rate_limiter,
)
# Check astream
tic = time.time()
response = [chunk async for chunk in model.astream("foo")]
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
assert 0.1 < toc - tic < 0.2
# Second time around we should have 1 token left
tic = time.time()
response = [chunk async for chunk in model.astream("foo")]
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert toc - tic < 0.01 # Slightly smaller than check every n seconds
# Third time around we should have 0 tokens left
tic = time.time()
response = [chunk async for chunk in model.astream("foo")]
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
assert 0.1 < toc - tic < 0.2
def test_rate_limit_skips_cache() -> None:
"""Test that rate limiting does not rate limit cache look ups."""
cache = InMemoryCache()
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
),
cache=cache,
)
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
for _ in range(2):
# Cache hits
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert toc - tic < 0.005
# Test verifies that there's only a single key
# Test also verifies that rate_limiter information is not part of the
# cache key
assert list(cache._cache) == [
(
'[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", '
'"messages", '
'"HumanMessage"], "kwargs": {"content": "foo", "type": "human"}}]',
"[('_type', 'generic-fake-chat-model'), ('stop', None)]",
)
]
async def test_rate_limit_skips_cache_async() -> None:
"""Test that rate limiting does not rate limit cache look ups."""
cache = InMemoryCache()
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
),
cache=cache,
)
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
for _ in range(2):
# Cache hits
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert toc - tic < 0.005

View File

@@ -11,7 +11,6 @@ EXPECTED_ALL = [
"run_in_executor",
"patch_config",
"RouterInput",
"InMemoryRateLimiter",
"RouterRunnable",
"Runnable",
"RunnableSerializable",

View File

@@ -1,145 +0,0 @@
"""Test rate limiter."""
import time
import pytest
from freezegun import freeze_time
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
@pytest.fixture
def rate_limiter() -> InMemoryRateLimiter:
"""Return an instance of InMemoryRateLimiter."""
return InMemoryRateLimiter(
requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
)
def test_initial_state(rate_limiter: InMemoryRateLimiter) -> None:
"""Test the initial state of the rate limiter."""
assert rate_limiter.available_tokens == 0.0
def test_sync_wait(rate_limiter: InMemoryRateLimiter) -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter.last = time.time()
assert not rate_limiter.acquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not rate_limiter.acquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not rate_limiter.acquire(blocking=False)
frozen_time.tick(1.8)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1.0
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 0
frozen_time.tick(2.1)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1
frozen_time.tick(0.9)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1
# Check max bucket size
frozen_time.tick(100)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 1
async def test_async_wait(rate_limiter: InMemoryRateLimiter) -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter.last = time.time()
assert not await rate_limiter.aacquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not await rate_limiter.aacquire(blocking=False)
frozen_time.tick(0.1) # Increment by 0.1 seconds
assert rate_limiter.available_tokens == 0
assert not await rate_limiter.aacquire(blocking=False)
frozen_time.tick(1.8)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 1.0
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 0
frozen_time.tick(2.1)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 1
frozen_time.tick(0.9)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 1
def test_sync_wait_max_bucket_size() -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter = InMemoryRateLimiter(
requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
)
rate_limiter.last = time.time()
frozen_time.tick(100) # Increment by 100 seconds
assert rate_limiter.acquire(blocking=False)
# After 100 seconds we manage to refill the bucket with 200 tokens
# After consuming 1 token, we should have 199 tokens left
assert rate_limiter.available_tokens == 199.0
frozen_time.tick(10000)
assert rate_limiter.acquire(blocking=False)
assert rate_limiter.available_tokens == 499.0
# Assert that sync wait can proceed without blocking
# since we have enough tokens
rate_limiter.acquire(blocking=True)
async def test_async_wait_max_bucket_size() -> None:
with freeze_time("2023-01-01 00:00:00") as frozen_time:
rate_limiter = InMemoryRateLimiter(
requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=500
)
rate_limiter.last = time.time()
frozen_time.tick(100) # Increment by 100 seconds
assert await rate_limiter.aacquire(blocking=False)
# After 100 seconds we manage to refill the bucket with 200 tokens
# After consuming 1 token, we should have 199 tokens left
assert rate_limiter.available_tokens == 199.0
frozen_time.tick(10000)
assert await rate_limiter.aacquire(blocking=False)
assert rate_limiter.available_tokens == 499.0
# Assert that sync wait can proceed without blocking
# since we have enough tokens
await rate_limiter.aacquire(blocking=True)
def test_add_rate_limiter() -> None:
"""Add rate limiter."""
def foo(x: int) -> int:
"""Return x."""
return x
rate_limiter = InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
)
foo_ = RunnableLambda(foo)
chain = rate_limiter | foo_
assert chain.invoke(1) == 1
async def test_async_add_rate_limiter() -> None:
"""Add rate limiter."""
async def foo(x: int) -> int:
"""Return x."""
return x
rate_limiter = InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
)
# mypy is unable to follow the type information when
# RunnableLambda is used with an async function
foo_ = RunnableLambda(foo) # type: ignore
chain = rate_limiter | foo_
assert (await chain.ainvoke(1)) == 1
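Since the limiter no longer subclasses Runnable, composing it into a chain with `|` (as in the removed tests above) is no longer supported by this change. A sketch of roughly equivalent behaviour with the new API, mirroring the removed test's foo helper but calling acquire() explicitly (not part of the PR itself):

    from langchain_core.rate_limiters import InMemoryRateLimiter
    from langchain_core.runnables import RunnableLambda

    limiter = InMemoryRateLimiter(
        requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
    )

    def foo(x: int) -> int:
        """Return x, waiting for a token before doing the work."""
        limiter.acquire(blocking=True)
        return x

    chain = RunnableLambda(foo)
    assert chain.invoke(1) == 1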