From 6c07eb0c12cfa67dc5fbe13996b452914617f7aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cahid=20Arda=20=C3=96z?=
Date: Sat, 8 Jun 2024 00:02:06 +0300
Subject: [PATCH] community[minor]: Add UpstashRatelimitHandler (#21885)

Adding `UpstashRatelimitHandler` callback for rate limiting based on the
number of chain invocations or LLM token usage.

For more details, see the [upstash/ratelimit-py
repository](https://github.com/upstash/ratelimit-py) or the notebook guide
included in this PR.

Twitter handle: @cahidarda

---------

Co-authored-by: Eugene Yurtsev
---
 .../callbacks/upstash_ratelimit.ipynb         | 245 ++++++++++++++++++
 libs/community/extended_testing_deps.txt      |   3 +-
 .../langchain_community/callbacks/__init__.py |   8 +
 .../callbacks/upstash_ratelimit_callback.py   | 206 +++++++++++++++
 .../unit_tests/callbacks/test_imports.py      |   2 +
 .../test_upstash_ratelimit_callback.py        | 234 +++++++++++++++++
 6 files changed, 697 insertions(+), 1 deletion(-)
 create mode 100644 docs/docs/integrations/callbacks/upstash_ratelimit.ipynb
 create mode 100644 libs/community/langchain_community/callbacks/upstash_ratelimit_callback.py
 create mode 100644 libs/community/tests/unit_tests/callbacks/test_upstash_ratelimit_callback.py

diff --git a/docs/docs/integrations/callbacks/upstash_ratelimit.ipynb b/docs/docs/integrations/callbacks/upstash_ratelimit.ipynb
new file mode 100644
index 00000000000..78c5e15f832
--- /dev/null
+++ b/docs/docs/integrations/callbacks/upstash_ratelimit.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Upstash Ratelimit Callback\n",
+    "\n",
+    "In this guide, we will go over how to add rate limiting based on the number of requests or the number of tokens using `UpstashRatelimitHandler`. This handler uses [Upstash's ratelimit library](https://github.com/upstash/ratelimit-py/), which utilizes [Upstash Redis](https://upstash.com/docs/redis/overall/getstarted).\n",
+    "\n",
+    "Upstash Ratelimit works by sending an HTTP request to Upstash Redis every time the `limit` method is called. The remaining tokens/requests of the user are checked and updated. Based on the remaining tokens, we can stop the execution of costly operations like invoking an LLM or querying a vector store:\n",
+    "\n",
+    "```py\n",
+    "response = ratelimit.limit()\n",
+    "if response.allowed:\n",
+    "    execute_costly_operation()\n",
+    "```\n",
+    "\n",
+    "`UpstashRatelimitHandler` allows you to incorporate this ratelimit logic into your chain in a few minutes.\n",
+    "\n",
+    "First, you will need to go to [the Upstash Console](https://console.upstash.com/login) and create a Redis database ([see our docs](https://upstash.com/docs/redis/overall/getstarted)). After creating a database, you will need to set the environment variables:\n",
+    "\n",
+    "```\n",
+    "UPSTASH_REDIS_REST_URL=\"****\"\n",
+    "UPSTASH_REDIS_REST_TOKEN=\"****\"\n",
+    "```\n",
+    "\n",
+    "Next, you will need to install the Upstash Ratelimit and Redis libraries with:\n",
+    "\n",
+    "```\n",
+    "pip install upstash-ratelimit upstash-redis\n",
+    "```\n",
+    "\n",
+    "You are now ready to add rate limiting to your chain!\n",
+    "\n",
+    "## Ratelimiting Per Request\n",
+    "\n",
+    "Let's imagine that we want to allow our users to invoke our chain 10 times per minute. Achieving this is as simple as:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Error in UpstashRatelimitHandler.on_chain_start callback: UpstashRatelimitError('Request limit reached!')\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Handling ratelimit.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# set env variables\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"UPSTASH_REDIS_REST_URL\"] = \"****\"\n",
+    "os.environ[\"UPSTASH_REDIS_REST_TOKEN\"] = \"****\"\n",
+    "\n",
+    "from langchain_community.callbacks import UpstashRatelimitError, UpstashRatelimitHandler\n",
+    "from langchain_core.runnables import RunnableLambda\n",
+    "from upstash_ratelimit import FixedWindow, Ratelimit\n",
+    "from upstash_redis import Redis\n",
+    "\n",
+    "# create ratelimit\n",
+    "ratelimit = Ratelimit(\n",
+    "    redis=Redis.from_env(),\n",
+    "    # 10 requests per window, where window size is 60 seconds:\n",
+    "    limiter=FixedWindow(max_requests=10, window=60),\n",
+    ")\n",
+    "\n",
+    "# create handler\n",
+    "user_id = \"user_id\"  # in practice, replace with the caller's user id or IP address\n",
+    "handler = UpstashRatelimitHandler(identifier=user_id, request_ratelimit=ratelimit)\n",
+    "\n",
+    "# create mock chain\n",
+    "chain = RunnableLambda(str)\n",
+    "\n",
+    "# invoke chain with handler:\n",
+    "try:\n",
+    "    result = chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
+    "except UpstashRatelimitError:\n",
+    "    print(\"Handling ratelimit.\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that we pass the handler to the `invoke` method instead of passing it when defining the chain.\n",
+    "\n",
+    "For rate limiting algorithms other than `FixedWindow`, see the [upstash-ratelimit docs](https://github.com/upstash/ratelimit-py?tab=readme-ov-file#ratelimiting-algorithms).\n",
+    "\n",
+    "Before executing any steps in our pipeline, the handler checks whether the user has exceeded the request limit. If so, `UpstashRatelimitError` is raised.\n",
+    "\n",
+    "## Ratelimiting Per Token\n",
+    "\n",
+    "Another option is to rate limit chain invocations based on:\n",
+    "1. the number of tokens in the prompt\n",
+    "2. the number of tokens in the prompt and the LLM completion\n",
+    "\n",
+    "This only works if you have an LLM in your chain. Another requirement is that the LLM you are using should return the token usage in its `LLMOutput`.\n",
+    "\n",
+    "### How it works\n",
+    "\n",
+    "The handler gets the remaining tokens before calling the LLM. If the number of remaining tokens is greater than zero, the LLM is called. Otherwise, `UpstashRatelimitError` is raised.\n",
+    "\n",
+    "After the LLM is called, its token usage is subtracted from the remaining tokens of the user. No error is raised at this stage of the chain.\n",
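+    "\n",
+    "Roughly, the handler's token-based flow looks like the sketch below. Note that `call_llm` and `used_tokens` are placeholders, not part of the API; `get_remaining` and `limit` are the upstash-ratelimit methods the handler calls under the hood:\n",
+    "\n",
+    "```py\n",
+    "remaining = ratelimit.get_remaining(user_id)\n",
+    "if remaining <= 0:\n",
+    "    raise UpstashRatelimitError(\"Token limit reached!\", \"token\")\n",
+    "\n",
+    "result = call_llm()  # placeholder for the LLM call\n",
+    "\n",
+    "# subtract the tokens which were used; no error is raised here:\n",
+    "ratelimit.limit(user_id, rate=used_tokens)\n",
+    "```\n",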
+    "\n",
+    "### Configuration\n",
+    "\n",
+    "To rate limit based on the number of tokens in the prompt only (the first option above), simply initialize the handler like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ratelimit = Ratelimit(\n",
+    "    redis=Redis.from_env(),\n",
+    "    # 1000 tokens per window, where window size is 60 seconds:\n",
+    "    limiter=FixedWindow(max_requests=1000, window=60),\n",
+    ")\n",
+    "\n",
+    "handler = UpstashRatelimitHandler(identifier=user_id, token_ratelimit=ratelimit)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For the second option, which also counts the tokens in the LLM completion, initialize the handler with `include_output_tokens=True`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ratelimit = Ratelimit(\n",
+    "    redis=Redis.from_env(),\n",
+    "    # 1000 tokens per window, where window size is 60 seconds:\n",
+    "    limiter=FixedWindow(max_requests=1000, window=60),\n",
+    ")\n",
+    "\n",
+    "handler = UpstashRatelimitHandler(\n",
+    "    identifier=user_id,\n",
+    "    token_ratelimit=ratelimit,\n",
+    "    include_output_tokens=True,  # set to True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can also employ ratelimiting based on requests and tokens at the same time, simply by passing both the `request_ratelimit` and `token_ratelimit` parameters."
+   ]
+  },
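+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Keep in mind that the handler holds per-invocation state, so it should not be baked into a chain at definition time. A sketch of the intended pattern, using the handler's `reset` method:\n",
+    "\n",
+    "```py\n",
+    "# get a fresh handler with the same ratelimit configuration\n",
+    "# (optionally for a different user) before each invocation:\n",
+    "handler = handler.reset(identifier=user_id)\n",
+    "chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
+    "```\n"
+   ]
+  },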
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here is an example with a chain utilizing an LLM:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Error in UpstashRatelimitHandler.on_llm_start callback: UpstashRatelimitError('Token limit reached!')\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Handling ratelimit.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# set env variables\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"UPSTASH_REDIS_REST_URL\"] = \"****\"\n",
+    "os.environ[\"UPSTASH_REDIS_REST_TOKEN\"] = \"****\"\n",
+    "os.environ[\"OPENAI_API_KEY\"] = \"****\"\n",
+    "\n",
+    "from langchain_community.callbacks import UpstashRatelimitError, UpstashRatelimitHandler\n",
+    "from langchain_core.runnables import RunnableLambda\n",
+    "from langchain_openai import ChatOpenAI\n",
+    "from upstash_ratelimit import FixedWindow, Ratelimit\n",
+    "from upstash_redis import Redis\n",
+    "\n",
+    "# create ratelimit\n",
+    "ratelimit = Ratelimit(\n",
+    "    redis=Redis.from_env(),\n",
+    "    # 500 tokens per window, where window size is 60 seconds:\n",
+    "    limiter=FixedWindow(max_requests=500, window=60),\n",
+    ")\n",
+    "\n",
+    "# create handler\n",
+    "user_id = \"user_id\"  # in practice, replace with the caller's user id or IP address\n",
+    "handler = UpstashRatelimitHandler(identifier=user_id, token_ratelimit=ratelimit)\n",
+    "\n",
+    "# create chain\n",
+    "as_str = RunnableLambda(str)\n",
+    "model = ChatOpenAI()\n",
+    "\n",
+    "chain = as_str | model\n",
+    "\n",
+    "# invoke chain with handler:\n",
+    "try:\n",
+    "    result = chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
+    "except UpstashRatelimitError:\n",
+    "    print(\"Handling ratelimit.\")\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "lc39",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.9.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt
index eee0b66e36a..9f5e8284af6 100644
--- a/libs/community/extended_testing_deps.txt
+++ b/libs/community/extended_testing_deps.txt
@@ -80,7 +80,8 @@ timescale-vector==0.0.1
 tqdm>=4.48.0
 tree-sitter>=0.20.2,<0.21
 tree-sitter-languages>=1.8.0,<2
-upstash-redis>=0.15.0,<0.16
+upstash-redis>=1.1.0,<2
+upstash-ratelimit>=1.1.0,<2
 vdms==0.0.20
 xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
diff --git a/libs/community/langchain_community/callbacks/__init__.py b/libs/community/langchain_community/callbacks/__init__.py
index 201f5ef670c..72a098a3527 100644
--- a/libs/community/langchain_community/callbacks/__init__.py
+++ b/libs/community/langchain_community/callbacks/__init__.py
@@ -72,6 +72,10 @@ if TYPE_CHECKING:
     from langchain_community.callbacks.trubrics_callback import (
         TrubricsCallbackHandler,
     )
+    from langchain_community.callbacks.upstash_ratelimit_callback import (
+        UpstashRatelimitError,
+        UpstashRatelimitHandler,  # noqa: F401
+    )
     from langchain_community.callbacks.uptrain_callback import (
         UpTrainCallbackHandler,
     )
@@ -104,6 +108,8 @@ _module_lookup = {
     "SageMakerCallbackHandler": "langchain_community.callbacks.sagemaker_callback",
     "StreamlitCallbackHandler": "langchain_community.callbacks.streamlit",
     "TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback",
+    "UpstashRatelimitError": "langchain_community.callbacks.upstash_ratelimit_callback",
+    "UpstashRatelimitHandler": "langchain_community.callbacks.upstash_ratelimit_callback",  # noqa
     "UpTrainCallbackHandler": "langchain_community.callbacks.uptrain_callback",
     "WandbCallbackHandler": "langchain_community.callbacks.wandb_callback",
     "WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback",
@@ -140,6 +146,8 @@ __all__ = [
     "SageMakerCallbackHandler",
     "StreamlitCallbackHandler",
     "TrubricsCallbackHandler",
+    "UpstashRatelimitError",
+    "UpstashRatelimitHandler",
     "UpTrainCallbackHandler",
     "WandbCallbackHandler",
     "WhyLabsCallbackHandler",
diff --git a/libs/community/langchain_community/callbacks/upstash_ratelimit_callback.py b/libs/community/langchain_community/callbacks/upstash_ratelimit_callback.py
new file mode 100644
index 00000000000..9bd984be4f1
--- /dev/null
+++ b/libs/community/langchain_community/callbacks/upstash_ratelimit_callback.py
@@ -0,0 +1,206 @@
+"""Ratelimiting handler to limit requests or tokens."""
+
+import logging
+from typing import Any, Dict, List, Literal, Optional
+
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+logger = logging.getLogger(__name__)
+try:
+    from upstash_ratelimit import Ratelimit
+except ImportError:
+    Ratelimit = None
+
+
+class UpstashRatelimitError(Exception):
+    """
+    Upstash Ratelimit Error
+
+    Raised when the rate limit is reached in `UpstashRatelimitHandler`.
+    """
+
+    def __init__(
+        self,
+        message: str,
+        type: Literal["token", "request"],
+        limit: Optional[int] = None,
+        reset: Optional[float] = None,
+    ):
+        """
+        Args:
+            message (str): error message
+            type (str): the kind of limit which was reached. One of
+                "token" or "request"
+            limit (Optional[int]): the limit which was reached. Passed when
+                type is "request"
+            reset (Optional[float]): unix timestamp in milliseconds when the
+                limits are reset. Passed when type is "request"
+        """
+        # Call the base class constructor with the parameters it needs
+        super().__init__(message)
+        self.type = type
+        self.limit = limit
+        self.reset = reset
+
+
+class UpstashRatelimitHandler(BaseCallbackHandler):
+    """
+    Callback to handle rate limiting based on the number of requests
+    or the number of tokens in the input.
+
+    It uses Upstash Ratelimit to track the rate limit, which utilizes
+    Upstash Redis to track the state.
+
+    Should not be passed to the chain when initialising the chain.
+    This is because the handler has state which should be fresh every
+    time invoke is called. Instead, initialise and pass a handler
+    every time you invoke.
+    """
+
+    raise_error = True
+    _checked = False
+
+    def __init__(
+        self,
+        identifier: str,
+        *,
+        token_ratelimit: Optional[Ratelimit] = None,
+        request_ratelimit: Optional[Ratelimit] = None,
+        include_output_tokens: bool = False,
+    ):
+        """
+        Creates UpstashRatelimitHandler. Must be passed an identifier to
+        rate limit by, such as a user id or an IP address.
+
+        Additionally, it must be passed at least one of the token_ratelimit
+        or request_ratelimit parameters.
+
+        Args:
+            identifier (str): the identifier to rate limit by, such as
+                a user id or an IP address
+            token_ratelimit (Optional[Ratelimit]): Ratelimit to limit the
+                number of tokens. Only works with OpenAI models since only
+                these models provide the number of tokens as information
+                in their output.
+            request_ratelimit (Optional[Ratelimit]): Ratelimit to limit the
+                number of requests
+            include_output_tokens (bool): Whether to count output tokens when
+                rate limiting based on the number of tokens. Only used when
+                `token_ratelimit` is passed. False by default.
+
+        Example:
+            .. code-block:: python
+
+                from upstash_redis import Redis
+                from upstash_ratelimit import Ratelimit, FixedWindow
+
+                redis = Redis.from_env()
+                ratelimit = Ratelimit(
+                    redis=redis,
+                    # fixed window to allow 10 requests every 10 seconds:
+                    limiter=FixedWindow(max_requests=10, window=10),
+                )
+
+                user_id = "foo"
+                handler = UpstashRatelimitHandler(
+                    identifier=user_id,
+                    request_ratelimit=ratelimit
+                )
+
+                # Initialize a simple runnable to test
+                chain = RunnableLambda(str)
+
+                # pass handler as callback:
+                output = chain.invoke(
+                    "input",
+                    config={
+                        "callbacks": [handler]
+                    }
+                )
+
+        """
+        if not any([token_ratelimit, request_ratelimit]):
+            raise ValueError(
+                "You must pass at least one of the token_ratelimit or"
+                " request_ratelimit parameters for the handler to work."
+            )
+
+        self.identifier = identifier
+        self.token_ratelimit = token_ratelimit
+        self.request_ratelimit = request_ratelimit
+        self.include_output_tokens = include_output_tokens
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> Any:
+        """
+        Run when the chain starts running.
+
+        on_chain_start runs multiple times during a chain execution. To make
+        sure that it's only called once, we keep a bool state `_checked`. If
+        not `self._checked`, we call limit with `request_ratelimit` and raise
+        `UpstashRatelimitError` if the identifier is rate limited.
+        """
+        if self.request_ratelimit and not self._checked:
+            response = self.request_ratelimit.limit(self.identifier)
+            if not response.allowed:
+                raise UpstashRatelimitError(
+                    "Request limit reached!", "request", response.limit, response.reset
+                )
+            self._checked = True
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """
+        Run when the LLM starts running.
+        """
+        if self.token_ratelimit:
+            remaining = self.token_ratelimit.get_remaining(self.identifier)
+            if remaining <= 0:
+                raise UpstashRatelimitError("Token limit reached!", "token")
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """
+        Run when the LLM ends running.
+
+        If `include_output_tokens` is set to True, the number of tokens
+        in the LLM completion is also counted for rate limiting.
+        """
+        if self.token_ratelimit:
+            try:
+                llm_output = response.llm_output or {}
+                token_usage = llm_output["token_usage"]
+                token_count = (
+                    token_usage["total_tokens"]
+                    if self.include_output_tokens
+                    else token_usage["prompt_tokens"]
+                )
+            except KeyError:
+                raise ValueError(
+                    "LLM response doesn't include a"
+                    " `token_usage: {total_tokens: int, prompt_tokens: int}`"
+                    " field. To use UpstashRatelimitHandler with token_ratelimit,"
+                    " either use a model which returns token_usage (like"
+                    " OpenAI models) or rate limit only with request_ratelimit."
+                )
+
+            # call limit to subtract the used tokens from the rate limit,
+            # but don't raise an exception since we already generated
+            # the tokens and would rather continue execution.
+            self.token_ratelimit.limit(self.identifier, rate=token_count)
+
+    def reset(self, identifier: Optional[str] = None) -> "UpstashRatelimitHandler":
+        """
+        Creates a new UpstashRatelimitHandler object with the same
+        ratelimit configurations but with a new identifier if one
+        is provided.
+
+        Also resets the state of the handler.
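+
+        Example:
+            A sketch of the intended per-invocation flow ("user-42" is a
+            hypothetical identifier):
+
+            .. code-block:: python
+
+                # get a fresh handler before each invoke:
+                handler = handler.reset(identifier="user-42")
+                chain.invoke("input", config={"callbacks": [handler]})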
+        """
+        return UpstashRatelimitHandler(
+            identifier=identifier or self.identifier,
+            token_ratelimit=self.token_ratelimit,
+            request_ratelimit=self.request_ratelimit,
+            include_output_tokens=self.include_output_tokens,
+        )
diff --git a/libs/community/tests/unit_tests/callbacks/test_imports.py b/libs/community/tests/unit_tests/callbacks/test_imports.py
index 26e6b7daaad..566099cbdd0 100644
--- a/libs/community/tests/unit_tests/callbacks/test_imports.py
+++ b/libs/community/tests/unit_tests/callbacks/test_imports.py
@@ -26,6 +26,8 @@ EXPECTED_ALL = [
     "TrubricsCallbackHandler",
     "FiddlerCallbackHandler",
     "UpTrainCallbackHandler",
+    "UpstashRatelimitError",
+    "UpstashRatelimitHandler",
 ]
diff --git a/libs/community/tests/unit_tests/callbacks/test_upstash_ratelimit_callback.py b/libs/community/tests/unit_tests/callbacks/test_upstash_ratelimit_callback.py
new file mode 100644
index 00000000000..cf728c4c118
--- /dev/null
+++ b/libs/community/tests/unit_tests/callbacks/test_upstash_ratelimit_callback.py
@@ -0,0 +1,234 @@
+import logging
+from typing import Any
+from unittest.mock import create_autospec
+
+import pytest
+from langchain_core.outputs import LLMResult
+
+from langchain_community.callbacks import UpstashRatelimitError, UpstashRatelimitHandler
+
+logger = logging.getLogger(__name__)
+
+try:
+    from upstash_ratelimit import Ratelimit, Response
+except ImportError:
+    Ratelimit, Response = None, None
+
+
+# Fixtures
+@pytest.fixture
+def request_ratelimit() -> Ratelimit:
+    ratelimit = create_autospec(Ratelimit)
+    response = Response(allowed=True, limit=10, remaining=10, reset=10000)
+    ratelimit.limit.return_value = response
+    return ratelimit
+
+
+@pytest.fixture
+def token_ratelimit() -> Ratelimit:
+    ratelimit = create_autospec(Ratelimit)
+    response = Response(allowed=True, limit=1000, remaining=1000, reset=10000)
+    ratelimit.limit.return_value = response
+    ratelimit.get_remaining.return_value = 1000
+    return ratelimit
+
+
+@pytest.fixture
+def handler_with_both_limits(
+    request_ratelimit: Ratelimit, token_ratelimit: Ratelimit
+) -> UpstashRatelimitHandler:
+    return UpstashRatelimitHandler(
+        identifier="user123",
+        token_ratelimit=token_ratelimit,
+        request_ratelimit=request_ratelimit,
+        include_output_tokens=False,
+    )
+
+
+# Tests
+@pytest.mark.requires("upstash_ratelimit")
+def test_init_no_limits() -> None:
+    with pytest.raises(ValueError):
+        UpstashRatelimitHandler(identifier="user123")
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_init_request_limit_only(request_ratelimit: Ratelimit) -> None:
+    handler = UpstashRatelimitHandler(
+        identifier="user123", request_ratelimit=request_ratelimit
+    )
+    assert handler.request_ratelimit is not None
+    assert handler.token_ratelimit is None
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_init_token_limit_only(token_ratelimit: Ratelimit) -> None:
+    handler = UpstashRatelimitHandler(
+        identifier="user123", token_ratelimit=token_ratelimit
+    )
+    assert handler.token_ratelimit is not None
+    assert handler.request_ratelimit is None
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_chain_start_request_limit(handler_with_both_limits: Any) -> None:
+    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
+    handler_with_both_limits.request_ratelimit.limit.assert_called_once_with("user123")
+    handler_with_both_limits.token_ratelimit.limit.assert_not_called()
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_chain_start_request_limit_reached(request_ratelimit: Any) -> None:
+    request_ratelimit.limit.return_value = Response(
+        allowed=False, limit=10, remaining=0, reset=10000
+    )
+    handler = UpstashRatelimitHandler(
+        identifier="user123", token_ratelimit=None, request_ratelimit=request_ratelimit
+    )
+    with pytest.raises(UpstashRatelimitError):
+        handler.on_chain_start(serialized={}, inputs={})
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_start_token_limit_reached(token_ratelimit: Any) -> None:
+    token_ratelimit.get_remaining.return_value = 0
+    handler = UpstashRatelimitHandler(
+        identifier="user123", token_ratelimit=token_ratelimit, request_ratelimit=None
+    )
+    with pytest.raises(UpstashRatelimitError):
+        handler.on_llm_start(serialized={}, prompts=["test"])
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_start_token_limit_reached_negative(token_ratelimit: Any) -> None:
+    token_ratelimit.get_remaining.return_value = -10
+    handler = UpstashRatelimitHandler(
+        identifier="user123", token_ratelimit=token_ratelimit, request_ratelimit=None
+    )
+    with pytest.raises(UpstashRatelimitError):
+        handler.on_llm_start(serialized={}, prompts=["test"])
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_end_with_token_limit(handler_with_both_limits: Any) -> None:
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 2,
+                "completion_tokens": 3,
+                "total_tokens": 5,
+            }
+        },
+    )
+    handler_with_both_limits.on_llm_end(response)
+    handler_with_both_limits.token_ratelimit.limit.assert_called_once_with("user123", 2)
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_end_with_token_limit_include_output_tokens(
+    token_ratelimit: Any,
+) -> None:
+    handler = UpstashRatelimitHandler(
+        identifier="user123",
+        token_ratelimit=token_ratelimit,
+        request_ratelimit=None,
+        include_output_tokens=True,
+    )
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 2,
+                "completion_tokens": 3,
+                "total_tokens": 5,
+            }
+        },
+    )
+    handler.on_llm_end(response)
+    token_ratelimit.limit.assert_called_once_with("user123", 5)
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_end_without_token_usage(handler_with_both_limits: Any) -> None:
+    response = LLMResult(generations=[], llm_output={})
+    with pytest.raises(ValueError):
+        handler_with_both_limits.on_llm_end(response)
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_reset_handler(handler_with_both_limits: Any) -> None:
+    new_handler = handler_with_both_limits.reset(identifier="user456")
+    assert new_handler.identifier == "user456"
+    assert not new_handler._checked
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_reset_handler_no_new_identifier(handler_with_both_limits: Any) -> None:
+    new_handler = handler_with_both_limits.reset()
+    assert new_handler.identifier == "user123"
+    assert not new_handler._checked
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_chain_start_called_once(handler_with_both_limits: Any) -> None:
+    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
+    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
+    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_chain_start_reset_checked(handler_with_both_limits: Any) -> None:
+    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
+    new_handler = handler_with_both_limits.reset(identifier="user456")
+    new_handler.on_chain_start(serialized={}, inputs={})
+
+    # becomes two because the mock object is kept in reset
+    assert new_handler.request_ratelimit.limit.call_count == 2
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_start_no_token_limit(request_ratelimit: Any) -> None:
+    handler = UpstashRatelimitHandler(
+        identifier="user123", token_ratelimit=None, request_ratelimit=request_ratelimit
+    )
+    handler.on_llm_start(serialized={}, prompts=["test"])
+    assert request_ratelimit.limit.call_count == 0
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_on_llm_start_token_limit(handler_with_both_limits: Any) -> None:
+    handler_with_both_limits.on_llm_start(serialized={}, prompts=["test"])
+    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1
+
+
+@pytest.mark.requires("upstash_ratelimit")
+def test_full_chain_with_both_limits(handler_with_both_limits: Any) -> None:
+    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
+    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
+
+    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
+    assert handler_with_both_limits.token_ratelimit.limit.call_count == 0
+    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 0
+
+    handler_with_both_limits.on_llm_start(serialized={}, prompts=["test"])
+
+    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
+    assert handler_with_both_limits.token_ratelimit.limit.call_count == 0
+    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1
+
+    response = LLMResult(
+        generations=[],
+        llm_output={
+            "token_usage": {
+                "prompt_tokens": 2,
+                "completion_tokens": 3,
+                "total_tokens": 5,
+            }
+        },
+    )
+    handler_with_both_limits.on_llm_end(response)
+
+    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
+    assert handler_with_both_limits.token_ratelimit.limit.call_count == 1
+    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1