Apply patch

This commit is contained in:
William FH 2025-06-12 18:09:02 +00:00
parent 16b47a8c17
commit 0c5ca73d9d

View File

@ -4,9 +4,19 @@ from __future__ import annotations
import abc
import asyncio
import logging
import random
import threading
import time
from typing import Optional
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Union
# Optional Redis imports - will be None if not available
coredis = None
try:
import coredis
except ImportError:
pass
class BaseRateLimiter(abc.ABC):
@ -253,7 +263,542 @@ class InMemoryRateLimiter(BaseRateLimiter):
return True
class RedisRateLimiter(BaseRateLimiter):
"""A Redis-based distributed rate limiter using sliding window algorithm.
This rate limiter provides distributed rate limiting across multiple processes,
pods, or containers using Redis as the shared state store. It uses a sliding
window counter algorithm implemented with Redis sorted sets and Lua scripts
for atomic operations.
Key features:
- Distributed rate limiting across multiple processes/pods
- Sliding window algorithm for accurate rate limiting
- Atomic operations using Lua scripts
- Connection pooling and Redis cluster support
- Comprehensive error handling with fallback options
- High performance with minimal Redis round trips
Example:
.. code-block:: python
from langchain_core.rate_limiters import RedisRateLimiter
# Basic usage
rate_limiter = RedisRateLimiter(
redis_url="redis://localhost:6379",
requests_per_second=10,
window_size=60, # 60 second sliding window
)
# With Redis cluster
rate_limiter = RedisRateLimiter(
redis_url="redis://node1:7000,redis://node2:7000,redis://node3:7000",
requests_per_second=100,
window_size=60,
redis_cluster=True,
)
# Use with chat model
from langchain_anthropic import ChatAnthropic
model = ChatAnthropic(
model_name="claude-3-opus-20240229",
rate_limiter=rate_limiter
)
.. versionadded:: 0.2.25
"""
# Lua script for atomic rate limit checking and updating
_CHECK_AND_UPDATE_SCRIPT = """
local key = KEYS[1]
local window_size = tonumber(ARGV[1])
local max_requests = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
local cleanup_probability = tonumber(ARGV[4])
-- Remove expired entries (sliding window cleanup)
local cutoff_time = current_time - window_size
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff_time)
-- Count current requests in window
local current_count = redis.call('ZCARD', key)
-- Check if we can proceed
if current_count < max_requests then
-- Add current request with unique score to handle concurrent requests
local score = current_time + math.random() * 0.001
redis.call('ZADD', key, score, score)
-- Set expiration for the key (window_size + buffer)
redis.call('EXPIRE', key, window_size + 60)
-- Probabilistic cleanup of old entries
if math.random() < cleanup_probability then
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff_time - window_size)
end
return {1, current_count + 1, max_requests}
else
return {0, current_count, max_requests}
end
"""
# Lua script for non-blocking rate limit check
_CHECK_ONLY_SCRIPT = """
local key = KEYS[1]
local window_size = tonumber(ARGV[1])
local max_requests = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
-- Remove expired entries
local cutoff_time = current_time - window_size
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff_time)
-- Count current requests in window
local current_count = redis.call('ZCARD', key)
-- Return availability without consuming
if current_count < max_requests then
return {1, current_count, max_requests}
else
return {0, current_count, max_requests}
end
"""
def __init__(
self,
*,
redis_url: str = "redis://localhost:6379",
requests_per_second: float = 1,
window_size: int = 60,
key_prefix: str = "langchain:ratelimit",
identifier: Optional[str] = None,
redis_cluster: bool = False,
connection_pool_size: int = 10,
socket_timeout: float = 5.0,
socket_connect_timeout: float = 5.0,
retry_on_timeout: bool = True,
health_check_interval: int = 30,
fallback_to_memory: bool = True,
cleanup_probability: float = 0.1,
check_every_n_seconds: float = 0.1,
**redis_kwargs: Any,
) -> None:
"""Initialize Redis-based distributed rate limiter.
Args:
redis_url: Redis connection URL. For clusters, provide comma-separated URLs.
requests_per_second: Maximum requests allowed per second.
window_size: Sliding window size in seconds.
key_prefix: Prefix for Redis keys to avoid collisions.
identifier: Unique identifier for this rate limiter instance.
If None, will be auto-generated.
redis_cluster: Whether to use Redis cluster mode.
connection_pool_size: Size of the Redis connection pool.
socket_timeout: Socket timeout for Redis operations.
socket_connect_timeout: Socket connection timeout.
retry_on_timeout: Whether to retry on timeout errors.
health_check_interval: Interval for Redis health checks.
fallback_to_memory: Whether to fallback to InMemoryRateLimiter on Redis errors.
cleanup_probability: Probability of performing cleanup on each request (0.0-1.0).
check_every_n_seconds: Sleep interval when blocking and waiting for tokens.
**redis_kwargs: Additional Redis client configuration.
"""
if coredis is None:
raise ImportError(
"RedisRateLimiter requires the 'coredis' package. "
"Install it with: pip install coredis"
)
self.redis_url = redis_url
self.requests_per_second = requests_per_second
self.window_size = window_size
self.max_requests = int(requests_per_second * window_size)
self.key_prefix = key_prefix
self.identifier = identifier or f"rl_{int(time.time())}_{random.randint(1000, 9999)}"
self.redis_cluster = redis_cluster
self.connection_pool_size = connection_pool_size
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout
self.retry_on_timeout = retry_on_timeout
self.health_check_interval = health_check_interval
self.fallback_to_memory = fallback_to_memory
self.cleanup_probability = cleanup_probability
self.check_every_n_seconds = check_every_n_seconds
self.redis_kwargs = redis_kwargs
# Redis client and connection management
self._redis_client: Optional[Union[coredis.Redis, coredis.RedisCluster]] = None
self._connection_lock = threading.Lock()
self._last_health_check = 0.0
self._redis_healthy = True
self._circuit_breaker_failures = 0
self._circuit_breaker_last_failure = 0.0
self._circuit_breaker_threshold = 5
self._circuit_breaker_timeout = 60.0
# Fallback rate limiter
self._fallback_limiter: Optional[InMemoryRateLimiter] = None
if self.fallback_to_memory:
self._fallback_limiter = InMemoryRateLimiter(
requests_per_second=requests_per_second,
check_every_n_seconds=check_every_n_seconds,
max_bucket_size=max(1, requests_per_second * 2), # Allow some burst
)
# Lua script SHA hashes (will be populated on first use)
self._check_and_update_sha: Optional[str] = None
self._check_only_sha: Optional[str] = None
# Logger
self._logger = logging.getLogger(__name__)
def _get_redis_key(self) -> str:
"""Generate Redis key for this rate limiter instance."""
return f"{self.key_prefix}:{self.identifier}"
def _create_redis_client(self) -> Union[coredis.Redis, coredis.RedisCluster]:
"""Create and configure Redis client."""
common_config = {
"socket_timeout": self.socket_timeout,
"socket_connect_timeout": self.socket_connect_timeout,
"retry_on_timeout": self.retry_on_timeout,
"health_check_interval": self.health_check_interval,
**self.redis_kwargs,
}
if self.redis_cluster:
# Parse cluster URLs
urls = [url.strip() for url in self.redis_url.split(",")]
startup_nodes = []
for url in urls:
if "://" in url:
# Parse redis://host:port format
parts = url.split("://")[1].split(":")
host = parts[0]
port = int(parts[1]) if len(parts) > 1 else 6379
else:
# Parse host:port format
parts = url.split(":")
host = parts[0]
port = int(parts[1]) if len(parts) > 1 else 6379
startup_nodes.append({"host": host, "port": port})
return coredis.RedisCluster(
startup_nodes=startup_nodes,
**common_config,
)
else:
return coredis.Redis.from_url(
self.redis_url,
max_connections=self.connection_pool_size,
**common_config,
)
def _get_redis_client(self) -> Union[coredis.Redis, coredis.RedisCluster]:
"""Get Redis client with connection management and health checking."""
current_time = time.time()
# Check circuit breaker
if (self._circuit_breaker_failures >= self._circuit_breaker_threshold and
current_time - self._circuit_breaker_last_failure < self._circuit_breaker_timeout):
raise ConnectionError("Circuit breaker is open - Redis unavailable")
with self._connection_lock:
# Create client if needed
if self._redis_client is None:
self._redis_client = self._create_redis_client()
# Periodic health check
if current_time - self._last_health_check > self.health_check_interval:
try:
# Simple ping to check health
asyncio.create_task(self._redis_client.ping())
self._redis_healthy = True
self._circuit_breaker_failures = 0
self._last_health_check = current_time
except Exception as e:
self._logger.warning(f"Redis health check failed: {e}")
self._redis_healthy = False
self._circuit_breaker_failures += 1
self._circuit_breaker_last_failure = current_time
# Recreate client on health check failure
try:
if hasattr(self._redis_client, 'close'):
asyncio.create_task(self._redis_client.close())
except Exception:
pass
self._redis_client = None
raise ConnectionError(f"Redis health check failed: {e}")
return self._redis_client
async def _ensure_scripts_loaded(self, redis_client: Union[coredis.Redis, coredis.RedisCluster]) -> None:
"""Ensure Lua scripts are loaded into Redis."""
try:
if self._check_and_update_sha is None:
self._check_and_update_sha = await redis_client.script_load(self._CHECK_AND_UPDATE_SCRIPT)
if self._check_only_sha is None:
self._check_only_sha = await redis_client.script_load(self._CHECK_ONLY_SCRIPT)
except Exception as e:
self._logger.warning(f"Failed to load Lua scripts: {e}")
# Reset SHAs to force reload on next attempt
self._check_and_update_sha = None
self._check_only_sha = None
raise
async def _execute_rate_limit_check(
self,
redis_client: Union[coredis.Redis, coredis.RedisCluster],
consume_token: bool = True
) -> tuple[bool, int, int]:
"""Execute rate limit check using Lua script.
Returns:
Tuple of (allowed, current_count, max_requests)
"""
await self._ensure_scripts_loaded(redis_client)
key = self._get_redis_key()
current_time = time.time()
try:
if consume_token:
result = await redis_client.evalsha(
self._check_and_update_sha,
1, # number of keys
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
str(self.cleanup_probability),
)
else:
result = await redis_client.evalsha(
self._check_only_sha,
1, # number of keys
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
)
allowed, current_count, max_requests = result
return bool(allowed), int(current_count), int(max_requests)
except Exception as e:
# If script execution fails, try to reload and retry once
self._logger.warning(f"Lua script execution failed, reloading: {e}")
self._check_and_update_sha = None
self._check_only_sha = None
await self._ensure_scripts_loaded(redis_client)
if consume_token:
result = await redis_client.evalsha(
self._check_and_update_sha,
1,
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
str(self.cleanup_probability),
)
else:
result = await redis_client.evalsha(
self._check_only_sha,
1,
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
)
allowed, current_count, max_requests = result
return bool(allowed), int(current_count), int(max_requests)
def _should_use_fallback(self, error: Exception) -> bool:
"""Determine if we should use fallback rate limiter based on error type."""
if not self.fallback_to_memory or self._fallback_limiter is None:
return False
# Use fallback for connection errors, timeouts, and Redis unavailability
return isinstance(error, (
ConnectionError,
TimeoutError,
OSError, # Network errors
)) or "connection" in str(error).lower()
def acquire(self, *, blocking: bool = True) -> bool:
"""Attempt to acquire a token from the rate limiter.
Args:
blocking: If True, block until token is available.
If False, return immediately with result.
Returns:
True if token was acquired, False otherwise.
"""
# Run async method in sync context
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context, create a new event loop in a thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: asyncio.run(self.aacquire(blocking=blocking))
)
return future.result()
else:
return loop.run_until_complete(self.aacquire(blocking=blocking))
except RuntimeError:
# No event loop, create one
return asyncio.run(self.aacquire(blocking=blocking))
async def aacquire(self, *, blocking: bool = True) -> bool:
"""Attempt to acquire a token from the rate limiter. Async version.
Args:
blocking: If True, block until token is available.
If False, return immediately with result.
Returns:
True if token was acquired, False otherwise.
"""
max_retries = 3
retry_delay = 0.1
for attempt in range(max_retries):
try:
redis_client = self._get_redis_client()
if not blocking:
# Non-blocking: just check availability
allowed, current_count, max_requests = await self._execute_rate_limit_check(
redis_client, consume_token=False
)
if allowed:
# If available, consume the token
allowed, _, _ = await self._execute_rate_limit_check(
redis_client, consume_token=True
)
return allowed
else:
# Blocking: keep trying until we get a token
while True:
allowed, current_count, max_requests = await self._execute_rate_limit_check(
redis_client, consume_token=True
)
if allowed:
return True
# Wait before retrying
await asyncio.sleep(self.check_every_n_seconds)
except Exception as e:
self._logger.warning(f"Redis rate limiter error (attempt {attempt + 1}): {e}")
# Update circuit breaker
self._circuit_breaker_failures += 1
self._circuit_breaker_last_failure = time.time()
if self._should_use_fallback(e):
self._logger.info("Falling back to in-memory rate limiter")
return await self._fallback_limiter.aacquire(blocking=blocking)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay * (2 ** attempt)) # Exponential backoff
else:
# Final attempt failed
if self.fallback_to_memory and self._fallback_limiter:
self._logger.error(f"All Redis attempts failed, using fallback: {e}")
return await self._fallback_limiter.aacquire(blocking=blocking)
else:
raise ConnectionError(f"Redis rate limiter unavailable: {e}")
# Should not reach here, but just in case
if self.fallback_to_memory and self._fallback_limiter:
return await self._fallback_limiter.aacquire(blocking=blocking)
return False
async def get_current_usage(self) -> Dict[str, Union[int, float]]:
"""Get current rate limiter usage statistics.
Returns:
Dictionary with usage statistics including current count,
max requests, and utilization percentage.
"""
try:
redis_client = self._get_redis_client()
allowed, current_count, max_requests = await self._execute_rate_limit_check(
redis_client, consume_token=False
)
utilization = (current_count / max_requests) * 100 if max_requests > 0 else 0
return {
"current_count": current_count,
"max_requests": max_requests,
"utilization_percent": utilization,
"requests_per_second": self.requests_per_second,
"window_size": self.window_size,
"redis_healthy": self._redis_healthy,
}
except Exception as e:
self._logger.warning(f"Failed to get usage statistics: {e}")
return {
"current_count": -1,
"max_requests": self.max_requests,
"utilization_percent": -1,
"requests_per_second": self.requests_per_second,
"window_size": self.window_size,
"redis_healthy": False,
"error": str(e),
}
async def reset(self) -> bool:
"""Reset the rate limiter by clearing all stored data.
Returns:
True if reset was successful, False otherwise.
"""
try:
redis_client = self._get_redis_client()
key = self._get_redis_key()
await redis_client.delete(key)
self._logger.info(f"Rate limiter reset for key: {key}")
return True
except Exception as e:
self._logger.error(f"Failed to reset rate limiter: {e}")
return False
async def close(self) -> None:
"""Close Redis connections and cleanup resources."""
if self._redis_client:
try:
if hasattr(self._redis_client, 'close'):
await self._redis_client.close()
except Exception as e:
self._logger.warning(f"Error closing Redis client: {e}")
finally:
self._redis_client = None
def __del__(self) -> None:
"""Cleanup on object destruction."""
if self._redis_client:
try:
asyncio.create_task(self.close())
except Exception:
pass
__all__ = [
"BaseRateLimiter",
"InMemoryRateLimiter",
"RedisRateLimiter",
]