ApplicantPortal/app/auth/model.py
2025-03-12 20:43:26 -06:00

181 lines
6.6 KiB
Python

import jwt
from werkzeug.security import generate_password_hash, check_password_hash
from time import time
from datetime import datetime
from uuid import uuid4
from mongoengine.fields import StringField
from app.database import db, DocumentBase, EmbeddedDocumentBase
from app.database.fields import QuarterDateField
from app.util.datatypes import enum
from app_common.const import USER_ID_REGEX
class UserIDField(StringField):
    """A ``StringField`` that only accepts well-formed user IDs.

    A value is rejected (via ``self.error``) when it fails the base
    ``StringField`` checks (e.g. it is not a string), when its lowercased form
    does not fully match ``USER_ID_REGEX``, or when its lowercased form is a
    user role name.
    """

    def validate(self, value):
        # Run the inherited StringField checks first so a non-string value is
        # reported as a validation error instead of crashing on ``.lower()``.
        super().validate(value)
        # Only the lowercased form is validated; the stored value itself is not
        # normalised by this method.
        value = value.lower()
        if not USER_ID_REGEX.fullmatch(value):
            self.error(f'The given user ID "{value}" is invalid, can only contain alphanumeric characters and "_, -".')
        # Prevent IDs that collide with role names.
        if value in Roles:
            self.error(f'A user ID cannot be a user role (invalid values: {Roles})')
class Roles(enum.BaseEnum):
    # The roles a user account can hold (stored in ``User.roles``).
    # NOTE(review): ``UserIDField.validate`` checks ``value in Roles`` with a
    # lowercased string, which presumably matches against these string values
    # — confirm against ``app.util.datatypes.enum.BaseEnum``.
    APPLICANT = 'applicant'
    OCCUPANT = 'occupant'
    INVESTOR = 'investor'
    IMPACT_INVESTOR = 'impact_investor'
    NETWORK_ADMIN = 'network_admin'
    ADMIN = 'admin'
class BlackListedJWT(DocumentBase):
    """A mongo document recording a JWT ID (``jti``) that must no longer be accepted.

    Entries are written once a single-use token (e.g. a password reset token)
    has been consumed, so later lookups can reject replays of the same token.
    """

    # The ``jti`` claim of the revoked token.
    jti = db.StringField(required=True)
    # When the token was revoked, as a naive UTC timestamp.
    created_at = db.DateTimeField(required=True)

    @staticmethod
    def blacklist_jwt(jti):
        """Persist a blacklist entry for ``jti``, timestamped with the current UTC time.

        Arguments:
            jti (`str`): The JWT ID to revoke.
        """
        revoked_at = datetime.utcnow()
        entry = BlackListedJWT(jti=jti, created_at=revoked_at)
        entry.save()

    @staticmethod
    def is_blacklisted(jti):
        """Report whether ``jti`` has been revoked.

        Arguments:
            jti (`str`): The JWT ID to look up.
        Returns:
            `bool`: ``True`` if a blacklist entry exists for ``jti``.
        """
        existing = BlackListedJWT.objects(jti=jti).first()
        return existing is not None
class UserPrefs(EmbeddedDocumentBase):
    """Embedded document holding a user's preferences as a free-form dict."""
    # Optional preference data; no key/value schema is declared for this dict.
    data = db.DictField(required=False)
class User(DocumentBase):
    """A mongo document to represent users."""
    user_id = UserIDField(required=True, unique=True)
    """The user's ID (unique, required)."""
    email = db.EmailField(required=True, unique=True)
    """The user's email (unique, required)."""
    password_hash = db.StringField(required=False)
    """The hashed password for the user (may be unset; ``check_password`` then never matches)."""
    first_name = db.StringField(required=True)
    """The user's first name (required)."""
    middle_init_or_name = db.StringField(required=False, default='')
    """The middle initial or middle name for the user."""
    last_name = db.StringField(required=True)
    """The user's last name (required)."""
    last_login = db.DateTimeField(required=False)
    """The last time the user logged in."""
    address1 = db.StringField(required=False)
    address2 = db.StringField(required=False)
    city = db.StringField(required=False)
    state = db.StringField(required=False)
    zip = db.StringField(required=False)
    """The user's address (``address1`` through ``zip``)."""
    join_date = QuarterDateField(required=True)
    """The date the user joined (required; set by ``from_request_args``)."""
    date_of_birth = QuarterDateField(required=False)
    """The user's date of birth."""
    # NOTE(review): uniqueness is not enforced here (no ``unique=True``); the
    # ``sparse`` flag only has an effect on an index — confirm intent.
    phone_number = db.StringField(required=False, sparse=True)
    """The user's phone number (optional)."""
    roles = db.ListField(db.EnumField(Roles), default=[Roles.APPLICANT])
    """The roles the user has (defaults to ``[Roles.APPLICANT]``)."""
    active = db.BooleanField(required=True, default=False)
    """Whether or not the user account is active."""
    prefs = db.EmbeddedDocumentField('UserPrefs')
    """The user's preferences."""
    misc_data = db.DictField(required=False)
    """Free-form extra data attached to the user."""

    @classmethod
    def ignore_to_json(cls):
        """Return the names of fields to leave out of this document's JSON form.

        Extends the parent's list with ``password_hash`` and ``active``. A new
        list is built so the parent's return value is never mutated.

        Returns:
            `list`: Field names to exclude.
        """
        return [*super().ignore_to_json(), 'password_hash', 'active']

    @classmethod
    def from_request_args(cls, **kwargs):
        """Create, hash the password for, and save a new user.

        Arguments:
            **kwargs: Field values for the new user. Must include ``password``
                (plain text; only its hash is stored). ``roles`` may be a
                comma-separated string (surrounding whitespace is ignored). Any
                ``join_date`` supplied is overwritten with today's UTC date.
        Returns:
            `User`: The newly saved user.
        Raises:
            KeyError: If ``password`` is not supplied.
        """
        password = kwargs.pop('password')
        kwargs['join_date'] = datetime.utcnow().date()
        roles = kwargs.get('roles')
        if isinstance(roles, str):
            # Accept "a, b" as well as "a,b".
            kwargs['roles'] = [role.strip() for role in roles.split(',')]
        user = cls(**kwargs)
        user.set_password(password)
        user.save()
        return user

    def set_password(self, password):
        """Hash and store the password for the user.

        Arguments:
            password (`str`): The password to store the hash of for the user.
        """
        self.password_hash = generate_password_hash(password, salt_length=15)

    def check_password(self, password):
        """Check whether the password is correct for the user.

        Arguments:
            password (`str`): The password to check against the user's stored password hash.
        Returns:
            `bool`: ``True`` if the password is correct, ``False`` otherwise
            (including when the user has no password hash stored).
        """
        if not self.password_hash:
            # ``password_hash`` is optional; with no hash, no password can match.
            return False
        return check_password_hash(self.password_hash, password)

    def get_password_reset_token(self, expire_secs=600, secret_key=None):
        """Generate a single-use JWT that lets the user reset their password.

        Arguments:
            expire_secs (`int`): The number of seconds the password reset token should be valid for.
            secret_key (`str`): The secret key to use when signing the JWT (HS256).
        Returns:
            The encoded JWT. It carries a random ``jti`` so that it can be
            blacklisted once used.
        """
        payload = {
            'exp': time() + expire_secs,
            'reset_password': self.user_id,
            'jti': str(uuid4())
        }
        return jwt.encode(payload, secret_key, algorithm='HS256')

    def get_activation_token(self, expire_secs=1200, secret_key=None, redirect_url=''):
        """Generate a JWT that activates this user's account.

        Arguments:
            expire_secs (`int`): The number of seconds the activation token should be valid for.
            secret_key (`str`): The secret key to use when signing the JWT (HS256).
            redirect_url (`str`): A URL carried in the token and returned by
                ``from_activation_token``.
        Returns:
            The encoded JWT.
        """
        payload = {
            'exp': time() + expire_secs,
            'activate': self.user_id,
            'redirect_url': redirect_url
        }
        return jwt.encode(payload, secret_key, algorithm='HS256')

    @staticmethod
    def from_activation_token(token, secret_key=None):
        """Load the user and redirect URL encoded in an activation token.

        Arguments:
            token (`str`): The encoded JWT produced by ``get_activation_token``.
            secret_key (`str`): The secret key used to sign the JWT.
        Returns:
            `tuple|None`: ``(User|None, redirect_url)`` for a valid token, or
            ``None`` if the token is expired, has a bad signature, is malformed,
            or lacks the activation claims.
        """
        try:
            # ``algorithms`` must be a list; a bare string makes PyJWT do a
            # substring membership test on the token's header algorithm.
            jwt_data = jwt.decode(token, secret_key, algorithms=['HS256'])
            user_id = jwt_data['activate']
            redirect_url = jwt_data['redirect_url']
        except (jwt.InvalidTokenError, KeyError):
            return None
        return User.objects(user_id=user_id).first(), redirect_url

    @staticmethod
    def from_password_reset_token(token, secret_key=None):
        """Get a `User` from a JWT that allows them to reset their password.

        A token is accepted only once: on success its ``jti`` is blacklisted.

        Arguments:
            token (`str`): The encoded JWT token to load the user from.
            secret_key (`str`): The secret key used to generate the JWT token.
        Returns:
            `User|None`: The encoded user if the token is valid and unused, or
            ``None`` if it is expired, invalid, malformed, or already used.
        """
        try:
            jwt_data = jwt.decode(token, secret_key, algorithms=['HS256'])
            jti = jwt_data['jti']
            user_id = jwt_data['reset_password']
        except (jwt.InvalidTokenError, KeyError):
            return None
        # NOTE: this check and the blacklist write below are separate DB
        # operations (not atomic). Two concurrent requests with the same token
        # could both pass; a unique index on ``BlackListedJWT.jti`` would close
        # this gap.
        if BlackListedJWT.is_blacklisted(jti):
            return None
        # Blacklist the token now that it has been used.
        BlackListedJWT.blacklist_jwt(jti)
        return User.objects(user_id=user_id).first()