feat: add validation checks

This commit is contained in:
agatha 2026-03-15 15:24:18 -04:00
parent 67089c570c
commit e0cdf94063

View File

@ -0,0 +1,237 @@
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timedelta
from itertools import groupby
from operator import attrgetter
from uuid import UUID
import httpx
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import async_sessionmaker
from proxy_pool.plugins.protocols import CheckContext, CheckResult, Event
from proxy_pool.plugins.registry import PluginRegistry
from proxy_pool.proxy.models import Proxy, ProxyCheck, ProxyStatus
logger = logging.getLogger(__name__)
async def _run_single_check(checker, proxy, context) -> CheckResult:
    """Run one checker with timeout and error isolation.

    Args:
        checker: Registered checker plugin exposing ``check()``, ``name``,
            and ``timeout`` attributes.
        proxy: Proxy ORM row; its ip/port/protocol are forwarded to the checker.
        context: Shared CheckContext the checker may read and mutate.

    Returns:
        The checker's own CheckResult, or a failed CheckResult when the
        checker times out or raises — a broken checker never takes down
        the whole pipeline.
    """
    try:
        return await asyncio.wait_for(
            checker.check(
                proxy_ip=str(proxy.ip),
                proxy_port=proxy.port,
                proxy_protocol=proxy.protocol.value,
                context=context,
            ),
            timeout=checker.timeout,
        )
    except asyncio.TimeoutError:
        # asyncio.TimeoutError, not the bare builtin: on Python < 3.11 they
        # are distinct classes and wait_for raises the asyncio one, which the
        # builtin clause would have let escape. On 3.11+ they are aliases,
        # so this is equivalent there.
        return CheckResult(
            passed=False,
            detail=f"{checker.name} timed out after {checker.timeout}s",
        )
    except Exception as err:
        # Deliberately broad: convert any checker crash into a failed result
        # so the remaining pipeline logic stays uniform.
        return CheckResult(
            passed=False,
            detail=f"{checker.name} raised {type(err).__name__}: {err}",
        )
async def validate_proxy(ctx: dict, proxy_id: str) -> dict:
    """Run the full checker pipeline for a single proxy.

    Stages run sequentially; checkers within one stage run concurrently.
    The first stage containing any failed check aborts the pipeline and
    marks the proxy DEAD. Every individual check is persisted as a
    ProxyCheck row, and a ``proxy.pool_low`` event is emitted when an
    ACTIVE->DEAD transition drops the active pool below the threshold.

    Args:
        ctx: Worker context holding "session_factory", "registry",
            and "settings".
        proxy_id: UUID string of the proxy to validate.

    Returns:
        Summary dict with status/score/checks/passed, or a
        ``{"status": "skipped", "reason": ...}`` dict when the proxy is
        missing or no checkers are registered.
    """
    session_factory: async_sessionmaker = ctx["session_factory"]
    registry: PluginRegistry = ctx["registry"]
    settings = ctx["settings"]
    async with session_factory() as db:
        proxy = await db.get(Proxy, UUID(proxy_id))
        if proxy is None:
            logger.warning("Proxy %s not found", proxy_id)
            return {"status": "skipped", "reason": "not_found"}
        checkers = registry.get_checker_pipeline()
        if not checkers:
            logger.warning("No checkers registered, skipping validation")
            return {"status": "skipped", "reason": "no_checkers"}
        # One shared HTTP client for every checker in this run.
        async with httpx.AsyncClient(
            timeout=settings.proxy.check_http_timeout,
        ) as http_client:
            context = CheckContext(
                # NOTE(review): naive local time — confirm whether UTC
                # (datetime.now(tz=UTC)) is expected by the schema.
                started_at=datetime.now(),
                http_client=http_client,
            )
            all_results: list[tuple[object, CheckResult]] = []
            final_status = ProxyStatus.ACTIVE
            # groupby only merges *adjacent* items: this assumes the pipeline
            # is already sorted by stage — TODO confirm the registry
            # guarantees that ordering.
            for _stage_num, stage_group in groupby(checkers, key=attrgetter("stage")):
                stage_list = [
                    c for c in stage_group if not c.should_skip(proxy.protocol.value)
                ]
                if not stage_list:
                    continue
                # Checkers within a stage are independent: run concurrently.
                stage_results = await asyncio.gather(
                    *(_run_single_check(c, proxy, context) for c in stage_list),
                )
                for checker, result in zip(stage_list, stage_results, strict=False):
                    all_results.append((checker, result))
                    # Persist an audit row per individual check.
                    check_record = ProxyCheck(
                        proxy_id=proxy.id,
                        checker_name=checker.name,
                        stage=checker.stage,
                        passed=result.passed,
                        latency_ms=result.latency_ms,
                        detail=result.detail,
                        exit_ip=context.exit_ip,
                    )
                    db.add(check_record)
                # Any failure in this stage aborts the remaining stages.
                if any(not r.passed for r in stage_results):
                    final_status = ProxyStatus.DEAD
                    break
            # Update the proxy record from the pipeline outcome.
            old_status = proxy.status
            proxy.status = final_status
            proxy.last_checked_at = datetime.now()
            if context.exit_ip:
                proxy.exit_ip = context.exit_ip
            if context.anonymity_level:
                proxy.anonymity = context.anonymity_level
            if context.country:
                proxy.country = context.country
            # NOTE(review): truthiness also skips a measured 0 ms latency —
            # confirm whether `is not None` is intended here.
            if context.tcp_latency_ms:
                proxy.avg_latency_ms = context.tcp_latency_ms
            # Simple score: fraction of checks that passed (0.0 if none ran).
            # We'll refine this later with historical data.
            passed_count = sum(1 for _, r in all_results if r.passed)
            total_count = len(all_results)
            score = passed_count / total_count if total_count > 0 else 0.0
            proxy.score = score
            await db.commit()
            # From here on use the local `score`, not `proxy.score`: with the
            # async-session default expire_on_commit=True, reading an ORM
            # attribute after commit triggers an implicit refresh, which
            # raises MissingGreenlet outside a greenlet context.
            # Emit events on status transitions.
            if old_status == ProxyStatus.ACTIVE and final_status == ProxyStatus.DEAD:
                # Check whether the active pool has run low.
                count_result = await db.execute(
                    select(func.count())
                    .select_from(Proxy)
                    .where(Proxy.status == ProxyStatus.ACTIVE)
                )
                active_count = count_result.scalar_one()
                if active_count < settings.proxy.pool_low_threshold:
                    await registry.emit(
                        Event(
                            type="proxy.pool_low",
                            payload={
                                "active_count": active_count,
                                "threshold": settings.proxy.pool_low_threshold,
                            },
                        )
                    )
            logger.info(
                "Validated proxy %s: %s (score=%.2f)",
                proxy_id,
                final_status.value,
                score,
            )
            return {
                "status": final_status.value,
                "score": score,
                "checks": total_count,
                "passed": passed_count,
            }
async def revalidate_sweep(ctx: dict) -> dict:
    """Select proxies due for revalidation and validate them.

    Selection runs in priority order within a single batch budget:
    1. never-checked proxies, 2. ACTIVE proxies whose last check is older
    than the active cutoff, 3. DEAD proxies older than the dead cutoff.

    Args:
        ctx: Worker context holding "session_factory" and "settings";
            forwarded unchanged to validate_proxy.

    Returns:
        Counts of proxies processed per selection bucket plus the total.
    """
    session_factory: async_sessionmaker = ctx["session_factory"]
    settings = ctx["settings"]
    now = datetime.now()
    active_cutoff = now - timedelta(minutes=settings.proxy.revalidate_active_minutes)
    dead_cutoff = now - timedelta(hours=settings.proxy.revalidate_dead_hours)
    batch_size = settings.proxy.revalidate_batch_size
    # Only collect plain id strings inside the session; the (slow) validation
    # loop runs after the session closes so we don't pin a DB connection for
    # the whole sweep.
    async with session_factory() as db:
        # Priority 1: never checked.
        unchecked = await db.execute(
            select(Proxy.id)
            .where(Proxy.status == ProxyStatus.UNCHECKED)
            .limit(batch_size)
        )
        unchecked_ids = [str(row[0]) for row in unchecked.all()]
        remaining = batch_size - len(unchecked_ids)
        # Priority 2: stale active proxies (never-checked actives first).
        stale_active_ids = []
        if remaining > 0:
            stale_active = await db.execute(
                select(Proxy.id)
                .where(
                    Proxy.status == ProxyStatus.ACTIVE,
                    (Proxy.last_checked_at.is_(None))
                    | (Proxy.last_checked_at < active_cutoff),
                )
                .order_by(Proxy.last_checked_at.asc().nulls_first())
                .limit(remaining)
            )
            stale_active_ids = [str(row[0]) for row in stale_active.all()]
            remaining -= len(stale_active_ids)
        # Priority 3: dead proxies worth rechecking.
        dead_ids = []
        if remaining > 0:
            dead_recheck = await db.execute(
                select(Proxy.id)
                .where(
                    Proxy.status == ProxyStatus.DEAD,
                    (Proxy.last_checked_at.is_(None))
                    | (Proxy.last_checked_at < dead_cutoff),
                )
                .order_by(Proxy.last_checked_at.asc().nulls_first())
                .limit(remaining)
            )
            dead_ids = [str(row[0]) for row in dead_recheck.all()]
    all_ids = unchecked_ids + stale_active_ids + dead_ids
    # validate_proxy opens its own session per proxy; run outside the
    # selection session (ids are already materialized as strings).
    results = []
    for proxy_id in all_ids:
        result = await validate_proxy(ctx, proxy_id)
        results.append(result)
    logger.info(
        "Revalidation sweep: %d unchecked, %d stale active, %d dead recheck",
        len(unchecked_ids),
        len(stale_active_ids),
        len(dead_ids),
    )
    return {
        "total": len(results),
        "unchecked": len(unchecked_ids),
        "stale_active": len(stale_active_ids),
        "dead_recheck": len(dead_ids),
    }