"""FastAPI application exposing a minimal thread/post board.

Routes:
    GET  /              list every post
    POST /              start a new thread (thread row + its opening post)
    GET  /catalog       list every thread
    GET  /{thread_id}   list the posts belonging to one thread
    POST /{thread_id}   reply to an existing thread
"""
from typing import Annotated

from fastapi import FastAPI, Depends, HTTPException, Path
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from starlette import status

from database import engine, SessionLocal
from models import Base, Post, Thread

app = FastAPI()

# Create any tables that do not exist yet, at import time.
Base.metadata.create_all(bind=engine)


def get_db():
    """Yield a database session for one request and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


db_dependency = Annotated[Session, Depends(get_db)]


class PostCreate(BaseModel):
    """Request body for creating a thread or a reply; every field has a default."""

    author: str = Field('anon')
    title: str = Field('')
    content: str = Field('')


def _post_summary(post: Post) -> dict:
    """Build the JSON payload returned after a post is created."""
    return {
        'id': post.id,
        'author': post.author,
        'title': post.title,
        'content': post.content,
    }


# NOTE: the handlers below are plain `def`, not `async def`. The SQLAlchemy
# Session is synchronous, so FastAPI must run them in its threadpool rather
# than letting each blocking query stall the event loop.


@app.get('/', status_code=status.HTTP_200_OK)
def get_posts(db: db_dependency):
    """Return every post in the database."""
    return db.query(Post).all()


@app.post('/', status_code=status.HTTP_201_CREATED)
def create_thread(db: db_dependency, data: PostCreate):
    """Start a new thread together with its opening post, in one transaction.

    The thread row stores a copy of the opening post's author/title/content.

    Returns:
        The opening post's id, author, title and content.
    """
    thread = Thread(author=data.author, title=data.title, content=data.content)
    db.add(thread)
    # flush() sends the INSERT so thread.id is populated without committing.
    db.flush()

    post = Post(
        thread_id=thread.id,
        author=data.author,
        title=data.title,
        content=data.content,
    )
    db.add(post)
    # Single commit: both rows are stored or neither is (no orphan post).
    db.commit()
    return _post_summary(post)


# '/catalog' must be registered before '/{thread_id}': FastAPI matches routes
# in declaration order, and '/{thread_id}' would otherwise capture 'catalog'
# and reject it with a 422 validation error.
@app.get('/catalog', status_code=status.HTTP_200_OK)
def get_catalog(db: db_dependency):
    """Return every thread."""
    return db.query(Thread).all()


@app.get('/{thread_id}', status_code=status.HTTP_200_OK)
def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
    """Return all posts belonging to the given thread.

    Raises:
        HTTPException: 404 if no posts reference ``thread_id``.
    """
    posts = db.query(Post).filter(Post.thread_id == thread_id).all()
    if not posts:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail='Could not find thread',
        )
    return posts


@app.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
    """Add a reply post to an existing thread.

    Returns:
        The new post's id, author, title and content.

    Raises:
        HTTPException: 404 if no thread with ``thread_id`` exists.
    """
    thread = db.query(Thread).filter(Thread.id == thread_id).first()
    if thread is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail='Could not find thread',
        )

    post = Post(
        thread_id=thread_id,
        author=data.author,
        title=data.title,
        content=data.content,
    )
    db.add(post)
    db.commit()
    return _post_summary(post)