Source code for aioscraper.core.rate_limiter

import asyncio
import logging
import sys
from contextlib import suppress
from dataclasses import dataclass
from time import monotonic
from typing import Any, Awaitable, Callable, Hashable, Self

from yarl import URL

from aioscraper.config import RateLimitConfig, RequestRetryConfig
from aioscraper.types.session import PRequest, Request

logger = logging.getLogger(__name__)


[docs] @dataclass(slots=True) class AdaptiveMetrics: """Tracks metrics for adaptive rate limiting using EWMA + AIMD. Attributes: ewma_latency (float): Exponentially weighted moving average of request latency. ewma_alpha (float): Smoothing factor for EWMA (0 < alpha <= 1). success_count (int): Consecutive successful requests since last failure. failure_count (int): Consecutive failures since last success. last_outcome_time (float | None): Timestamp of last completed request. last_outcome_success (bool | None): Whether last request was successful. total_requests (int): Total number of completed requests in this group. """ ewma_latency: float = 0.0 ewma_alpha: float = 0.3 success_count: int = 0 failure_count: int = 0 last_outcome_time: float | None = None last_outcome_success: bool = True total_requests: int = 0
[docs] def update_latency(self, latency: float): """Update EWMA latency with new measurement.""" if self.total_requests == 0: self.ewma_latency = latency else: self.ewma_latency = (self.ewma_alpha * latency) + ((1 - self.ewma_alpha) * self.ewma_latency)
[docs] def record_success(self, latency: float): """Record a successful request outcome.""" self.update_latency(latency) self.success_count += 1 self.failure_count = 0 self.last_outcome_success = True self.last_outcome_time = monotonic() self.total_requests += 1
[docs] def record_failure(self, latency: float | None = None): """Record a failed request outcome (timeout, error status, etc).""" if latency is not None: self.update_latency(latency) self.failure_count += 1 self.success_count = 0 self.last_outcome_success = False self.last_outcome_time = monotonic() self.total_requests += 1
[docs] @dataclass(slots=True) class RequestOutcome: """Captures the result of a request execution. Attributes: group_key (Hashable): The RequestGroup key this outcome belongs to. latency (float): Request latency in seconds (start to finish). retry_after (float | None): Value from Retry-After header if present. status_code (int | None): HTTP status code if applicable. exception_type (type[BaseException] | None): Type of exception if one occurred. """ group_key: Hashable latency: float retry_after: float | None = None status_code: int | None = None exception_type: type[BaseException] | None = None
class AdaptiveStrategy:
    """EWMA + AIMD adaptive rate limiting strategy.

    Fast multiplicative increase on overload (server pushback).
    Slow additive decrease on sustained success (probing for capacity).
    Every interval returned by :meth:`calculate_interval`, including one taken from a
    ``Retry-After`` header, is clamped to ``[min_interval, max_interval]``.

    Args:
        min_interval (float): Minimum allowed interval (seconds).
        max_interval (float): Maximum allowed interval (seconds).
        increase_factor (float): Multiplicative factor for interval increase on failure.
        decrease_step (float): Additive step for interval decrease on success.
        success_threshold (int): Number of consecutive successes before decreasing interval.
        ewma_alpha (float): Smoothing factor for latency EWMA (0 < alpha <= 1).
        trigger_statuses (tuple[int, ...]): HTTP statuses that trigger adaptive slowdown.
        trigger_exceptions (tuple[type[BaseException], ...]): Exception types that trigger adaptive slowdown.
        respect_retry_after (bool): Whether to use Retry-After header as override.
    """

    def __init__(
        self,
        *,
        min_interval: float = 0.001,
        max_interval: float = 5.0,
        increase_factor: float = 2.0,
        decrease_step: float = 0.01,
        success_threshold: int = 5,
        ewma_alpha: float = 0.3,
        trigger_statuses: tuple[int, ...] = (429, 500, 502, 503, 504, 522, 524, 408),
        trigger_exceptions: tuple[type[BaseException], ...] = (asyncio.TimeoutError,),
        respect_retry_after: bool = True,
    ):
        self.min_interval = min_interval
        self.max_interval = max_interval
        self.increase_factor = increase_factor
        self.decrease_step = decrease_step
        self.success_threshold = success_threshold
        self.ewma_alpha = ewma_alpha
        # Stored as a set for O(1) membership checks in _is_adaptive_failure.
        self.trigger_statuses = set(trigger_statuses)
        self.trigger_exceptions = trigger_exceptions
        self.respect_retry_after = respect_retry_after
        # Per-group outcome metrics, created lazily by get_or_create_metrics.
        self._metrics: dict[Hashable, AdaptiveMetrics] = {}

    def get_or_create_metrics(self, group_key: Hashable) -> AdaptiveMetrics:
        """Get or create metrics for a group."""
        if group_key not in self._metrics:
            self._metrics[group_key] = AdaptiveMetrics(ewma_alpha=self.ewma_alpha)
        return self._metrics[group_key]

    def calculate_interval(self, group_key: Hashable, current_interval: float, outcome: RequestOutcome) -> float:
        """Calculate new interval based on request outcome.

        Algorithm:
            - Retry-After override: on failure, if the header value is present and
              ``respect_retry_after`` is enabled, it replaces the interval.
            - On failure: interval = interval * increase_factor
            - On success: if success_count >= threshold: interval = interval - decrease_step
            - Otherwise the interval is kept.
            The result is always clamped to ``[min_interval, max_interval]``.

        Args:
            group_key: Group whose metrics are updated.
            current_interval: The group's interval before this outcome (seconds).
            outcome: The finished request's outcome.

        Returns:
            New interval in seconds.
        """
        metrics = self.get_or_create_metrics(group_key)
        success = not self._is_adaptive_failure(outcome.status_code, outcome.exception_type)
        if success:
            metrics.record_success(outcome.latency)
        else:
            metrics.record_failure(outcome.latency)

        # Priority 1: Retry-After override takes precedence. It is clamped on both sides
        # so a zero or negative header value cannot push the interval below min_interval.
        if not success and self.respect_retry_after and outcome.retry_after is not None:
            new_interval = self._clamp(outcome.retry_after)
            logger.info(
                "Adaptive rate limit: Retry-After header for group %r, setting interval to %.4f "
                "(status=%s, latency=%.4f)",
                group_key,
                new_interval,
                outcome.status_code,
                outcome.latency,
            )
            return new_interval

        # Priority 2: Apply AIMD
        if not success:
            # Multiplicative increase on failure
            new_interval = self._clamp(current_interval * self.increase_factor)
            logger.info(
                "Adaptive rate limit: failure for group %r, increasing interval %.4f -> %.4f "
                "(status=%s, latency=%.4f, failure_count=%d)",
                group_key,
                current_interval,
                new_interval,
                outcome.status_code or "exception",
                outcome.latency,
                metrics.failure_count,
            )
        elif metrics.success_count >= self.success_threshold:
            # Additive decrease after sustained success; success_count is not reset here,
            # so every further success keeps stepping the interval down.
            new_interval = self._clamp(current_interval - self.decrease_step)
            logger.debug(
                "Adaptive rate limit: sustained success for group %r, decreasing interval %.4f -> %.4f "
                "(latency=%.4f, success_count=%d)",
                group_key,
                current_interval,
                new_interval,
                outcome.latency,
                metrics.success_count,
            )
        else:
            # Not enough successes yet, maintain current interval (still clamped).
            new_interval = self._clamp(current_interval)
        return new_interval

    def reset_metrics(self, group_key: Hashable):
        """Reset metrics for a group (e.g., on cleanup)."""
        self._metrics.pop(group_key, None)

    def _clamp(self, interval: float) -> float:
        """Clamp ``interval`` to ``[min_interval, max_interval]``."""
        return max(self.min_interval, min(self.max_interval, interval))

    def _is_adaptive_failure(self, status_code: int | None, exception_type: type[BaseException] | None) -> bool:
        """Check if status/exception should trigger adaptive slowdown."""
        if status_code and status_code in self.trigger_statuses:
            return True
        if exception_type and any(issubclass(exception_type, exc_type) for exc_type in self.trigger_exceptions):
            return True
        return False
def default_group_by_factory(default_interval: float) -> Callable[[Request], tuple[Hashable, float]]:
    """Build the fallback ``group_by`` callable used when none is configured.

    The returned function keys each request by the host of its URL (``"unknown"``
    when the URL has no host) and assigns every group ``default_interval``.
    """

    def _group_by(request: Request) -> tuple[Hashable, float]:
        host = URL(request.url).host
        group_key = host if host else "unknown"
        return group_key, default_interval

    return _group_by
class RequestGroup:
    """Manages a group of requests that share the same rate limit interval.

    Each group processes requests sequentially with a configured delay between them.
    Groups automatically clean up after a period of inactivity.

    Args:
        key (Hashable): Unique identifier for this request group.
        interval (float): Delay in seconds between processing requests in this group.
        cleanup_timeout (float): Timeout in seconds before cleaning up an idle group
            (raised to at least ``2 * interval``).
        schedule (Callable[[PRequest], Awaitable[None]]): Callback function to schedule request execution.
        on_finished (Callable[[Hashable, RequestGroup], None]): Callback invoked when the worker
            exits (idle timeout, shutdown sentinel, or crash). It may be called more than once
            for the same group, so it must be idempotent.
    """

    def __init__(
        self,
        key: Hashable,
        interval: float,
        cleanup_timeout: float,
        schedule: Callable[[PRequest], Awaitable[None]],
        on_finished: Callable[[Hashable, "RequestGroup"], None],
    ):
        self._key = key
        self._interval = interval
        self._cleanup_timeout = self._effective_cleanup_timeout(interval, cleanup_timeout)
        self._schedule = schedule
        self._on_finished = on_finished
        self._queue: asyncio.PriorityQueue[PRequest] = asyncio.PriorityQueue()
        self._task: asyncio.Task[None] | None = None

    @property
    def key(self) -> Hashable:
        return self._key

    @property
    def active(self) -> bool:
        "Check if the group has pending requests in its queue."
        return not self._queue.empty()

    @property
    def interval(self) -> float:
        "Get the current interval for this group."
        return self._interval

    @property
    def worker_alive(self) -> bool:
        "Whether the worker task has been started and has not finished yet."
        # A cancelled task is always done(), so a single check covers both cases.
        return self._task is not None and not self._task.done()

    def set_intervals(self, interval: float, cleanup_timeout: float):
        "Update group interval and cleanup timeout (cleanup is kept >= 2 * interval, as in __init__)."
        self._interval = interval
        self._cleanup_timeout = self._effective_cleanup_timeout(interval, cleanup_timeout)

    async def put(self, pr: PRequest):
        "Add a request to this group's processing queue."
        await self._queue.put(pr)

    def start_listening(self):
        "Start the worker task; calling it again once started is a no-op."
        if self._task is not None:
            return
        self._task = asyncio.create_task(self._listen_queue())
        self._task.add_done_callback(self._on_task_done_factory())

    async def close(self):
        "Cancel the worker task and wait for graceful shutdown."
        if self._task is None or self._task.done():
            # Nothing to cancel. A crashed worker's exception was already logged by the
            # done callback; awaiting the finished task here would re-raise it.
            return
        self._task.cancel()
        with suppress(asyncio.CancelledError):
            await self._task

    @staticmethod
    def _effective_cleanup_timeout(interval: float, cleanup_timeout: float) -> float:
        """Keep the idle window at least two intervals long."""
        return max(cleanup_timeout, interval * 2)

    async def _listen_queue(self):
        """Worker loop: run queued requests one at a time, sleeping ``interval`` after each."""
        while True:
            try:
                # Wait for next request with timeout. If no requests arrive within
                # cleanup_timeout, the group is considered idle and will be cleaned up.
                pr = await asyncio.wait_for(self._queue.get(), timeout=self._cleanup_timeout)
            except asyncio.TimeoutError:
                # Race condition: item may have been added while timeout was firing
                if not self._queue.empty():
                    continue
                # Group is idle - trigger cleanup callback and exit worker loop
                self._on_finished(self._key, self)
                break

            if pr.request.url == "stub":
                # Shutdown sentinel. Detach synchronously (like the idle path) so no new
                # request can be routed into this queue after the worker has stopped.
                self._on_finished(self._key, self)
                break

            try:
                # Shielded so cancelling the worker does not cancel an in-flight request.
                await asyncio.shield(self._schedule(pr))
            except Exception:
                logger.exception("Rate limiter scheduler failed for %r", self._key)

            await asyncio.sleep(self._interval)

    def _on_task_done_factory(self) -> Callable[[asyncio.Task[None]], None]:
        def _on_task_done(task: asyncio.Task[None]):
            if task.cancelled():
                logger.debug("Rate limiter group %r cancelled", self._key)
                return
            # Safe: task is done and not cancelled, so exception() cannot raise here.
            exc = task.exception()
            if exc is not None:
                logger.error("Rate limiter group %r crashed: %s", self._key, exc, exc_info=exc)
            self._on_finished(self._key, self)

        return _on_task_done
class RateLimitManager:
    """Manages rate limiting for requests using group-based throttling.

    Requests are grouped by a configurable key (default: hostname) and processed with a specified
    interval between requests in each group. Groups are created dynamically and cleaned up
    automatically after inactivity.

    Args:
        config (RateLimitConfig): Rate limiting configuration including grouping strategy and intervals.
        retry_config (RequestRetryConfig): Retry configuration for inheriting trigger conditions.
        schedule (Callable[[PRequest], Awaitable[Any]]): Callback function to schedule request execution.
    """

    def __init__(
        self,
        config: RateLimitConfig,
        retry_config: RequestRetryConfig,
        schedule: Callable[[PRequest], Awaitable[Any]],
    ):
        self._schedule = schedule
        self._group_by = config.group_by or default_group_by_factory(config.default_interval)
        self._default_interval = config.default_interval
        self._cleanup_timeout = config.cleanup_timeout
        self._groups: dict[Hashable, RequestGroup] = {}
        self._enabled = config.enabled
        self._stopped = False

        self._adaptive_strategy: AdaptiveStrategy | None = None
        if config.enabled and config.adaptive:
            trigger_statuses = config.adaptive.custom_trigger_statuses
            trigger_exceptions = config.adaptive.custom_trigger_exceptions

            # Merge retry triggers if configured
            if config.adaptive.inherit_retry_triggers:
                trigger_statuses = tuple(set(trigger_statuses) | set(retry_config.statuses))
                trigger_exceptions = tuple(set(trigger_exceptions) | set(retry_config.exceptions))

            self._adaptive_strategy = AdaptiveStrategy(
                min_interval=config.adaptive.min_interval,
                max_interval=config.adaptive.max_interval,
                increase_factor=config.adaptive.increase_factor,
                decrease_step=config.adaptive.decrease_step,
                success_threshold=config.adaptive.success_threshold,
                ewma_alpha=config.adaptive.ewma_alpha,
                trigger_statuses=trigger_statuses,
                trigger_exceptions=trigger_exceptions,
                respect_retry_after=config.adaptive.respect_retry_after,
            )

        if config.enabled:
            self._handle = self._handle_with_group
            logger.info(
                "Rate limiting enabled: grouping=%s, default_interval=%0.10g, cleanup_timeout=%0.10g",
                "custom" if config.group_by else "by hostname",
                self._default_interval,
                self._cleanup_timeout,
            )
        else:
            # Without grouping, requests are scheduled inline followed by a flat sleep.
            self._handle = self._handle_without_group
            if self._default_interval > 0:
                logger.info(
                    "Rate limiting disabled (no grouping), but default_interval=%0.10g will be applied",
                    self._default_interval,
                )

        if config.adaptive and self._adaptive_strategy:
            logger.info(
                "Adaptive rate limiting enabled: min_interval=%.3f, max_interval=%.3f, "
                "increase_factor=%.2f, decrease_step=%.3f, success_threshold=%d, ewma_alpha=%.2f",
                config.adaptive.min_interval,
                config.adaptive.max_interval,
                config.adaptive.increase_factor,
                config.adaptive.decrease_step,
                config.adaptive.success_threshold,
                config.adaptive.ewma_alpha,
            )
            logger.info(
                "Adaptive rate limiting triggers (inherit_retry_triggers=%s): statuses=%s; exceptions=%s",
                config.adaptive.inherit_retry_triggers,
                ",".join(map(str, sorted(self._adaptive_strategy.trigger_statuses))),
                ",".join(exc.__module__ + "." + exc.__qualname__ for exc in self._adaptive_strategy.trigger_exceptions),
            )

    @property
    def adaptive_strategy(self) -> AdaptiveStrategy | None:
        return self._adaptive_strategy

    @property
    def active(self) -> bool:
        "Check if any request groups have pending requests."
        return any(group.active for group in self._groups.values())

    async def __call__(self, pr: PRequest):
        "Process a request through the rate limiter."
        await self._handle(pr)

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, *args: object) -> None:
        try:
            await self.shutdown()
        finally:
            await self.close()

    async def shutdown(self) -> bool:
        """Send a shutdown sentinel to every group.

        The sentinel is queued with priority ``sys.maxsize``, so it is dequeued after the
        requests already waiting in each group.

        Returns:
            ``True`` if this call performed the shutdown, ``False`` if it was already done.
        """
        if self._stopped:
            return False
        # Snapshot: groups may detach from self._groups while we await below.
        groups = list(self._groups.values())
        if groups:
            logger.info(
                "Rate limiter: shutting down %d active group(s): %s",
                len(groups),
                ",".join(str(group.key) for group in groups),
            )
            for group in groups:
                await group.put(PRequest(priority=sys.maxsize, request=Request(url="stub")))
        self._stopped = True
        return True

    async def close(self):
        "Close all request groups, clean up resources, and drop their adaptive metrics."
        groups = list(self._groups.values())
        self._groups.clear()
        for group in groups:
            await group.close()
            # Cancelled workers skip on_finished, so release their metrics here instead.
            if self._adaptive_strategy:
                self._adaptive_strategy.reset_metrics(group.key)

    def get_group_key(self, request: Request) -> Hashable:
        """Get group key for a request."""
        return self._group_by(request)[0]

    def on_request_outcome(self, outcome: RequestOutcome):
        """Handle request outcome and adjust group interval adaptively.

        Outcomes for groups that no longer exist (e.g. already cleaned up) are ignored.
        """
        if not self._adaptive_strategy:
            return
        group = self._groups.get(outcome.group_key)
        if group is None:
            return

        new_interval = self._adaptive_strategy.calculate_interval(
            group_key=outcome.group_key,
            current_interval=group.interval,
            outcome=outcome,
        )
        if new_interval != group.interval:
            group.set_intervals(interval=new_interval, cleanup_timeout=max(self._cleanup_timeout, new_interval * 2))

    async def _handle_with_group(self, pr: PRequest):
        """Route a request to its group's queue, creating (and starting) the group on first use."""
        group_key, interval = self._group_by(pr.request)

        # Ensure minimum interval to prevent busy-waiting. Custom group_by functions
        # may return zero or negative intervals, which we adjust to a safe minimum.
        if interval <= 0:
            logger.debug("Adjusting invalid interval %.3f to 0.01s for group %r", interval, group_key)
            interval = 0.01

        if (group := self._groups.get(group_key)) is None:
            group = self._groups[group_key] = self._create_group(group_key, interval)
            logger.debug(
                "Created rate limit group %r: interval=%0.10g, cleanup_timeout=%0.10g",
                group_key,
                interval,
                self._cleanup_timeout,
            )
        else:
            # Existing groups keep their own (possibly adapted) interval.
            logger.debug("Queueing request to existing group %r (interval=%0.3fs)", group_key, group.interval)

        await group.put(pr)

    async def _handle_without_group(self, pr: PRequest):
        await self._schedule(pr)
        await asyncio.sleep(self._default_interval)

    def _create_group(self, key: Hashable, interval: float) -> RequestGroup:
        group = RequestGroup(
            key=key,
            interval=interval,
            cleanup_timeout=self._cleanup_timeout,
            schedule=self._schedule,
            on_finished=self._on_group_finished,
        )
        group.start_listening()
        return group

    def _on_group_finished(self, key: Hashable, group: RequestGroup):
        # Identity check: the callback may fire more than once, or after a newer group
        # has replaced this one under the same key; only remove the exact instance.
        if self._groups.get(key) is not group:
            return
        self._groups.pop(key, None)
        if self._adaptive_strategy:
            self._adaptive_strategy.reset_metrics(key)
        logger.debug("Rate limit group %r finished and removed (idle timeout or shutdown)", key)