import uuid
from typing import Optional

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models import Image, ImageTag, Tag


class ImageRepository:
    """Async data-access layer for ``Image`` rows and their tag associations.

    All reads that return ``Image`` objects eagerly load ``Image.image_tags``
    and each association's ``ImageTag.tag``.  Callers can then walk the tags
    without triggering a lazy load, which an ``AsyncSession`` does not allow
    implicitly.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Wrap *session*; the repository never commits, it only flushes."""
        self._session = session

    @staticmethod
    def _tags_loader():
        """Return the loader option that selectin-loads image_tags -> tag."""
        return selectinload(Image.image_tags).selectinload(ImageTag.tag)

    async def get_by_hash(self, hash_hex: str) -> Optional[Image]:
        """Return the image whose ``hash`` equals *hash_hex*, or ``None``.

        Tags are eagerly loaded.  Raises ``MultipleResultsFound`` (from
        ``scalar_one_or_none``) if more than one row matches.
        """
        result = await self._session.execute(
            select(Image)
            .where(Image.hash == hash_hex)
            .options(self._tags_loader())
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, image_id: uuid.UUID) -> Optional[Image]:
        """Return the image with primary key *image_id*, or ``None``.

        Tags are eagerly loaded.
        """
        result = await self._session.execute(
            select(Image)
            .where(Image.id == image_id)
            .options(self._tags_loader())
        )
        return result.scalar_one_or_none()

    async def create(
        self,
        *,
        hash_hex: str,
        filename: str,
        mime_type: str,
        size_bytes: int,
        width: int,
        height: int,
        storage_key: str,
        thumbnail_key: str | None = None,
    ) -> Image:
        """Insert a new ``Image`` and return it.

        The row is flushed (not committed), so database-generated values are
        available on the returned object.  ``image_tags`` is loaded
        explicitly so that later access does not need a lazy load.
        """
        image = Image(
            hash=hash_hex,
            filename=filename,
            mime_type=mime_type,
            size_bytes=size_bytes,
            width=width,
            height=height,
            storage_key=storage_key,
            thumbnail_key=thumbnail_key,
        )
        self._session.add(image)
        await self._session.flush()
        # Populate the relationship now; touching an unloaded relationship
        # later would require implicit IO, which AsyncSession rejects.
        await self._session.refresh(image, ["image_tags"])
        return image

    async def list_images(
        self,
        tag_names: list[str] | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[Image], int]:
        """Return one page of images plus the total number of matching images.

        Args:
            tag_names: If non-empty, only images carrying *every* listed tag
                name are returned (AND semantics).  ``None`` or ``[]`` means
                no tag filter.  Duplicate names are ignored.
            limit: Maximum number of images in the page.
            offset: Number of matching images to skip.

        Returns:
            ``(images, total)``.  ``images`` is ordered newest first by
            ``created_at``, with ``id`` as a tie-breaker so page boundaries
            are stable.  ``total`` counts all matches, ignoring
            ``limit``/``offset``.
        """
        # One IN-subquery per distinct tag name; stacking them ANDs the tags.
        filters = [
            Image.id.in_(
                select(ImageTag.image_id)
                .join(Tag, ImageTag.tag_id == Tag.id)
                .where(Tag.name == tag_name)
            )
            for tag_name in dict.fromkeys(tag_names or ())
        ]

        # The filters are IN predicates (no joins on Image), so counting Image
        # rows directly gives the same total as counting the page query, without
        # wrapping an eager-loading SELECT of every column in a subquery.
        count_query = select(func.count()).select_from(Image).where(*filters)
        total = (await self._session.execute(count_query)).scalar_one()

        page_query = (
            select(Image)
            .where(*filters)
            .options(self._tags_loader())
            .order_by(Image.created_at.desc(), Image.id.desc())
            .limit(limit)
            .offset(offset)
        )
        result = await self._session.execute(page_query)
        return list(result.scalars().all()), total

    async def reload_with_tags(self, image_id: uuid.UUID) -> Image:
        """Re-select the image with *image_id*, overwriting any cached state.

        ``populate_existing`` makes the query refresh an instance already in
        the identity map, including its tags.  Raises ``NoResultFound``
        (from ``scalar_one``) if no such image exists.
        """
        result = await self._session.execute(
            select(Image)
            .where(Image.id == image_id)
            .options(self._tags_loader())
            .execution_options(populate_existing=True)
        )
        return result.scalar_one()

    async def delete(self, image: Image) -> None:
        """Mark *image* for deletion and flush the DELETE (no commit)."""
        await self._session.delete(image)
        await self._session.flush()