refactor app structure
This commit is contained in:
parent
4a31a1b671
commit
d0e593a268
110
backend/main.py
110
backend/main.py
@ -1,108 +1,12 @@
|
||||
from typing import Annotated
|
||||
from fastapi import FastAPI, Depends, HTTPException, Path
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette import status
|
||||
from fastapi import FastAPI
|
||||
|
||||
from database import engine, SessionLocal
|
||||
from models import Base, Post, Thread
|
||||
import models
|
||||
from database import engine
|
||||
from routers import auth, forum
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
db_dependency = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
class PostCreate(BaseModel):
|
||||
author: str = Field('anon')
|
||||
title: str = Field('')
|
||||
content: str = Field('')
|
||||
|
||||
|
||||
@app.get('/', status_code=status.HTTP_200_OK)
|
||||
async def get_posts(db: db_dependency):
|
||||
return db.query(Post).all()
|
||||
|
||||
|
||||
@app.post('/', status_code=status.HTTP_201_CREATED)
|
||||
async def create_thread(db: db_dependency, data: PostCreate):
|
||||
try:
|
||||
# Create the post
|
||||
post = Post(
|
||||
author=data.author,
|
||||
title=data.title,
|
||||
content=data.content
|
||||
)
|
||||
db.add(post)
|
||||
db.flush()
|
||||
|
||||
# Create the thread
|
||||
thread = Thread(
|
||||
author=post.author,
|
||||
title=post.title,
|
||||
content=post.content
|
||||
)
|
||||
db.add(thread)
|
||||
db.flush()
|
||||
|
||||
# Update the Post with thread_id
|
||||
post.thread_id = thread.id
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
'id': post.id,
|
||||
'author': post.author,
|
||||
'title': post.title,
|
||||
'content': post.content
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.get('/{thread_id}', status_code=status.HTTP_200_OK)
|
||||
async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
|
||||
posts = db.query(Post).filter(Post.thread_id == thread_id).all()
|
||||
if posts:
|
||||
return posts
|
||||
|
||||
raise HTTPException(404, f'Could not find thread')
|
||||
|
||||
|
||||
@app.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
|
||||
async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
|
||||
thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
||||
if thread:
|
||||
post = Post(
|
||||
thread_id=thread_id,
|
||||
author=data.author,
|
||||
title=data.title,
|
||||
content=data.content
|
||||
)
|
||||
db.add(post)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
'id': post.id,
|
||||
'author': post.author,
|
||||
'title': post.title,
|
||||
'content': post.content
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail='Could not find thread')
|
||||
|
||||
|
||||
@app.get('/catalog', status_code=status.HTTP_200_OK)
|
||||
async def get_catalog(db: db_dependency):
|
||||
return db.query(Thread).all()
|
||||
models.Base.metadata.create_all(bind=engine)
|
||||
|
||||
app.include_router(auth.router)
|
||||
app.include_router(forum.router)
|
||||
|
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
8
backend/routers/auth.py
Normal file
8
backend/routers/auth.py
Normal file
@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get('/auth/')
|
||||
async def get_user():
|
||||
return {'user': 'authenticated'}
|
106
backend/routers/forum.py
Normal file
106
backend/routers/forum.py
Normal file
@ -0,0 +1,106 @@
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette import status
|
||||
|
||||
from database import SessionLocal
|
||||
from models import Post, Thread
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
db_dependency = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
class PostCreate(BaseModel):
|
||||
author: str = Field('anon')
|
||||
title: str = Field('')
|
||||
content: str = Field('')
|
||||
|
||||
|
||||
@router.get('/', status_code=status.HTTP_200_OK)
|
||||
async def get_posts(db: db_dependency):
|
||||
return db.query(Post).all()
|
||||
|
||||
|
||||
@router.post('/', status_code=status.HTTP_201_CREATED)
|
||||
async def create_thread(db: db_dependency, data: PostCreate):
|
||||
try:
|
||||
# Create the post
|
||||
post = Post(
|
||||
author=data.author,
|
||||
title=data.title,
|
||||
content=data.content
|
||||
)
|
||||
db.add(post)
|
||||
db.flush()
|
||||
|
||||
# Create the thread
|
||||
thread = Thread(
|
||||
author=post.author,
|
||||
title=post.title,
|
||||
content=post.content
|
||||
)
|
||||
db.add(thread)
|
||||
db.flush()
|
||||
|
||||
# Update the Post with thread_id
|
||||
post.thread_id = thread.id
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
'id': post.id,
|
||||
'author': post.author,
|
||||
'title': post.title,
|
||||
'content': post.content
|
||||
}
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get('/{thread_id}', status_code=status.HTTP_200_OK)
|
||||
async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
|
||||
posts = db.query(Post).filter(Post.thread_id == thread_id).all()
|
||||
if posts:
|
||||
return posts
|
||||
|
||||
raise HTTPException(404, f'Could not find thread')
|
||||
|
||||
|
||||
@router.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
|
||||
async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
|
||||
thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
||||
if thread:
|
||||
post = Post(
|
||||
thread_id=thread_id,
|
||||
author=data.author,
|
||||
title=data.title,
|
||||
content=data.content
|
||||
)
|
||||
db.add(post)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
'id': post.id,
|
||||
'author': post.author,
|
||||
'title': post.title,
|
||||
'content': post.content
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail='Could not find thread')
|
||||
|
||||
|
||||
@router.get('/catalog', status_code=status.HTTP_200_OK)
|
||||
async def get_catalog(db: db_dependency):
|
||||
return db.query(Thread).all()
|
||||
|
Loading…
Reference in New Issue
Block a user