Source code for yadawia.helpers

"""
Helpers
-------
Contains helper functions (decorators, others) used by other parts of the app.

"""
from functools import wraps
import os
import random
import re
import secrets
import string
from urllib.parse import urlparse, urljoin
import uuid

import boto3
from flask import session, url_for, redirect, request, flash
from sqlalchemy import exc

from yadawia import db, app
from yadawia.classes import User, LoginException, DBException, MessageThread, Product, Variety, ProductCategory, Upload


def login_user(username, password):
    """Log a user in by username and password.

    The username is matched case-insensitively (it is lower-cased first).
    On success the session receives:

    - ``session['logged_in'] = True``
    - ``session['username']`` = the lower-cased username
    - ``session['userId']`` = the user's ID

    and the CSRF token is regenerated. A disabled (but not suspended)
    account is re-enabled and greeted with a "Welcome back!" flash.

    Raises:
        LoginException: ``e.args[0]['code']`` is ``'username'`` when no such
            user exists, ``'password'`` when the password is wrong, and
            ``'suspend'`` when the account is suspended.
    """
    normalized = username.lower()
    user = User.query.filter_by(username=normalized).first()
    if user is None:
        raise LoginException(
            {'message': 'Username does not exist.', 'code': 'username'})
    if not user.isPassword(password):
        raise LoginException(
            {'message': 'Password is incorrect.', 'code': 'password'})
    if user.suspended:
        raise LoginException(
            {'message': 'Your account has been suspended.', 'code': 'suspend'})
    if user.disabled:
        # Logging back in re-activates a self-disabled account.
        enable_user(username)
        flash('Welcome back!')
    session['logged_in'] = True
    session['username'] = normalized
    session['userId'] = user.id
    # Rotate the token on every login.
    generate_csrf_token(force=True)
def suspend_user(username):
    """Suspend the account ``username``.

    Implemented as :func:`disable_user` with the suspended flag set.
    """
    # TODO cancel orders, refund people etc
    disable_user(username=username, suspended=True)
def unsuspend_user(username):
    """Lift the suspension of the account ``username``.

    Implemented as :func:`enable_user` with ``was_suspended=True`` so the
    suspended flag is cleared as well.
    """
    enable_user(username=username, was_suspended=True)
def disable_user(username, suspended=False):
    """Disable the account ``username`` and hide its products.

    Stores ``suspended`` on the user, marks the user disabled, and makes every
    currently available product unavailable. Those products are also flagged
    ``force_unavailable`` so :func:`enable_user` can tell which ones to
    restore. If no such user exists, nothing is changed.
    """
    user = User.query.filter_by(username=username.lower()).first()
    if user is None:
        return
    user.suspended = suspended
    user.disabled = True
    for listed in user.products.filter_by(available=True).all():
        listed.available = False
        listed.force_unavailable = True
    db.session.commit()
def enable_user(username, was_suspended=False):
    """Re-enable the account ``username`` and restore its hidden products.

    Clears ``disabled`` (and ``suspended`` too when ``was_suspended``). Every
    product flagged ``force_unavailable`` becomes available again and loses
    the flag. If no such user exists, nothing is changed.
    """
    user = User.query.filter_by(username=username.lower()).first()
    if user is None:
        return
    if was_suspended:
        user.suspended = False
    user.disabled = False
    for hidden in user.products.filter_by(force_unavailable=True).all():
        hidden.available = True
        hidden.force_unavailable = False
    db.session.commit()
def _update_existing_varieties(product, variety_titles, variety_prices, var_indexes):
    """Sync a product's existing varieties with the submitted variety rows.

    A variety whose name is still submitted stays available and gets its
    price updated, and its row index is removed from ``var_indexes`` so the
    caller does not create it a second time. A variety that is no longer
    submitted is hidden (``available = False``) rather than deleted.
    """
    for pv in product.varieties.all():
        if pv.name not in variety_titles:
            pv.available = False
            continue
        var_index = variety_titles.index(pv.name)
        # Row 0 is never in var_indexes, and a duplicate title would already
        # have been removed; list.remove would raise ValueError in both cases.
        if var_index in var_indexes:
            var_indexes.remove(var_index)
        # A variety that was removed earlier and is listed again must reappear.
        pv.available = True
        var_price = variety_prices[var_index]
        if pv.price != var_price:
            pv.price = var_price


def create_edit_product(create=True, productID=None):
    """Create a new product, or edit an existing one, from the submitted form.

    Used by the product create/edit views. Reads ``pname``, ``description``,
    ``currency``, ``price``, ``categories``, ``variety_title``,
    ``variety_price`` and ``photo_url`` from ``request.form``, and takes the
    seller from ``session['userId']``.

    Args:
        create: ``True`` to create a new product, ``False`` to edit the
            product identified by ``productID``.
        productID: ID of the product to edit (unused when ``create``).

    Returns:
        A redirect response. If creation fails, it goes back to
        ``create_product``. Otherwise it goes to the product page. Errors are
        reported to the user with ``flash`` and the DB session is rolled back.
    """
    name = request.form['pname']
    seller_id = session['userId']
    description = request.form['description']
    currency = request.form['currency']
    categories = request.form.getlist('categories')
    variety_titles = request.form.getlist('variety_title')
    pictures = request.form.getlist('photo_url')
    # Row 0 of the variety fields is never created as a Variety (kept from the
    # original code; presumably a template row in the form -- TODO confirm).
    var_indexes = list(range(1, len(variety_titles)))
    error = None
    product = None

    try:
        raw_price = request.form.get('price')
        # float('') raises ValueError, so an empty field means "no price".
        price = float(raw_price) if raw_price else None
        variety_prices = [float(x) if x not in ('Default', '') else None
                          for x in request.form.getlist('variety_price')]
    except ValueError:
        error = 'Price must be a number.'

    if error is None:
        try:
            if create:
                product = Product(name, seller_id, description, price, currency)
                db.session.add(product)
                db.session.flush()  # assigns product.id for the rows below
            else:
                product = Product.query.filter_by(id=productID).first()
                if product is None:
                    raise DBException(
                        {'message': 'Product does not exist.', 'code': 'id'})
                product.description = description
                product.name = name
                product.price = price
                product.currency_id = currency
                ProductCategory.query.filter_by(product_id=productID).delete()
                _update_existing_varieties(
                    product, variety_titles, variety_prices, var_indexes)
            for category_id in categories:
                db.session.add(ProductCategory(product.id, category_id))
            for i in var_indexes:
                db.session.add(
                    Variety(variety_titles[i], product.id, variety_prices[i]))
            for pic in pictures:
                db.session.add(Upload(pic, product.id))
            db.session.commit()
        except DBException as dbe:
            error = dbe.args[0]['message']
        except exc.SQLAlchemyError:
            # Python 3 exceptions have no `.message`; log the details server-side
            # instead of showing raw SQL to the user.
            app.logger.exception('Saving product failed')
            error = 'Something went wrong while saving the product.'

    if error is not None:
        # Discard any half-applied changes so the session stays usable.
        db.session.rollback()
        flash(error)
        if create:
            return redirect(url_for('create_product'))
        return redirect(url_for('product', productID=productID))
    return redirect(url_for('product', productID=product.id))
def valid_photo(photo_type, photo_size):
    """Return True if the upload looks like an acceptable photo.

    The size must not exceed ``app.config['MAX_PHOTO_SIZE']``, and the MIME
    type must start with ``'image/'``. The type is only inspected when the
    size check passes.
    """
    size_limit = app.config['MAX_PHOTO_SIZE']
    if not photo_size <= size_limit:
        return False
    return photo_type[:6] == 'image/'
def get_presigned_post(filename, filetype):
    """Generate a presigned S3 POST for uploading ``filename`` via boto3.

    The bucket is taken from the ``S3_BUCKET`` environment variable. The upload
    is restricted to a ``public-read`` ACL and the given Content-Type, and the
    signature expires after 3600 seconds.
    """
    bucket = os.environ.get('S3_BUCKET')
    client = boto3.client('s3')
    acl_rule = {"acl": "public-read"}
    type_rule = {"Content-Type": filetype}
    return client.generate_presigned_post(
        Bucket=bucket,
        Key=filename,
        Fields={"acl": "public-read", "Content-Type": filetype},
        Conditions=[acl_rule, type_rule],
        ExpiresIn=3600,
    )
def get_random_string(length=32):
    """Return a random string of ``length`` ASCII letters and digits.

    Used for CSRF tokens in ``generate_csrf_token()``. Characters are
    therefore drawn with :mod:`secrets` (a CSPRNG) rather than :mod:`random`,
    whose Mersenne Twister output is predictable and unsuitable for security
    tokens.

    Args:
        length: number of characters to generate (defaults to 32).
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def generate_csrf_token(force=False):
    """Return the session's CSRF-protection token, creating it if needed.

    A new token is stored in ``session['_csrf_token']`` when none exists yet,
    or unconditionally when ``force`` is true (as done on login and logout).
    """
    if not force and '_csrf_token' in session:
        return session['_csrf_token']
    session['_csrf_token'] = get_random_string()
    return session['_csrf_token']
def is_allowed_in_thread(threadID):
    """Return whether the signed-in user participates in thread ``threadID``.

    A thread that does not exist yields False.
    """
    thread = MessageThread.query.filter_by(id=threadID).first()
    if thread is None:
        return False
    return thread.isParticipant(session['userId'])
def is_safe(url):
    """Return True if ``url`` is safe to redirect to.

    ``url`` is resolved against the current host URL. It is safe only if the
    result uses http/https and stays on this same host, which guards against
    open redirects.
    """
    host_url = request.host_url
    own = urlparse(host_url)
    candidate = urlparse(urljoin(host_url, url))
    if candidate.scheme not in ('http', 'https'):
        return False
    return own.netloc == candidate.netloc
def redirect_back(endpoint, **values):
    """Redirect to the request's ``next`` URL when present and safe.

    ``next`` is read from the form on POST requests and from the query
    string otherwise. When it is missing, empty, or fails ``is_safe``, the
    redirect goes to ``url_for(endpoint, **values)`` instead.
    """
    source = request.form if request.method == 'POST' else request.args
    # .get(): a POST without a 'next' field must fall back to the endpoint
    # instead of aborting with a 400 BadRequestKeyError.
    target = source.get('next')
    if not target or not is_safe(target):
        target = url_for(endpoint, **values)
    return redirect(target)
def logout_user():
    """Log the current user out.

    Drops the login keys from the session (ignoring any that are missing)
    and rotates the CSRF token.
    """
    for key in ('logged_in', 'username', 'userId'):
        session.pop(key, None)
    generate_csrf_token(force=True)
def no_special_chars(string, allowNumbers=False, optional=True, allowComma=False):
    """Check that ``string`` contains none of the disallowed special characters.

    Always disallowed: ``_ + @ ! # $ % ^ & * ( ) ; \\ / | < > " ' : ? =``.
    Digits are disallowed unless ``allowNumbers`` is set, and commas unless
    ``allowComma`` is set. Anything else is accepted, including letters
    (non-ASCII too), spaces, hyphens and dots.

    Args:
        string: text to check.
        allowNumbers: accept the digits 0-9.
        optional: accept the empty string; otherwise at least one character
            is required.
        allowComma: accept ``,``.

    Returns:
        A match object (truthy) if the string is acceptable, otherwise None.
    """
    forbidden = '_+@!#$%^&*();\\/|<>"\':?='
    if not allowComma:
        forbidden += ','
    digits = '' if allowNumbers else '0-9'
    quantifier = '*' if optional else '+'
    # re.escape makes every forbidden character a literal inside the class.
    # This replaces the old non-raw literal full of invalid escape sequences.
    # The '0-9' range is appended unescaped on purpose.
    pattern = re.compile(
        '^([^' + digits + re.escape(forbidden) + '])' + quantifier + '$')
    return pattern.match(string)
def validate_name_pattern(name_input, allowNumbers=False, optional=True):
    """Validate a name, since names generally contain no special characters.

    ``allowNumbers`` and ``optional`` are passed on to ``no_special_chars``.
    Commas are not allowed here.

    Raises:
        DBException: with code ``'name'`` when the name is rejected.
    """
    acceptable = no_special_chars(
        name_input, allowNumbers=allowNumbers, optional=optional)
    if acceptable:
        return
    raise DBException({'message': 'Name cannot contain numbers or special characters.', 'code': 'name'})
def public(obj, keys):
    """Return the table columns of a db object as a dict, minus ``keys``.

    Pass a db class object and the column names that must not be exposed
    (e.g. a password hash). The remaining columns keep their table order.
    """
    column_names = obj.__table__.columns.keys()
    values = {}
    for column in column_names:
        values[column] = getattr(obj, column)
    visible = {}
    for column, value in values.items():
        if column not in keys:
            visible[column] = value
    return visible
def curr_user(username):
    """Return True if ``username`` is the logged-in user's username."""
    if 'username' not in session:
        return False
    return session['username'] == username
def get_upload_url(filename):
    """Return the static URL of an uploaded file, or None when ``filename`` is falsy."""
    if not filename:
        return None
    return url_for('static', filename='uploads/' + filename)
def is_logged_in():
    """Return True if the session marks the user as logged in with a user ID."""
    if session.get('logged_in') != True:  # noqa: E712 - mirror `== True` check
        return False
    return 'userId' in session
def user_exists():
    """Return True if ``session['userId']`` refers to a user in the database.

    Only meaningful for a logged-in user, because ``session['userId']`` is
    read directly.
    """
    current_id = session['userId']
    found = User.query.filter_by(id=current_id).first()
    return found is not None
def splitall(path):
    """Split ``path`` into all of its components using ``os.path.split``.

    For example, ``'/a/b/c'`` becomes ``['/', 'a', 'b', 'c']`` and
    ``'a/b/c'`` becomes ``['a', 'b', 'c']``. A trailing separator yields a
    final ``''`` component.
    """
    pieces = []
    while True:
        head, tail = os.path.split(path)
        if head == path:  # reached the root of an absolute path
            pieces.append(head)
            break
        if tail == path:  # reached the first component of a relative path
            pieces.append(tail)
            break
        pieces.append(tail)
        path = head
    # Components were collected from the end backwards.
    pieces.reverse()
    return pieces
def assetsList(app, folder='js', extension='js', exclusions=()):
    """List files with a given extension under a folder of the static dir.

    Recursively walks ``app.static_folder/folder``. Matching files are
    returned as '/'-separated paths relative to the static folder (e.g.
    ``'js/vendor/lib.js'``), in ``os.walk`` order.

    Args:
        app: object with a ``static_folder`` attribute (the Flask app).
        folder: sub-folder of the static folder to search.
        extension: file extension to match, without the leading dot.
        exclusions: bare file names to skip. The default is an immutable
            tuple, which avoids the shared mutable-default pitfall.
    """
    static_root = app.static_folder
    suffix = '.' + extension
    files_list = []
    for root, _dirs, files in os.walk(os.path.join(static_root, folder)):
        for file in files:
            if not file.endswith(suffix) or file in exclusions:
                continue
            # Make the path relative to the static folder itself. Searching for
            # a directory literally named "static" breaks when an ancestor
            # directory shares that name, or when the folder is named otherwise.
            rel_path = os.path.relpath(os.path.join(root, file), static_root)
            files_list.append(rel_path.replace(os.sep, '/'))
    return files_list
def authenticate(f):
    """Decorator that only lets a logged-in user who still exists reach the view.

    Anyone else is redirected to ``login``, with ``next`` set to the URL they
    asked for.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        if is_logged_in() and user_exists():
            return f(*args, **kwargs)
        return redirect(url_for('login', next=request.url))
    return guarded
def anonymous_only(f):
    """Decorator that only lets visitors who are NOT logged in reach the view.

    Logged-in users are redirected to ``home``.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        if not is_logged_in():
            return f(*args, **kwargs)
        return redirect(url_for('home'))
    return guarded