Compare commits

...

9 Commits

Author SHA1 Message Date
03a95c3872 admins can create new users 2024-04-06 15:55:18 -04:00
24887274f0 add router tags 2024-04-06 15:36:48 -04:00
ea4ae48b23 add jwt validation 2024-04-06 15:29:08 -04:00
353218aca7 implement jwt 2024-04-06 15:08:41 -04:00
8aa077441f feat: add login and user creation endpoints 2024-04-06 14:51:28 -04:00
594c5c6de8 chore: refactor 2024-04-06 14:50:59 -04:00
d0e593a268 refactor app structure 2024-04-06 13:58:20 -04:00
4a31a1b671 update backend .gitignore 2024-04-06 13:55:38 -04:00
42ee9326ae chore: remove unused import 2024-04-06 13:33:44 -04:00
7 changed files with 240 additions and 107 deletions

1
backend/.gitignore vendored
View File

@ -1 +1,2 @@
.idea/
forum.db forum.db

View File

@ -1,108 +1,12 @@
from typing import Annotated from fastapi import FastAPI
from fastapi import FastAPI, Depends, HTTPException, Path
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from starlette import status
from database import engine, SessionLocal import models
from models import Base, Post, Thread from database import engine
from routers import auth, forum
app = FastAPI() app = FastAPI()
Base.metadata.create_all(bind=engine) models.Base.metadata.create_all(bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
db_dependency = Annotated[Session, Depends(get_db)]
class PostCreate(BaseModel):
author: str = Field('anon')
title: str = Field('')
content: str = Field('')
@app.get('/', status_code=status.HTTP_200_OK)
async def get_posts(db: db_dependency):
return db.query(Post).all()
@app.post('/', status_code=status.HTTP_201_CREATED)
async def create_thread(db: db_dependency, data: PostCreate):
try:
# Create the post
post = Post(
author=data.author,
title=data.title,
content=data.content
)
db.add(post)
db.flush()
# Create the thread
thread = Thread(
author=post.author,
title=post.title,
content=post.content
)
db.add(thread)
db.flush()
# Update the Post with thread_id
post.thread_id = thread.id
db.commit()
return {
'id': post.id,
'author': post.author,
'title': post.title,
'content': post.content
}
except Exception as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e))
@app.get('/{thread_id}', status_code=status.HTTP_200_OK)
async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
posts = db.query(Post).filter(Post.thread_id == thread_id).all()
if posts:
return posts
raise HTTPException(404, f'Could not find thread')
@app.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
thread = db.query(Thread).filter(Thread.id == thread_id).first()
if thread:
post = Post(
thread_id=thread_id,
author=data.author,
title=data.title,
content=data.content
)
db.add(post)
db.commit()
return {
'id': post.id,
'author': post.author,
'title': post.title,
'content': post.content
}
raise HTTPException(status_code=404, detail='Could not find thread')
@app.get('/catalog', status_code=status.HTTP_200_OK)
async def get_catalog(db: db_dependency):
return db.query(Thread).all()
app.include_router(forum.router)
app.include_router(auth.router)

View File

@ -1,5 +1,3 @@
from sqlalchemy.orm import relationship
from database import Base from database import Base
from sqlalchemy import Column, Integer, String, ForeignKey from sqlalchemy import Column, Integer, String, ForeignKey
@ -7,7 +5,7 @@ from sqlalchemy import Column, Integer, String, ForeignKey
class Post(Base): class Post(Base):
__tablename__ = 'posts' __tablename__ = 'posts'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True, index=True)
thread_id = Column(Integer, ForeignKey("threads.id")) thread_id = Column(Integer, ForeignKey("threads.id"))
author = Column(String) author = Column(String)
title = Column(String) title = Column(String)
@ -17,7 +15,17 @@ class Post(Base):
class Thread(Base): class Thread(Base):
__tablename__ = 'threads' __tablename__ = 'threads'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True, index=True)
author = Column(String) author = Column(String)
title = Column(String) title = Column(String)
content = Column(String) content = Column(String)
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True)
email = Column(String, unique=True)
password = Column(String)
role = Column(String)

View File

@ -1,3 +1,8 @@
fastapi fastapi
uvicorn[standard] uvicorn[standard]
sqlalchemy sqlalchemy
passlib[bcrypt]
pydantic
starlette
python-multipart
python-jose[cryptography]

View File

109
backend/routers/auth.py Normal file
View File

@ -0,0 +1,109 @@
from datetime import timedelta, datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from passlib.context import CryptContext
from sqlalchemy.orm import Session
from starlette import status
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
from jose import jwt, JWTError
from models import User
from database import SessionLocal
router = APIRouter(
prefix='/auth',
tags=['auth']
)
SECRET_KEY = '3b004eeae34b43bd05226f210d9bdc2ad99abdd3c52bf32802906085b762ff55'
ALGORITHM = 'HS256'
bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
oauth2_bearer = OAuth2PasswordBearer(tokenUrl='auth/token')
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get('sub')
user_id: int = payload.get('id')
role: str = payload.get('role')
if username is None or user_id is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="1Could not validate credentials")
return {'username': username, 'user_id': user_id, 'role': role}
except JWTError as e:
print(str(e))
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2Could not validate credentials")
db_dependency = Annotated[Session, Depends(get_db)]
user_dependency = Annotated[dict, Depends(get_current_user)]
def authenticate_user(username: str, password: str, db):
user = db.query(User).filter(User.username == username).first()
if not user:
return False
if not bcrypt_context.verify(password, user.password):
return False
return user
def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta):
encode = {'sub': username, 'id': user_id, 'role': role}
expire = datetime.now(timezone.utc) + expires_delta
encode.update({'exp': expire})
return jwt.encode(encode, SECRET_KEY, ALGORITHM)
class CreateUser(BaseModel):
username: str
email: str
password: str
class Token(BaseModel):
access_token: str
token_type: str
@router.post('/user/create', status_code=status.HTTP_201_CREATED)
async def create_user(user: user_dependency, db: db_dependency, data: CreateUser):
if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Authentication failed')
if user['role'] != 'admin':
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Not authorized')
create_user_model = User(
username=data.username,
email=data.email,
password=bcrypt_context.hash(data.password),
role='admin'
)
db.add(create_user_model)
db.commit()
@router.post('/token', status_code=status.HTTP_200_OK, response_model=Token)
async def get_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: db_dependency
):
user = authenticate_user(form_data.username, form_data.password, db)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
token = create_access_token(user.username, user.id, user.role, timedelta(minutes=20))
return {'access_token': token, 'token_type': 'bearer'}

106
backend/routers/forum.py Normal file
View File

@ -0,0 +1,106 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from starlette import status
from database import SessionLocal
from models import Post, Thread
router = APIRouter(
tags=['forum']
)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
db_dependency = Annotated[Session, Depends(get_db)]
class PostCreate(BaseModel):
author: str = Field('anon')
title: str = Field('')
content: str = Field('')
@router.get('/catalog', status_code=status.HTTP_200_OK)
async def get_catalog(db: db_dependency):
return db.query(Thread).all()
@router.get('/', status_code=status.HTTP_200_OK)
async def get_posts(db: db_dependency):
return db.query(Post).all()
@router.get('/{thread_id}', status_code=status.HTTP_200_OK)
async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
posts = db.query(Post).filter(Post.thread_id == thread_id).all()
if posts:
return posts
raise HTTPException(404, f'Could not find thread')
@router.post('/', status_code=status.HTTP_201_CREATED)
async def create_thread(db: db_dependency, data: PostCreate):
try:
# Create the post
post = Post(
author=data.author,
title=data.title,
content=data.content
)
db.add(post)
db.flush()
# Create the thread
thread = Thread(
author=post.author,
title=post.title,
content=post.content
)
db.add(thread)
db.flush()
# Update the Post with thread_id
post.thread_id = thread.id
db.commit()
return {
'id': post.id,
'author': post.author,
'title': post.title,
'content': post.content
}
except Exception as e:
db.rollback()
raise HTTPException(status_code=400, detail=str(e))
@router.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
thread = db.query(Thread).filter(Thread.id == thread_id).first()
if thread:
post = Post(
thread_id=thread_id,
author=data.author,
title=data.title,
content=data.content
)
db.add(post)
db.commit()
return {
'id': post.id,
'author': post.author,
'title': post.title,
'content': post.content
}
raise HTTPException(status_code=404, detail='Could not find thread')