Skip to content
Snippets Groups Projects

starlette: Allow to provide access token in authorization header

2 files
+ 29
14
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -8,6 +8,7 @@ import hashlib
from typing import Any, Dict, Optional, Tuple
from aiocache.base import BaseCache
from jose.exceptions import JWTError
from starlette.authentication import (
AuthCredentials,
AuthenticationBackend,
@@ -85,24 +86,36 @@ class BearerTokenAuthBackend(AuthenticationBackend):
async def authenticate(
self, conn: HTTPConnection
) -> Optional[Tuple[AuthCredentials, SimpleUser]]:
auth_header = conn.headers.get("Authorization")
if auth_header is None:
# anonymous user
return None
refresh_token = self._get_token_from_header(auth_header)
# get the cache key
cache_key = self._get_token_cache_key(refresh_token)
# read access token from the cache
access_token = await self.cache.get(cache_key)
decoded_token = self._decode_token(access_token)
if not access_token or not decoded_token:
access_token = self._get_new_access_token(refresh_token)["access_token"]
decoded_token = self._decode_token(access_token)
token = self._get_token_from_header(auth_header)
try:
# check if access token was provided in authorization header
decoded_token = self._decode_token(token)
if not decoded_token:
raise AuthenticationError("Access token failed to be decoded")
exp = datetime.fromtimestamp(decoded_token["exp"])
ttl = int(exp.timestamp() - datetime.now().timestamp())
await self.cache.set(cache_key, access_token, ttl=ttl)
except JWTError:
# token is a refresh one so backend handles access token renewal
# get the cache key
cache_key = self._get_token_cache_key(token)
# read access token from the cache
access_token = await self.cache.get(cache_key)
decoded_token = self._decode_token(access_token)
if not access_token or not decoded_token:
access_token = self._get_new_access_token(token)["access_token"]
decoded_token = self._decode_token(access_token)
if not decoded_token:
raise AuthenticationError("Access token failed to be decoded")
exp = datetime.fromtimestamp(decoded_token["exp"])
ttl = int(exp.timestamp() - datetime.now().timestamp())
await self.cache.set(cache_key, access_token, ttl=ttl)
# set user scopes
realm_access = decoded_token.get("realm_access", {})
user_scopes = realm_access.get("roles", [])
Loading