Files
amap/backend/src/auth/auth.py

91 lines
2.8 KiB
Python

import secrets

import jwt
import requests
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import PyJWKClient
from sqlmodel import Session

import src.users.service as service
from src.database import get_session
from src.models import UserCreate
from src.settings import AUTH_URL, TOKEN_URL, JWKS_URL, ISSUER, settings

# Router for the authentication endpoints; every route here is served under "/auth".
router = APIRouter(prefix="/auth")
# JWKS client pointed at JWKS_URL; verify_token uses it to pick the key that signed a token.
jwk_client = PyJWKClient(JWKS_URL)
# Bearer-credential extractor injected into get_current_user via Security(...).
security = HTTPBearer()
@router.get('/login')
def login():
    """Start the OAuth2 authorization-code flow.

    Builds the authorization URL from ``AUTH_URL`` with the client id,
    redirect URI, ``openid`` scope, a random ``state`` value, and
    ``response_type=code``. It then sends the browser there with a redirect.
    """
    # NOTE(review): the state value is generated but never stored, so nothing
    # can compare it with the ``state`` that comes back on /auth/callback.
    # Until it is persisted (cookie / server session) and checked there, it
    # provides no CSRF protection.
    query = dict(
        client_id=settings.keycloak_client_id,
        response_type="code",
        scope="openid",
        redirect_uri=settings.keycloak_redirect_uri,
        state=secrets.token_urlsafe(16),
    )
    # Let requests do the query-string encoding against AUTH_URL.
    prepared = requests.Request('GET', AUTH_URL, params=query).prepare()
    return RedirectResponse(prepared.url)
@router.get("/callback")
def callback(code: str, session: Session = Depends(get_session)):
    """Exchange an authorization code for tokens and register the user.

    The one-time ``code`` is posted to ``TOKEN_URL`` using the confidential
    client credentials from ``settings``. The ``email`` and
    ``preferred_username`` claims of the returned ID token are used to look
    up or create the local user.

    Args:
        code: Authorization code taken from the ``code`` query parameter.
        session: Database session injected via ``get_session``.

    Returns:
        A dict with ``access_token``, ``id_token`` and ``refresh_token``.
        ``refresh_token`` is ``None`` when the token endpoint did not issue one.
        If the code exchange fails or the token endpoint is unreachable, a
        400 ``JSONResponse`` is returned instead.
    """
    # NOTE(review): the ``state`` produced by /auth/login is not verified here,
    # so the flow is open to login CSRF. Persist it in login and compare it to
    # a ``state`` query parameter before exchanging the code.
    data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": settings.keycloak_redirect_uri,
        "client_id": settings.keycloak_client_id,
        "client_secret": settings.keycloak_client_secret,
    }
    headers = {
        "Content-Type": "application/x-www-form-urlencoded"
    }
    try:
        # Timeout (seconds) so an unresponsive IdP cannot hang this worker forever.
        response = requests.post(TOKEN_URL, data=data, headers=headers, timeout=10)
    except requests.RequestException:
        response = None
    if response is None or response.status_code != 200:
        return JSONResponse(
            {"error": "Failed to get token"},
            status_code=400
        )

    token_data = response.json()
    id_token = token_data["id_token"]
    # The ID token comes straight from the token endpoint (back channel), so its
    # signature is not re-checked here. OIDC Core 3.1.3.7 allows relying on TLS
    # in that case. NOTE(review): this assumes TOKEN_URL is https.
    claims = jwt.decode(id_token, options={"verify_signature": False})
    user_create = UserCreate(
        email=claims.get("email"),
        name=claims.get("preferred_username"),
    )
    # Called for its side effect of ensuring the user row exists.
    service.get_or_create_user(session, user_create)
    return {
        "access_token": token_data["access_token"],
        "id_token": id_token,
        # Refresh tokens are optional in the token response; don't 500 without one.
        "refresh_token": token_data.get("refresh_token"),
    }
def verify_token(token: str) -> dict:
    """Validate a bearer JWT and return its claims.

    The verification key is looked up through ``jwk_client`` from the token's
    header. The token is then decoded with RS256 only, requiring
    ``settings.keycloak_client_id`` as audience and ``ISSUER`` as issuer.

    Args:
        token: The raw encoded JWT.

    Returns:
        The decoded claims payload.

    Raises:
        HTTPException: 401 "Token expired" when the token has expired.
        HTTPException: 401 "Invalid token" for any other validation failure,
            including a token whose signing key cannot be resolved from the
            JWKS.
    """
    try:
        signing_key = jwk_client.get_signing_key_from_jwt(token)
        return jwt.decode(
            token,
            signing_key.key,
            algorithms=["RS256"],
            audience=settings.keycloak_client_id,
            issuer=ISSUER,
        )
    except jwt.ExpiredSignatureError as err:
        # Must precede InvalidTokenError: ExpiredSignatureError is a subclass of it.
        raise HTTPException(status_code=401, detail="Token expired") from err
    except (jwt.InvalidTokenError, jwt.PyJWKClientError) as err:
        # PyJWKClientError (e.g. an unknown "kid") is not an InvalidTokenError
        # subclass. Without it here, such tokens escaped as 500 errors.
        # NOTE(review): this also maps JWKS fetch failures to 401. Newer PyJWT
        # has PyJWKClientConnectionError, which could be mapped to 503 instead.
        raise HTTPException(status_code=401, detail="Invalid token") from err
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Security(security)
):
    """FastAPI dependency: return the verified claims of the request's bearer token.

    The credentials are extracted by the module's ``HTTPBearer`` scheme, and
    the token string is validated by ``verify_token``. Any ``HTTPException``
    it raises propagates unchanged.
    """
    bearer_token = credentials.credentials
    claims = verify_token(bearer_token)
    return claims