From 2db6e18d5b77c16703fe92050d0f5af788059f8c Mon Sep 17 00:00:00 2001 From: agatha Date: Sat, 6 Apr 2024 19:56:48 +0000 Subject: [PATCH] feat: add authentication for administration (#2) Reviewed-on: https://git.juggalol.com/agatha/forum-app/pulls/2 Co-authored-by: agatha Co-committed-by: agatha --- backend/.gitignore | 1 + backend/main.py | 110 +++--------------------------------- backend/models.py | 16 ++++-- backend/requirements.txt | 5 ++ backend/routers/__init__.py | 0 backend/routers/auth.py | 109 +++++++++++++++++++++++++++++++++++ backend/routers/forum.py | 106 ++++++++++++++++++++++++++++++++++ 7 files changed, 240 insertions(+), 107 deletions(-) create mode 100644 backend/routers/__init__.py create mode 100644 backend/routers/auth.py create mode 100644 backend/routers/forum.py diff --git a/backend/.gitignore b/backend/.gitignore index 866cead..b9b8699 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -1 +1,2 @@ +.idea/ forum.db diff --git a/backend/main.py b/backend/main.py index bc04e52..e392f81 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,108 +1,12 @@ -from typing import Annotated -from fastapi import FastAPI, Depends, HTTPException, Path -from sqlalchemy.orm import Session -from pydantic import BaseModel, Field -from starlette import status +from fastapi import FastAPI -from database import engine, SessionLocal -from models import Base, Post, Thread +import models +from database import engine +from routers import auth, forum app = FastAPI() -Base.metadata.create_all(bind=engine) - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -db_dependency = Annotated[Session, Depends(get_db)] - - -class PostCreate(BaseModel): - author: str = Field('anon') - title: str = Field('') - content: str = Field('') - - -@app.get('/', status_code=status.HTTP_200_OK) -async def get_posts(db: db_dependency): - return db.query(Post).all() - - -@app.post('/', status_code=status.HTTP_201_CREATED) -async def create_thread(db: db_dependency, data: PostCreate): - try: - # Create the post - post = Post( - author=data.author, - title=data.title, - content=data.content - ) - db.add(post) - db.flush() - - # Create the thread - thread = Thread( - author=post.author, - title=post.title, - content=post.content - ) - db.add(thread) - db.flush() - - # Update the Post with thread_id - post.thread_id = thread.id - db.commit() - - return { - 'id': post.id, - 'author': post.author, - 'title': post.title, - 'content': post.content - } - except Exception as e: - db.rollback() - raise HTTPException(status_code=400, detail=str(e)) - - -@app.get('/{thread_id}', status_code=status.HTTP_200_OK) -async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)): - posts = db.query(Post).filter(Post.thread_id == thread_id).all() - if posts: - return posts - - raise HTTPException(404, f'Could not find thread') - - -@app.post('/{thread_id}', status_code=status.HTTP_201_CREATED) -async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)): - thread = db.query(Thread).filter(Thread.id == thread_id).first() - if thread: - post = Post( - thread_id=thread_id, - author=data.author, - title=data.title, - content=data.content - ) - db.add(post) - db.commit() - - return { - 'id': post.id, - 'author': post.author, - 'title': post.title, - 'content': post.content - } - - raise HTTPException(status_code=404, detail='Could not find thread') - - -@app.get('/catalog', status_code=status.HTTP_200_OK) -async def get_catalog(db: db_dependency): - return db.query(Thread).all() +models.Base.metadata.create_all(bind=engine) +app.include_router(forum.router) +app.include_router(auth.router) diff --git a/backend/models.py b/backend/models.py index d342728..fdd77b2 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,5 +1,3 @@ -from sqlalchemy.orm import relationship - from database import Base from sqlalchemy import Column, Integer, String, ForeignKey @@ -7,7 +5,7 @@ from sqlalchemy import Column, Integer, String, ForeignKey class Post(Base): __tablename__ = 'posts' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, index=True) thread_id = Column(Integer, ForeignKey("threads.id")) author = Column(String) title = Column(String) @@ -17,7 +15,17 @@ class Post(Base): class Thread(Base): __tablename__ = 'threads' - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, index=True) author = Column(String) title = Column(String) content = Column(String) + + +class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True, index=True) + username = Column(String, unique=True) + email = Column(String, unique=True) + password = Column(String) + role = Column(String) diff --git a/backend/requirements.txt b/backend/requirements.txt index 09b688d..7fe8d81 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,3 +1,8 @@ fastapi uvicorn[standard] sqlalchemy +passlib[bcrypt] +pydantic +starlette +python-multipart +python-jose[cryptography] \ No newline at end of file diff --git a/backend/routers/__init__.py b/backend/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/routers/auth.py b/backend/routers/auth.py new file mode 100644 index 0000000..8bb11f9 --- /dev/null +++ b/backend/routers/auth.py @@ -0,0 +1,109 @@ +from datetime import timedelta, datetime, timezone +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel +from passlib.context import CryptContext +from sqlalchemy.orm import Session +from starlette import status +from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer +from jose import jwt, JWTError + +from models import User +from database import SessionLocal + +router = APIRouter( + prefix='/auth', + tags=['auth'] +) + +SECRET_KEY = '3b004eeae34b43bd05226f210d9bdc2ad99abdd3c52bf32802906085b762ff55' +ALGORITHM = 'HS256' + +bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto') +oauth2_bearer = OAuth2PasswordBearer(tokenUrl='auth/token') + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get('sub') + user_id: int = payload.get('id') + role: str = payload.get('role') + if username is None or user_id is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="1Could not validate credentials") + return {'username': username, 'user_id': user_id, 'role': role} + except JWTError as e: + print(str(e)) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2Could not validate credentials") + + +db_dependency = Annotated[Session, Depends(get_db)] +user_dependency = Annotated[dict, Depends(get_current_user)] + + +def authenticate_user(username: str, password: str, db): + user = db.query(User).filter(User.username == username).first() + if not user: + return False + if not bcrypt_context.verify(password, user.password): + return False + return user + + +def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta): + encode = {'sub': username, 'id': user_id, 'role': role} + expire = datetime.now(timezone.utc) + expires_delta + encode.update({'exp': expire}) + + return jwt.encode(encode, SECRET_KEY, ALGORITHM) + + +class CreateUser(BaseModel): + username: str + email: str + password: str + + +class Token(BaseModel): + access_token: str + token_type: str + + +@router.post('/user/create', status_code=status.HTTP_201_CREATED) +async def create_user(user: user_dependency, db: db_dependency, data: CreateUser): + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Authentication failed') + + if user['role'] != 'admin': + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Not authorized') + + create_user_model = User( + username=data.username, + email=data.email, + password=bcrypt_context.hash(data.password), + role='admin' + ) + + db.add(create_user_model) + db.commit() + + +@router.post('/token', status_code=status.HTTP_200_OK, response_model=Token) +async def get_token( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + db: db_dependency +): + user = authenticate_user(form_data.username, form_data.password, db) + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") + + token = create_access_token(user.username, user.id, user.role, timedelta(minutes=20)) + return {'access_token': token, 'token_type': 'bearer'} diff --git a/backend/routers/forum.py b/backend/routers/forum.py new file mode 100644 index 0000000..5558dbc --- /dev/null +++ b/backend/routers/forum.py @@ -0,0 +1,106 @@ +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException, Path +from sqlalchemy.orm import Session +from pydantic import BaseModel, Field +from starlette import status + +from database import SessionLocal +from models import Post, Thread + +router = APIRouter( + tags=['forum'] +) + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() + + +db_dependency = Annotated[Session, Depends(get_db)] + + +class PostCreate(BaseModel): + author: str = Field('anon') + title: str = Field('') + content: str = Field('') + + +@router.get('/catalog', status_code=status.HTTP_200_OK) +async def get_catalog(db: db_dependency): + return db.query(Thread).all() + + +@router.get('/', status_code=status.HTTP_200_OK) +async def get_posts(db: db_dependency): + return db.query(Post).all() + + +@router.get('/{thread_id}', status_code=status.HTTP_200_OK) +async def get_thread_by_id(db: db_dependency, thread_id: int = Path(gt=0)): + posts = db.query(Post).filter(Post.thread_id == thread_id).all() + if posts: + return posts + + raise HTTPException(404, f'Could not find thread') + +@router.post('/', status_code=status.HTTP_201_CREATED) +async def create_thread(db: db_dependency, data: PostCreate): + try: + # Create the post + post = Post( + author=data.author, + title=data.title, + content=data.content + ) + db.add(post) + db.flush() + + # Create the thread + thread = Thread( + author=post.author, + title=post.title, + content=post.content + ) + db.add(thread) + db.flush() + + # Update the Post with thread_id + post.thread_id = thread.id + db.commit() + + return { + 'id': post.id, + 'author': post.author, + 'title': post.title, + 'content': post.content + } + except Exception as e: + db.rollback() + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post('/{thread_id}', status_code=status.HTTP_201_CREATED) +async def create_reply(db: db_dependency, data: PostCreate, thread_id: int = Path(gt=0)): + thread = db.query(Thread).filter(Thread.id == thread_id).first() + if thread: + post = Post( + thread_id=thread_id, + author=data.author, + title=data.title, + content=data.content + ) + db.add(post) + db.commit() + + return { + 'id': post.id, + 'author': post.author, + 'title': post.title, + 'content': post.content + } + + raise HTTPException(status_code=404, detail='Could not find thread')