Implement JWT authentication

This commit is contained in:
François Voron
2019-10-08 17:18:38 +02:00
parent 20aa806375
commit 06dd8ad22e
10 changed files with 233 additions and 37 deletions

View File

@ -18,6 +18,7 @@ email-validator = "*"
sqlalchemy = "*" sqlalchemy = "*"
databases = "*" databases = "*"
python-multipart = "*" python-multipart = "*"
pyjwt = "*"
[requires] [requires]
python_version = "3.7" python_version = "3.7"

26
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "f950ae2475c73ff3553c19edb24c30a2ea7e0ffa630beb02b13e8588e31b9cb2" "sha256": "3f2120d9354ee0e8e94a16b4dd1c948e2eb13d91fd71798c4005effc786b93f0"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -93,11 +93,11 @@
}, },
"fastapi": { "fastapi": {
"hashes": [ "hashes": [
"sha256:345a8b0a8761bf830c99af2c38881f2215703844e4d5e52fff1ff711824bae0b", "sha256:96d2ba720f154a7f876fc419a464faac5c4bf3e660dd72f7134f53b90a8488ff",
"sha256:7a81762807d8ed43b2dddac8bb189d174e270f9c64493542d279673a6d768b26" "sha256:c4a6cb1b0b068efdb2b05ed27b8a88c29d7df5870c7267b02352a0a4eed5f4e9"
], ],
"index": "pypi", "index": "pypi",
"version": "==0.40.0" "version": "==0.41.0"
}, },
"idna": { "idna": {
"hashes": [ "hashes": [
@ -134,6 +134,14 @@
], ],
"version": "==0.32.2" "version": "==0.32.2"
}, },
"pyjwt": {
"hashes": [
"sha256:5c6eca3c2940464d106b99ba83b00c6add741c9becaec087fb7ccdefea71350e",
"sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96"
],
"index": "pypi",
"version": "==1.7.1"
},
"python-multipart": { "python-multipart": {
"hashes": [ "hashes": [
"sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43" "sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43"
@ -157,9 +165,9 @@
}, },
"starlette": { "starlette": {
"hashes": [ "hashes": [
"sha256:f600bf9d0beeeeebcb143e6d0c4f8858c2b05067d5a4feb446ba7400ba5e5dc5" "sha256:c2ac9a42e0e0328ad20fe444115ac5e3760c1ee2ac1ff8cdb5ec915c4a453411"
], ],
"version": "==0.12.8" "version": "==0.12.9"
} }
}, },
"develop": { "develop": {
@ -300,11 +308,11 @@
}, },
"pytest": { "pytest": {
"hashes": [ "hashes": [
"sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5", "sha256:7e4800063ccfc306a53c461442526c5571e1462f61583506ce97e4da6a1d88c8",
"sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5" "sha256:ca563435f4941d0cb34767301c27bc65c510cb82e90b9ecf9cb52dc2c63caaa0"
], ],
"index": "pypi", "index": "pypi",
"version": "==5.2.0" "version": "==5.2.1"
}, },
"pytest-asyncio": { "pytest-asyncio": {
"hashes": [ "hashes": [

View File

@ -0,0 +1,20 @@
from typing import Callable
from starlette.responses import Response
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
class BaseAuthentication:
userDB: BaseUserDatabase
def __init__(self, userDB: BaseUserDatabase):
self.userDB = userDB
async def get_login_response(self, user: UserDB, response: Response):
raise NotImplementedError()
def get_authentication_method(self) -> Callable[..., UserDB]:
raise NotImplementedError()

View File

@ -0,0 +1,58 @@
from datetime import datetime, timedelta
import jwt
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.models import UserDB
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
def generate_jwt(data: dict, lifetime_seconds: int, secret: str, algorithm: str) -> str:
payload = data.copy()
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
payload['exp'] = expire
return jwt.encode(payload, secret, algorithm=algorithm).decode('utf-8')
class JWTAuthentication(BaseAuthentication):
algorithm: str = 'HS256'
secret: str
lifetime_seconds: int
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.secret = secret
self.lifetime_seconds = lifetime_seconds
async def get_login_response(self, user: UserDB, response: Response):
data = {'user_id': user.id}
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
return {'token': token}
def get_authentication_method(self):
async def authentication_method(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
try:
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
user_id: str = data.get('user_id')
if user_id is None:
raise credentials_exception
except jwt.PyJWTError:
raise credentials_exception
user = await self.userDB.get(user_id)
if user is None or not user.is_active:
raise credentials_exception
return user
return authentication_method

View File

@ -3,7 +3,9 @@ from typing import List
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users.models import UserDB from fastapi_users.models import UserDB
from fastapi_users.password import get_password_hash, verify_and_update_password from fastapi_users.password import (
get_password_hash, verify_and_update_password
)
class BaseUserDatabase: class BaseUserDatabase:

View File

@ -1,7 +1,9 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from starlette import status from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserCreate, UserDB from fastapi_users.models import UserCreate, UserDB
from fastapi_users.password import get_password_hash from fastapi_users.password import get_password_hash
@ -9,7 +11,7 @@ from fastapi_users.password import get_password_hash
class UserRouter: class UserRouter:
def __new__(cls, userDB: BaseUserDatabase) -> APIRouter: def __new__(cls, userDB: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
router = APIRouter() router = APIRouter()
@router.post('/register') @router.post('/register')
@ -20,7 +22,7 @@ class UserRouter:
return created_user return created_user
@router.post('/login') @router.post('/login')
async def login(credentials: OAuth2PasswordRequestForm = Depends()): async def login(response: Response, credentials: OAuth2PasswordRequestForm = Depends()):
user = await userDB.authenticate(credentials) user = await userDB.authenticate(credentials)
if user is None: if user is None:
@ -28,6 +30,6 @@ class UserRouter:
elif not user.is_active: elif not user.is_active:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return user return await auth.get_login_response(user, response)
return router return router

View File

@ -1,15 +1,21 @@
import pytest import pytest
from fastapi import HTTPException
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB from fastapi_users.models import UserDB
from fastapi_users.password import get_password_hash from fastapi_users.password import get_password_hash
active_user = UserDB( active_user_data = UserDB(
id='aaa',
email='king.arthur@camelot.bt', email='king.arthur@camelot.bt',
hashed_password=get_password_hash('guinevere'), hashed_password=get_password_hash('guinevere'),
) )
inactive_user = UserDB( inactive_user_data = UserDB(
id='bbb',
email='percival@camelot.bt', email='percival@camelot.bt',
hashed_password=get_password_hash('angharad'), hashed_password=get_password_hash('angharad'),
is_active=False is_active=False
@ -18,16 +24,28 @@ inactive_user = UserDB(
@pytest.fixture @pytest.fixture
def user() -> UserDB: def user() -> UserDB:
return active_user return active_user_data
class MockUserDBInterface(BaseUserDatabase): @pytest.fixture
def inactive_user() -> UserDB:
return inactive_user_data
class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> UserDB:
if id == active_user_data.id:
return active_user_data
elif id == inactive_user_data.id:
return inactive_user_data
return None
async def get_by_email(self, email: str) -> UserDB: async def get_by_email(self, email: str) -> UserDB:
if email == active_user.email: if email == active_user_data.email:
return active_user return active_user_data
elif email == inactive_user.email: elif email == inactive_user_data.email:
return inactive_user return inactive_user_data
return None return None
async def create(self, user: UserDB) -> UserDB: async def create(self, user: UserDB) -> UserDB:
@ -35,5 +53,22 @@ class MockUserDBInterface(BaseUserDatabase):
@pytest.fixture @pytest.fixture
def mock_db_interface() -> MockUserDBInterface: def mock_user_db() -> MockUserDatabase:
return MockUserDBInterface() return MockUserDatabase()
class MockAuthentication(BaseAuthentication):
async def get_login_response(self, user: UserDB, response: Response):
return {'token': user.id}
async def authenticate(self, token: str) -> UserDB:
user = await self.userDB.get(token)
if user is None or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return user
@pytest.fixture
def mock_authentication(mock_user_db) -> MockAuthentication:
return MockAuthentication(mock_user_db)

View File

@ -0,0 +1,70 @@
import jwt
import pytest
from fastapi import Depends, FastAPI
from starlette import status
from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
from fastapi_users.models import UserDB
SECRET = 'SECRET'
ALGORITHM = 'HS256'
LIFETIME = 3600
@pytest.fixture
def jwt_authentication(mock_user_db):
return JWTAuthentication(SECRET, LIFETIME, mock_user_db)
@pytest.fixture
def token():
def _token(user, lifetime=LIFETIME):
data = {'user_id': user.id}
return generate_jwt(data, lifetime, SECRET, ALGORITHM)
return _token
@pytest.fixture
def test_auth_client(jwt_authentication):
app = FastAPI()
@app.get('/test-auth')
def test_auth(user: UserDB = Depends(jwt_authentication.get_authentication_method())):
return user
return TestClient(app)
@pytest.mark.asyncio
async def test_get_login_response(jwt_authentication, user):
login_response = await jwt_authentication.get_login_response(user, Response())
assert 'token' in login_response
token = login_response['token']
decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM])
assert decoded['user_id'] == user.id
class TestGetAuthenticationMethod:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get('/test-auth')
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get('/test-auth', headers={'Authorization': 'Bearer foo'})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(inactive_user)}'})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(user)}'})
assert response.status_code == status.HTTP_200_OK
json = response.json()
assert json['id'] == user.id

View File

@ -1,5 +1,4 @@
import pytest import pytest
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
@ -11,27 +10,26 @@ def create_oauth2_password_request_form():
password=password, password=password,
scope='', scope='',
) )
return _create_oauth2_password_request_form return _create_oauth2_password_request_form
class TestAuthenticate: class TestAuthenticate:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unknown_user(self, create_oauth2_password_request_form, mock_db_interface): async def test_unknown_user(self, create_oauth2_password_request_form, mock_user_db):
form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere') form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere')
user = await mock_db_interface.authenticate(form) user = await mock_user_db.authenticate(form)
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_password(self, create_oauth2_password_request_form, mock_db_interface): async def test_wrong_password(self, create_oauth2_password_request_form, mock_user_db):
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival') form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival')
user = await mock_db_interface.authenticate(form) user = await mock_user_db.authenticate(form)
assert user is None assert user is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_valid_credentials(self, create_oauth2_password_request_form, mock_db_interface): async def test_valid_credentials(self, create_oauth2_password_request_form, mock_user_db):
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere') form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere')
user = await mock_db_interface.authenticate(form) user = await mock_user_db.authenticate(form)
assert user is not None assert user is not None
assert user.email == 'king.arthur@camelot.bt' assert user.email == 'king.arthur@camelot.bt'

View File

@ -1,14 +1,15 @@
import pytest import pytest
from fastapi import FastAPI
from starlette import status from starlette import status
from starlette.testclient import TestClient from starlette.testclient import TestClient
from fastapi_users.models import UserDB
from fastapi_users.router import UserRouter
@pytest.fixture @pytest.fixture
def test_app_client(mock_db_interface) -> TestClient: def test_app_client(mock_user_db, mock_authentication) -> TestClient:
from fastapi import FastAPI userRouter = UserRouter(mock_user_db, mock_authentication)
from fastapi_users.router import UserRouter
userRouter = UserRouter(mock_db_interface)
app = FastAPI() app = FastAPI()
app.include_router(userRouter) app.include_router(userRouter)
@ -82,13 +83,14 @@ class TestLogin:
response = test_app_client.post('/login', data=data) response = test_app_client.post('/login', data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_valid_credentials(self, test_app_client: TestClient): def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
data = { data = {
'username': 'king.arthur@camelot.bt', 'username': 'king.arthur@camelot.bt',
'password': 'guinevere', 'password': 'guinevere',
} }
response = test_app_client.post('/login', data=data) response = test_app_client.post('/login', data=data)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.json() == {'token': user.id}
def test_inactive_user(self, test_app_client: TestClient): def test_inactive_user(self, test_app_client: TestClient):
data = { data = {