from time import time
from datetime import datetime
from uuid import uuid4

import jwt
from werkzeug.security import generate_password_hash, check_password_hash
from mongoengine.fields import StringField

from app.database import db, DocumentBase, EmbeddedDocumentBase
from app.database.fields import QuarterDateField
from app.util.datatypes import enum
from app_common.const import USER_ID_REGEX


class UserIDField(StringField):
    """A `StringField` that only accepts valid user IDs.

    The value is lower-cased before being checked against `USER_ID_REGEX`
    and against the `Roles` values. Only the check is case-insensitive:
    this field does not lower-case the value it stores.
    """

    def validate(self, value):
        """Raise a `ValidationError` if `value` is not an acceptable user ID.

        Arguments:
            value (`str`): The user ID to validate.
        """
        # Run the base StringField checks (string type, optional length /
        # regex limits) first, so a non-string value produces a proper
        # ValidationError instead of an AttributeError from `.lower()`.
        super().validate(value)
        value = value.lower()
        if not USER_ID_REGEX.fullmatch(value):
            self.error(f'The given user ID "{value}" is invalid, can only contain alphanumeric characters and "_, -".')
        # `Roles` is defined further down; the name is resolved when this
        # method runs, not when the class is created. The membership semantics
        # come from `enum.BaseEnum` (presumably a by-value check -- confirm in
        # app.util.datatypes).
        if value in Roles:
            self.error(f'A user ID cannot be a user role (invalid values: {Roles})')


class Roles(enum.BaseEnum):
    """The set of roles a `User` can be assigned."""

    APPLICANT = 'applicant'
    OCCUPANT = 'occupant'
    INVESTOR = 'investor'
    IMPACT_INVESTOR = 'impact_investor'
    NETWORK_ADMIN = 'network_admin'
    ADMIN = 'admin'


class BlackListedJWT(DocumentBase):
    """A mongo document recording JWT IDs (`jti` claims) that may not be reused.

    `User.from_password_reset_token` blacklists a reset token's `jti` when the
    token is redeemed, so the same token cannot be used a second time.
    """

    # The `jti` claim of the blacklisted token.
    jti = db.StringField(required=True)
    # When the token was blacklisted (naive UTC, from `datetime.utcnow()`).
    created_at = db.DateTimeField(required=True)

    @staticmethod
    def blacklist_jwt(jti):
        """Save `jti` as blacklisted, timestamped with the current UTC time.

        Arguments:
            jti (`str`): The JWT ID to blacklist.
        """
        to_blacklist = BlackListedJWT(jti=jti, created_at=datetime.utcnow())
        to_blacklist.save()

    @staticmethod
    def is_blacklisted(jti):
        """Check whether a JWT ID has been blacklisted.

        Arguments:
            jti (`str`): The JWT ID to look up.

        Returns:
            `bool`: `True` if a `BlackListedJWT` with this `jti` exists.
        """
        return BlackListedJWT.objects(jti=jti).first() is not None


class UserPrefs(EmbeddedDocumentBase):
    """Free-form user preferences, embedded in `User.prefs`."""

    data = db.DictField(required=False)


class User(DocumentBase):
    """A mongo document to represent users."""

    user_id = UserIDField(required=True, unique=True)
    """The user's ID (unique, required)."""
    email = db.EmailField(required=True, unique=True)
    """The user's email (unique, required)."""
    password_hash = db.StringField(required=False)
    """The hashed password for the user (never the plain-text password)."""
    first_name = db.StringField(required=True)
    """The user's first name (required)."""
    middle_init_or_name = db.StringField(required=False, default='')
    """The middle initial or middle name for the user."""
    last_name = db.StringField(required=True)
    """The user's last name (required)."""
    last_login = db.DateTimeField(required=False)
    """The last time the user logged in."""
    address1 = db.StringField(required=False)
    address2 = db.StringField(required=False)
    city = db.StringField(required=False)
    state = db.StringField(required=False)
    zip = db.StringField(required=False)
    """The user's address (`address1` through `zip`)."""
    join_date = QuarterDateField(required=True)
    """The date the user joined (required)."""
    date_of_birth = QuarterDateField(required=False)
    """The user's date of birth."""
    phone_number = db.StringField(required=False, sparse=True)
    """The user's phone number (not required).

    NOTE(review): this was previously documented as unique, but no
    `unique=True` is declared, and `sparse=True` on its own does not enforce
    uniqueness. Confirm the intended constraint before adding one.
    """
    # A callable default gives every new user its own list instead of sharing
    # one module-level list object (mongoengine's advice for mutable defaults).
    roles = db.ListField(db.EnumField(Roles), default=lambda: [Roles.APPLICANT])
    """The roles the user has (defaults to `[Roles.APPLICANT]`)."""
    active = db.BooleanField(required=True, default=False)
    """Whether or not the user account is active."""
    prefs = db.EmbeddedDocumentField('UserPrefs')
    """The user's preferences."""
    misc_data = db.DictField(required=False)
    """Arbitrary extra data stored for the user."""

    @classmethod
    def ignore_to_json(cls):
        """Return the field names excluded from JSON serialisation.

        Extends the parent class's list with `password_hash` and `active`.

        Returns:
            `list`: The field names to omit.
        """
        # Copy instead of appending in place, in case the parent returns a
        # shared (e.g. class-level) list that must not grow on every call.
        result = list(super().ignore_to_json())
        result.extend(['password_hash', 'active'])
        return result

    @classmethod
    def from_request_args(cls, **kwargs):
        """Create and save a new user from request arguments.

        Arguments:
            **kwargs: Field values for the new user. Must include `password`,
                the plain-text password. It is hashed with `set_password` and
                never stored directly. `roles` may be given as a list or as a
                comma-separated string. `join_date` is always set to today's
                UTC date, overriding any value passed in.

        Returns:
            `User`: The newly saved user.

        Raises:
            `KeyError`: If `password` is not supplied.
        """
        password = kwargs.pop('password')
        kwargs['join_date'] = datetime.utcnow().date()
        roles = kwargs.get('roles')
        if isinstance(roles, str):
            # Tolerate whitespace around commas, e.g. "occupant, investor".
            kwargs['roles'] = [role.strip() for role in roles.split(',')]
        user = cls(**kwargs)
        user.set_password(password)
        user.save()
        return user

    def set_password(self, password):
        """Hashes and stores the password for the user.

        Arguments:
            password (`str`): The password to store the hash of for the user.
        """
        self.password_hash = generate_password_hash(password, salt_length=15)

    def check_password(self, password):
        """Checks if the password is correct for the user.

        Arguments:
            password (`str`): The password to check against the user's stored password hash.

        Returns:
            `bool`: `True` if the password is correct, `False` otherwise
            (including when the user has no stored password hash).
        """
        if not self.password_hash:
            # `password_hash` is optional. Werkzeug cannot parse a missing
            # hash, and no password can match one.
            return False
        return check_password_hash(self.password_hash, password)

    def get_password_reset_token(self, expire_secs=600, secret_key=None):
        """Generates a JWT token allowing the user to reset their password that will expire after the given time.

        The token carries a random `jti` so it can be blacklisted once used.

        Arguments:
            expire_secs (`int`): The number of seconds the password reset token should be valid for.
            secret_key (`str`): The secret key to use when generating the JWT token.

        Returns:
            `str`: A JWT token which allows the user to reset their password.
        """
        payload = {
            'exp': time() + expire_secs,
            'reset_password': self.user_id,
            'jti': str(uuid4())
        }
        return jwt.encode(payload, secret_key, algorithm='HS256')

    def get_activation_token(self, expire_secs=1200, secret_key=None, redirect_url=''):
        """Generates an HS256-signed JWT token that activates this user's account.

        Arguments:
            expire_secs (`int`): The number of seconds the activation token should be valid for.
            secret_key (`str`): The secret key to use when generating the JWT token.
            redirect_url (`str`): A URL embedded in the token and handed back
                by `from_activation_token`.

        Returns:
            `str`: The encoded activation token.
        """
        payload = {
            'exp': time() + expire_secs,
            'activate': self.user_id,
            'redirect_url': redirect_url
        }
        return jwt.encode(payload, secret_key, algorithm='HS256')

    @staticmethod
    def from_activation_token(token, secret_key=None):
        """Gets the `User` and redirect URL encoded in an activation token.

        Arguments:
            token (`str`): The encoded JWT token produced by `get_activation_token`.
            secret_key (`str`): The secret key used to generate the JWT token.

        Returns:
            `tuple|None`: `(user, redirect_url)` if the token is valid, where
            `user` is `None` when no user has the encoded ID. Returns `None`
            if the token has expired or is not an activation token.

        Raises:
            `jwt.InvalidTokenError`: Subclasses other than `ExpiredSignatureError`
            (e.g. a bad signature or malformed token) propagate to the caller.
        """
        try:
            jwt_data = jwt.decode(token, secret_key, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            return None
        user_id = jwt_data.get('activate')
        if user_id is None:
            # A validly signed token of another kind (e.g. a password reset
            # token) -- refuse it rather than raising KeyError.
            return None
        redirect_url = jwt_data.get('redirect_url', '')
        return User.objects(user_id=user_id).first(), redirect_url

    @staticmethod
    def from_password_reset_token(token, secret_key=None):
        """Gets a `User` from the given JWT token that allows them to reset their password.

        A token can be redeemed only once: its `jti` is blacklisted on first use.

        Arguments:
            token (`str`): The encoded JWT token to load the user from.
            secret_key (`str`): The secret key used to generate the JWT token.

        Returns:
            `User|None`: The user encoded in the token if it is still valid.
            Returns `None` if the token has expired, was already used, or is
            not a password reset token (or no user has the encoded ID).

        Raises:
            `jwt.InvalidTokenError`: Subclasses other than `ExpiredSignatureError`
            (e.g. a bad signature or malformed token) propagate to the caller.
        """
        try:
            jwt_data = jwt.decode(token, secret_key, algorithms=['HS256'])
        except jwt.ExpiredSignatureError:
            return None
        jti = jwt_data.get('jti')
        user_id = jwt_data.get('reset_password')
        if jti is None or user_id is None:
            # Not a password reset token (e.g. an activation token signed
            # with the same key).
            return None
        # Has this token already been used?
        if BlackListedJWT.is_blacklisted(jti):
            return None
        # Blacklist the token now that it has been redeemed.
        BlackListedJWT.blacklist_jwt(jti)
        return User.objects(user_id=user_id).first()