from flask import Blueprint, jsonify, current_app, make_response
from flask_restx import abort
from flask_jwt_extended import JWTManager, verify_jwt_in_request, get_jwt
from functools import wraps
from app_common.const import USER_NOT_ACTIVE_MESSAGE, USER_NOT_ACTIVE_STATUS_CODE
from .model import Roles, BlackListedJWT

jwt = JWTManager()
blueprint = Blueprint('auth', __name__)

# NOTE(review): imported after `blueprint` is created, presumably so `routes`
# can import `blueprint` from this module without a circular-import failure.
from app.auth import routes
# NOTE(review): BlackListedJWT is also imported from `.model` above; presumably
# the same module (if this package is `app.auth`) — confirm and drop one.
from app.auth.model import User, BlackListedJWT

NO_USER_FOUND_MESSAGE = 'No user found'


@jwt.user_identity_loader
def _user_identity_lookup(user):
    """Return the value stored as the token identity for ``user``: its ``user_id``."""
    return user.user_id


@jwt.additional_claims_loader
def add_claims_to_token(identity):
    """Embed the user's roles in the token as a comma-separated ``roles`` claim.

    Arguments:
        identity: The object the token is created for; must expose ``roles``.

    Returns:
        dict: ``{'roles': '<role1>,<role2>,...'}`` built from ``str(role)``.
    """
    return {'roles': ','.join([str(r) for r in identity.roles])}


@jwt.user_lookup_loader
def _user_lookup_callback(_jwt_header, jwt_data):
    """Load the active user whose ``user_id`` equals the token's ``sub`` claim.

    Returns the first matching active user (presumably ``None`` when no match,
    which makes flask_jwt_extended invoke the user-lookup error callback).
    """
    user_id = jwt_data['sub']
    return User.objects(user_id=user_id, active=True).first()


@jwt.user_lookup_error_loader
def _user_lookup_error_callback(_jwt_header, jwt_data):
    """Build the error response when the active-user lookup for a token fails.

    Distinguishes why the lookup failed:
    - no user with that ``user_id`` at all -> aborts with 404;
    - user exists but is inactive -> aborts with ``USER_NOT_ACTIVE_STATUS_CODE``;
    - otherwise -> a 400 JSON response carrying the ``user_id``.
    """
    user_id = jwt_data['sub']
    # Query without the `active` filter to find out why the lookup failed.
    user = User.objects(user_id=user_id).first()
    if not user:
        return abort(404, NO_USER_FOUND_MESSAGE)
    elif not user.active:
        return abort(USER_NOT_ACTIVE_STATUS_CODE, USER_NOT_ACTIVE_MESSAGE)
    else:
        return make_response(
            jsonify(message='An error occurred trying to load the given user.', user_id=user_id),
            400,
        )


@jwt.token_in_blocklist_loader
def _is_token_revoked(_jwt_header, jwt_payload):
    """Return whether the token's ``jti`` has been blacklisted."""
    return BlackListedJWT.is_blacklisted(jwt_payload['jti'])


@jwt.expired_token_loader
def _expired_token_callback(jwt_header, jwt_payload):
    """Return the JSON response sent when a request carries an expired token."""
    # NOTE(review): 405 means "Method Not Allowed"; flask_jwt_extended's own
    # default for expired tokens is 401. Kept as-is because clients may rely on
    # this code (e.g. to trigger a refresh) — confirm before changing.
    return jsonify(message='The given JWT token has expired'), 405


def role_required(roles):
    """Require the calling user to hold at least one of ``roles`` (implies ``jwt_required()``).

    Arguments:
        roles (`list(Roles)`): The roles allowed to call this endpoint.

    Returns:
        A decorator for a view function. The decorated view aborts with 401 when
        ``roles`` contains an entry that is not a ``Roles`` member, or when the
        token's ``roles`` claim shares no entry with ``roles``.
    """
    def wrapper(func):
        @wraps(func)
        def decorator(*args, **kwargs):
            # Guard against a misconfigured decorator: every entry must be a Roles member.
            for role in roles:
                if role not in Roles:
                    msg = f'The given role is not a valid (role="{role}", valid_roles="{Roles.values()}").'
                    # `Logger.warn` is a deprecated alias; `warning` is the supported name.
                    current_app.logger.warning(msg)
                    return abort(401, msg)
            verify_jwt_in_request()
            claims = get_jwt()
            # `add_claims_to_token` stores roles as a comma-separated string. A token
            # without that claim is treated as having no roles (401 below) instead of
            # raising KeyError and producing a 500.
            token_roles = (claims.get('roles') or '').split(',')
            # NOTE(review): compares claim strings against entries of `roles`; this
            # presumably relies on Roles members comparing equal to their str() form
            # (e.g. a str-based enum) — confirm.
            if any(role in roles for role in token_roles):
                return func(*args, **kwargs)
            # NOTE(review): 403 is the conventional code for an authenticated but
            # unauthorized caller; 401 kept for compatibility with existing clients.
            return abort(401, f'The endpoint requires a user with the following roles "{roles}".')
        return decorator
    return wrapper