Skip to content
Snippets Groups Projects
Verified Commit 34e97658 authored by nd's avatar nd
Browse files

fixed csrf protection

parent bd27a38d
No related branches found
No related tags found
No related merge requests found
...@@ -8,23 +8,31 @@ bp = Blueprint("csrf", __name__) ...@@ -8,23 +8,31 @@ bp = Blueprint("csrf", __name__)
csrfEndpoints = [] csrfEndpoints = []
# pylint: enable=invalid-name # pylint: enable=invalid-name
def csrf_protect(func): def csrf_protect(blueprint=None, endpoint=None):
csrfEndpoints.append(func.__name__) def wraper(func):
@wraps(func) if not endpoint:
def decorator(*args, **kwargs): if blueprint:
if '_csrf_token' in request.values: urlendpoint = "{}.{}".format(blueprint.name, func.__name__)
token = request.values['_csrf_token'] else:
elif request.get_json() and ('_csrf_token' in request.get_json()): urlendpoint = func.__name__
token = request.get_json()['_csrf_token'] csrfEndpoints.append(urlendpoint)
else: @wraps(func)
token = None def decorator(*args, **kwargs):
if ('_csrf_token' not in session) or (session['_csrf_token'] != token) or not token: if '_csrf_token' in request.values:
return 'csrf test failed', 403 token = request.values['_csrf_token']
return func(*args, **kwargs) elif request.get_json() and ('_csrf_token' in request.get_json()):
return decorator token = request.get_json()['_csrf_token']
else:
token = None
if ('_csrf_token' not in session) or (session['_csrf_token'] != token) or not token:
return 'csrf test failed', 403
return func(*args, **kwargs)
return decorator
return wraper
@bp.url_defaults @bp.app_url_defaults
def csrf_inject(endpoint, values): def csrf_inject(endpoint, values):
print(endpoint, csrfEndpoints, endpoint not in csrfEndpoints)
if endpoint not in csrfEndpoints or not session.get('_csrf_token'): if endpoint not in csrfEndpoints or not session.get('_csrf_token'):
return return
values['_csrf_token'] = session['_csrf_token'] values['_csrf_token'] = session['_csrf_token']
...@@ -26,7 +26,7 @@ def self_index(): ...@@ -26,7 +26,7 @@ def self_index():
return render_template('self.html', user=get_current_user()) return render_template('self.html', user=get_current_user())
@bp.route("/update", methods=(['POST'])) @bp.route("/update", methods=(['POST']))
@csrf_protect @csrf_protect(blueprint=bp)
def self_update(): def self_update():
pass pass
import datetime import datetime
import random
import string
import functools import functools
from flask import Blueprint, render_template, request, url_for, redirect, flash, current_app, session from flask import Blueprint, render_template, request, url_for, redirect, flash, current_app, session
...@@ -36,6 +38,7 @@ def login(): ...@@ -36,6 +38,7 @@ def login():
return redirect(url_for('.login')) return redirect(url_for('.login'))
session['user_uid'] = user.uid session['user_uid'] = user.uid
session['logintime'] = datetime.datetime.now().timestamp() session['logintime'] = datetime.datetime.now().timestamp()
session['_csrf_token'] = ''.join(random.SystemRandom().choice(string.ascii_letters + string.digits) for _ in range(64))
return redirect(request.values.get('ref', url_for('index'))) return redirect(request.values.get('ref', url_for('index')))
def get_current_user(): def get_current_user():
......
...@@ -44,6 +44,7 @@ def user_show(uid=None): ...@@ -44,6 +44,7 @@ def user_show(uid=None):
@bp_user.route("/<int:uid>/update", methods=['POST']) @bp_user.route("/<int:uid>/update", methods=['POST'])
@bp_user.route("/new", methods=['POST']) @bp_user.route("/new", methods=['POST'])
@csrf_protect(blueprint=bp_user)
def user_update(uid=False): def user_update(uid=False):
conn = get_conn() conn = get_conn()
if uid: if uid:
...@@ -70,7 +71,7 @@ def user_update(uid=False): ...@@ -70,7 +71,7 @@ def user_update(uid=False):
return redirect(url_for('.user_list')) return redirect(url_for('.user_list'))
@bp_user.route("/<int:uid>/del") @bp_user.route("/<int:uid>/del")
@csrf_protect @csrf_protect(blueprint=bp_user)
def user_delete(uid): def user_delete(uid):
conn = get_conn() conn = get_conn()
conn.search(current_app.config["LDAP_BASE_USER"], '(&(objectclass=person)(uidNumber={}))'.format((escape_filter_chars(uid)))) conn.search(current_app.config["LDAP_BASE_USER"], '(&(objectclass=person)(uidNumber={}))'.format((escape_filter_chars(uid))))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment