ApplicantPortal/app/auth/__init__.py
2025-03-12 20:43:26 -06:00

81 lines
2.5 KiB
Python

from flask import Blueprint, jsonify, current_app, make_response
from flask_restx import abort
from flask_jwt_extended import JWTManager, verify_jwt_in_request, get_jwt
from functools import wraps
from app_common.const import USER_NOT_ACTIVE_MESSAGE, USER_NOT_ACTIVE_STATUS_CODE
from .model import Roles, BlackListedJWT
# Module-level JWTManager; the callbacks further down in this file are
# registered on it via its decorators. Presumably bound to the Flask app
# elsewhere (e.g. `jwt.init_app(app)`) — confirm.
jwt = JWTManager()
# Blueprint named 'auth' for this package's routes.
blueprint = Blueprint('auth', __name__)
# Imported only after `blueprint` exists, presumably because `routes` imports
# `blueprint` from this package (circular import) — keep this ordering.
from app.auth import routes
# NOTE(review): BlackListedJWT is already imported at the top of the file via
# `from .model import Roles, BlackListedJWT`; only `User` is new here.
from app.auth.model import User, BlackListedJWT
# Message used for the 404 raised when a token's subject matches no stored user.
NO_USER_FOUND_MESSAGE = 'No user found'
@jwt.user_identity_loader
def _user_identity_lookup(user):
    """Convert the identity object given at token creation into the stored identity.

    Registered as the JWTManager ``user_identity_loader``. The value returned
    here is the user's ``user_id``; ``_user_lookup_callback`` reads it back
    from the token's ``sub`` claim and queries ``User`` by ``user_id``.

    Args:
        user: the object passed as the token identity; assumed to expose a
            ``user_id`` attribute (presumably a ``User``) — confirm at callers.

    Returns:
        The ``user_id`` of ``user``.
    """
    identity_value = user.user_id
    return identity_value
@jwt.additional_claims_loader
def add_claims_to_token(identity):
    """Add the identity's roles to the token as a single comma-separated claim.

    Registered as the JWTManager ``additional_claims_loader``. Each role is
    rendered with ``str()`` and the results are joined with ``,``.
    ``role_required`` later splits this string on ``,``.

    Args:
        identity: the object passed as the token identity; assumed to expose
            an iterable ``roles`` attribute (presumably a ``User``) — confirm.

    Returns:
        dict: ``{'roles': '<role>,<role>,...'}`` (an empty string when the
        identity has no roles).
    """
    role_names = map(str, identity.roles)
    return {'roles': ','.join(role_names)}
@jwt.user_lookup_loader
def _user_lookup_callback(_jwt_header, jwt_data):
    """Resolve the token's ``sub`` claim to an *active* ``User``.

    Registered as the JWTManager ``user_lookup_loader``. It queries ``User``
    for a document with a matching ``user_id`` and ``active=True`` and
    returns the first match. When nothing matches, this is presumably
    ``None`` (QuerySet ``first()``), which hands control to the
    ``user_lookup_error_loader`` callback — confirm against the library
    versions in use.

    Args:
        _jwt_header: decoded JWT header (unused).
        jwt_data: decoded JWT payload; must contain ``'sub'``.
    """
    subject = jwt_data['sub']
    active_matches = User.objects(user_id=subject, active=True)
    return active_matches.first()
@jwt.user_lookup_error_loader
def _user_lookup_error_callback(_jwt_header, jwt_data):
    """Explain why the token's user could not be loaded.

    Registered as the JWTManager ``user_lookup_error_loader``. It re-queries
    ``User`` by ``user_id`` *without* the ``active`` filter to tell the cases
    apart:

    * no user with that id -> ``abort(404, NO_USER_FOUND_MESSAGE)``;
    * the user exists but is not active ->
      ``abort(USER_NOT_ACTIVE_STATUS_CODE, USER_NOT_ACTIVE_MESSAGE)``;
    * otherwise -> a 400 JSON response containing a message and the
      ``user_id``.

    ``flask_restx.abort`` raises an HTTP exception, so the first two cases
    never actually return a value.

    Args:
        _jwt_header: decoded JWT header (unused).
        jwt_data: decoded JWT payload; must contain ``'sub'``.
    """
    subject = jwt_data['sub']
    candidate = User.objects(user_id=subject).first()
    if not candidate:
        return abort(404, NO_USER_FOUND_MESSAGE)
    if not candidate.active:
        return abort(USER_NOT_ACTIVE_STATUS_CODE, USER_NOT_ACTIVE_MESSAGE)
    # The user exists and is active, yet the active-user lookup still failed.
    body = jsonify(message='An error occurred trying to load the given user.', user_id=subject)
    return make_response(body, 400)
@jwt.token_in_blocklist_loader
def _is_token_revoked(_jwt_header, jwt_payload):
    """Tell the JWTManager whether this token has been revoked.

    Registered as the JWTManager ``token_in_blocklist_loader``. It delegates
    to ``BlackListedJWT.is_blacklisted`` with the token's ``jti`` (unique
    token id) and returns that result unchanged.

    Args:
        _jwt_header: decoded JWT header (unused).
        jwt_payload: decoded JWT payload; must contain ``'jti'``.
    """
    revoked = BlackListedJWT.is_blacklisted(jwt_payload['jti'])
    return revoked
@jwt.expired_token_loader
def _expired_token_callback(jwt_header, jwt_payload):
    """Build the response returned when a request presents an expired JWT.

    Registered as the JWTManager ``expired_token_loader``. It responds with
    ``{"message": "The given JWT token has expired"}`` and status 405.

    NOTE(review): 405 means "Method Not Allowed", and flask_jwt_extended's own
    default for expired tokens is 401. Clients may key token refresh off 405,
    so confirm with the frontend before changing it.
    """
    body = jsonify(message='The given JWT token has expired')
    status_code = 405
    return body, status_code
def role_required(roles):
    """A wrapper for an endpoint that requires the calling user to have any of the given roles (implies `jwt_required()`).

    On every request the wrapped view:

    1. aborts with 401 (and logs a warning) if any entry of ``roles`` is not a
       member of ``Roles``, i.e. the decorator itself is misconfigured;
    2. calls ``verify_jwt_in_request()`` so a missing or invalid token is
       rejected by flask_jwt_extended;
    3. splits the token's comma-separated ``roles`` claim and calls the view
       if any claimed role is in ``roles``; otherwise it aborts with 401.

    Arguments:
        roles (`list(Roles)`): The roles allowed to call this endpoint.

    Returns:
        A decorator to apply to a view function.
    """
    def wrapper(func):
        @wraps(func)
        def decorator(*args, **kwargs):
            for role in roles:
                if role not in Roles:
                    msg = f'The given role is not a valid (role="{role}", valid_roles="{Roles.values()}").'
                    # `Logger.warn` is a deprecated alias of `Logger.warning`.
                    current_app.logger.warning(msg)
                    return abort(401, msg)
            verify_jwt_in_request()
            claims = get_jwt()
            # A token whose payload lacks (or nulls) the custom `roles` claim is
            # treated as having no roles instead of failing with KeyError (500).
            claimed_roles = (claims.get('roles') or '').split(',')
            if any(claimed in roles for claimed in claimed_roles):
                return func(*args, **kwargs)
            # NOTE(review): an authenticated user lacking the role is arguably a
            # 403; 401 is kept because clients may depend on it.
            return abort(401, f'The endpoint requires a user with the following roles "{roles}".')
        return decorator
    return wrapper