mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
community[minor]: Add UpstashRatelimitHandler (#21885)
Adding `UpstashRatelimitHandler` callback for rate limiting based on number of chain invocations or LLM token usage. For more details, see [upstash/ratelimit-py repository](https://github.com/upstash/ratelimit-py) or the notebook guide included in this PR. Twitter handle: @cahidarda --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
9b3ce16982
commit
6c07eb0c12
245
docs/docs/integrations/callbacks/upstash_ratelimit.ipynb
Normal file
@@ -0,0 +1,245 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upstash Ratelimit Callback\n",
    "\n",
    "In this guide, we will go over how to add rate limiting based on the number of requests or the number of tokens using `UpstashRatelimitHandler`. This handler uses the [ratelimit library of Upstash](https://github.com/upstash/ratelimit-py/), which utilizes [Upstash Redis](https://upstash.com/docs/redis/overall/getstarted).\n",
    "\n",
    "Upstash Ratelimit works by sending an HTTP request to Upstash Redis every time the `limit` method is called. The remaining tokens/requests of the user are checked and updated. Based on the remaining tokens, we can stop the execution of costly operations like invoking an LLM or querying a vector store:\n",
    "\n",
    "```py\n",
    "response = ratelimit.limit()\n",
    "if response.allowed:\n",
    "    execute_costly_operation()\n",
    "```\n",
    "\n",
    "`UpstashRatelimitHandler` allows you to incorporate this rate limiting logic into your chain in a few minutes.\n",
    "\n",
    "First, you will need to go to [the Upstash Console](https://console.upstash.com/login) and create a Redis database ([see our docs](https://upstash.com/docs/redis/overall/getstarted)). After creating a database, you will need to set the environment variables:\n",
    "\n",
    "```\n",
    "UPSTASH_REDIS_REST_URL=\"****\"\n",
    "UPSTASH_REDIS_REST_TOKEN=\"****\"\n",
    "```\n",
    "\n",
    "Next, you will need to install the Upstash Ratelimit and Redis libraries with:\n",
    "\n",
    "```\n",
    "pip install upstash-ratelimit upstash-redis\n",
    "```\n",
    "\n",
    "You are now ready to add rate limiting to your chain!\n",
    "\n",
    "## Ratelimiting Per Request\n",
    "\n",
    "Let's imagine that we want to allow our users to invoke our chain 10 times per minute. Achieving this is as simple as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error in UpstashRatelimitHandler.on_chain_start callback: UpstashRatelimitError('Request limit reached!')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Handling ratelimit. <class 'langchain_community.callbacks.upstash_ratelimit_callback.UpstashRatelimitError'>\n"
     ]
    }
   ],
   "source": [
    "# set env variables\n",
    "import os\n",
    "\n",
    "os.environ[\"UPSTASH_REDIS_REST_URL\"] = \"****\"\n",
    "os.environ[\"UPSTASH_REDIS_REST_TOKEN\"] = \"****\"\n",
    "\n",
    "from langchain_community.callbacks import UpstashRatelimitError, UpstashRatelimitHandler\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from upstash_ratelimit import FixedWindow, Ratelimit\n",
    "from upstash_redis import Redis\n",
    "\n",
    "# create ratelimit\n",
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 10 requests per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=10, window=60),\n",
    ")\n",
    "\n",
    "# create handler\n",
    "user_id = \"user_id\"  # should be a method which gets the user id\n",
    "handler = UpstashRatelimitHandler(identifier=user_id, request_ratelimit=ratelimit)\n",
    "\n",
    "# create mock chain\n",
    "chain = RunnableLambda(str)\n",
    "\n",
    "# invoke chain with handler:\n",
    "try:\n",
    "    result = chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
    "except UpstashRatelimitError:\n",
    "    print(\"Handling ratelimit.\", UpstashRatelimitError)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that we pass the handler to the `invoke` method instead of passing the handler when defining the chain.\n",
    "\n",
    "For rate limiting algorithms other than `FixedWindow`, see the [upstash-ratelimit docs](https://github.com/upstash/ratelimit-py?tab=readme-ov-file#ratelimiting-algorithms).\n",
    "\n",
    "Before executing any steps in our pipeline, ratelimit checks whether the user has passed the request limit. If so, `UpstashRatelimitError` is raised.\n",
    "\n",
    "## Ratelimiting Per Token\n",
    "\n",
    "Another option is to rate limit chain invocations based on:\n",
    "1. the number of tokens in the prompt\n",
    "2. the number of tokens in the prompt and the LLM completion\n",
    "\n",
    "This only works if you have an LLM in your chain. Another requirement is that the LLM you are using should return the token usage in its `LLMOutput`.\n",
    "\n",
    "### How it works\n",
    "\n",
    "The handler gets the remaining tokens before calling the LLM. If the number of remaining tokens is more than 0, the LLM is called. Otherwise, `UpstashRatelimitError` is raised.\n",
    "\n",
    "After the LLM is called, the token usage is subtracted from the remaining tokens of the user. No error is raised at this stage of the chain.\n",
    "\n",
    "### Configuration\n",
    "\n",
    "For the first configuration, simply initialize the handler like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 1000 tokens per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=1000, window=60),\n",
    ")\n",
    "\n",
    "handler = UpstashRatelimitHandler(identifier=user_id, token_ratelimit=ratelimit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the second configuration, here is how to initialize the handler:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 1000 tokens per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=1000, window=60),\n",
    ")\n",
    "\n",
    "handler = UpstashRatelimitHandler(\n",
    "    identifier=user_id,\n",
    "    token_ratelimit=ratelimit,\n",
    "    include_output_tokens=True,  # set to True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also employ rate limiting based on requests and tokens at the same time, simply by passing both the `request_ratelimit` and `token_ratelimit` parameters.\n",
    "\n",
    "Here is an example with a chain utilizing an LLM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error in UpstashRatelimitHandler.on_llm_start callback: UpstashRatelimitError('Token limit reached!')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Handling ratelimit. <class 'langchain_community.callbacks.upstash_ratelimit_callback.UpstashRatelimitError'>\n"
     ]
    }
   ],
   "source": [
    "# set env variables\n",
    "import os\n",
    "\n",
    "os.environ[\"UPSTASH_REDIS_REST_URL\"] = \"****\"\n",
    "os.environ[\"UPSTASH_REDIS_REST_TOKEN\"] = \"****\"\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"****\"\n",
    "\n",
    "from langchain_community.callbacks import UpstashRatelimitError, UpstashRatelimitHandler\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langchain_openai import ChatOpenAI\n",
    "from upstash_ratelimit import FixedWindow, Ratelimit\n",
    "from upstash_redis import Redis\n",
    "\n",
    "# create ratelimit\n",
    "ratelimit = Ratelimit(\n",
    "    redis=Redis.from_env(),\n",
    "    # 500 tokens per window, where window size is 60 seconds:\n",
    "    limiter=FixedWindow(max_requests=500, window=60),\n",
    ")\n",
    "\n",
    "# create handler\n",
    "user_id = \"user_id\"  # should be a method which gets the user id\n",
    "handler = UpstashRatelimitHandler(identifier=user_id, token_ratelimit=ratelimit)\n",
    "\n",
    "# create chain\n",
    "as_str = RunnableLambda(str)\n",
    "model = ChatOpenAI()\n",
    "\n",
    "chain = as_str | model\n",
    "\n",
    "# invoke chain with handler:\n",
    "try:\n",
    "    result = chain.invoke(\"Hello world!\", config={\"callbacks\": [handler]})\n",
    "except UpstashRatelimitError:\n",
    "    print(\"Handling ratelimit.\", UpstashRatelimitError)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lc39",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
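The `FixedWindow` limiter used throughout the notebook can be illustrated with a self-contained, in-memory sketch. `ToyFixedWindow` is a hypothetical name for illustration only; the real upstash-ratelimit library keeps this counter in Upstash Redis so it works across processes:

```python
import time


class ToyFixedWindow:
    """In-memory sketch of a fixed-window rate limiter (illustrative only)."""

    def __init__(self, max_requests: int, window: float) -> None:
        self.max_requests = max_requests
        self.window = window  # window length in seconds
        self._counts: dict = {}  # (identifier, window index) -> units consumed

    def limit(self, identifier: str, rate: int = 1) -> bool:
        """Consume `rate` units for `identifier`; return True if allowed."""
        # all timestamps in the same window share one counter
        window_index = int(time.time() // self.window)
        key = (identifier, window_index)
        used = self._counts.get(key, 0)
        if used + rate > self.max_requests:
            return False
        self._counts[key] = used + rate
        return True


limiter = ToyFixedWindow(max_requests=3, window=60)
results = [limiter.limit("user_id") for _ in range(4)]
print(results)  # first three calls allowed, fourth rejected
```

The counter resets when the window index rolls over, which is why fixed windows allow short bursts at window boundaries; the upstash-ratelimit README lists sliding-window and token-bucket alternatives for smoother limits.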
@@ -80,7 +80,8 @@ timescale-vector==0.0.1
 tqdm>=4.48.0
 tree-sitter>=0.20.2,<0.21
 tree-sitter-languages>=1.8.0,<2
-upstash-redis>=0.15.0,<0.16
+upstash-redis>=1.1.0,<2
+upstash-ratelimit>=1.1.0,<2
 vdms==0.0.20
 xata>=1.0.0a7,<2
 xmltodict>=0.13.0,<0.14
@@ -72,6 +72,10 @@ if TYPE_CHECKING:
     from langchain_community.callbacks.trubrics_callback import (
         TrubricsCallbackHandler,
     )
+    from langchain_community.callbacks.upstash_ratelimit_callback import (
+        UpstashRatelimitError,
+        UpstashRatelimitHandler,  # noqa: F401
+    )
     from langchain_community.callbacks.uptrain_callback import (
         UpTrainCallbackHandler,
     )
@@ -104,6 +108,8 @@ _module_lookup = {
     "SageMakerCallbackHandler": "langchain_community.callbacks.sagemaker_callback",
     "StreamlitCallbackHandler": "langchain_community.callbacks.streamlit",
     "TrubricsCallbackHandler": "langchain_community.callbacks.trubrics_callback",
+    "UpstashRatelimitError": "langchain_community.callbacks.upstash_ratelimit_callback",
+    "UpstashRatelimitHandler": "langchain_community.callbacks.upstash_ratelimit_callback",  # noqa
     "UpTrainCallbackHandler": "langchain_community.callbacks.uptrain_callback",
     "WandbCallbackHandler": "langchain_community.callbacks.wandb_callback",
     "WhyLabsCallbackHandler": "langchain_community.callbacks.whylabs_callback",
@@ -140,6 +146,8 @@ __all__ = [
     "SageMakerCallbackHandler",
     "StreamlitCallbackHandler",
     "TrubricsCallbackHandler",
+    "UpstashRatelimitError",
+    "UpstashRatelimitHandler",
     "UpTrainCallbackHandler",
     "WandbCallbackHandler",
     "WhyLabsCallbackHandler",
@@ -0,0 +1,206 @@
"""Ratelimiting Handler to limit requests or tokens"""

import logging
from typing import Any, Dict, List, Literal, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

logger = logging.getLogger(__name__)

try:
    from upstash_ratelimit import Ratelimit
except ImportError:
    Ratelimit = None


class UpstashRatelimitError(Exception):
    """
    Upstash Ratelimit Error

    Raised when the rate limit is reached in `UpstashRatelimitHandler`
    """

    def __init__(
        self,
        message: str,
        type: Literal["token", "request"],
        limit: Optional[int] = None,
        reset: Optional[float] = None,
    ):
        """
        Args:
            message (str): error message
            type (str): The kind of limit which was reached. One of
                "token" or "request"
            limit (Optional[int]): The limit which was reached. Passed when type
                is "request"
            reset (Optional[float]): Unix timestamp in milliseconds when the
                limits are reset. Passed when type is "request"
        """
        # Call the base class constructor with the parameters it needs
        super().__init__(message)
        self.type = type
        self.limit = limit
        self.reset = reset


class UpstashRatelimitHandler(BaseCallbackHandler):
    """
    Callback to handle rate limiting based on the number of requests
    or the number of tokens in the input.

    It uses Upstash Ratelimit to track the ratelimit, which utilizes
    Upstash Redis to track the state.

    Should not be passed to the chain when initialising the chain.
    This is because the handler has a state which should be fresh
    every time invoke is called. Instead, initialise and pass a handler
    every time you invoke.
    """

    raise_error = True
    _checked = False

    def __init__(
        self,
        identifier: str,
        *,
        token_ratelimit: Optional[Ratelimit] = None,
        request_ratelimit: Optional[Ratelimit] = None,
        include_output_tokens: bool = False,
    ):
        """
        Creates UpstashRatelimitHandler. Must be passed an identifier to
        ratelimit, like a user id or an IP address.

        Additionally, it must be passed at least one of the token_ratelimit
        or request_ratelimit parameters.

        Args:
            identifier (str): the identifier to rate limit, such as a user
                id or an IP address
            token_ratelimit (Optional[Ratelimit]): Ratelimit to limit the
                number of tokens. Only works with OpenAI models since only
                these models provide the number of tokens as information
                in their output.
            request_ratelimit (Optional[Ratelimit]): Ratelimit to limit the
                number of requests
            include_output_tokens (bool): Whether to count output tokens when
                rate limiting based on number of tokens. Only used when
                `token_ratelimit` is passed. False by default.

        Example:
            .. code-block:: python

                from upstash_redis import Redis
                from upstash_ratelimit import Ratelimit, FixedWindow

                redis = Redis.from_env()
                ratelimit = Ratelimit(
                    redis=redis,
                    # fixed window to allow 10 requests every 10 seconds:
                    limiter=FixedWindow(max_requests=10, window=10),
                )

                user_id = "foo"
                handler = UpstashRatelimitHandler(
                    identifier=user_id,
                    request_ratelimit=ratelimit,
                )

                # Initialize a simple runnable to test
                chain = RunnableLambda(str)

                # pass handler as callback:
                output = chain.invoke(
                    "input",
                    config={
                        "callbacks": [handler]
                    }
                )

        """
        if not any([token_ratelimit, request_ratelimit]):
            raise ValueError(
                "You must pass at least one of the token_ratelimit or"
                " request_ratelimit parameters for the handler to work."
            )

        self.identifier = identifier
        self.token_ratelimit = token_ratelimit
        self.request_ratelimit = request_ratelimit
        self.include_output_tokens = include_output_tokens

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """
        Run when chain starts running.

        on_chain_start runs multiple times during a chain execution. To make
        sure that it's only called once, we keep a bool state `_checked`. If
        not `self._checked`, we call limit with `request_ratelimit` and raise
        `UpstashRatelimitError` if the identifier is rate limited.
        """
        if self.request_ratelimit and not self._checked:
            response = self.request_ratelimit.limit(self.identifier)
            if not response.allowed:
                raise UpstashRatelimitError(
                    "Request limit reached!", "request", response.limit, response.reset
                )
            self._checked = True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """
        Run when LLM starts running
        """
        if self.token_ratelimit:
            remaining = self.token_ratelimit.get_remaining(self.identifier)
            if remaining <= 0:
                raise UpstashRatelimitError("Token limit reached!", "token")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """
        Run when LLM ends running

        If `include_output_tokens` is set to True, the number of tokens
        in the LLM completion is counted for rate limiting as well
        """
        if self.token_ratelimit:
            try:
                llm_output = response.llm_output or {}
                token_usage = llm_output["token_usage"]
                token_count = (
                    token_usage["total_tokens"]
                    if self.include_output_tokens
                    else token_usage["prompt_tokens"]
                )
            except KeyError:
                raise ValueError(
                    "LLM response doesn't include"
                    " `token_usage: {total_tokens: int, prompt_tokens: int}`"
                    " field. To use UpstashRatelimitHandler with token_ratelimit,"
                    " either use a model which returns token_usage (like"
                    " OpenAI models) or rate limit only with request_ratelimit."
                )

            # call limit to subtract the used tokens from the rate limit,
            # but don't raise an exception since we already generated
            # the tokens and would rather continue execution.
            self.token_ratelimit.limit(self.identifier, rate=token_count)

    def reset(self, identifier: Optional[str] = None) -> "UpstashRatelimitHandler":
        """
        Creates a new UpstashRatelimitHandler object with the same
        ratelimit configurations but with a new identifier if it's
        provided.

        Also resets the state of the handler.
        """
        return UpstashRatelimitHandler(
            identifier=identifier or self.identifier,
            token_ratelimit=self.token_ratelimit,
            request_ratelimit=self.request_ratelimit,
            include_output_tokens=self.include_output_tokens,
        )
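The fresh-state requirement in the class docstring can be seen with a minimal stand-in: once `_checked` flips to True, a reused instance silently skips the request-limit check, which is why `reset()` (or a new handler) is needed per invoke. `ToyHandler` below is a toy mimicking only that state flag, not the real handler:

```python
class ToyHandler:
    """Stand-in mimicking only the `_checked` state of the real handler."""

    def __init__(self, identifier: str) -> None:
        self.identifier = identifier
        self._checked = False
        self.checks = 0  # how many times the rate limit was consulted

    def on_chain_start(self) -> None:
        # the real handler calls request_ratelimit.limit(identifier) here
        if not self._checked:
            self.checks += 1
            self._checked = True

    def reset(self, identifier=None) -> "ToyHandler":
        # fresh object, same configuration: the state starts over
        return ToyHandler(identifier or self.identifier)


handler = ToyHandler("user123")
handler.on_chain_start()
handler.on_chain_start()  # reused within one invoke: check is skipped
fresh = handler.reset()
fresh.on_chain_start()    # fresh handler for the next invoke: checked again
print(handler.checks, fresh.checks)  # 1 1
```

Skipping repeats is intentional within a single invoke (nested chains fire `on_chain_start` several times), but it means a handler reused across invokes would never rate limit again.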
@@ -26,6 +26,8 @@ EXPECTED_ALL = [
     "TrubricsCallbackHandler",
     "FiddlerCallbackHandler",
     "UpTrainCallbackHandler",
+    "UpstashRatelimitError",
+    "UpstashRatelimitHandler",
 ]
@@ -0,0 +1,234 @@
import logging
from typing import Any
from unittest.mock import create_autospec

import pytest
from langchain_core.outputs import LLMResult

from langchain_community.callbacks import UpstashRatelimitError, UpstashRatelimitHandler

logger = logging.getLogger(__name__)

try:
    from upstash_ratelimit import Ratelimit, Response
except ImportError:
    Ratelimit, Response = None, None


# Fixtures
@pytest.fixture
def request_ratelimit() -> Ratelimit:
    ratelimit = create_autospec(Ratelimit)
    response = Response(allowed=True, limit=10, remaining=10, reset=10000)
    ratelimit.limit.return_value = response
    return ratelimit


@pytest.fixture
def token_ratelimit() -> Ratelimit:
    ratelimit = create_autospec(Ratelimit)
    response = Response(allowed=True, limit=1000, remaining=1000, reset=10000)
    ratelimit.limit.return_value = response
    ratelimit.get_remaining.return_value = 1000
    return ratelimit


@pytest.fixture
def handler_with_both_limits(
    request_ratelimit: Ratelimit, token_ratelimit: Ratelimit
) -> UpstashRatelimitHandler:
    return UpstashRatelimitHandler(
        identifier="user123",
        token_ratelimit=token_ratelimit,
        request_ratelimit=request_ratelimit,
        include_output_tokens=False,
    )


# Tests
@pytest.mark.requires("upstash_ratelimit")
def test_init_no_limits() -> None:
    with pytest.raises(ValueError):
        UpstashRatelimitHandler(identifier="user123")


@pytest.mark.requires("upstash_ratelimit")
def test_init_request_limit_only(request_ratelimit: Ratelimit) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123", request_ratelimit=request_ratelimit
    )
    assert handler.request_ratelimit is not None
    assert handler.token_ratelimit is None


@pytest.mark.requires("upstash_ratelimit")
def test_init_token_limit_only(token_ratelimit: Ratelimit) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=token_ratelimit
    )
    assert handler.token_ratelimit is not None
    assert handler.request_ratelimit is None


@pytest.mark.requires("upstash_ratelimit")
def test_on_chain_start_request_limit(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    handler_with_both_limits.request_ratelimit.limit.assert_called_once_with("user123")
    handler_with_both_limits.token_ratelimit.limit.assert_not_called()


@pytest.mark.requires("upstash_ratelimit")
def test_on_chain_start_request_limit_reached(request_ratelimit: Any) -> None:
    request_ratelimit.limit.return_value = Response(
        allowed=False, limit=10, remaining=0, reset=10000
    )
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=None, request_ratelimit=request_ratelimit
    )
    with pytest.raises(UpstashRatelimitError):
        handler.on_chain_start(serialized={}, inputs={})


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_start_token_limit_reached(token_ratelimit: Any) -> None:
    token_ratelimit.get_remaining.return_value = 0
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=token_ratelimit, request_ratelimit=None
    )
    with pytest.raises(UpstashRatelimitError):
        handler.on_llm_start(serialized={}, prompts=["test"])


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_start_token_limit_reached_negative(token_ratelimit: Any) -> None:
    token_ratelimit.get_remaining.return_value = -10
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=token_ratelimit, request_ratelimit=None
    )
    with pytest.raises(UpstashRatelimitError):
        handler.on_llm_start(serialized={}, prompts=["test"])


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_end_with_token_limit(handler_with_both_limits: Any) -> None:
    response = LLMResult(
        generations=[],
        llm_output={
            "token_usage": {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
            }
        },
    )
    handler_with_both_limits.on_llm_end(response)
    handler_with_both_limits.token_ratelimit.limit.assert_called_once_with("user123", 2)


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_end_with_token_limit_include_output_tokens(
    token_ratelimit: Any,
) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123",
        token_ratelimit=token_ratelimit,
        request_ratelimit=None,
        include_output_tokens=True,
    )
    response = LLMResult(
        generations=[],
        llm_output={
            "token_usage": {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
            }
        },
    )
    handler.on_llm_end(response)
    token_ratelimit.limit.assert_called_once_with("user123", 5)


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_end_without_token_usage(handler_with_both_limits: Any) -> None:
    response = LLMResult(generations=[], llm_output={})
    with pytest.raises(ValueError):
        handler_with_both_limits.on_llm_end(response)


@pytest.mark.requires("upstash_ratelimit")
def test_reset_handler(handler_with_both_limits: Any) -> None:
    new_handler = handler_with_both_limits.reset(identifier="user456")
    assert new_handler.identifier == "user456"
    assert not new_handler._checked


@pytest.mark.requires("upstash_ratelimit")
def test_reset_handler_no_new_identifier(handler_with_both_limits: Any) -> None:
    new_handler = handler_with_both_limits.reset()
    assert new_handler.identifier == "user123"
    assert not new_handler._checked


@pytest.mark.requires("upstash_ratelimit")
def test_on_chain_start_called_once(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1


@pytest.mark.requires("upstash_ratelimit")
def test_on_chain_start_reset_checked(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    new_handler = handler_with_both_limits.reset(identifier="user456")
    new_handler.on_chain_start(serialized={}, inputs={})

    # becomes two because the mock object is kept in reset
    assert new_handler.request_ratelimit.limit.call_count == 2


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_start_no_token_limit(request_ratelimit: Any) -> None:
    handler = UpstashRatelimitHandler(
        identifier="user123", token_ratelimit=None, request_ratelimit=request_ratelimit
    )
    handler.on_llm_start(serialized={}, prompts=["test"])
    assert request_ratelimit.limit.call_count == 0


@pytest.mark.requires("upstash_ratelimit")
def test_on_llm_start_token_limit(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_llm_start(serialized={}, prompts=["test"])
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1


@pytest.mark.requires("upstash_ratelimit")
def test_full_chain_with_both_limits(handler_with_both_limits: Any) -> None:
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})
    handler_with_both_limits.on_chain_start(serialized={}, inputs={})

    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.limit.call_count == 0
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 0

    handler_with_both_limits.on_llm_start(serialized={}, prompts=["test"])

    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.limit.call_count == 0
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1

    response = LLMResult(
        generations=[],
        llm_output={
            "token_usage": {
                "prompt_tokens": 2,
                "completion_tokens": 3,
                "total_tokens": 5,
            }
        },
    )
    handler_with_both_limits.on_llm_end(response)

    assert handler_with_both_limits.request_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.limit.call_count == 1
    assert handler_with_both_limits.token_ratelimit.get_remaining.call_count == 1
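The fixtures above rely on `unittest.mock.create_autospec`, which builds a mock constrained to the spec'd class's method signatures, so typos and invalid arguments fail loudly instead of being silently recorded. A self-contained sketch of the same pattern; `Limiter` is a hypothetical stand-in for `Ratelimit`:

```python
from unittest.mock import create_autospec


class Limiter:
    """Minimal stand-in for the class being autospec'd in the fixtures."""

    def limit(self, identifier: str, rate: int = 1) -> bool:
        raise NotImplementedError  # the real body is never needed for the mock


mock = create_autospec(Limiter)
mock.limit.return_value = True

# autospec knows `limit` is a method, so `self` is handled automatically
assert mock.limit("user123", rate=2) is True

# calls are bound to the real signature, so positional and keyword forms
# of the same call match in assertions:
mock.limit.assert_called_once_with("user123", 2)

# and arguments that don't fit the signature raise immediately:
try:
    mock.limit("user123", bogus=1)
except TypeError:
    print("rejected: unexpected keyword argument")
```

This signature binding is also why the handler's `limit(self.identifier, rate=token_count)` call satisfies the positional `assert_called_once_with("user123", 2)` assertion in the tests above.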