# Recovered from a whitespace-collapsed git patch against
# ``libs/core/langchain_core/rate_limiters.py``.  The module's existing header
# (``from __future__ import annotations``) and the ``BaseRateLimiter`` /
# ``InMemoryRateLimiter`` classes precede this code in the real file; the patch
# only shows them as unchanged diff context.

import abc
import asyncio
import logging
import math
import random
import threading
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Tuple, Union
from urllib.parse import urlsplit

# Optional Redis dependency: ``coredis`` stays ``None`` when the package is not
# installed so that importing this module never fails.
coredis = None
try:
    import coredis
except ImportError:
    pass


class _CircuitOpenError(ConnectionError):
    """Raised while the circuit breaker is open.

    Kept distinct from real connection failures so that short-circuited calls
    are not counted as new failures (which would keep the breaker open for as
    long as traffic keeps arriving).
    """


class RedisRateLimiter(BaseRateLimiter):
    """A Redis-based distributed rate limiter using a sliding-window log.

    Every granted request is recorded as one member of a Redis sorted set whose
    score is the request timestamp.  A Lua script trims entries older than
    ``window_size`` seconds, counts what is left and, if the count is below
    ``requests_per_second * window_size``, records the new request -- all in a
    single atomic server-side step.

    Limiters share a budget only when they use the same ``key_prefix`` and
    ``identifier``; pass an explicit ``identifier`` in every process that
    should be throttled together.

    Timestamps are taken from the calling host's clock (``time.time()``), so
    hosts sharing a budget should have reasonably synchronised clocks.

    When Redis is unreachable the limiter can fall back to a process-local
    ``InMemoryRateLimiter`` (``fallback_to_memory=True``); the fallback is not
    shared between processes.

    Example:
        .. code-block:: python

            from langchain_core.rate_limiters import RedisRateLimiter

            # Basic usage
            rate_limiter = RedisRateLimiter(
                redis_url="redis://localhost:6379",
                requests_per_second=10,
                window_size=60,  # 60 second sliding window
                identifier="anthropic-prod",  # same value in every process
            )

            # With Redis cluster
            rate_limiter = RedisRateLimiter(
                redis_url="redis://node1:7000,redis://node2:7000,redis://node3:7000",
                requests_per_second=100,
                window_size=60,
                redis_cluster=True,
                identifier="anthropic-prod",
            )

            # Use with chat model
            from langchain_anthropic import ChatAnthropic
            model = ChatAnthropic(
                model_name="claude-3-opus-20240229",
                rate_limiter=rate_limiter
            )

    .. versionadded:: 0.2.25
    """

    # Atomically trim the window, count, and (if under the limit) record one
    # request.  ARGV: window_size, max_requests, now, key TTL, unique member.
    _CHECK_AND_UPDATE_SCRIPT = """
    local key = KEYS[1]
    local window_size = tonumber(ARGV[1])
    local max_requests = tonumber(ARGV[2])
    local current_time = tonumber(ARGV[3])

    -- Remove expired entries (sliding window cleanup)
    redis.call('ZREMRANGEBYSCORE', key, '-inf', current_time - window_size)

    -- Count current requests in window
    local current_count = redis.call('ZCARD', key)
    if current_count >= max_requests then
        return {0, current_count, max_requests}
    end

    -- The member is a client-generated unique id, so two requests with
    -- (nearly) identical timestamps can never overwrite each other.
    redis.call('ZADD', key, ARGV[3], ARGV[5])

    -- Let an idle key disappear on its own (window_size + buffer)
    redis.call('EXPIRE', key, ARGV[4])

    return {1, current_count + 1, max_requests}
    """

    # Trim the window and report availability without recording a request.
    # ARGV: window_size, max_requests, now.
    _CHECK_ONLY_SCRIPT = """
    local key = KEYS[1]
    local window_size = tonumber(ARGV[1])
    local max_requests = tonumber(ARGV[2])
    local current_time = tonumber(ARGV[3])

    -- Remove expired entries
    redis.call('ZREMRANGEBYSCORE', key, '-inf', current_time - window_size)

    -- Count current requests in window
    local current_count = redis.call('ZCARD', key)

    -- Return availability without consuming
    if current_count < max_requests then
        return {1, current_count, max_requests}
    else
        return {0, current_count, max_requests}
    end
    """

    def __init__(
        self,
        *,
        redis_url: str = "redis://localhost:6379",
        requests_per_second: float = 1,
        window_size: int = 60,
        key_prefix: str = "langchain:ratelimit",
        identifier: Optional[str] = None,
        redis_cluster: bool = False,
        connection_pool_size: int = 10,
        socket_timeout: float = 5.0,
        socket_connect_timeout: float = 5.0,
        retry_on_timeout: bool = True,
        health_check_interval: int = 30,
        fallback_to_memory: bool = True,
        cleanup_probability: float = 0.1,
        check_every_n_seconds: float = 0.1,
        **redis_kwargs: Any,
    ) -> None:
        """Initialize Redis-based distributed rate limiter.

        Args:
            redis_url: Redis connection URL. For clusters, provide
                comma-separated URLs (or ``host:port`` pairs).
            requests_per_second: Maximum requests allowed per second.
            window_size: Sliding window size in seconds.
            key_prefix: Prefix for Redis keys to avoid collisions.
            identifier: Name of the shared budget. Limiters only share a limit
                when they use the same ``key_prefix`` and ``identifier``.
                If None, a per-instance value is generated, which means the
                limit is NOT shared with any other instance.
            redis_cluster: Whether to use Redis cluster mode.
            connection_pool_size: Maximum connections per client pool.
            socket_timeout: Timeout for Redis operations; forwarded to coredis
                as ``stream_timeout``.
            socket_connect_timeout: Connection timeout; forwarded to coredis
                as ``connect_timeout``.
            retry_on_timeout: Whether ``aacquire`` retries after a timeout.
            health_check_interval: Minimum seconds between health-check pings.
            fallback_to_memory: Whether to fall back to InMemoryRateLimiter on
                Redis errors.
            cleanup_probability: Accepted for backward compatibility and
                stored, but unused: expired entries are trimmed on every call.
            check_every_n_seconds: Sleep interval when blocking and waiting
                for a slot.
            **redis_kwargs: Additional coredis client configuration; these
                override the values derived from the arguments above.

        Raises:
            ImportError: If ``coredis`` is not installed.
            ValueError: If a rate/interval argument is not positive, or if
                ``requests_per_second * window_size`` is below 1 (no request
                could ever be granted).
        """
        if coredis is None:
            raise ImportError(
                "RedisRateLimiter requires the 'coredis' package. "
                "Install it with: pip install coredis"
            )
        if requests_per_second <= 0:
            raise ValueError("requests_per_second must be positive")
        if window_size <= 0:
            raise ValueError("window_size must be positive")
        if check_every_n_seconds <= 0:
            raise ValueError("check_every_n_seconds must be positive")

        # The epsilon guards against float error truncating e.g. 0.29 * 100
        # (28.999999999999996) down to 28.
        max_requests = math.floor(requests_per_second * window_size + 1e-9)
        if max_requests < 1:
            raise ValueError(
                "requests_per_second * window_size must be at least 1; "
                "increase window_size"
            )

        self._logger = logging.getLogger(__name__)

        self.redis_url = redis_url
        self.requests_per_second = requests_per_second
        self.window_size = window_size
        self.max_requests = max_requests
        self.key_prefix = key_prefix
        if identifier is None:
            identifier = f"rl_{int(time.time())}_{random.randint(1000, 9999)}"
            self._logger.warning(
                "RedisRateLimiter created without an identifier; using %r. "
                "The limit will not be shared with other instances.",
                identifier,
            )
        self.identifier = identifier
        self.redis_cluster = redis_cluster
        self.connection_pool_size = connection_pool_size
        self.socket_timeout = socket_timeout
        self.socket_connect_timeout = socket_connect_timeout
        self.retry_on_timeout = retry_on_timeout
        self.health_check_interval = health_check_interval
        self.fallback_to_memory = fallback_to_memory
        self.cleanup_probability = cleanup_probability
        self.check_every_n_seconds = check_every_n_seconds
        self.redis_kwargs = redis_kwargs

        # One client per event loop: a client's connections are bound to the
        # loop that opened them and must not be driven from another loop.
        self._clients: Dict[asyncio.AbstractEventLoop, Any] = {}
        self._connection_lock = threading.Lock()
        self._last_health_check = float("-inf")  # time.monotonic() timestamps
        self._redis_healthy = True

        # Circuit breaker state (timestamps are time.monotonic() values)
        self._circuit_breaker_failures = 0
        self._circuit_breaker_last_failure = 0.0
        self._circuit_breaker_threshold = 5
        self._circuit_breaker_timeout = 60.0

        # Background loop that serves the synchronous ``acquire`` API
        self._sync_loop: Optional[asyncio.AbstractEventLoop] = None
        self._sync_thread: Optional[threading.Thread] = None

        # Fallback rate limiter (process-local)
        self._fallback_limiter: Optional[InMemoryRateLimiter] = None
        if self.fallback_to_memory:
            self._fallback_limiter = InMemoryRateLimiter(
                requests_per_second=requests_per_second,
                check_every_n_seconds=check_every_n_seconds,
                max_bucket_size=max(1, requests_per_second * 2),  # Allow some burst
            )

        # Lua script SHA hashes (populated on first use; scripts live on the
        # server, so the hashes are valid for every client)
        self._check_and_update_sha: Optional[str] = None
        self._check_only_sha: Optional[str] = None

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _coredis_exceptions(*names: str) -> Tuple[type, ...]:
        """Return the ``coredis.exceptions`` classes with the given names.

        Names that do not exist in the installed coredis version are skipped.
        """
        module = getattr(coredis, "exceptions", None)
        found = (getattr(module, name, None) for name in names)
        return tuple(exc for exc in found if isinstance(exc, type))

    @staticmethod
    def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
        """Thread target: run ``loop`` until stopped, then close it."""
        asyncio.set_event_loop(loop)
        try:
            loop.run_forever()
        finally:
            loop.close()

    @staticmethod
    def _disconnect_client(client: Any) -> None:
        """Close every pooled connection of ``client`` (best effort).

        Must run on the event loop that owns the client's connections.
        """
        pool = getattr(client, "connection_pool", None)
        disconnect = getattr(pool, "disconnect", None)
        if callable(disconnect):
            disconnect()

    def _ensure_sync_loop(self) -> asyncio.AbstractEventLoop:
        """Return the background loop used by ``acquire``, starting it if needed."""
        with self._connection_lock:
            loop, thread = self._sync_loop, self._sync_thread
            if loop is None or loop.is_closed() or thread is None or not thread.is_alive():
                loop = asyncio.new_event_loop()
                thread = threading.Thread(
                    target=self._run_loop,
                    args=(loop,),
                    name="RedisRateLimiter-sync",
                    daemon=True,
                )
                thread.start()
                self._sync_loop, self._sync_thread = loop, thread
            return loop

    def _circuit_is_open(self) -> bool:
        """Return True while recent failures should short-circuit Redis calls."""
        if self._circuit_breaker_failures < self._circuit_breaker_threshold:
            return False
        elapsed = time.monotonic() - self._circuit_breaker_last_failure
        return elapsed < self._circuit_breaker_timeout

    def _record_failure(self) -> None:
        """Count one Redis failure towards the circuit breaker."""
        with self._connection_lock:
            self._circuit_breaker_failures += 1
            self._circuit_breaker_last_failure = time.monotonic()
            self._redis_healthy = False

    def _record_success(self) -> None:
        """Close the circuit breaker after a successful Redis round trip."""
        with self._connection_lock:
            self._circuit_breaker_failures = 0
            self._redis_healthy = True

    def _get_redis_key(self) -> str:
        """Generate Redis key for this rate limiter instance."""
        return f"{self.key_prefix}:{self.identifier}"

    # ------------------------------------------------------------------
    # Client management
    # ------------------------------------------------------------------

    def _create_redis_client(self) -> "Union[coredis.Redis, coredis.RedisCluster]":
        """Create and configure a coredis client.

        ``socket_timeout`` / ``socket_connect_timeout`` are translated to the
        coredis names ``stream_timeout`` / ``connect_timeout``.
        ``retry_on_timeout`` and ``health_check_interval`` are handled by this
        class itself and are therefore not forwarded to coredis.
        """
        options: Dict[str, Any] = {
            "max_connections": self.connection_pool_size,
            "stream_timeout": self.socket_timeout,
            "connect_timeout": self.socket_connect_timeout,
            **self.redis_kwargs,  # explicit user settings win
        }

        if not self.redis_cluster:
            return coredis.Redis.from_url(self.redis_url, **options)

        startup_nodes = []
        for raw_url in self.redis_url.split(","):
            raw_url = raw_url.strip()
            if not raw_url:
                continue
            # urlsplit copes with credentials and IPv6 literals; bare
            # ``host:port`` entries get a scheme so they parse the same way.
            parsed = urlsplit(raw_url if "://" in raw_url else f"redis://{raw_url}")
            if not parsed.hostname:
                raise ValueError(f"Invalid Redis cluster node URL: {raw_url!r}")
            startup_nodes.append({"host": parsed.hostname, "port": parsed.port or 6379})

        return coredis.RedisCluster(startup_nodes=startup_nodes, **options)

    async def _get_redis_client(self) -> "Union[coredis.Redis, coredis.RedisCluster]":
        """Return the client bound to the running event loop.

        Creates the client on first use and pings Redis at most once every
        ``health_check_interval`` seconds.

        Raises:
            ConnectionError: If the circuit breaker is open or the health
                check fails.
        """
        if self._circuit_is_open():
            raise _CircuitOpenError("Circuit breaker is open - Redis unavailable")

        loop = asyncio.get_running_loop()
        with self._connection_lock:
            # Clients of closed loops can never be used again; forget them.
            for stale in [known for known in self._clients if known.is_closed()]:
                del self._clients[stale]

            client = self._clients.get(loop)
            if client is None:
                client = self._clients[loop] = self._create_redis_client()

            since_last_check = time.monotonic() - self._last_health_check
            health_check_due = since_last_check > self.health_check_interval

        if health_check_due:
            try:
                await client.ping()
            except Exception as e:
                self._logger.warning("Redis health check failed: %s", e)
                self._redis_healthy = False
                # Drop the client so the next call starts from a fresh pool.
                with self._connection_lock:
                    if self._clients.get(loop) is client:
                        del self._clients[loop]
                try:
                    self._disconnect_client(client)
                except Exception:
                    self._logger.debug("Error discarding Redis client", exc_info=True)
                raise ConnectionError(f"Redis health check failed: {e}") from e
            self._redis_healthy = True
            self._last_health_check = time.monotonic()

        return client

    # ------------------------------------------------------------------
    # Script execution
    # ------------------------------------------------------------------

    async def _ensure_scripts_loaded(
        self, redis_client: "Union[coredis.Redis, coredis.RedisCluster]"
    ) -> None:
        """Ensure Lua scripts are loaded into Redis."""
        try:
            if self._check_and_update_sha is None:
                self._check_and_update_sha = await redis_client.script_load(
                    self._CHECK_AND_UPDATE_SCRIPT
                )
            if self._check_only_sha is None:
                self._check_only_sha = await redis_client.script_load(
                    self._CHECK_ONLY_SCRIPT
                )
        except Exception as e:
            self._logger.warning("Failed to load Lua scripts: %s", e)
            # Reset SHAs to force reload on next attempt
            self._check_and_update_sha = None
            self._check_only_sha = None
            raise

    async def _execute_rate_limit_check(
        self,
        redis_client: "Union[coredis.Redis, coredis.RedisCluster]",
        consume_token: bool = True,
    ) -> Tuple[bool, int, int]:
        """Execute rate limit check using Lua script.

        Args:
            redis_client: Client bound to the running event loop.
            consume_token: If True, record a request when one is allowed;
                if False, only report availability.

        Returns:
            Tuple of (allowed, current_count, max_requests)
        """
        key = self._get_redis_key()
        args = [str(self.window_size), str(self.max_requests), repr(time.time())]
        if consume_token:
            # EXPIRE needs an integer TTL; the member is generated once so a
            # retry re-adds the same member instead of consuming twice.
            args += [str(math.ceil(self.window_size) + 60), uuid.uuid4().hex]

        noscript_errors = self._coredis_exceptions("NoScriptError")

        for is_retry in (False, True):
            await self._ensure_scripts_loaded(redis_client)
            sha = self._check_and_update_sha if consume_token else self._check_only_sha
            try:
                result = await redis_client.evalsha(sha, keys=[key], args=args)
            except Exception as e:
                message = str(e).lower()
                script_missing = (
                    isinstance(e, noscript_errors)
                    or "noscript" in message
                    or "no matching script" in message
                )
                # Only a flushed script cache is worth a reload; anything else
                # (timeouts, connection loss) is the caller's to handle.
                if is_retry or not script_missing:
                    raise
                self._logger.warning("Lua script execution failed, reloading: %s", e)
                self._check_and_update_sha = None
                self._check_only_sha = None
                continue

            allowed, current_count, max_requests = result
            return bool(allowed), int(current_count), int(max_requests)

        raise RuntimeError("unreachable: script retry loop exited without result")

    def _should_use_fallback(self, error: Exception) -> bool:
        """Determine if we should use fallback rate limiter based on error type."""
        if not self.fallback_to_memory or self._fallback_limiter is None:
            return False

        # Connection errors, timeouts and network errors, both the builtin
        # classes and coredis' own (which do not derive from the builtins).
        transient = (
            ConnectionError,
            TimeoutError,
            asyncio.TimeoutError,
            OSError,
        ) + self._coredis_exceptions("ConnectionError", "TimeoutError")
        return isinstance(error, transient) or "connection" in str(error).lower()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def acquire(self, *, blocking: bool = True) -> bool:
        """Attempt to acquire a token from the rate limiter.

        The call is executed on a private background event loop, so it works
        whether or not the calling thread already runs an event loop.  (When
        called from a running loop's thread it blocks that loop; prefer
        ``aacquire`` there.)

        Args:
            blocking: If True, block until token is available.
                If False, return immediately with result.

        Returns:
            True if token was acquired, False otherwise.
        """
        loop = self._ensure_sync_loop()
        future = asyncio.run_coroutine_threadsafe(
            self.aacquire(blocking=blocking), loop
        )
        return future.result()

    async def aacquire(self, *, blocking: bool = True) -> bool:
        """Attempt to acquire a token from the rate limiter. Async version.

        Args:
            blocking: If True, block until token is available.
                If False, return immediately with result.

        Returns:
            True if token was acquired, False otherwise.

        Raises:
            ConnectionError: If Redis is unavailable and no in-memory fallback
                is configured.
        """
        max_retries = 3
        retry_delay = 0.1
        timeout_errors = (
            TimeoutError,
            asyncio.TimeoutError,
        ) + self._coredis_exceptions("TimeoutError")
        last_error: Optional[BaseException] = None

        for attempt in range(max_retries):
            try:
                redis_client = await self._get_redis_client()
                while True:
                    # The script checks and records atomically, so a single
                    # call serves both the blocking and non-blocking paths.
                    allowed, _, _ = await self._execute_rate_limit_check(
                        redis_client, consume_token=True
                    )
                    self._record_success()
                    if allowed or not blocking:
                        return allowed
                    await asyncio.sleep(self.check_every_n_seconds)
            except _CircuitOpenError as e:
                # Short-circuited, not a new failure: do not touch the breaker.
                last_error = e
                break
            except Exception as e:
                last_error = e
                self._logger.warning(
                    "Redis rate limiter error (attempt %d): %s", attempt + 1, e
                )
                self._record_failure()

                if self._should_use_fallback(e):
                    break
                if isinstance(e, timeout_errors) and not self.retry_on_timeout:
                    break
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay * (2**attempt))  # Exponential backoff

        if self.fallback_to_memory and self._fallback_limiter is not None:
            self._logger.info("Falling back to in-memory rate limiter")
            return await self._fallback_limiter.aacquire(blocking=blocking)
        raise ConnectionError(
            f"Redis rate limiter unavailable: {last_error}"
        ) from last_error

    async def get_current_usage(self) -> Dict[str, Any]:
        """Get current rate limiter usage statistics.

        Returns:
            Dictionary with usage statistics including current count,
            max requests, and utilization percentage.  On failure the counts
            are ``-1`` and an ``"error"`` entry carries the message.
        """
        try:
            redis_client = await self._get_redis_client()
            _, current_count, max_requests = await self._execute_rate_limit_check(
                redis_client, consume_token=False
            )
        except Exception as e:
            self._logger.warning("Failed to get usage statistics: %s", e)
            return {
                "current_count": -1,
                "max_requests": self.max_requests,
                "utilization_percent": -1,
                "requests_per_second": self.requests_per_second,
                "window_size": self.window_size,
                "redis_healthy": False,
                "error": str(e),
            }

        utilization = (current_count / max_requests) * 100 if max_requests > 0 else 0
        return {
            "current_count": current_count,
            "max_requests": max_requests,
            "utilization_percent": utilization,
            "requests_per_second": self.requests_per_second,
            "window_size": self.window_size,
            "redis_healthy": self._redis_healthy,
        }

    async def reset(self) -> bool:
        """Reset the rate limiter by clearing all stored data.

        Returns:
            True if reset was successful, False otherwise.
        """
        try:
            redis_client = await self._get_redis_client()
            key = self._get_redis_key()
            # coredis takes the keys as one iterable; a bare string would be
            # unpacked character by character.
            await redis_client.delete([key])
        except Exception as e:
            self._logger.error("Failed to reset rate limiter: %s", e)
            return False
        self._logger.info("Rate limiter reset for key: %s", key)
        return True

    async def close(self) -> None:
        """Close Redis connections and cleanup resources.

        The limiter stays usable afterwards; clients and the background loop
        are recreated on demand.
        """
        with self._connection_lock:
            clients, self._clients = self._clients, {}
            sync_loop, self._sync_loop = self._sync_loop, None
            self._sync_thread = None

        current_loop = asyncio.get_running_loop()
        for loop, client in clients.items():
            try:
                if loop is current_loop:
                    self._disconnect_client(client)
                elif loop.is_running():
                    # Connections must be closed on the loop that owns them.
                    loop.call_soon_threadsafe(self._disconnect_client, client)
            except Exception as e:
                self._logger.warning("Error closing Redis client: %s", e)

        if sync_loop is not None and sync_loop.is_running():
            # Queued after the disconnect above, so it runs after it.
            sync_loop.call_soon_threadsafe(sync_loop.stop)

    def __del__(self) -> None:
        """Cleanup on object destruction (best effort, never raises)."""
        # Attributes may be missing when __init__ raised before setting them.
        clients = getattr(self, "_clients", None) or {}
        sync_loop = getattr(self, "_sync_loop", None)
        try:
            # Only plain callbacks are scheduled: creating a coroutine here
            # could leave it un-awaited when no loop is running.
            for loop, client in list(clients.items()):
                if loop.is_running():
                    loop.call_soon_threadsafe(self._disconnect_client, client)
            if sync_loop is not None and sync_loop.is_running():
                sync_loop.call_soon_threadsafe(sync_loop.stop)
        except Exception:
            pass


# ``BaseRateLimiter`` and ``InMemoryRateLimiter`` are the module's pre-existing
# public classes (defined above this code in the real file).
__all__ = [
    "BaseRateLimiter",
    "InMemoryRateLimiter",
    "RedisRateLimiter",
]