feat: add authentication for administration (#2)
Reviewed-on: #2 Co-authored-by: agatha <agatha@juggalol.com> Co-committed-by: agatha <agatha@juggalol.com>
This commit is contained in:
parent
0525dec9b9
commit
2db6e18d5b
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
|
.idea/
|
||||||
forum.db
|
forum.db
|
||||||
|
110
backend/main.py
110
backend/main.py
@ -1,108 +1,12 @@
|
|||||||
from typing import Annotated
|
from fastapi import FastAPI
|
||||||
from fastapi import FastAPI, Depends, HTTPException, Path
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from starlette import status
|
|
||||||
|
|
||||||
from database import engine, SessionLocal
|
import models
|
||||||
from models import Base, Post, Thread
|
from database import engine
|
||||||
|
from routers import auth, forum
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
Base.metadata.create_all(bind=engine)
|
models.Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
db_dependency = Annotated[Session, Depends(get_db)]
|
|
||||||
|
|
||||||
|
|
||||||
class PostCreate(BaseModel):
|
|
||||||
author: str = Field('anon')
|
|
||||||
title: str = Field('')
|
|
||||||
content: str = Field('')
|
|
||||||
|
|
||||||
|
|
||||||
@app.get('/', status_code=status.HTTP_200_OK)
|
|
||||||
async def get_posts(db: db_dependency):
|
|
||||||
return db.query(Post).all()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post('/', status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_thread(db: db_dependency, data: PostCreate):
|
|
||||||
try:
|
|
||||||
# Create the post
|
|
||||||
post = Post(
|
|
||||||
author=data.author,
|
|
||||||
title=data.title,
|
|
||||||
content=data.content
|
|
||||||
)
|
|
||||||
db.add(post)
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
# Create the thread
|
|
||||||
thread = Thread(
|
|
||||||
author=post.author,
|
|
||||||
title=post.title,
|
|
||||||
content=post.content
|
|
||||||
)
|
|
||||||
db.add(thread)
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
# Update the Post with thread_id
|
|
||||||
post.thread_id = thread.id
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {
|
|
||||||
'id': post.id,
|
|
||||||
'author': post.author,
|
|
||||||
'title': post.title,
|
|
||||||
'content': post.content
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
@app.get('/{thread_id}', status_code=status.HTTP_200_OK)
|
|
||||||
async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
|
|
||||||
posts = db.query(Post).filter(Post.thread_id == thread_id).all()
|
|
||||||
if posts:
|
|
||||||
return posts
|
|
||||||
|
|
||||||
raise HTTPException(404, f'Could not find thread')
|
|
||||||
|
|
||||||
|
|
||||||
@app.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
|
|
||||||
thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
|
||||||
if thread:
|
|
||||||
post = Post(
|
|
||||||
thread_id=thread_id,
|
|
||||||
author=data.author,
|
|
||||||
title=data.title,
|
|
||||||
content=data.content
|
|
||||||
)
|
|
||||||
db.add(post)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {
|
|
||||||
'id': post.id,
|
|
||||||
'author': post.author,
|
|
||||||
'title': post.title,
|
|
||||||
'content': post.content
|
|
||||||
}
|
|
||||||
|
|
||||||
raise HTTPException(status_code=404, detail='Could not find thread')
|
|
||||||
|
|
||||||
|
|
||||||
@app.get('/catalog', status_code=status.HTTP_200_OK)
|
|
||||||
async def get_catalog(db: db_dependency):
|
|
||||||
return db.query(Thread).all()
|
|
||||||
|
|
||||||
|
app.include_router(forum.router)
|
||||||
|
app.include_router(auth.router)
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from sqlalchemy.orm import relationship
|
|
||||||
|
|
||||||
from database import Base
|
from database import Base
|
||||||
from sqlalchemy import Column, Integer, String, ForeignKey
|
from sqlalchemy import Column, Integer, String, ForeignKey
|
||||||
|
|
||||||
@ -7,7 +5,7 @@ from sqlalchemy import Column, Integer, String, ForeignKey
|
|||||||
class Post(Base):
|
class Post(Base):
|
||||||
__tablename__ = 'posts'
|
__tablename__ = 'posts'
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
thread_id = Column(Integer, ForeignKey("threads.id"))
|
thread_id = Column(Integer, ForeignKey("threads.id"))
|
||||||
author = Column(String)
|
author = Column(String)
|
||||||
title = Column(String)
|
title = Column(String)
|
||||||
@ -17,7 +15,17 @@ class Post(Base):
|
|||||||
class Thread(Base):
|
class Thread(Base):
|
||||||
__tablename__ = 'threads'
|
__tablename__ = 'threads'
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
author = Column(String)
|
author = Column(String)
|
||||||
title = Column(String)
|
title = Column(String)
|
||||||
content = Column(String)
|
content = Column(String)
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = 'users'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
username = Column(String, unique=True)
|
||||||
|
email = Column(String, unique=True)
|
||||||
|
password = Column(String)
|
||||||
|
role = Column(String)
|
||||||
|
@ -1,3 +1,8 @@
|
|||||||
fastapi
|
fastapi
|
||||||
uvicorn[standard]
|
uvicorn[standard]
|
||||||
sqlalchemy
|
sqlalchemy
|
||||||
|
passlib[bcrypt]
|
||||||
|
pydantic
|
||||||
|
starlette
|
||||||
|
python-multipart
|
||||||
|
python-jose[cryptography]
|
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
109
backend/routers/auth.py
Normal file
109
backend/routers/auth.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
from datetime import timedelta, datetime, timezone
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from starlette import status
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
|
||||||
|
from jose import jwt, JWTError
|
||||||
|
|
||||||
|
from models import User
|
||||||
|
from database import SessionLocal
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix='/auth',
|
||||||
|
tags=['auth']
|
||||||
|
)
|
||||||
|
|
||||||
|
SECRET_KEY = '3b004eeae34b43bd05226f210d9bdc2ad99abdd3c52bf32802906085b762ff55'
|
||||||
|
ALGORITHM = 'HS256'
|
||||||
|
|
||||||
|
bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
||||||
|
oauth2_bearer = OAuth2PasswordBearer(tokenUrl='auth/token')
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get('sub')
|
||||||
|
user_id: int = payload.get('id')
|
||||||
|
role: str = payload.get('role')
|
||||||
|
if username is None or user_id is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="1Could not validate credentials")
|
||||||
|
return {'username': username, 'user_id': user_id, 'role': role}
|
||||||
|
except JWTError as e:
|
||||||
|
print(str(e))
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2Could not validate credentials")
|
||||||
|
|
||||||
|
|
||||||
|
db_dependency = Annotated[Session, Depends(get_db)]
|
||||||
|
user_dependency = Annotated[dict, Depends(get_current_user)]
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user(username: str, password: str, db):
|
||||||
|
user = db.query(User).filter(User.username == username).first()
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
if not bcrypt_context.verify(password, user.password):
|
||||||
|
return False
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta):
|
||||||
|
encode = {'sub': username, 'id': user_id, 'role': role}
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
encode.update({'exp': expire})
|
||||||
|
|
||||||
|
return jwt.encode(encode, SECRET_KEY, ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUser(BaseModel):
|
||||||
|
username: str
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/user/create', status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_user(user: user_dependency, db: db_dependency, data: CreateUser):
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Authentication failed')
|
||||||
|
|
||||||
|
if user['role'] != 'admin':
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Not authorized')
|
||||||
|
|
||||||
|
create_user_model = User(
|
||||||
|
username=data.username,
|
||||||
|
email=data.email,
|
||||||
|
password=bcrypt_context.hash(data.password),
|
||||||
|
role='admin'
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(create_user_model)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/token', status_code=status.HTTP_200_OK, response_model=Token)
|
||||||
|
async def get_token(
|
||||||
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
db: db_dependency
|
||||||
|
):
|
||||||
|
user = authenticate_user(form_data.username, form_data.password, db)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
|
||||||
|
|
||||||
|
token = create_access_token(user.username, user.id, user.role, timedelta(minutes=20))
|
||||||
|
return {'access_token': token, 'token_type': 'bearer'}
|
106
backend/routers/forum.py
Normal file
106
backend/routers/forum.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from starlette import status
|
||||||
|
|
||||||
|
from database import SessionLocal
|
||||||
|
from models import Post, Thread
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
tags=['forum']
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
db_dependency = Annotated[Session, Depends(get_db)]
|
||||||
|
|
||||||
|
|
||||||
|
class PostCreate(BaseModel):
|
||||||
|
author: str = Field('anon')
|
||||||
|
title: str = Field('')
|
||||||
|
content: str = Field('')
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/catalog', status_code=status.HTTP_200_OK)
|
||||||
|
async def get_catalog(db: db_dependency):
|
||||||
|
return db.query(Thread).all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/', status_code=status.HTTP_200_OK)
|
||||||
|
async def get_posts(db: db_dependency):
|
||||||
|
return db.query(Post).all()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/{thread_id}', status_code=status.HTTP_200_OK)
|
||||||
|
async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)):
|
||||||
|
posts = db.query(Post).filter(Post.thread_id == thread_id).all()
|
||||||
|
if posts:
|
||||||
|
return posts
|
||||||
|
|
||||||
|
raise HTTPException(404, f'Could not find thread')
|
||||||
|
|
||||||
|
@router.post('/', status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_thread(db: db_dependency, data: PostCreate):
|
||||||
|
try:
|
||||||
|
# Create the post
|
||||||
|
post = Post(
|
||||||
|
author=data.author,
|
||||||
|
title=data.title,
|
||||||
|
content=data.content
|
||||||
|
)
|
||||||
|
db.add(post)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
# Create the thread
|
||||||
|
thread = Thread(
|
||||||
|
author=post.author,
|
||||||
|
title=post.title,
|
||||||
|
content=post.content
|
||||||
|
)
|
||||||
|
db.add(thread)
|
||||||
|
db.flush()
|
||||||
|
|
||||||
|
# Update the Post with thread_id
|
||||||
|
post.thread_id = thread.id
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
'id': post.id,
|
||||||
|
'author': post.author,
|
||||||
|
'title': post.title,
|
||||||
|
'content': post.content
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/{thread_id}', status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)):
|
||||||
|
thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
||||||
|
if thread:
|
||||||
|
post = Post(
|
||||||
|
thread_id=thread_id,
|
||||||
|
author=data.author,
|
||||||
|
title=data.title,
|
||||||
|
content=data.content
|
||||||
|
)
|
||||||
|
db.add(post)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
'id': post.id,
|
||||||
|
'author': post.author,
|
||||||
|
'title': post.title,
|
||||||
|
'content': post.content
|
||||||
|
}
|
||||||
|
|
||||||
|
raise HTTPException(status_code=404, detail='Could not find thread')
|
Loading…
Reference in New Issue
Block a user