add jwt validation

This commit is contained in:
agatha 2024-04-06 15:29:08 -04:00
parent 353218aca7
commit ea4ae48b23

View File

@ -1,12 +1,12 @@
from datetime import timedelta, datetime from datetime import timedelta, datetime
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from passlib.context import CryptContext from passlib.context import CryptContext
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette import status from starlette import status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
from jose import jwt from jose import jwt, JWTError
from models import User from models import User
from database import SessionLocal from database import SessionLocal
@ -17,6 +17,7 @@ SECRET_KEY = '3b004eeae34b43bd05226f210d9bdc2ad99abdd3c52bf32802906085b762ff55'
ALGORITHM = 'HS256' ALGORITHM = 'HS256'
bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto') bcrypt_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
oauth2_bearer = OAuth2PasswordBearer(tokenUrl='validate')
def get_db(): def get_db():
@ -58,6 +59,19 @@ class Token(BaseModel):
token_type: str token_type: str
async def get_current_user(token: Annotated[str, Depends(oauth2_bearer)]):
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get('sub')
user_id: int = payload.get('id')
if username is None or user_id is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
return {'username': username, 'user_id': user_id}
except JWTError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
@router.post('/auth/create', status_code=status.HTTP_201_CREATED) @router.post('/auth/create', status_code=status.HTTP_201_CREATED)
async def create_user(db: db_dependency, data: CreateUser): async def create_user(db: db_dependency, data: CreateUser):
create_user_model = User( create_user_model = User(