"""Keycloak OpenID Connect login flow and bearer-token verification routes."""

import logging
import secrets

import jwt
import requests
from fastapi import APIRouter, Depends, HTTPException, Security
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import PyJWKClient
from sqlmodel import Session

import src.users.service as service
from src.database import get_session
from src.models import UserCreate
from src.settings import AUTH_URL, ISSUER, JWKS_URL, TOKEN_URL, settings

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/auth")
# Resolves Keycloak's public signing keys from the JWKS endpoint.
jwk_client = PyJWKClient(JWKS_URL)
security = HTTPBearer()

# Seconds to wait on the token endpoint before giving up, so a stalled
# identity provider cannot block a worker thread indefinitely.
TOKEN_REQUEST_TIMEOUT = 10


@router.get("/login")
def login():
    """Redirect the browser to Keycloak's authorization endpoint.

    Starts the OAuth 2.0 authorization-code flow, requesting the ``openid``
    scope so the token response for the returned code includes an ID token.
    """
    # NOTE(review): ``state`` is generated but never persisted, and
    # ``callback`` never compares it, so it currently gives no CSRF
    # protection. Store it (e.g. an HttpOnly cookie or server-side session)
    # and check it in ``callback`` with ``secrets.compare_digest``.
    state = secrets.token_urlsafe(16)
    params = {
        "client_id": settings.keycloak_client_id,
        "response_type": "code",
        "scope": "openid",
        "redirect_uri": settings.keycloak_redirect_uri,
        "state": state,
    }
    # Let requests build and URL-encode the query string for us.
    request_url = requests.Request("GET", AUTH_URL, params=params).prepare().url
    return RedirectResponse(request_url)


@router.get("/callback")
def callback(code: str, session: Session = Depends(get_session)):
    """Exchange an authorization code for tokens and provision the user.

    Args:
        code: Authorization code that Keycloak appended to the redirect URI.
        session: Database session used to look up or create the user.

    Returns:
        A dict with the ``access_token``, ``id_token`` and ``refresh_token``
        from the token endpoint. On failure, a ``JSONResponse`` error is
        returned instead: 400 when the token endpoint answers with a
        non-200 status, 502 when it cannot be reached at all.
    """
    data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": settings.keycloak_redirect_uri,
        "client_id": settings.keycloak_client_id,
        "client_secret": settings.keycloak_client_secret,
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}

    try:
        response = requests.post(
            TOKEN_URL, data=data, headers=headers, timeout=TOKEN_REQUEST_TIMEOUT
        )
    except requests.RequestException:
        logger.exception("Token request to %s failed", TOKEN_URL)
        return JSONResponse(
            {"error": "Failed to reach identity provider"}, status_code=502
        )

    if response.status_code != 200:
        return JSONResponse({"error": "Failed to get token"}, status_code=400)

    token_data = response.json()
    id_token = token_data["id_token"]
    # The ID token is taken directly from the token-endpoint response to a
    # request authenticated with our client secret, so its claims are only
    # read here, not signature-verified.
    # NOTE(review): this relies on TOKEN_URL being https — confirm.
    claims = jwt.decode(id_token, options={"verify_signature": False})

    user_create = UserCreate(
        email=claims.get("email"),
        name=claims.get("preferred_username"),
    )
    logger.debug("Resolved user from ID token: %s", user_create)
    # Called for its side effect of ensuring a local user record exists.
    service.get_or_create_user(session, user_create)

    return {
        "access_token": token_data["access_token"],
        "id_token": id_token,
        "refresh_token": token_data["refresh_token"],
    }


def verify_token(token: str) -> dict:
    """Validate a Keycloak-issued JWT and return its decoded claims.

    The signing key is selected from the JWKS endpoint using the token
    header. The token must be RS256-signed, its ``aud`` must be this client's
    id, and its ``iss`` must equal ``ISSUER``.

    Raises:
        HTTPException: 401 "Token expired" if the token has expired.
        HTTPException: 401 "Invalid token" if it is malformed, fails the
            signature or claim checks, or no matching signing key can be
            obtained.
    """
    try:
        signing_key = jwk_client.get_signing_key_from_jwt(token)
        return jwt.decode(
            token,
            signing_key.key,
            algorithms=["RS256"],
            audience=settings.keycloak_client_id,
            issuer=ISSUER,
        )
    except jwt.ExpiredSignatureError:
        # Must precede InvalidTokenError, of which it is a subclass.
        raise HTTPException(status_code=401, detail="Token expired") from None
    except (jwt.InvalidTokenError, jwt.PyJWKClientError):
        # PyJWKClientError covers an unknown ``kid``.
        # NOTE(review): it is also raised when the JWKS endpoint itself is
        # unreachable, which is reported as 401 here.
        raise HTTPException(status_code=401, detail="Invalid token") from None


def get_current_user(
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    """FastAPI dependency returning the verified claims of the bearer token.

    ``HTTPBearer`` extracts the token from the ``Authorization`` header, and
    ``verify_token`` then validates it.
    """
    return verify_token(credentials.credentials)