81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
from flask import Blueprint, jsonify, current_app, make_response
|
|
from flask_restx import abort
|
|
from flask_jwt_extended import JWTManager, verify_jwt_in_request, get_jwt
|
|
from functools import wraps
|
|
from app_common.const import USER_NOT_ACTIVE_MESSAGE, USER_NOT_ACTIVE_STATUS_CODE
|
|
from .model import Roles, BlackListedJWT
|
|
|
|
|
|
# Module-level JWT manager, created without an app; the loader callbacks
# defined further down in this module are registered on this instance.
jwt = JWTManager()
# Flask blueprint for this package, registered under the name 'auth'.
blueprint = Blueprint('auth', __name__)

# NOTE(review): these imports sit below `jwt`/`blueprint` instead of at the top
# of the file -- presumably because `app.auth.routes` imports those names back
# from this module (circular import) and is imported for its side effects.
# Confirm before moving or removing them.
from app.auth import routes
from app.auth.model import User, BlackListedJWT

# Message used for the 404 raised when a token's subject matches no user.
NO_USER_FOUND_MESSAGE = 'No user found'
|
|
|
|
|
|
@jwt.user_identity_loader
def _user_identity_lookup(user):
    """Return the identity value for ``user``.

    Registered as the JWT manager's user-identity loader.

    Arguments:
        user: The object supplied as the token identity; only its ``user_id``
            attribute is read.

    Returns:
        The value of ``user.user_id``.
    """
    identity = user.user_id
    return identity
|
|
|
|
|
|
@jwt.additional_claims_loader
def add_claims_to_token(identity):
    """Build the additional claims stored in a token.

    Registered as the JWT manager's additional-claims loader.

    Arguments:
        identity: The object supplied as the token identity; only its ``roles``
            attribute is read.

    Returns:
        dict: A single ``'roles'`` entry holding the ``str()`` form of every
        item in ``identity.roles``, joined with commas.
    """
    role_names = ','.join(map(str, identity.roles))
    return {'roles': role_names}
|
|
|
|
|
|
@jwt.user_lookup_loader
def _user_lookup_callback(_jwt_header, jwt_data):
    """Load the active user a token was issued for.

    Registered as the JWT manager's user-lookup loader.

    Arguments:
        _jwt_header: Unused.
        jwt_data: Decoded token payload; its ``'sub'`` entry is used as the
            user id.

    Returns:
        The first ``User`` whose ``user_id`` equals the token subject and whose
        ``active`` flag is true, as returned by ``.first()``.
    """
    active_matches = User.objects(user_id=jwt_data['sub'], active=True)
    return active_matches.first()
|
|
|
|
|
|
@jwt.user_lookup_error_loader
def _user_lookup_error_callback(_jwt_header, jwt_data):
    """Produce the error response for a token whose user could not be loaded.

    Registered as the JWT manager's user-lookup error loader. The user is
    queried again by id only (no ``active`` filter) to tell apart an unknown
    user from a deactivated one.

    Arguments:
        _jwt_header: Unused.
        jwt_data: Decoded token payload; its ``'sub'`` entry is used as the
            user id.

    Returns:
        A 400 JSON response carrying a generic message and the user id when the
        user exists and is active. The other two cases go through ``abort``:
        404 with ``NO_USER_FOUND_MESSAGE`` when no user has that id, and
        ``USER_NOT_ACTIVE_STATUS_CODE`` with ``USER_NOT_ACTIVE_MESSAGE`` when
        the user exists but is not active.
    """
    subject = jwt_data['sub']
    account = User.objects(user_id=subject).first()

    # Unknown user id.
    if not account:
        return abort(404, NO_USER_FOUND_MESSAGE)

    # Known user that has been deactivated.
    if not account.active:
        return abort(USER_NOT_ACTIVE_STATUS_CODE, USER_NOT_ACTIVE_MESSAGE)

    # The user exists and is active, yet the lookup still failed.
    payload = jsonify(message='An error occurred trying to load the given user.', user_id=subject)
    return make_response(payload, 400)
|
|
|
|
|
|
@jwt.token_in_blocklist_loader
def _is_token_revoked(_jwt_header, jwt_payload):
    """Report whether a token has been blacklisted.

    Registered as the JWT manager's token-in-blocklist loader.

    Arguments:
        _jwt_header: Unused.
        jwt_payload: Decoded token payload; its ``'jti'`` entry identifies the
            token.

    Returns:
        Whatever ``BlackListedJWT.is_blacklisted`` returns for the token's
        ``jti``.
    """
    token_id = jwt_payload['jti']
    return BlackListedJWT.is_blacklisted(token_id)
|
|
|
|
|
|
@jwt.expired_token_loader
def _expired_token_callback(jwt_header, jwt_payload):
    """Build the response sent when a request carries an expired token.

    Registered as the JWT manager's expired-token loader.

    Arguments:
        jwt_header: Unused.
        jwt_payload: Unused.

    Returns:
        tuple: A JSON body with a fixed ``message`` and the status code 405.
    """
    # NOTE(review): 405 is "Method Not Allowed"; 401 would be the conventional
    # status for an expired credential. Kept as-is because clients may already
    # rely on it -- confirm before changing.
    body = jsonify(message='The given JWT token has expired')
    return body, 405
|
|
|
|
|
|
def role_required(roles):
    """Require that the calling user has any of the given roles.

    Decorator factory for endpoints. Each call of the decorated endpoint:

    1. checks every entry of ``roles`` against ``Roles``;
    2. verifies the JWT of the current request via ``verify_jwt_in_request()``
       (so a separate ``jwt_required()`` is not needed);
    3. splits the token's comma-separated ``'roles'`` claim and calls the
       endpoint only if at least one claimed role is in ``roles``.

    Arguments:
        roles (`list(Roles)`): The allowed roles for this endpoint.

    Returns:
        A decorator that wraps the endpoint; the wrapped endpoint returns
        whatever the original endpoint returns when the role check passes.

    Raises:
        Calls ``abort(401, ...)`` when ``roles`` contains an entry that is not
        in ``Roles`` (also logged as a warning), or when the token carries none
        of the required roles.
    """
    def wrapper(func):
        @wraps(func)
        def decorator(*args, **kwargs):
            # Reject a misconfigured endpoint before looking at the token.
            for role in roles:
                if role not in Roles:
                    msg = f'The given role is not a valid (role="{role}", valid_roles="{Roles.values()}").'
                    # `Logger.warn` is a deprecated alias of `warning`.
                    current_app.logger.warning(msg)
                    return abort(401, msg)

            verify_jwt_in_request()
            claims = get_jwt()
            # A token without a 'roles' claim is treated as having no roles
            # (ends in the 401 below) instead of failing with a KeyError.
            claimed_roles = claims.get('roles', '').split(',')
            user_has_role = any(claimed in roles for claimed in claimed_roles)
            if user_has_role:
                return func(*args, **kwargs)
            return abort(401, f'The endpoint requires a user with the following roles "{roles}".')

        return decorator

    return wrapper
|