"""
saraki.auth
~~~~~~~~~~~
"""
from datetime import datetime, timezone
from functools import wraps

import jwt
from cerberus import Validator
from flask import request, current_app, jsonify, abort, _request_ctx_stack
from werkzeug.local import LocalProxy
from werkzeug.routing import BaseConverter

from .utility import generate_schema, get_key_path
from .model import (
    User,
    Org,
    Membership,
    _persist_actions,
    _persist_resources,
    _persist_abilities,
    get_member_privileges,
)
from .exc import (
    NotFoundCredentialError,
    InvalidUserError,
    InvalidOrgError,
    InvalidMemberError,
    InvalidPasswordError,
    JWTError,
    TokenNotFoundError,
    AuthorizationError,
    ProgrammingError,
)
#: Cerberus schema used by :func:`_authenticate` to validate the
#: ``username``/``password`` JSON body. ``unique`` rules are dropped,
#: presumably because a login refers to an already existing user.
AUTH_SCHEMA = generate_schema(
    User,
    include=["username", "password"],
    exclude_rules=["unique"],
)

#: Maps a lowercase HTTP method to the action a request implies when
#: :func:`require_auth` is used without an explicit ``action``. Methods
#: missing from this table are never authorized for scoped endpoints.
HTTP_VERBS_CRUD = dict(
    get="read",
    post="write",
    patch="write",
    put="write",
    delete="delete",
)
def _request_ctx_attr(name):
    """Return attribute *name* of the active request context, or None.

    None is also returned when there is no active request context
    (``_request_ctx_stack.top`` is None) or the attribute was never set.
    """
    return getattr(_request_ctx_stack.top, name, None)


#: A local proxy object that points to the user accessing an endpoint in the
#: current request. The value of this object is an instance of the model class
#: :class:`~saraki.model.User` or None if there is not a user.
current_user = LocalProxy(lambda: _request_ctx_attr("current_user"))

#: A local proxy object that points to the organization being accessed in the
#: current request. The value of this object is an instance of the model class
#: :class:`~saraki.model.Org` or None if the endpoint is not a tenant endpoint.
current_org = LocalProxy(lambda: _request_ctx_attr("current_org"))
class Claim(str):
    """A string holding a URL argument that must match a JWT claim.

    It behaves exactly like ``str`` but also remembers which claim
    (e.g. ``"sub"`` or ``"aud"``) the value has to be compared against.
    """

    def __new__(cls, value, type):
        # str is immutable, so its text must be fixed at construction time.
        return super().__new__(cls, value)

    def __init__(self, value, type):
        self._claim_type = type

    @property
    def type(self):
        """Name of the JWT claim this value corresponds to."""
        return self._claim_type

    def __repr__(self):
        text = str.__str__(self)
        return f'Claim(value="{text}", type="{self.type}")'
class SubClaimConverter(BaseConverter):
    """URL converter for ``<sub:...>`` route arguments.

    Matched path segments become :class:`Claim` objects of type ``"sub"``,
    which :func:`_is_authorized` compares with the token's ``sub`` claim.
    """

    def to_python(self, value):
        claim = Claim(value, "sub")
        return claim

    def to_url(self, value):
        # URL building passes the value through unchanged.
        return value
class AudClaimConverter(BaseConverter):
    """URL converter for ``<aud:...>`` route arguments.

    Matched path segments become :class:`Claim` objects of type ``"aud"``,
    which :func:`_is_authorized` compares with the token's ``aud`` claim.
    """

    def to_python(self, value):
        claim = Claim(value, "aud")
        return claim

    def to_url(self, value):
        # URL building passes the value through unchanged.
        return value
def _verify_username(username):
    """Return the registered user whose canonical username matches *username*.

    The lookup is case-insensitive: *username* is lowercased before being
    compared with ``User.canonical_username``.

    :raises InvalidUserError: If no such user exists.
    """
    user = User.query.filter_by(canonical_username=username.lower()).one_or_none()
    if user is not None:
        return user
    raise InvalidUserError(f'Username "{username}" is not registered')
def _verify_orgname(orgname):
    """Return the organization named exactly *orgname*.

    :raises InvalidOrgError: If no such organization exists.
    """
    organization = Org.query.filter_by(orgname=orgname).one_or_none()
    if organization is not None:
        return organization
    raise InvalidOrgError(f'Orgname "{orgname}" is not registered')
def _verify_member(user, org):
    """Return the membership record linking *user* to *org*.

    :raises InvalidMemberError: If *user* does not belong to *org*.
    """
    membership = (
        Membership.query
        .filter_by(user_id=user.id, org_id=org.id)
        .one_or_none()
    )
    if membership is not None:
        return membership
    raise InvalidMemberError(f"{user.username} is not a member of {org.orgname}")
def _get_request_jwt():
    """Extract the token from the request's ``Authorization`` header.

    The header must have the form ``<prefix> <token>``, where ``<prefix>`` is
    the ``JWT_AUTH_HEADER_PREFIX`` config value.

    :returns: The raw token string, or None when the header is absent.
    :raises JWTError: If the header is present but does not consist of
        exactly the configured prefix followed by a single token.
    """
    header = request.headers.get("Authorization", None)
    expected_prefix = current_app.config["JWT_AUTH_HEADER_PREFIX"]

    if header is None:
        return None

    pieces = header.split()
    if len(pieces) < 2:
        raise JWTError("Missing or malformed token")
    if len(pieces) > 2:
        raise JWTError("The token contains spaces")

    scheme, token = pieces
    if scheme != expected_prefix:
        raise JWTError("Unsupported authorization type")
    return token
def _generate_jwt_payload(user, org=None):
    """Build the claims dictionary for an access token issued to *user*.

    :param user: The authenticated :class:`~saraki.model.User`; its username
        becomes the ``sub`` claim.
    :param org: Optional :class:`~saraki.model.Org`. When given, its orgname
        becomes the ``aud`` claim and the member privileges returned by
        :func:`~saraki.model.get_member_privileges` become the ``scp`` claim.
    :returns: A dict with ``iat``, ``exp`` and ``sub`` claims. It also has
        ``iss`` when ``iss`` is listed in ``JWT_REQUIRED_CLAIMS``, and
        ``aud``/``scp`` when *org* is given.
    :raises RuntimeError: If ``iss`` is required but neither ``JWT_ISSUER``
        nor ``SERVER_NAME`` is configured.
    """
    config = current_app.config
    required_claim_list = config["JWT_REQUIRED_CLAIMS"]
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    # PyJWT converts aware datetimes to the same UTC timestamp as naive ones.
    iat = datetime.now(timezone.utc)
    exp = iat + config["JWT_EXPIRATION_DELTA"]
    iss = config["JWT_ISSUER"] or config["SERVER_NAME"]

    payload = {}
    if "iss" in required_claim_list:
        if not iss:
            raise RuntimeError(
                "The token payload could not be generated. The claim iss is "
                "required, but neither JWT_ISSUER nor SERVER_NAME are provided"
            )
        payload["iss"] = iss

    if org:
        payload["aud"] = org.orgname
        payload["scp"] = get_member_privileges(org, user)

    payload.update({"iat": iat, "exp": exp, "sub": user.username})
    return payload
def _encode_jwt(payload):
    """Sign *payload* with ``SECRET_KEY`` and return the encoded JWT.

    :param payload: Claims to encode. It must contain every claim listed in
        ``JWT_REQUIRED_CLAIMS``.
    :returns: The value returned by ``jwt.encode`` (``bytes`` on PyJWT 1.x,
        ``str`` on PyJWT 2.x).
    :raises RuntimeError: If ``SECRET_KEY`` is not set.
    :raises ValueError: If *payload* lacks required claims.
    """
    secret = current_app.config["SECRET_KEY"]
    if secret is None:
        raise RuntimeError("SECRET_KEY is not set. Can not generate Token")

    algorithm = current_app.config["JWT_ALGORITHM"]
    required_claim_list = current_app.config["JWT_REQUIRED_CLAIMS"]
    # Sorted so the error message is stable; raw set order varies with the
    # interpreter's string hash seed.
    missing_claims = sorted(set(required_claim_list).difference(payload))
    if missing_claims:
        raise ValueError(
            f'Payload is missing required claims: {", ".join(missing_claims)}'
        )

    return jwt.encode(payload, secret, algorithm=algorithm)
def _decode_jwt(token):
    """Verify *token* and return its payload.

    The signature is checked with ``SECRET_KEY`` and ``JWT_ALGORITHM``, and
    every claim in ``JWT_REQUIRED_CLAIMS`` must be present. When ``iss`` is
    required, it must equal the issuer :func:`_generate_jwt_payload` uses
    (``JWT_ISSUER``, falling back to ``SERVER_NAME``). The ``aud`` claim is
    not verified here. Tenant URLs are instead matched against it by
    :func:`_is_authorized`.

    :param token: The encoded JWT (``str`` or ``bytes``).
    :returns: The decoded payload dict.
    :raises ValueError: If *token* is not ``str`` or ``bytes``.
    :raises JWTError: If the token is malformed, has a bad signature, has
        expired, or fails any other claim validation.
    """
    if not isinstance(token, (str, bytes)):
        raise ValueError(f"{type(token)} is not a valid JWT string")

    config = current_app.config
    required_claim_list = config["JWT_REQUIRED_CLAIMS"]
    # PyJWT 1.x reads per-claim "require_<claim>" flags, while PyJWT 2.x reads
    # the "require" list. Setting both keeps required claims enforced on
    # either version.
    options = {"require_" + claim: True for claim in required_claim_list}
    options.update({"verify_" + claim: True for claim in required_claim_list})
    options["require"] = list(required_claim_list)
    options["verify_aud"] = False

    parameters = {
        "jwt": token,
        "key": config["SECRET_KEY"],
        "leeway": config["JWT_LEEWAY"],
        "options": options,
        "algorithms": [config["JWT_ALGORITHM"]],
    }
    if "iss" in required_claim_list:
        # Same fallback as _generate_jwt_payload. Otherwise tokens issued with
        # SERVER_NAME as "iss" would be verified against None, and PyJWT skips
        # the issuer check entirely in that case.
        parameters["issuer"] = config["JWT_ISSUER"] or config["SERVER_NAME"]

    try:
        return jwt.decode(**parameters)
    except jwt.exceptions.MissingRequiredClaimError as e:
        raise JWTError(str(e)) from e
    except jwt.exceptions.ExpiredSignatureError as e:
        raise JWTError("Token has expired") from e
    except jwt.exceptions.InvalidIssuerError as e:
        raise JWTError(str(e)) from e
    except jwt.exceptions.DecodeError as e:
        raise JWTError("Invalid or malformed token") from e
    except jwt.exceptions.InvalidTokenError as e:
        # Remaining validation failures (e.g. a token that is not yet valid,
        # an invalid "iat", or a disallowed algorithm) previously escaped as
        # raw PyJWT exceptions.
        raise JWTError(str(e)) from e
def _get_parent_resource(required_resource, scopes):
    """Return a scope entry that lies on *required_resource*'s path, or None.

    The path comes from ``get_key_path`` over the registered resource tree
    ``current_app.auth.resources``. It is presumably the list of keys from
    the root down to *required_resource*; confirm in ``saraki.utility``.
    The first resource of *scopes*, in iteration order, that appears on that
    path is returned.
    """
    lineage = get_key_path(required_resource, current_app.auth.resources) or []
    return next((name for name in scopes if name in lineage), None)
def _is_authorized(payload, resource=None, action=None):
    """Decide whether the token *payload* grants access to the current request.

    1. Every URL argument converted to a :class:`Claim` must equal the
       payload claim of the same type (e.g. ``<aud:orgname>`` vs ``aud``).
    2. Without a *resource*, that is enough.
    3. Otherwise the ``scp`` claim must grant *action* (or ``manage``) on
       *resource*. If *resource* has no entry in the scope, the grant is
       checked on the ancestor found by :func:`_get_parent_resource`. When
       *action* is None, it is derived from the request method through
       ``HTTP_VERBS_CRUD``, and unmapped methods are denied.

    :returns: True or False.
    """
    url_claims = [
        arg for arg in request.view_args.values() if isinstance(arg, Claim)
    ]
    for claim in url_claims:
        if claim.type not in payload or claim != payload[claim.type]:
            return False

    if not resource:
        return True

    scopes = payload.get("scp")
    if not scopes:
        return False

    if action is None:
        action = HTTP_VERBS_CRUD.get(request.method.lower())
        if action is None:
            return False

    granted = scopes.get(resource)
    if not granted:
        ancestor = _get_parent_resource(resource, scopes)
        if not ancestor:
            return False
        granted = scopes[ancestor]

    return any(name in granted for name in (action, "manage"))
def _validate_request(resource=None, action=None):
    """Authenticate and authorize the current request.

    On success, the token's user and (for tenant tokens) organization are
    stored on the request context as ``current_user`` and ``current_org``.
    The module-level proxies of the same names read them from there.

    :param resource: Resource required by the endpoint, or None to skip the
        scope check.
    :param action: Action required on *resource*. None derives it from the
        HTTP method.
    :raises TokenNotFoundError: If the request has no ``Authorization``
        header.
    :raises JWTError: If the header or the token is invalid.
    :raises AuthorizationError: If the token lacks the required privileges,
        has no ``sub`` claim, or names an unregistered user or organization.
    """
    token = _get_request_jwt()
    if token is None:
        raise TokenNotFoundError

    payload = _decode_jwt(token)
    if not _is_authorized(payload, resource, action):
        raise AuthorizationError

    username = payload.get("sub")
    if username is None:
        # No subject means no user to attach to the request. Deny instead of
        # failing with a KeyError.
        raise AuthorizationError

    try:
        user = _verify_username(username)
        org = _verify_orgname(payload["aud"]) if "aud" in payload else None
    except (InvalidUserError, InvalidOrgError) as e:
        raise AuthorizationError from e

    ctx = _request_ctx_stack.top
    ctx.current_user = user
    ctx.current_org = org
# ~~~~~~~~~~~~~~~~~~~~~
#
# AUTHENTICATION
#
# ~~~~~~~~~~~~~~~~~~~~~
def _authenticate_with_token(token):
    """Return the user identified by a valid access *token*.

    The token is verified by :func:`_decode_jwt` (raising ``JWTError`` when
    it is invalid). The user is then looked up from its ``sub`` claim.

    :raises InvalidUserError: If the ``sub`` username is not registered.
    """
    claims = _decode_jwt(token)
    return _verify_username(claims["sub"])
def _authenticate_with_password(username, password):
    """Return the user identified by *username* if *password* matches.

    :raises InvalidUserError: If *username* is not registered.
    :raises InvalidPasswordError: If the password check does not succeed.
    """
    user = _verify_username(username)
    # Reject on any falsy result, not only False. With ``is False``, a check
    # returning None (or 0) would have let the login through.
    if not user.verify_password(password):
        raise InvalidPasswordError("Invalid password")
    return user
def _authenticate(orgname=None):
    """Handle an authentication request and return a JSON access token.

    Credentials come from a valid ``Authorization`` token (token refresh) or,
    when that header is absent, from a JSON body with ``username`` and
    ``password``. When *orgname* is given, the user must be a member of that
    organization, and the issued token carries its ``aud``/``scp`` claims.

    :param orgname: Organization from the ``/auth/<orgname>`` route, or None
        for ``/auth``.
    :returns: A JSON response ``{"access_token": "<token>"}``.
    :raises NotFoundCredentialError: If neither a token nor a JSON body is
        supplied.
    """
    org = None
    if orgname:
        try:
            org = _verify_orgname(orgname)
        except InvalidOrgError:
            abort(404)  # abort() raises by itself

    token = _get_request_jwt()
    if token:
        user = _authenticate_with_token(token)
    else:
        credentials = request.get_json()
        if credentials is None:
            raise NotFoundCredentialError("Missing token and username/password")
        # Cerberus raises DocumentError for non-mapping documents (e.g. a JSON
        # array), which would surface as a server error instead of a 400.
        if not isinstance(credentials, dict):
            abort(400, "Credentials must be a JSON object")
        validator = Validator(AUTH_SCHEMA)
        if not validator.validate(credentials):
            abort(400, validator.errors)
        user = _authenticate_with_password(**credentials)

    if org:
        _verify_member(user, org)

    access_token = _encode_jwt(_generate_jwt_payload(user, org))
    # jwt.encode returns bytes on PyJWT < 2 and str on PyJWT >= 2.
    if isinstance(access_token, bytes):
        access_token = access_token.decode("utf-8")
    return jsonify({"access_token": access_token})
def require_auth(resource=None, action=None, parent_resource=None):
    """Decorator to restrict view function access to requests with enough
    authorization.

    A valid request must meet the following conditions:

    1. The request must have an ``Authorization`` header with a valid JSON
       Web Token.
    2. The token ``sub`` claim must contain a username registered in the
       application. If the ``aud`` claim is present, its value must be an
       orgname also registered in the application.
    3. The token scope must have enough privileges to access the view
       function being accessed.

    If the parameter **resource** is not provided, the token scope won't be
    verified.

    The **resource** parameter locks an endpoint to access tokens that contain
    that resource, or any parent resource, in their ``scp`` claim. Let's look
    at an example to illustrate how this works:

    .. code-block:: python

        @require_auth("cartoon")
        def view_cartoons():
            pass

        @require_auth("movie", parent_resource="catalog")
        def view_movies():
            pass

        @require_auth("comic")
        def view_comics():
            pass

    And a hypothetical access token ``scp`` claim:

    .. code-block:: json

        {
            "catalog": ["read"],
            "cartoon": ["read"]
        }

    The above access token would be authorized to access ``view_cartoons``
    and ``view_movies`` but not ``view_comics``. In the case of
    ``view_cartoons``, the resource ``cartoon`` is present in the token scope.
    The resource ``movie`` is not present, but ``catalog``, which is a parent
    of it, is. That is why ``view_movies`` can be accessed. ``view_comics``
    is not accessible because neither ``comic`` nor a parent of it is
    present.

    The **action** parameter locks the endpoint to a specific action, for
    instance read, write, delete, etc. If this parameter is omitted, the
    action is derived from the HTTP method of the request (``GET`` → read,
    ``POST``/``PUT``/``PATCH`` → write, ``DELETE`` → delete). A scope entry
    containing ``manage`` grants every action on its resource.

    .. code-block:: python

        @app.route('/friends')
        @require_auth('private', 'follow')
        def endpoint_handler():
            pass

        @app.route('/friends', methods=['DELETE'])
        @require_auth('private')
        def endpoint_handler():
            pass

    The first example requires the resource ``private`` with the ``follow``
    action:

    .. code-block:: json

        {"private": ["follow"]}

    The second example requires the resource ``private`` with the
    ``delete`` action:

    .. code-block:: json

        {"private": ["delete"]}

    The last argument, ``parent_resource``, is optional. It defines the parent
    resource of the endpoint. If an access token has a resource matching the
    parent resource, but not the required resource, it still passes the
    validation. For instance, ``@require_auth('resource', 'action',
    parent_resource='parent')`` will pass with the following access token:

    .. code-block:: json

        {"parent": ["action"]}

    Parent relationships are looked up in the resource tree that
    :meth:`Auth.persist_data` builds from this decorator's metadata. The
    declaration alone does not register them.

    Whenever a request with an unauthorized access token reaches a locked view
    function, an :class:`~saraki.exc.AuthorizationError` exception is raised.
    A request without a token raises
    :class:`~saraki.exc.TokenNotFoundError`, and a malformed or invalid token
    raises :class:`~saraki.exc.JWTError`.

    :param resource: The name of the resource.
    :param action: The action that can be performed on the resource.
    :param parent_resource: The parent resource.
    :raises ProgrammingError: If *action* is given without *resource*.
    """
    if action and not resource:
        raise ProgrammingError(f"You passed an action '{action}' without a resource")

    def decorator(func):
        # Set before @wraps so the wrapper's copied __dict__ also carries it.
        # Auth._collect_metadata reads it from app.view_functions.
        func._auth_metadata = dict(
            resource=resource, action=action, parent_resource=parent_resource
        )

        @wraps(func)
        def wrapper(*args, **kwargs):
            _validate_request(resource, action)
            return func(*args, **kwargs)

        return wrapper

    return decorator
class Auth:
    """Flask extension that wires up authentication and collects the
    resources and actions declared with :func:`require_auth`.

    :param app: Optional Flask application to initialize immediately.
    """

    def __init__(self, app=None):
        # Resource tree: resource name -> None (leaf) or dict of children.
        self._resources = {}
        # "manage" is always known; _is_authorized treats it as granting
        # every action on a resource.
        self._actions = ["manage"]
        # Persistence hooks, replaceable through the persist_* decorators.
        self._persist_actions_func = _persist_actions
        self._persist_resources_func = _persist_resources
        self._persist_abilities_func = _persist_abilities
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Bind to *app*.

        Registers the ``sub`` and ``aud`` URL converters, plus POST routes
        ``/auth`` and ``/auth/<orgname>`` handled by :func:`_authenticate`.
        """
        # NOTE(review): _get_parent_resource reads current_app.auth.resources,
        # but nothing here sets app.auth. Presumably the application does;
        # confirm.
        self.app = app
        app.url_map.converters["sub"] = SubClaimConverter
        app.url_map.converters["aud"] = AudClaimConverter

        methods = ["POST"]
        app.add_url_rule(
            rule="/auth",
            view_func=_authenticate,
            methods=methods,
            defaults={"orgname": None},
        )
        app.add_url_rule(
            rule="/auth/<orgname>", view_func=_authenticate, methods=methods
        )

    @property
    def resources(self):
        """The registered resource tree (nested dicts; leaves are None)."""
        return self._resources

    @property
    def actions(self):
        """The registered action names, starting with ``"manage"``."""
        return self._actions

    def _add_resource(self, resource, parent=None):
        """Register *resource* in the tree, under *parent* when given.

        A resource already present anywhere in the tree is left where it is.
        An unknown *parent* is first added as a top-level resource.
        """
        if not isinstance(resource, str):
            raise TypeError(
                f"resource argument must be an string, got {type(resource)}"
            )

        if get_key_path(resource, self._resources):
            return

        parent_path = get_key_path(parent, self._resources) if parent else []
        if parent_path is None:
            self._resources[parent] = None
            parent_path = [parent]

        node = self._resources
        for branch in parent_path:
            # A leaf becomes a branch once it gets its first child.
            if node[branch] is None:
                node[branch] = {}
            node = node[branch]
        node[resource] = None

    def _add_action(self, action):
        """Register *action* unless it is already known."""
        if not isinstance(action, str):
            raise TypeError(f"action argument must be an string, got {type(action)}")
        if action not in self._actions:
            self._actions.append(action)

    def _collect_metadata(self):
        """Register resources and actions declared via :func:`require_auth`
        on the application's view functions."""
        for view_func in self.app.view_functions.values():
            metadata = getattr(view_func, "_auth_metadata", None)
            if metadata is None:
                continue
            if metadata["resource"]:
                self._add_resource(metadata["resource"], metadata["parent_resource"])
            if metadata["action"]:
                self._add_action(metadata["action"])

    def persist_actions(self, f):
        """Registers a function called to persist actions in a database."""
        self._persist_actions_func = f
        return f

    def persist_resources(self, f):
        """Registers a function called to persist resources in a database."""
        self._persist_resources_func = f
        return f

    def persist_abilities(self, f):
        """Registers a function called to persist abilities in a database."""
        # Previously this overwrote the resources hook by mistake.
        self._persist_abilities_func = f
        return f

    def persist_data(self):
        """Collect declared resources/actions, then run the persistence hooks:
        actions, then resources, then abilities."""
        self._collect_metadata()
        self._persist_actions_func(self.actions)
        self._persist_resources_func(self.resources)
        self._persist_abilities_func()