diff --git a/backend/main.py b/backend/main.py index bc04e52..a8de0cb 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,108 +1,12 @@ -from typing import Annotated -from fastapi import FastAPI, Depends, HTTPException, Path -from sqlalchemy.orm import Session -from pydantic import BaseModel, Field -from starlette import status +from fastapi import FastAPI -from database import engine, SessionLocal -from models import Base, Post, Thread +import models +from database import engine +from routers import auth, forum app = FastAPI() -Base.metadata.create_all(bind=engine) - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -db_dependency = Annotated[Session, Depends(get_db)] - - -class PostCreate(BaseModel): - author: str = Field('anon') - title: str = Field('') - content: str = Field('') - - -@app.get('/', status_code=status.HTTP_200_OK) -async def get_posts(db: db_dependency): - return db.query(Post).all() - - -@app.post('/', status_code=status.HTTP_201_CREATED) -async def create_thread(db: db_dependency, data: PostCreate): - try: - # Create the post - post = Post( - author=data.author, - title=data.title, - content=data.content - ) - db.add(post) - db.flush() - - # Create the thread - thread = Thread( - author=post.author, - title=post.title, - content=post.content - ) - db.add(thread) - db.flush() - - # Update the Post with thread_id - post.thread_id = thread.id - db.commit() - - return { - 'id': post.id, - 'author': post.author, - 'title': post.title, - 'content': post.content - } - except Exception as e: - db.rollback() - raise HTTPException(status_code=400, detail=str(e)) - - -@app.get('/{thread_id}', status_code=status.HTTP_200_OK) -async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)): - posts = db.query(Post).filter(Post.thread_id == thread_id).all() - if posts: - return posts - - raise HTTPException(404, f'Could not find thread') - - -@app.post('/{thread_id}', status_code=status.HTTP_201_CREATED) -async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)): - thread = db.query(Thread).filter(Thread.id == thread_id).first() - if thread: - post = Post( - thread_id=thread_id, - author=data.author, - title=data.title, - content=data.content - ) - db.add(post) - db.commit() - - return { - 'id': post.id, - 'author': post.author, - 'title': post.title, - 'content': post.content - } - - raise HTTPException(status_code=404, detail='Could not find thread') - - -@app.get('/catalog', status_code=status.HTTP_200_OK) -async def get_catalog(db: db_dependency): - return db.query(Thread).all() +models.Base.metadata.create_all(bind=engine) +app.include_router(auth.router) +app.include_router(forum.router) diff --git a/backend/routers/__init__.py b/backend/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/routers/auth.py b/backend/routers/auth.py new file mode 100644 index 0000000..7980643 --- /dev/null +++ b/backend/routers/auth.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +router = APIRouter() + + +@router.get('/auth/') +async def get_user(): + return {'user': 'authenticated'} diff --git a/backend/routers/forum.py b/backend/routers/forum.py new file mode 100644 index 0000000..a7fda60 --- /dev/null +++ b/backend/routers/forum.py @@ -0,0 +1,106 @@ +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException, Path +from sqlalchemy.orm import Session +from pydantic import BaseModel, Field +from starlette import status + +from database import SessionLocal +from models import Post, Thread + +router = APIRouter() + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +db_dependency = Annotated[Session, Depends(get_db)] + + +class PostCreate(BaseModel): + author: str = Field('anon') + title: str = Field('') + content: str = Field('') + + +@router.get('/', status_code=status.HTTP_200_OK) +async def get_posts(db: db_dependency): + return db.query(Post).all() + + +@router.post('/', status_code=status.HTTP_201_CREATED) +async def create_thread(db: db_dependency, data: PostCreate): + try: + # Create the post + post = Post( + author=data.author, + title=data.title, + content=data.content + ) + db.add(post) + db.flush() + + # Create the thread + thread = Thread( + author=post.author, + title=post.title, + content=post.content + ) + db.add(thread) + db.flush() + + # Update the Post with thread_id + post.thread_id = thread.id + db.commit() + + return { + 'id': post.id, + 'author': post.author, + 'title': post.title, + 'content': post.content + } + except Exception as e: + db.rollback() + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get('/{thread_id}', status_code=status.HTTP_200_OK) +async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)): + posts = db.query(Post).filter(Post.thread_id == thread_id).all() + if posts: + return posts + + raise HTTPException(404, f'Could not find thread') + + +@router.post('/{thread_id}', status_code=status.HTTP_201_CREATED) +async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)): + thread = db.query(Thread).filter(Thread.id == thread_id).first() + if thread: + post = Post( + thread_id=thread_id, + author=data.author, + title=data.title, + content=data.content + ) + db.add(post) + db.commit() + + return { + 'id': post.id, + 'author': post.author, + 'title': post.title, + 'content': post.content + } + + raise HTTPException(status_code=404, detail='Could not find thread') + + +@router.get('/catalog', status_code=status.HTTP_200_OK) +async def get_catalog(db: db_dependency): + return db.query(Thread).all() +