diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index d6f2cd7bef8..c2485dbe432 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -60,6 +60,7 @@ from langchain_core.pydantic_v1 import ( Field, root_validator, ) +from langchain_core.rate_limiters import BaseRateLimiter from langchain_core.runnables import RunnableMap, RunnablePassthrough from langchain_core.runnables.config import ensure_config, run_in_executor from langchain_core.tracers._streaming import _StreamingCallbackHandler @@ -210,6 +211,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """[DEPRECATED] Callback manager to add to the run trace.""" + rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True) + """An optional rate limiter to use for limiting the number of requests.""" + @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: """Raise deprecation warning if callback_manager is used. @@ -341,6 +345,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): batch_size=1, ) generation: Optional[ChatGenerationChunk] = None + + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + try: for chunk in self._stream(messages, stop=stop, **kwargs): if chunk.message.id is None: @@ -412,6 +420,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): batch_size=1, ) + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + generation: Optional[ChatGenerationChunk] = None try: async for chunk in self._astream( @@ -742,6 +753,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) + + # Apply the rate limiter after checking the cache, since + # we usually don't want to rate limit cache lookups, but + # we do want to rate limit API requests. + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + # If stream is not explicitly set, check if implicitly requested by # astream_events() or astream_log(). Bail out if _stream not implemented if type(self)._stream != BaseChatModel._stream and kwargs.pop( @@ -822,6 +840,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) + + # Apply the rate limiter after checking the cache, since + # we usually don't want to rate limit cache lookups, but + # we do want to rate limit API requests. + if self.rate_limiter: + self.rate_limiter.acquire(blocking=True) + # If stream is not explicitly set, check if implicitly requested by # astream_events() or astream_log(). Bail out if _astream not implemented if ( diff --git a/libs/core/langchain_core/runnables/rate_limiter.py b/libs/core/langchain_core/rate_limiters.py similarity index 70% rename from libs/core/langchain_core/runnables/rate_limiter.py rename to libs/core/langchain_core/rate_limiters.py index 378d73affb2..02a88535329 100644 --- a/libs/core/langchain_core/runnables/rate_limiter.py +++ b/libs/core/langchain_core/rate_limiters.py @@ -1,11 +1,4 @@ -"""Interface and implementation for time based rate limiters. - -This module defines an interface for rate limiting requests based on time. - -The interface cannot account for the size of the request or any other factors. - -The module also provides an in-memory implementation of the rate limiter. 
-""" +"""Interface for a rate limiter and an in-memory rate limiter.""" from __future__ import annotations @@ -14,22 +7,14 @@ import asyncio import threading import time from typing import ( - Any, Optional, - cast, ) from langchain_core._api import beta -from langchain_core.runnables import RunnableConfig -from langchain_core.runnables.base import ( - Input, - Output, - Runnable, -) @beta(message="Introduced in 0.2.24. API subject to change.") -class BaseRateLimiter(Runnable[Input, Output], abc.ABC): +class BaseRateLimiter(abc.ABC): """Base class for rate limiters. Usage of the base limiter is through the acquire and aacquire methods depending @@ -41,18 +26,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC): Current limitations: - - The rate limiter is not designed to work across different processes. It is - an in-memory rate limiter, but it is thread safe. - - The rate limiter only supports time-based rate limiting. It does not take - into account the size of the request or any other factors. - - The current implementation does not handle streaming inputs well and will - consume all inputs even if the rate limit has not been reached. Better support - for streaming inputs will be added in the future. - - When the rate limiter is combined with another runnable via a RunnableSequence, - usage of .batch() or .abatch() will only respect the average rate limit. - There will be bursty behavior as .batch() and .abatch() wait for each step - to complete before starting the next step. One way to mitigate this is to - use batch_as_completed() or abatch_as_completed(). + - Rate limiting information is not surfaced in tracing or callbacks. This means + that the total time it takes to invoke a chat model will encompass both + the time spent waiting for tokens and the time spent making the request. + .. versionadded:: 0.2.24 """ @@ -95,55 +72,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC): True if the tokens were successfully acquired, False otherwise. """ - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """Invoke the rate limiter. - - This is a blocking call that waits until the given number of tokens are - available. - - Args: - input: The input to the rate limiter. - config: The configuration for the rate limiter. - **kwargs: Additional keyword arguments. - - Returns: - The output of the rate limiter. - """ - - def _invoke(input: Input) -> Output: - """Invoke the rate limiter. Internal function.""" - self.acquire(blocking=True) - return cast(Output, input) - - return self._call_with_config(_invoke, input, config, **kwargs) - - async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> Output: - """Invoke the rate limiter. Async version. - - This is a blocking call that waits until the given number of tokens are - available. - - Args: - input: The input to the rate limiter. - config: The configuration for the rate limiter. - **kwargs: Additional keyword arguments. - """ - - async def _ainvoke(input: Input) -> Output: - """Invoke the rate limiter. Internal function.""" - await self.aacquire(blocking=True) - return cast(Output, input) - - return await self._acall_with_config(_ainvoke, input, config, **kwargs) - @beta(message="Introduced in 0.2.24. API subject to change.") class InMemoryRateLimiter(BaseRateLimiter): - """An in memory rate limiter. + """An in memory rate limiter based on a token bucket algorithm. 
This is an in memory rate limiter, so it cannot rate limit across different processes. @@ -168,19 +100,13 @@ class InMemoryRateLimiter(BaseRateLimiter): an in-memory rate limiter, but it is thread safe. - The rate limiter only supports time-based rate limiting. It does not take into account the size of the request or any other factors. - - The current implementation does not handle streaming inputs well and will - consume all inputs even if the rate limit has not been reached. Better support - for streaming inputs will be added in the future. - - When the rate limiter is combined with another runnable via a RunnableSequence, - usage of .batch() or .abatch() will only respect the average rate limit. - There will be bursty behavior as .batch() and .abatch() wait for each step - to complete before starting the next step. One way to mitigate this is to - use batch_as_completed() or abatch_as_completed(). Example: .. code-block:: python + from langchain_core import InMemoryRateLimiter + from langchain_core.runnables import RunnableLambda, InMemoryRateLimiter rate_limiter = InMemoryRateLimiter( @@ -239,7 +165,7 @@ class InMemoryRateLimiter(BaseRateLimiter): self.check_every_n_seconds = check_every_n_seconds def _consume(self) -> bool: - """Consume the given amount of tokens if possible. + """Try to consume a token. Returns: True means that the tokens were consumed, and the caller can proceed to @@ -317,3 +243,9 @@ class InMemoryRateLimiter(BaseRateLimiter): while not self._consume(): await asyncio.sleep(self.check_every_n_seconds) return True + + +__all__ = [ + "BaseRateLimiter", + "InMemoryRateLimiter", +] diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index 5ec88752bc1..44c95519c08 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -43,7 +43,6 @@ from langchain_core.runnables.passthrough import ( RunnablePassthrough, RunnablePick, ) -from langchain_core.runnables.rate_limiter import InMemoryRateLimiter from langchain_core.runnables.router import RouterInput, RouterRunnable from langchain_core.runnables.utils import ( AddableDict, @@ -65,7 +64,6 @@ __all__ = [ "ensure_config", "run_in_executor", "patch_config", - "InMemoryRateLimiter", "RouterInput", "RouterRunnable", "Runnable", diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py new file mode 100644 index 00000000000..b15b202f484 --- /dev/null +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py @@ -0,0 +1,258 @@ +import time + +from langchain_core.caches import InMemoryCache +from langchain_core.language_models import GenericFakeChatModel +from langchain_core.rate_limiters import InMemoryRateLimiter + + +def test_rate_limit_invoke() -> None: + """Add rate limiter.""" + + model = GenericFakeChatModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10 + ), + ) + tic = time.time() + model.invoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert 0.01 < toc - tic < 0.02 + + tic = time.time() + model.invoke("foo") + toc = time.time() + # The second time we call the model, we should have 1 extra token + # to proceed immediately. 
+ assert toc - tic < 0.005 + + # The third time we call the model, we need to wait again for a token + tic = time.time() + model.invoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert 0.01 < toc - tic < 0.02 + + +async def test_rate_limit_ainvoke() -> None: + """Add rate limiter.""" + + model = GenericFakeChatModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10 + ), + ) + tic = time.time() + await model.ainvoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert 0.1 < toc - tic < 0.2 + + tic = time.time() + await model.ainvoke("foo") + toc = time.time() + # The second time we call the model, we should have 1 extra token + # to proceed immediately. + assert toc - tic < 0.01 + + # The third time we call the model, we need to wait again for a token + tic = time.time() + await model.ainvoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert 0.1 < toc - tic < 0.2 + + +def test_rate_limit_batch() -> None: + """Test that batch and stream calls work with rate limiters.""" + model = GenericFakeChatModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10 + ), + ) + # Need 2 tokens to proceed + time_to_fill = 2 / 200.0 + tic = time.time() + model.batch(["foo", "foo"]) + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert time_to_fill < toc - tic < time_to_fill + 0.01 + + +async def test_rate_limit_abatch() -> None: + """Test that batch and stream calls work with rate limiters.""" + model = GenericFakeChatModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10 + ), + ) + # Need 2 tokens to proceed + time_to_fill = 2 / 200.0 + tic = time.time() + await model.abatch(["foo", "foo"]) + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. 
+ assert time_to_fill < toc - tic < time_to_fill + 0.01 + + +def test_rate_limit_stream() -> None: + """Test rate limit by stream.""" + model = GenericFakeChatModel( + messages=iter(["hello world", "hello world", "hello world"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10 + ), + ) + # Check astream + tic = time.time() + response = list(model.stream("foo")) + assert [msg.content for msg in response] == ["hello", " ", "world"] + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + assert 0.01 < toc - tic < 0.02 # Slightly smaller than check every n seconds + + # Second time around we should have 1 token left + tic = time.time() + response = list(model.stream("foo")) + assert [msg.content for msg in response] == ["hello", " ", "world"] + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + assert toc - tic < 0.005 # Slightly smaller than check every n seconds + + # Third time around we should have 0 tokens left + tic = time.time() + response = list(model.stream("foo")) + assert [msg.content for msg in response] == ["hello", " ", "world"] + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + assert 0.01 < toc - tic < 0.02 # Slightly smaller than check every n seconds + + +async def test_rate_limit_astream() -> None: + """Test rate limiting astream.""" + rate_limiter = InMemoryRateLimiter( + requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10 + ) + model = GenericFakeChatModel( + messages=iter(["hello world", "hello world", "hello world"]), + rate_limiter=rate_limiter, + ) + # Check astream + tic = time.time() + response = [chunk async for chunk in model.astream("foo")] + assert [msg.content for msg in response] == ["hello", " ", "world"] + toc = time.time() + assert 0.1 < toc - tic < 0.2 + + # Second time around we should have 1 token left + tic = time.time() + response = [chunk async for chunk in model.astream("foo")] + assert [msg.content for msg in response] == ["hello", " ", "world"] + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + assert toc - tic < 0.01 # Slightly smaller than check every n seconds + + # Third time around we should have 0 tokens left + tic = time.time() + response = [chunk async for chunk in model.astream("foo")] + assert [msg.content for msg in response] == ["hello", " ", "world"] + toc = time.time() + assert 0.1 < toc - tic < 0.2 + + +def test_rate_limit_skips_cache() -> None: + """Test that rate limiting does not rate limit cache look ups.""" + cache = InMemoryCache() + model = GenericFakeChatModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1 + ), + cache=cache, + ) + + tic = time.time() + model.invoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert 0.01 < toc - tic < 0.02 + + for _ in range(2): + # Cache hits + tic = time.time() + model.invoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. 
+ assert toc - tic < 0.005 + + # Test verifies that there's only a single key + # Test also verifies that rate_limiter information is not part of the + # cache key + assert list(cache._cache) == [ + ( + '[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", ' + '"messages", ' + '"HumanMessage"], "kwargs": {"content": "foo", "type": "human"}}]', + "[('_type', 'generic-fake-chat-model'), ('stop', None)]", + ) + ] + + +class SerializableModel(GenericFakeChatModel): + @classmethod + def is_lc_serializable(cls) -> bool: + return True + + +def test_serialization_with_rate_limiter() -> None: + """Test model serialization with rate limiter.""" + from langchain_core.load import dumps + + model = SerializableModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1 + ), + ) + serialized_model = dumps(model) + assert InMemoryRateLimiter.__name__ not in serialized_model + + +async def test_rate_limit_skips_cache_async() -> None: + """Test that rate limiting does not rate limit cache look ups.""" + cache = InMemoryCache() + model = GenericFakeChatModel( + messages=iter(["hello", "world", "!"]), + rate_limiter=InMemoryRateLimiter( + requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1 + ), + cache=cache, + ) + + tic = time.time() + await model.ainvoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert 0.01 < toc - tic < 0.02 + + for _ in range(2): + # Cache hits + tic = time.time() + await model.ainvoke("foo") + toc = time.time() + # Should be larger than check every n seconds since the token bucket starts + # with 0 tokens. + assert toc - tic < 0.005 diff --git a/libs/core/tests/unit_tests/rate_limiters/__init__.py b/libs/core/tests/unit_tests/rate_limiters/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/core/tests/unit_tests/runnables/test_rate_limiter.py b/libs/core/tests/unit_tests/rate_limiters/test_in_memory_rate_limiter.py similarity index 82% rename from libs/core/tests/unit_tests/runnables/test_rate_limiter.py rename to libs/core/tests/unit_tests/rate_limiters/test_in_memory_rate_limiter.py index b54a47e92a0..914b9d94262 100644 --- a/libs/core/tests/unit_tests/runnables/test_rate_limiter.py +++ b/libs/core/tests/unit_tests/rate_limiters/test_in_memory_rate_limiter.py @@ -5,8 +5,7 @@ import time import pytest from freezegun import freeze_time -from langchain_core.runnables import RunnableLambda -from langchain_core.runnables.rate_limiter import InMemoryRateLimiter +from langchain_core.rate_limiters import InMemoryRateLimiter @pytest.fixture @@ -109,37 +108,3 @@ async def test_async_wait_max_bucket_size() -> None: # Assert that sync wait can proceed without blocking # since we have enough tokens await rate_limiter.aacquire(blocking=True) - - -def test_add_rate_limiter() -> None: - """Add rate limiter.""" - - def foo(x: int) -> int: - """Return x.""" - return x - - rate_limiter = InMemoryRateLimiter( - requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10 - ) - - foo_ = RunnableLambda(foo) - chain = rate_limiter | foo_ - assert chain.invoke(1) == 1 - - -async def test_async_add_rate_limiter() -> None: - """Add rate limiter.""" - - async def foo(x: int) -> int: - """Return x.""" - return x - - rate_limiter = InMemoryRateLimiter( - requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10 - ) - - # mypy is unable to follow the 
type information when - # RunnableLambda is used with an async function - foo_ = RunnableLambda(foo) # type: ignore - chain = rate_limiter | foo_ - assert (await chain.ainvoke(1)) == 1 diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index 09e733a257e..12b1a80d1bf 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -11,7 +11,6 @@ EXPECTED_ALL = [ "run_in_executor", "patch_config", "RouterInput", - "InMemoryRateLimiter", "RouterRunnable", "Runnable", "RunnableSerializable",
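
A minimal usage sketch of the new `rate_limiter` field, based on the unit tests added in this patch (`GenericFakeChatModel` stands in for any `BaseChatModel` subclass; the parameter values are illustrative):

    import time

    from langchain_core.language_models import GenericFakeChatModel
    from langchain_core.rate_limiters import InMemoryRateLimiter

    # Token bucket: refill at 2 requests/second, poll every 100 ms,
    # and allow a burst of at most 2 accumulated tokens.
    rate_limiter = InMemoryRateLimiter(
        requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
    )

    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=rate_limiter,
    )

    # The bucket starts empty, so each call below waits for a token before
    # the underlying request is made.
    for prompt in ["foo", "foo", "foo"]:
        tic = time.time()
        print(model.invoke(prompt).content, f"({time.time() - tic:.2f}s)")

As the comments added to chat_models.py note, the limiter is acquired only after the cache lookup, so cache hits are never throttled; and per the updated BaseRateLimiter docstring, time spent waiting for a token is folded into the overall model call rather than surfaced separately in tracing or callbacks.

Because `BaseRateLimiter` now only requires `acquire` and `aacquire`, a custom limiter is a small subclass. The sketch below is a hypothetical fixed-interval limiter (the class name and its behaviour are illustrative, not part of langchain_core):

    import asyncio
    import threading
    import time

    from langchain_core.rate_limiters import BaseRateLimiter


    class FixedIntervalRateLimiter(BaseRateLimiter):
        """Hypothetical limiter enforcing a minimum delay between requests."""

        def __init__(self, min_seconds_between_requests: float = 0.5) -> None:
            self._interval = min_seconds_between_requests
            self._lock = threading.Lock()
            # Start "one interval in the past" so the first request proceeds.
            self._last_request = time.monotonic() - self._interval

        def acquire(self, *, blocking: bool = True) -> bool:
            while True:
                with self._lock:
                    wait = self._interval - (time.monotonic() - self._last_request)
                    if wait <= 0:
                        self._last_request = time.monotonic()
                        return True
                if not blocking:
                    return False
                time.sleep(wait)

        async def aacquire(self, *, blocking: bool = True) -> bool:
            # Run the blocking sync path in a worker thread so the event loop
            # is not blocked while waiting for the interval to elapse.
            return await asyncio.get_running_loop().run_in_executor(
                None, lambda: self.acquire(blocking=blocking)
            )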