diff --git a/src/proxy_pool/worker/tasks_validate.py b/src/proxy_pool/worker/tasks_validate.py new file mode 100644 index 0000000..a3dc13e --- /dev/null +++ b/src/proxy_pool/worker/tasks_validate.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, timedelta +from itertools import groupby +from operator import attrgetter +from uuid import UUID + +import httpx +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import async_sessionmaker + +from proxy_pool.plugins.protocols import CheckContext, CheckResult, Event +from proxy_pool.plugins.registry import PluginRegistry +from proxy_pool.proxy.models import Proxy, ProxyCheck, ProxyStatus + +logger = logging.getLogger(__name__) + + +async def _run_single_check(checker, proxy, context) -> CheckResult: + """Run one checker with timeout and error isolation.""" + try: + return await asyncio.wait_for( + checker.check( + proxy_ip=str(proxy.ip), + proxy_port=proxy.port, + proxy_protocol=proxy.protocol.value, + context=context, + ), + timeout=checker.timeout, + ) + except TimeoutError: + return CheckResult( + passed=False, + detail=f"{checker.name} timed out after {checker.timeout}s", + ) + except Exception as err: + return CheckResult( + passed=False, + detail=f"{checker.name} raised {type(err).__name__}: {err}", + ) + + +async def validate_proxy(ctx: dict, proxy_id: str) -> dict: + """Run the full checker pipeline for a single proxy.""" + session_factory: async_sessionmaker = ctx["session_factory"] + registry: PluginRegistry = ctx["registry"] + settings = ctx["settings"] + + async with session_factory() as db: + proxy = await db.get(Proxy, UUID(proxy_id)) + if proxy is None: + logger.warning("Proxy %s not found", proxy_id) + return {"status": "skipped", "reason": "not_found"} + + checkers = registry.get_checker_pipeline() + if not checkers: + logger.warning("No checkers registered, skipping validation") + return {"status": "skipped", "reason": "no_checkers"} + + # Create shared context + async with httpx.AsyncClient( + timeout=settings.proxy.check_http_timeout, + ) as http_client: + context = CheckContext( + started_at=datetime.now(), + http_client=http_client, + ) + + all_results: list[tuple[object, CheckResult]] = [] + final_status = ProxyStatus.ACTIVE + + # Group by stage, run stages sequentially + for _stage_num, stage_group in groupby(checkers, key=attrgetter("stage")): + stage_list = [ + c for c in stage_group if not c.should_skip(proxy.protocol.value) + ] + if not stage_list: + continue + + # Run checkers within a stage concurrently + stage_results = await asyncio.gather( + *(_run_single_check(c, proxy, context) for c in stage_list), + ) + + for checker, result in zip(stage_list, stage_results, strict=False): + all_results.append((checker, result)) + + # Log check to database + check_record = ProxyCheck( + proxy_id=proxy.id, + checker_name=checker.name, + stage=checker.stage, + passed=result.passed, + latency_ms=result.latency_ms, + detail=result.detail, + exit_ip=context.exit_ip, + ) + db.add(check_record) + + # If any check in this stage failed, abort pipeline + if any(not r.passed for r in stage_results): + final_status = ProxyStatus.DEAD + break + + # Update proxy record + old_status = proxy.status + proxy.status = final_status + proxy.last_checked_at = datetime.now() + + if context.exit_ip: + proxy.exit_ip = context.exit_ip + if context.anonymity_level: + proxy.anonymity = context.anonymity_level + if context.country: + proxy.country = context.country + if context.tcp_latency_ms: + proxy.avg_latency_ms = context.tcp_latency_ms + + # Simple score: 1.0 if all passed, 0.0 if any failed + # We'll refine this later with historical data + passed_count = sum(1 for _, r in all_results if r.passed) + total_count = len(all_results) + proxy.score = passed_count / total_count if total_count > 0 else 0.0 + + await db.commit() + + # Emit events on status transitions + if old_status == ProxyStatus.ACTIVE and final_status == ProxyStatus.DEAD: + # Check if pool is running low + count_result = await db.execute( + select(func.count()) + .select_from(Proxy) + .where(Proxy.status == ProxyStatus.ACTIVE) + ) + active_count = count_result.scalar_one() + + if active_count < settings.proxy.pool_low_threshold: + await registry.emit( + Event( + type="proxy.pool_low", + payload={ + "active_count": active_count, + "threshold": settings.proxy.pool_low_threshold, + }, + ) + ) + + logger.info( + "Validated proxy %s: %s (score=%.2f)", + proxy_id, + final_status.value, + proxy.score, + ) + + return { + "status": final_status.value, + "score": proxy.score, + "checks": total_count, + "passed": passed_count, + } + + +async def revalidate_sweep(ctx: dict) -> dict: + """Select proxies due for revalidation and validate them.""" + session_factory: async_sessionmaker = ctx["session_factory"] + settings = ctx["settings"] + + now = datetime.now() + active_cutoff = now - timedelta(minutes=settings.proxy.revalidate_active_minutes) + dead_cutoff = now - timedelta(hours=settings.proxy.revalidate_dead_hours) + batch_size = settings.proxy.revalidate_batch_size + + async with session_factory() as db: + # Priority 1: Never checked + unchecked = await db.execute( + select(Proxy.id) + .where(Proxy.status == ProxyStatus.UNCHECKED) + .limit(batch_size) + ) + unchecked_ids = [str(row[0]) for row in unchecked.all()] + + remaining = batch_size - len(unchecked_ids) + + # Priority 2: Stale active proxies + stale_active_ids = [] + if remaining > 0: + stale_active = await db.execute( + select(Proxy.id) + .where( + Proxy.status == ProxyStatus.ACTIVE, + (Proxy.last_checked_at.is_(None)) + | (Proxy.last_checked_at < active_cutoff), + ) + .order_by(Proxy.last_checked_at.asc().nulls_first()) + .limit(remaining) + ) + stale_active_ids = [str(row[0]) for row in stale_active.all()] + + remaining -= len(stale_active_ids) + + # Priority 3: Dead proxies worth rechecking + dead_ids = [] + if remaining > 0: + dead_recheck = await db.execute( + select(Proxy.id) + .where( + Proxy.status == ProxyStatus.DEAD, + (Proxy.last_checked_at.is_(None)) + | (Proxy.last_checked_at < dead_cutoff), + ) + .order_by(Proxy.last_checked_at.asc().nulls_first()) + .limit(remaining) + ) + dead_ids = [str(row[0]) for row in dead_recheck.all()] + + all_ids = unchecked_ids + stale_active_ids + dead_ids + + results = [] + for proxy_id in all_ids: + result = await validate_proxy(ctx, proxy_id) + results.append(result) + + logger.info( + "Revalidation sweep: %d unchecked, %d stale active, %d dead recheck", + len(unchecked_ids), + len(stale_active_ids), + len(dead_ids), + ) + + return { + "total": len(results), + "unchecked": len(unchecked_ids), + "stale_active": len(stale_active_ids), + "dead_recheck": len(dead_ids), + }