"""Client-IP resolution behind trusted proxies and an in-memory login rate limiter."""

import ipaddress
import logging
import time
from dataclasses import dataclass, field
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
from threading import Lock

from starlette.requests import Request

logger = logging.getLogger(__name__)


def _parse_ip(value: str) -> IPv4Address | IPv6Address | None:
    """Parse *value* as an IP address, returning ``None`` if it is not one.

    A trailing port (``192.0.2.1:8443`` or ``[2001:db8::1]:8443``), as written
    by some proxies into X-Forwarded-For, is stripped first. IPv4-mapped IPv6
    addresses (``::ffff:a.b.c.d``) are unwrapped to plain IPv4 so they match
    IPv4 trusted networks and yield the same key as the bare IPv4 form.
    """
    text = value.strip()
    if text.startswith("["):
        # Bracketed IPv6, optionally followed by ":port".
        text = text[1:].partition("]")[0]
    elif text.count(":") == 1:
        # IPv4 with a port; a bare IPv6 address always contains >= 2 colons.
        text = text.partition(":")[0]
    try:
        addr = ipaddress.ip_address(text)
    except ValueError:
        return None
    if isinstance(addr, IPv6Address) and addr.ipv4_mapped is not None:
        return addr.ipv4_mapped
    return addr


def _is_trusted(
    addr: IPv4Address | IPv6Address,
    trusted_networks: list[IPv4Network | IPv6Network],
) -> bool:
    """Return True if *addr* lies inside any of *trusted_networks*."""
    # Membership between different IP versions is simply False, so a mixed
    # IPv4/IPv6 list needs no special handling.
    return any(addr in net for net in trusted_networks)


def _client_from_forwarded_for(
    request: Request,
    trusted_networks: list[IPv4Network | IPv6Network],
) -> str | None:
    """Extract the client address from X-Forwarded-For, or return ``None``.

    Every proxy *appends* the address it received the connection from, so
    only the right-hand end of the list is trustworthy; entries further left
    may have been supplied by the client itself. The list is therefore walked
    right-to-left, skipping our own trusted proxies, and the first address
    that is not one of them is returned. If every hop is trusted, the
    left-most hop is returned. ``None`` means the header is absent, or an
    unparseable hop was reached before any untrusted address (hops to its
    left cannot be attributed reliably).
    """
    # A proxy may add a separate header line instead of extending an
    # existing one; join all lines in arrival order.
    raw = ",".join(request.headers.getlist("X-Forwarded-For"))
    hops = [hop.strip() for hop in raw.split(",") if hop.strip()]
    leftmost_trusted: IPv4Address | IPv6Address | None = None
    for hop in reversed(hops):
        addr = _parse_ip(hop)
        if addr is None:
            return None
        if not _is_trusted(addr, trusted_networks):
            return str(addr)
        leftmost_trusted = addr
    return str(leftmost_trusted) if leftmost_trusted is not None else None


def get_client_ip(
    request: Request,
    trusted_networks: list[IPv4Network | IPv6Network],
) -> str:
    """Return the resolved client IP for rate limiting and logging.

    Forwarding headers are honoured only when the TCP peer is inside
    *trusted_networks*:

    1. X-Forwarded-For, read right-to-left past trusted proxies (the
       left-most entry is client-controlled and is never trusted blindly);
    2. otherwise X-Real-IP, if it parses as an IP address;
    3. otherwise the TCP peer address.

    Addresses taken from headers are returned in canonical ``str(ip)`` form.
    When no trusted networks are configured, or the peer is absent, not an IP
    or not trusted, the peer host is returned unchanged (``"unknown"`` if the
    request has no client).
    """
    peer = request.client.host if request.client else "unknown"
    if not trusted_networks:
        return peer

    peer_addr = _parse_ip(peer)
    if peer_addr is None or not _is_trusted(peer_addr, trusted_networks):
        return peer

    forwarded = _client_from_forwarded_for(request, trusted_networks)
    if forwarded is not None:
        return forwarded

    real_ip = _parse_ip(request.headers.get("X-Real-IP", ""))
    if real_ip is not None:
        return str(real_ip)
    return peer


@dataclass
class _Record:
    """Failure bookkeeping for one IP; timestamps are ``time.monotonic()`` values."""

    failures: int = 0  # failed attempts counted in the current window
    window_start: float = field(default_factory=time.monotonic)
    blocked_until: float = 0.0  # 0.0 = never blocked, else end of cooldown


class LoginRateLimiter:
    """Block an IP for a cooldown after too many failed logins in a window.

    An IP reaching ``max_failures`` failures within ``window_seconds`` is
    blocked for ``cooldown_seconds``. State is kept in an in-process dict
    guarded by a ``threading.Lock``. Times are measured with
    ``time.monotonic()`` so wall-clock adjustments cannot shorten or extend
    a cooldown.
    """

    def __init__(
        self,
        max_failures: int = 5,
        window_seconds: int = 300,
        cooldown_seconds: int = 900,
    ) -> None:
        self._max = max_failures
        self._window = window_seconds
        self._cooldown = cooldown_seconds
        self._store: dict[str, _Record] = {}
        self._lock = Lock()
        # Time of the last stale-record sweep (see _sweep).
        self._last_sweep = time.monotonic()

    @property
    def cooldown_seconds(self) -> int:
        """Length of the block applied once the failure limit is reached."""
        return self._cooldown

    def is_blocked(self, ip: str) -> bool:
        """Return True while *ip* is inside its cooldown.

        A record whose cooldown has lapsed is discarded, so the IP starts
        over with a clean failure count.
        """
        now = time.monotonic()
        with self._lock:
            rec = self._store.get(ip)
            if rec is None:
                return False
            if rec.blocked_until > now:
                return True
            if rec.blocked_until > 0:
                del self._store[ip]
            return False

    def record_failure(self, ip: str) -> None:
        """Count a failed login for *ip*, blocking it once the limit is hit.

        Failures older than the window are forgotten. Reaching the limit sets
        the cooldown and logs a warning.
        """
        now = time.monotonic()
        with self._lock:
            self._sweep(now)
            rec = self._store.get(ip)
            if rec is None or 0 < rec.blocked_until <= now:
                # No history, or a previous cooldown has lapsed: start clean so
                # the old failure count cannot trigger an instant re-block.
                rec = _Record(window_start=now)
                self._store[ip] = rec
            elif now - rec.window_start > self._window:
                rec.failures = 0
                rec.window_start = now
            rec.failures += 1
            if rec.failures >= self._max:
                rec.blocked_until = now + self._cooldown
                logger.warning(
                    "Login blocked for %s after %d failures", ip, rec.failures
                )

    def record_success(self, ip: str) -> None:
        """Forget all failure history for *ip* after a successful login."""
        with self._lock:
            self._store.pop(ip, None)

    def _is_stale(self, rec: _Record, now: float) -> bool:
        """Return True if *rec* can no longer influence any decision."""
        if rec.blocked_until > now:
            return False
        if rec.blocked_until > 0:
            # Cooldown over: is_blocked/record_failure would discard it anyway.
            return True
        # Not blocked and its window has expired: the next failure resets it.
        return now - rec.window_start > self._window

    def _sweep(self, now: float) -> None:
        """Drop stale records so the store cannot grow without bound.

        Must be called with ``self._lock`` held. The O(n) scan runs at most
        once per window, which keeps per-call cost amortised.
        """
        if now - self._last_sweep < self._window:
            return
        self._last_sweep = now
        stale = [ip for ip, rec in self._store.items() if self._is_stale(rec, now)]
        for ip in stale:
            del self._store[ip]