diff --git a/backend/routers/auth.py b/backend/routers/auth.py index 29ee23d..8bb11f9 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -1,4 +1,4 @@ -from datetime import timedelta, datetime +from datetime import timedelta, datetime, timezone from typing import Annotated from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel @@ -31,7 +31,22 @@ def get_db(): db.close() +async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get('sub') + user_id: int = payload.get('id') + role: str = payload.get('role') + if username is None or user_id is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="1Could not validate credentials") + return {'username': username, 'user_id': user_id, 'role': role} + except JWTError as e: + print(str(e)) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="2Could not validate credentials") + + db_dependency = Annotated[Session, Depends(get_db)] +user_dependency = Annotated[dict, Depends(get_current_user)] def authenticate_user(username: str, password: str, db): @@ -43,9 +58,9 @@ def authenticate_user(username: str, password: str, db): return user -def create_access_token(username: str, user_id: int, expires_delta: timedelta): - encode = {'sub': username, 'id': user_id} - expire = datetime.now() + expires_delta +def create_access_token(username: str, user_id: int, role: str, expires_delta: timedelta): + encode = {'sub': username, 'id': user_id, 'role': role} + expire = datetime.now(timezone.utc) + expires_delta encode.update({'exp': expire}) return jwt.encode(encode, SECRET_KEY, ALGORITHM) @@ -62,21 +77,14 @@ class Token(BaseModel): token_type: str -async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]): - try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - username: str = payload.get('sub') - user_id: int = payload.get('id') - if username is None or user_id is None: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") - - return {'username': username, 'user_id': user_id} - except JWTError: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") - - @router.post('/user/create', status_code=status.HTTP_201_CREATED) -async def create_user(db: db_dependency, data: CreateUser): +async def create_user(user: user_dependency, db: db_dependency, data: CreateUser): + if user is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Authentication failed') + + if user['role'] != 'admin': + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail='Not authorized') + create_user_model = User( username=data.username, email=data.email, @@ -97,5 +105,5 @@ async def get_token( if not user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") - token = create_access_token(user.username, user.id, timedelta(minutes=20)) + token = create_access_token(user.username, user.id, user.role, timedelta(minutes=20)) return {'access_token': token, 'token_type': 'bearer'}