Apply patch

This commit is contained in:
William FH 2025-06-12 18:09:02 +00:00
parent 16b47a8c17
commit 0c5ca73d9d

View File

@ -4,9 +4,19 @@ from __future__ import annotations
import abc import abc
import asyncio import asyncio
import logging
import random
import threading import threading
import time import time
from typing import Optional from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Union
# Optional Redis imports - will be None if not available
coredis = None
try:
import coredis
except ImportError:
pass
class BaseRateLimiter(abc.ABC): class BaseRateLimiter(abc.ABC):
@ -253,7 +263,542 @@ class InMemoryRateLimiter(BaseRateLimiter):
return True return True
class RedisRateLimiter(BaseRateLimiter):
"""A Redis-based distributed rate limiter using sliding window algorithm.
This rate limiter provides distributed rate limiting across multiple processes,
pods, or containers using Redis as the shared state store. It uses a sliding
window counter algorithm implemented with Redis sorted sets and Lua scripts
for atomic operations.
Key features:
- Distributed rate limiting across multiple processes/pods
- Sliding window algorithm for accurate rate limiting
- Atomic operations using Lua scripts
- Connection pooling and Redis cluster support
- Comprehensive error handling with fallback options
- High performance with minimal Redis round trips
Example:
.. code-block:: python
from langchain_core.rate_limiters import RedisRateLimiter
# Basic usage
rate_limiter = RedisRateLimiter(
redis_url="redis://localhost:6379",
requests_per_second=10,
window_size=60, # 60 second sliding window
)
# With Redis cluster
rate_limiter = RedisRateLimiter(
redis_url="redis://node1:7000,redis://node2:7000,redis://node3:7000",
requests_per_second=100,
window_size=60,
redis_cluster=True,
)
# Use with chat model
from langchain_anthropic import ChatAnthropic
model = ChatAnthropic(
model_name="claude-3-opus-20240229",
rate_limiter=rate_limiter
)
.. versionadded:: 0.2.25
"""
# Lua script for atomic rate limit checking and updating
_CHECK_AND_UPDATE_SCRIPT = """
local key = KEYS[1]
local window_size = tonumber(ARGV[1])
local max_requests = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
local cleanup_probability = tonumber(ARGV[4])
-- Remove expired entries (sliding window cleanup)
local cutoff_time = current_time - window_size
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff_time)
-- Count current requests in window
local current_count = redis.call('ZCARD', key)
-- Check if we can proceed
if current_count < max_requests then
-- Add current request with unique score to handle concurrent requests
local score = current_time + math.random() * 0.001
redis.call('ZADD', key, score, score)
-- Set expiration for the key (window_size + buffer)
redis.call('EXPIRE', key, window_size + 60)
-- Probabilistic cleanup of old entries
if math.random() < cleanup_probability then
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff_time - window_size)
end
return {1, current_count + 1, max_requests}
else
return {0, current_count, max_requests}
end
"""
# Lua script for non-blocking rate limit check
_CHECK_ONLY_SCRIPT = """
local key = KEYS[1]
local window_size = tonumber(ARGV[1])
local max_requests = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])
-- Remove expired entries
local cutoff_time = current_time - window_size
redis.call('ZREMRANGEBYSCORE', key, '-inf', cutoff_time)
-- Count current requests in window
local current_count = redis.call('ZCARD', key)
-- Return availability without consuming
if current_count < max_requests then
return {1, current_count, max_requests}
else
return {0, current_count, max_requests}
end
"""
def __init__(
self,
*,
redis_url: str = "redis://localhost:6379",
requests_per_second: float = 1,
window_size: int = 60,
key_prefix: str = "langchain:ratelimit",
identifier: Optional[str] = None,
redis_cluster: bool = False,
connection_pool_size: int = 10,
socket_timeout: float = 5.0,
socket_connect_timeout: float = 5.0,
retry_on_timeout: bool = True,
health_check_interval: int = 30,
fallback_to_memory: bool = True,
cleanup_probability: float = 0.1,
check_every_n_seconds: float = 0.1,
**redis_kwargs: Any,
) -> None:
"""Initialize Redis-based distributed rate limiter.
Args:
redis_url: Redis connection URL. For clusters, provide comma-separated URLs.
requests_per_second: Maximum requests allowed per second.
window_size: Sliding window size in seconds.
key_prefix: Prefix for Redis keys to avoid collisions.
identifier: Unique identifier for this rate limiter instance.
If None, will be auto-generated.
redis_cluster: Whether to use Redis cluster mode.
connection_pool_size: Size of the Redis connection pool.
socket_timeout: Socket timeout for Redis operations.
socket_connect_timeout: Socket connection timeout.
retry_on_timeout: Whether to retry on timeout errors.
health_check_interval: Interval for Redis health checks.
fallback_to_memory: Whether to fallback to InMemoryRateLimiter on Redis errors.
cleanup_probability: Probability of performing cleanup on each request (0.0-1.0).
check_every_n_seconds: Sleep interval when blocking and waiting for tokens.
**redis_kwargs: Additional Redis client configuration.
"""
if coredis is None:
raise ImportError(
"RedisRateLimiter requires the 'coredis' package. "
"Install it with: pip install coredis"
)
self.redis_url = redis_url
self.requests_per_second = requests_per_second
self.window_size = window_size
self.max_requests = int(requests_per_second * window_size)
self.key_prefix = key_prefix
self.identifier = identifier or f"rl_{int(time.time())}_{random.randint(1000, 9999)}"
self.redis_cluster = redis_cluster
self.connection_pool_size = connection_pool_size
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout
self.retry_on_timeout = retry_on_timeout
self.health_check_interval = health_check_interval
self.fallback_to_memory = fallback_to_memory
self.cleanup_probability = cleanup_probability
self.check_every_n_seconds = check_every_n_seconds
self.redis_kwargs = redis_kwargs
# Redis client and connection management
self._redis_client: Optional[Union[coredis.Redis, coredis.RedisCluster]] = None
self._connection_lock = threading.Lock()
self._last_health_check = 0.0
self._redis_healthy = True
self._circuit_breaker_failures = 0
self._circuit_breaker_last_failure = 0.0
self._circuit_breaker_threshold = 5
self._circuit_breaker_timeout = 60.0
# Fallback rate limiter
self._fallback_limiter: Optional[InMemoryRateLimiter] = None
if self.fallback_to_memory:
self._fallback_limiter = InMemoryRateLimiter(
requests_per_second=requests_per_second,
check_every_n_seconds=check_every_n_seconds,
max_bucket_size=max(1, requests_per_second * 2), # Allow some burst
)
# Lua script SHA hashes (will be populated on first use)
self._check_and_update_sha: Optional[str] = None
self._check_only_sha: Optional[str] = None
# Logger
self._logger = logging.getLogger(__name__)
def _get_redis_key(self) -> str:
"""Generate Redis key for this rate limiter instance."""
return f"{self.key_prefix}:{self.identifier}"
def _create_redis_client(self) -> Union[coredis.Redis, coredis.RedisCluster]:
"""Create and configure Redis client."""
common_config = {
"socket_timeout": self.socket_timeout,
"socket_connect_timeout": self.socket_connect_timeout,
"retry_on_timeout": self.retry_on_timeout,
"health_check_interval": self.health_check_interval,
**self.redis_kwargs,
}
if self.redis_cluster:
# Parse cluster URLs
urls = [url.strip() for url in self.redis_url.split(",")]
startup_nodes = []
for url in urls:
if "://" in url:
# Parse redis://host:port format
parts = url.split("://")[1].split(":")
host = parts[0]
port = int(parts[1]) if len(parts) > 1 else 6379
else:
# Parse host:port format
parts = url.split(":")
host = parts[0]
port = int(parts[1]) if len(parts) > 1 else 6379
startup_nodes.append({"host": host, "port": port})
return coredis.RedisCluster(
startup_nodes=startup_nodes,
**common_config,
)
else:
return coredis.Redis.from_url(
self.redis_url,
max_connections=self.connection_pool_size,
**common_config,
)
def _get_redis_client(self) -> Union[coredis.Redis, coredis.RedisCluster]:
"""Get Redis client with connection management and health checking."""
current_time = time.time()
# Check circuit breaker
if (self._circuit_breaker_failures >= self._circuit_breaker_threshold and
current_time - self._circuit_breaker_last_failure < self._circuit_breaker_timeout):
raise ConnectionError("Circuit breaker is open - Redis unavailable")
with self._connection_lock:
# Create client if needed
if self._redis_client is None:
self._redis_client = self._create_redis_client()
# Periodic health check
if current_time - self._last_health_check > self.health_check_interval:
try:
# Simple ping to check health
asyncio.create_task(self._redis_client.ping())
self._redis_healthy = True
self._circuit_breaker_failures = 0
self._last_health_check = current_time
except Exception as e:
self._logger.warning(f"Redis health check failed: {e}")
self._redis_healthy = False
self._circuit_breaker_failures += 1
self._circuit_breaker_last_failure = current_time
# Recreate client on health check failure
try:
if hasattr(self._redis_client, 'close'):
asyncio.create_task(self._redis_client.close())
except Exception:
pass
self._redis_client = None
raise ConnectionError(f"Redis health check failed: {e}")
return self._redis_client
async def _ensure_scripts_loaded(self, redis_client: Union[coredis.Redis, coredis.RedisCluster]) -> None:
"""Ensure Lua scripts are loaded into Redis."""
try:
if self._check_and_update_sha is None:
self._check_and_update_sha = await redis_client.script_load(self._CHECK_AND_UPDATE_SCRIPT)
if self._check_only_sha is None:
self._check_only_sha = await redis_client.script_load(self._CHECK_ONLY_SCRIPT)
except Exception as e:
self._logger.warning(f"Failed to load Lua scripts: {e}")
# Reset SHAs to force reload on next attempt
self._check_and_update_sha = None
self._check_only_sha = None
raise
async def _execute_rate_limit_check(
self,
redis_client: Union[coredis.Redis, coredis.RedisCluster],
consume_token: bool = True
) -> tuple[bool, int, int]:
"""Execute rate limit check using Lua script.
Returns:
Tuple of (allowed, current_count, max_requests)
"""
await self._ensure_scripts_loaded(redis_client)
key = self._get_redis_key()
current_time = time.time()
try:
if consume_token:
result = await redis_client.evalsha(
self._check_and_update_sha,
1, # number of keys
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
str(self.cleanup_probability),
)
else:
result = await redis_client.evalsha(
self._check_only_sha,
1, # number of keys
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
)
allowed, current_count, max_requests = result
return bool(allowed), int(current_count), int(max_requests)
except Exception as e:
# If script execution fails, try to reload and retry once
self._logger.warning(f"Lua script execution failed, reloading: {e}")
self._check_and_update_sha = None
self._check_only_sha = None
await self._ensure_scripts_loaded(redis_client)
if consume_token:
result = await redis_client.evalsha(
self._check_and_update_sha,
1,
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
str(self.cleanup_probability),
)
else:
result = await redis_client.evalsha(
self._check_only_sha,
1,
key,
str(self.window_size),
str(self.max_requests),
str(current_time),
)
allowed, current_count, max_requests = result
return bool(allowed), int(current_count), int(max_requests)
def _should_use_fallback(self, error: Exception) -> bool:
"""Determine if we should use fallback rate limiter based on error type."""
if not self.fallback_to_memory or self._fallback_limiter is None:
return False
# Use fallback for connection errors, timeouts, and Redis unavailability
return isinstance(error, (
ConnectionError,
TimeoutError,
OSError, # Network errors
)) or "connection" in str(error).lower()
def acquire(self, *, blocking: bool = True) -> bool:
"""Attempt to acquire a token from the rate limiter.
Args:
blocking: If True, block until token is available.
If False, return immediately with result.
Returns:
True if token was acquired, False otherwise.
"""
# Run async method in sync context
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context, create a new event loop in a thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
lambda: asyncio.run(self.aacquire(blocking=blocking))
)
return future.result()
else:
return loop.run_until_complete(self.aacquire(blocking=blocking))
except RuntimeError:
# No event loop, create one
return asyncio.run(self.aacquire(blocking=blocking))
async def aacquire(self, *, blocking: bool = True) -> bool:
"""Attempt to acquire a token from the rate limiter. Async version.
Args:
blocking: If True, block until token is available.
If False, return immediately with result.
Returns:
True if token was acquired, False otherwise.
"""
max_retries = 3
retry_delay = 0.1
for attempt in range(max_retries):
try:
redis_client = self._get_redis_client()
if not blocking:
# Non-blocking: just check availability
allowed, current_count, max_requests = await self._execute_rate_limit_check(
redis_client, consume_token=False
)
if allowed:
# If available, consume the token
allowed, _, _ = await self._execute_rate_limit_check(
redis_client, consume_token=True
)
return allowed
else:
# Blocking: keep trying until we get a token
while True:
allowed, current_count, max_requests = await self._execute_rate_limit_check(
redis_client, consume_token=True
)
if allowed:
return True
# Wait before retrying
await asyncio.sleep(self.check_every_n_seconds)
except Exception as e:
self._logger.warning(f"Redis rate limiter error (attempt {attempt + 1}): {e}")
# Update circuit breaker
self._circuit_breaker_failures += 1
self._circuit_breaker_last_failure = time.time()
if self._should_use_fallback(e):
self._logger.info("Falling back to in-memory rate limiter")
return await self._fallback_limiter.aacquire(blocking=blocking)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay * (2 ** attempt)) # Exponential backoff
else:
# Final attempt failed
if self.fallback_to_memory and self._fallback_limiter:
self._logger.error(f"All Redis attempts failed, using fallback: {e}")
return await self._fallback_limiter.aacquire(blocking=blocking)
else:
raise ConnectionError(f"Redis rate limiter unavailable: {e}")
# Should not reach here, but just in case
if self.fallback_to_memory and self._fallback_limiter:
return await self._fallback_limiter.aacquire(blocking=blocking)
return False
async def get_current_usage(self) -> Dict[str, Union[int, float]]:
"""Get current rate limiter usage statistics.
Returns:
Dictionary with usage statistics including current count,
max requests, and utilization percentage.
"""
try:
redis_client = self._get_redis_client()
allowed, current_count, max_requests = await self._execute_rate_limit_check(
redis_client, consume_token=False
)
utilization = (current_count / max_requests) * 100 if max_requests > 0 else 0
return {
"current_count": current_count,
"max_requests": max_requests,
"utilization_percent": utilization,
"requests_per_second": self.requests_per_second,
"window_size": self.window_size,
"redis_healthy": self._redis_healthy,
}
except Exception as e:
self._logger.warning(f"Failed to get usage statistics: {e}")
return {
"current_count": -1,
"max_requests": self.max_requests,
"utilization_percent": -1,
"requests_per_second": self.requests_per_second,
"window_size": self.window_size,
"redis_healthy": False,
"error": str(e),
}
async def reset(self) -> bool:
"""Reset the rate limiter by clearing all stored data.
Returns:
True if reset was successful, False otherwise.
"""
try:
redis_client = self._get_redis_client()
key = self._get_redis_key()
await redis_client.delete(key)
self._logger.info(f"Rate limiter reset for key: {key}")
return True
except Exception as e:
self._logger.error(f"Failed to reset rate limiter: {e}")
return False
async def close(self) -> None:
"""Close Redis connections and cleanup resources."""
if self._redis_client:
try:
if hasattr(self._redis_client, 'close'):
await self._redis_client.close()
except Exception as e:
self._logger.warning(f"Error closing Redis client: {e}")
finally:
self._redis_client = None
def __del__(self) -> None:
"""Cleanup on object destruction."""
if self._redis_client:
try:
asyncio.create_task(self.close())
except Exception:
pass
__all__ = [ __all__ = [
"BaseRateLimiter", "BaseRateLimiter",
"InMemoryRateLimiter", "InMemoryRateLimiter",
"RedisRateLimiter",
] ]