181 lines
6.6 KiB
Python
181 lines
6.6 KiB
Python
import jwt
|
|
from werkzeug.security import generate_password_hash, check_password_hash
|
|
from time import time
|
|
from datetime import datetime
|
|
from uuid import uuid4
|
|
from mongoengine.fields import StringField
|
|
from app.database import db, DocumentBase, EmbeddedDocumentBase
|
|
from app.database.fields import QuarterDateField
|
|
from app.util.datatypes import enum
|
|
from app_common.const import USER_ID_REGEX
|
|
|
|
|
|
class UserIDField(StringField):
    """A string field holding a user ID.

    On top of the standard ``StringField`` checks, the lower-cased value must
    fully match ``USER_ID_REGEX`` and must not be one of the ``Roles``.
    """

    def validate(self, value):
        """Validate ``value`` as a user ID, reporting problems via ``self.error``.

        Arguments:
            value: The candidate user ID.
        """
        # Run the base StringField checks first. They report non-string values
        # through ``self.error`` (a ValidationError) instead of letting
        # ``value.lower()`` below fail with an AttributeError. They also keep
        # any max_length / min_length / regex options given to this field working.
        super().validate(value)

        # Checks run on a lower-cased copy; the stored value is not modified here.
        value = value.lower()

        if not USER_ID_REGEX.fullmatch(value):
            self.error(f'The given user ID "{value}" is invalid, can only contain alphanumeric characters and "_, -".')

        # NOTE(review): relies on ``str in Roles`` being supported by
        # ``enum.BaseEnum`` (a plain ``enum.Enum`` raises TypeError here before
        # Python 3.12) -- confirm its semantics.
        if value in Roles:
            self.error(f'A user ID cannot be a user role (invalid values: {Roles})')
|
|
|
|
|
|
class Roles(enum.BaseEnum):
    # Roles a user can hold (see ``User.roles``). User IDs may not equal any
    # of these values.

    # Default role given to a newly created user.
    APPLICANT = "applicant"
    OCCUPANT = "occupant"
    INVESTOR = "investor"
    IMPACT_INVESTOR = "impact_investor"
    NETWORK_ADMIN = "network_admin"
    ADMIN = "admin"
|
|
|
|
|
|
class BlackListedJWT(DocumentBase):
    """A mongo document recording a JWT that must no longer be accepted.

    Tokens are identified by their ``jti`` (JWT ID) claim. For example,
    password-reset tokens are blacklisted after their first use so they are
    single-use.
    """

    # The ``jti`` claim of the rejected token.
    jti = db.StringField(required=True)
    # When the entry was created (naive UTC, from ``datetime.utcnow()``).
    created_at = db.DateTimeField(required=True)

    @staticmethod
    def blacklist_jwt(jti):
        """Store ``jti`` as blacklisted, timestamped with the current UTC time.

        Arguments:
            jti (`str`): The JWT ID claim of the token to reject from now on.
        """
        BlackListedJWT(jti=jti, created_at=datetime.utcnow()).save()

    @staticmethod
    def is_blacklisted(jti):
        """Check whether a blacklist entry exists for ``jti``.

        Arguments:
            jti (`str`): The JWT ID claim to look up.

        Returns:
            `bool`: `True` if the token has been blacklisted, `False` otherwise.
        """
        existing_entry = BlackListedJWT.objects(jti=jti).first()
        return existing_entry is not None
|
|
|
|
|
|
class UserPrefs(EmbeddedDocumentBase):
    """Embedded document holding a user's free-form preferences."""

    # Arbitrary, schema-less preference key/value pairs (optional).
    data = db.DictField(required=False)
|
|
|
|
|
|
class User(DocumentBase):
    """A mongo document to represent users."""

    user_id = UserIDField(required=True, unique=True)
    """The user's ID (unique, required)."""

    email = db.EmailField(required=True, unique=True)
    """The user's email (unique, required)."""

    password_hash = db.StringField(required=False)
    """The hashed password for the user (may be unset)."""

    first_name = db.StringField(required=True)
    """The user's first name (required)."""

    middle_init_or_name = db.StringField(required=False, default='')
    """The middle initial or middle name for the user."""

    last_name = db.StringField(required=True)
    """The user's last name (required)."""

    last_login = db.DateTimeField(required=False)
    """The last time the user logged in."""

    # The user's address (all parts optional).
    address1 = db.StringField(required=False)
    address2 = db.StringField(required=False)
    city = db.StringField(required=False)
    state = db.StringField(required=False)
    zip = db.StringField(required=False)

    join_date = QuarterDateField(required=True)
    """The date the user joined (required)."""

    date_of_birth = QuarterDateField(required=False)
    """The user's date of birth."""

    # NOTE(review): this field was historically documented as unique, but no
    # ``unique=True`` is set, so uniqueness is not enforced; confirm intent.
    phone_number = db.StringField(required=False, sparse=True)
    """The user's phone number (optional)."""

    # A callable default gives every new document its own list instead of
    # sharing one list object between instances.
    roles = db.ListField(db.EnumField(Roles), default=lambda: [Roles.APPLICANT])
    """The roles the user has (defaults to ``[Roles.APPLICANT]``)."""

    active = db.BooleanField(required=True, default=False)
    """Whether or not the user account is active."""

    prefs = db.EmbeddedDocumentField('UserPrefs')
    """The user's preferences."""

    misc_data = db.DictField(required=False)
    """Free-form additional data about the user."""

    @classmethod
    def ignore_to_json(cls):
        """Return the field names excluded from JSON output for this document.

        Extends the base class's list (the hook is defined on ``DocumentBase``)
        with ``password_hash`` and ``active``.

        Returns:
            `list`: The field names to ignore.
        """
        # Copy the base result so the appends below can never mutate a list
        # owned by the base class (e.g. a shared class attribute).
        result = list(super().ignore_to_json())
        result.append('password_hash')
        result.append('active')
        return result

    @classmethod
    def from_request_args(cls, **kwargs):
        """Create, save and return a new user from request arguments.

        Arguments:
            **kwargs: Field values for the new user, plus a plain-text
                ``password`` that is hashed rather than stored. ``roles`` may be
                given as a comma-separated string. Any ``join_date`` passed in
                is overwritten with today's UTC date.

        Returns:
            `User`: The newly saved user.

        Raises:
            KeyError: If ``password`` is missing from ``kwargs``.
        """
        password = kwargs.pop('password')
        kwargs['join_date'] = datetime.utcnow().date()

        # NOTE(review): roles come straight from the caller's arguments. If this
        # is reachable from a public sign-up endpoint, a client could grant
        # itself privileged roles such as 'admin'. Confirm the callers.
        roles = kwargs.get('roles')
        if isinstance(roles, str):
            # Tolerate spaces after commas, e.g. "applicant, investor".
            kwargs['roles'] = [role.strip() for role in roles.split(',')]

        user = cls(**kwargs)
        user.set_password(password)
        user.save()
        return user

    def set_password(self, password):
        """Hashes and stores the password for the user.

        Only the in-memory document is updated; the caller must save it.

        Arguments:
            password (`str`): The password to store the hash of for the user.
        """
        self.password_hash = generate_password_hash(password, salt_length=15)

    def check_password(self, password):
        """Checks if the password is correct for the user.

        Arguments:
            password (`str`): The password to check against the user's stored password hash.

        Returns:
            `bool`: `True` if the password is correct, `False` otherwise
            (including when the user has no password set).
        """
        # ``password_hash`` is optional. werkzeug's check_password_hash cannot
        # handle ``None``, and a user without a hash can never match.
        if not self.password_hash:
            return False
        return check_password_hash(self.password_hash, password)

    def get_password_reset_token(self, expire_secs=600, secret_key=None):
        """Generates a JWT token allowing the user to reset their password that will expire after the given time.

        Arguments:
            expire_secs (`int`): The number of seconds the password reset token should be valid for.
            secret_key (`str`): The secret key to use when generating the JWT token.

        Returns:
            `str`: A JWT token which allows the user to reset their password.
        """
        payload = {
            'exp': time() + expire_secs,
            'reset_password': self.user_id,
            # Unique token ID; lets a used token be blacklisted by ID.
            'jti': str(uuid4()),
        }
        return jwt.encode(payload, secret_key, algorithm='HS256')

    def get_activation_token(self, expire_secs=1200, secret_key=None, redirect_url=''):
        """Generates an HS256-signed JWT token that activates the user's account.

        Arguments:
            expire_secs (`int`): The number of seconds the token should be valid for.
            secret_key (`str`): The secret key to use when generating the JWT token.
            redirect_url (`str`): A URL embedded in the token and returned by
                `from_activation_token`.

        Returns:
            `str`: The encoded JWT token.
        """
        payload = {
            'exp': time() + expire_secs,
            'activate': self.user_id,
            'redirect_url': redirect_url,
        }
        return jwt.encode(payload, secret_key, algorithm='HS256')

    @staticmethod
    def from_activation_token(token, secret_key=None):
        """Decodes an activation token produced by `get_activation_token`.

        Arguments:
            token (`str`): The encoded JWT token.
            secret_key (`str`): The secret key used to generate the JWT token.

        Returns:
            `tuple|None`: ``(user, redirect_url)`` if the token is valid, where
            ``user`` is `None` when no user has the encoded ID. Returns `None`
            if the token has expired.

        Raises:
            jwt.InvalidTokenError: For token problems other than expiry (e.g. a
                bad signature); only expiry is mapped to `None`.
        """
        try:
            # ``algorithms`` must be a list: with a plain string, PyJWT's
            # ``alg in algorithms`` check degrades to a substring test.
            jwt_data = jwt.decode(token, secret_key, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            return None

        user_id = jwt_data['activate']
        redirect_url = jwt_data['redirect_url']
        return User.objects(user_id=user_id).first(), redirect_url

    @staticmethod
    def from_password_reset_token(token, secret_key=None):
        """Gets a `User` from the given JWT token that allows them to reset their password.

        Each token is single-use: a successful call blacklists the token's ``jti``.

        Arguments:
            token (`str`): The encoded JWT token to load the user from.
            secret_key (`str`): The secret key used to generate the JWT token.

        Returns:
            `User|None`: The user encoded in the token if it is still valid.
            Returns `None` if the token has expired or was already used (or if
            no user has the encoded ID).

        Raises:
            jwt.InvalidTokenError: For token problems other than expiry (e.g. a
                bad signature).
        """
        try:
            jwt_data = jwt.decode(token, secret_key, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            return None

        jti = jwt_data['jti']
        # Reject a token that has already been consumed.
        if BlackListedJWT.is_blacklisted(jti):
            return None

        user_id = jwt_data['reset_password']
        # Consume the token so it cannot be used again.
        # NOTE(review): the check above and this insert are not atomic, so two
        # concurrent requests with the same token could both succeed.
        BlackListedJWT.blacklist_jwt(jti)

        return User.objects(user_id=user_id).first()
|
|
|