diff --git a/Pipfile b/Pipfile index 1ea22fae..8d08e0bd 100644 --- a/Pipfile +++ b/Pipfile @@ -18,6 +18,7 @@ email-validator = "*" sqlalchemy = "*" databases = "*" python-multipart = "*" +pyjwt = "*" [requires] python_version = "3.7" diff --git a/Pipfile.lock b/Pipfile.lock index 3985cdb6..e925284a 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "f950ae2475c73ff3553c19edb24c30a2ea7e0ffa630beb02b13e8588e31b9cb2" + "sha256": "3f2120d9354ee0e8e94a16b4dd1c948e2eb13d91fd71798c4005effc786b93f0" }, "pipfile-spec": 6, "requires": { @@ -93,11 +93,11 @@ }, "fastapi": { "hashes": [ - "sha256:345a8b0a8761bf830c99af2c38881f2215703844e4d5e52fff1ff711824bae0b", - "sha256:7a81762807d8ed43b2dddac8bb189d174e270f9c64493542d279673a6d768b26" + "sha256:96d2ba720f154a7f876fc419a464faac5c4bf3e660dd72f7134f53b90a8488ff", + "sha256:c4a6cb1b0b068efdb2b05ed27b8a88c29d7df5870c7267b02352a0a4eed5f4e9" ], "index": "pypi", - "version": "==0.40.0" + "version": "==0.41.0" }, "idna": { "hashes": [ @@ -134,6 +134,14 @@ ], "version": "==0.32.2" }, + "pyjwt": { + "hashes": [ + "sha256:5c6eca3c2940464d106b99ba83b00c6add741c9becaec087fb7ccdefea71350e", + "sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96" + ], + "index": "pypi", + "version": "==1.7.1" + }, "python-multipart": { "hashes": [ "sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43" @@ -157,9 +165,9 @@ }, "starlette": { "hashes": [ - "sha256:f600bf9d0beeeeebcb143e6d0c4f8858c2b05067d5a4feb446ba7400ba5e5dc5" + "sha256:c2ac9a42e0e0328ad20fe444115ac5e3760c1ee2ac1ff8cdb5ec915c4a453411" ], - "version": "==0.12.8" + "version": "==0.12.9" } }, "develop": { @@ -300,11 +308,11 @@ }, "pytest": { "hashes": [ - "sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5", - "sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5" + "sha256:7e4800063ccfc306a53c461442526c5571e1462f61583506ce97e4da6a1d88c8", + "sha256:ca563435f4941d0cb34767301c27bc65c510cb82e90b9ecf9cb52dc2c63caaa0" ], "index": "pypi", - "version": "==5.2.0" + "version": "==5.2.1" }, "pytest-asyncio": { "hashes": [ diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py new file mode 100644 index 00000000..bed698e0 --- /dev/null +++ b/fastapi_users/authentication/__init__.py @@ -0,0 +1,20 @@ +from typing import Callable + +from starlette.responses import Response + +from fastapi_users.db import BaseUserDatabase +from fastapi_users.models import UserDB + + +class BaseAuthentication: + + userDB: BaseUserDatabase + + def __init__(self, userDB: BaseUserDatabase): + self.userDB = userDB + + async def get_login_response(self, user: UserDB, response: Response): + raise NotImplementedError() + + def get_authentication_method(self) -> Callable[..., UserDB]: + raise NotImplementedError() diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py new file mode 100644 index 00000000..f7858de2 --- /dev/null +++ b/fastapi_users/authentication/jwt.py @@ -0,0 +1,58 @@ + +from datetime import datetime, timedelta + +import jwt +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from starlette import status +from starlette.responses import Response + +from fastapi_users.authentication import BaseAuthentication +from fastapi_users.models import UserDB + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") + + +def generate_jwt(data: dict, lifetime_seconds: int, secret: str, algorithm: str) -> str: + payload = data.copy() + expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds) + payload['exp'] = expire + return jwt.encode(payload, secret, algorithm=algorithm).decode('utf-8') + + +class JWTAuthentication(BaseAuthentication): + + algorithm: str = 'HS256' + secret: str + lifetime_seconds: int + + def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.secret = secret + self.lifetime_seconds = lifetime_seconds + + async def get_login_response(self, user: UserDB, response: Response): + data = {'user_id': user.id} + token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm) + + return {'token': token} + + def get_authentication_method(self): + async def authentication_method(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + + try: + data = jwt.decode(token, self.secret, algorithms=[self.algorithm]) + user_id: str = data.get('user_id') + if user_id is None: + raise credentials_exception + except jwt.PyJWTError: + raise credentials_exception + + user = await self.userDB.get(user_id) + if user is None or not user.is_active: + raise credentials_exception + + return user + + return authentication_method diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index ed80002a..93d21a05 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -3,7 +3,9 @@ from typing import List from fastapi.security import OAuth2PasswordRequestForm from fastapi_users.models import UserDB -from fastapi_users.password import get_password_hash, verify_and_update_password +from fastapi_users.password import ( + get_password_hash, verify_and_update_password +) class BaseUserDatabase: diff --git a/fastapi_users/router.py b/fastapi_users/router.py index d3d5ae40..e7ba7f18 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -1,7 +1,9 @@ from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm from starlette import status +from starlette.responses import Response +from fastapi_users.authentication import BaseAuthentication from fastapi_users.db import BaseUserDatabase from fastapi_users.models import UserCreate, UserDB from fastapi_users.password import get_password_hash @@ -9,7 +11,7 @@ from fastapi_users.password import get_password_hash class UserRouter: - def __new__(cls, userDB: BaseUserDatabase) -> APIRouter: + def __new__(cls, userDB: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter: router = APIRouter() @router.post('/register') @@ -20,7 +22,7 @@ class UserRouter: return created_user @router.post('/login') - async def login(credentials: OAuth2PasswordRequestForm = Depends()): + async def login(response: Response, credentials: OAuth2PasswordRequestForm = Depends()): user = await userDB.authenticate(credentials) if user is None: @@ -28,6 +30,6 @@ class UserRouter: elif not user.is_active: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - return user + return await auth.get_login_response(user, response) return router diff --git a/tests/conftest.py b/tests/conftest.py index 49413c08..e7f9cdcc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,21 @@ import pytest +from fastapi import HTTPException +from starlette import status +from starlette.responses import Response +from fastapi_users.authentication import BaseAuthentication from fastapi_users.db import BaseUserDatabase from fastapi_users.models import UserDB from fastapi_users.password import get_password_hash -active_user = UserDB( +active_user_data = UserDB( + id='aaa', email='king.arthur@camelot.bt', hashed_password=get_password_hash('guinevere'), ) -inactive_user = UserDB( +inactive_user_data = UserDB( + id='bbb', email='percival@camelot.bt', hashed_password=get_password_hash('angharad'), is_active=False @@ -18,16 +24,28 @@ inactive_user = UserDB( @pytest.fixture def user() -> UserDB: - return active_user + return active_user_data -class MockUserDBInterface(BaseUserDatabase): +@pytest.fixture +def inactive_user() -> UserDB: + return inactive_user_data + + +class MockUserDatabase(BaseUserDatabase): + + async def get(self, id: str) -> UserDB: + if id == active_user_data.id: + return active_user_data + elif id == inactive_user_data.id: + return inactive_user_data + return None async def get_by_email(self, email: str) -> UserDB: - if email == active_user.email: - return active_user - elif email == inactive_user.email: - return inactive_user + if email == active_user_data.email: + return active_user_data + elif email == inactive_user_data.email: + return inactive_user_data return None async def create(self, user: UserDB) -> UserDB: @@ -35,5 +53,22 @@ class MockUserDBInterface(BaseUserDatabase): @pytest.fixture -def mock_db_interface() -> MockUserDBInterface: - return MockUserDBInterface() +def mock_user_db() -> MockUserDatabase: + return MockUserDatabase() + + +class MockAuthentication(BaseAuthentication): + + async def get_login_response(self, user: UserDB, response: Response): + return {'token': user.id} + + async def authenticate(self, token: str) -> UserDB: + user = await self.userDB.get(token) + if user is None or not user.is_active: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + return user + + +@pytest.fixture +def mock_authentication(mock_user_db) -> MockAuthentication: + return MockAuthentication(mock_user_db) diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py new file mode 100644 index 00000000..4a83cc7b --- /dev/null +++ b/tests/test_authentication_jwt.py @@ -0,0 +1,70 @@ +import jwt +import pytest +from fastapi import Depends, FastAPI +from starlette import status +from starlette.responses import Response +from starlette.testclient import TestClient + +from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt +from fastapi_users.models import UserDB + +SECRET = 'SECRET' +ALGORITHM = 'HS256' +LIFETIME = 3600 + + +@pytest.fixture +def jwt_authentication(mock_user_db): + return JWTAuthentication(SECRET, LIFETIME, mock_user_db) + + +@pytest.fixture +def token(): + def _token(user, lifetime=LIFETIME): + data = {'user_id': user.id} + return generate_jwt(data, lifetime, SECRET, ALGORITHM) + return _token + + +@pytest.fixture +def test_auth_client(jwt_authentication): + app = FastAPI() + + @app.get('/test-auth') + def test_auth(user: UserDB = Depends(jwt_authentication.get_authentication_method())): + return user + + return TestClient(app) + + +@pytest.mark.asyncio +async def test_get_login_response(jwt_authentication, user): + login_response = await jwt_authentication.get_login_response(user, Response()) + + assert 'token' in login_response + + token = login_response['token'] + decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM]) + assert decoded['user_id'] == user.id + + +class TestGetAuthenticationMethod: + + def test_missing_token(self, test_auth_client): + response = test_auth_client.get('/test-auth') + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_invalid_token(self, test_auth_client): + response = test_auth_client.get('/test-auth', headers={'Authorization': 'Bearer foo'}) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user): + response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(inactive_user)}'}) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_valid_token(self, test_auth_client, token, user): + response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(user)}'}) + assert response.status_code == status.HTTP_200_OK + + json = response.json() + assert json['id'] == user.id diff --git a/tests/test_db_base.py b/tests/test_db_base.py index 158861ed..19e3c8da 100644 --- a/tests/test_db_base.py +++ b/tests/test_db_base.py @@ -1,5 +1,4 @@ import pytest - from fastapi.security import OAuth2PasswordRequestForm @@ -11,27 +10,26 @@ def create_oauth2_password_request_form(): password=password, scope='', ) - return _create_oauth2_password_request_form class TestAuthenticate: @pytest.mark.asyncio - async def test_unknown_user(self, create_oauth2_password_request_form, mock_db_interface): + async def test_unknown_user(self, create_oauth2_password_request_form, mock_user_db): form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere') - user = await mock_db_interface.authenticate(form) + user = await mock_user_db.authenticate(form) assert user is None @pytest.mark.asyncio - async def test_wrong_password(self, create_oauth2_password_request_form, mock_db_interface): + async def test_wrong_password(self, create_oauth2_password_request_form, mock_user_db): form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival') - user = await mock_db_interface.authenticate(form) + user = await mock_user_db.authenticate(form) assert user is None @pytest.mark.asyncio - async def test_valid_credentials(self, create_oauth2_password_request_form, mock_db_interface): + async def test_valid_credentials(self, create_oauth2_password_request_form, mock_user_db): form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere') - user = await mock_db_interface.authenticate(form) + user = await mock_user_db.authenticate(form) assert user is not None assert user.email == 'king.arthur@camelot.bt' diff --git a/tests/test_router.py b/tests/test_router.py index edaf9efb..ca2efe8a 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,14 +1,15 @@ import pytest +from fastapi import FastAPI from starlette import status from starlette.testclient import TestClient +from fastapi_users.models import UserDB +from fastapi_users.router import UserRouter + @pytest.fixture -def test_app_client(mock_db_interface) -> TestClient: - from fastapi import FastAPI - from fastapi_users.router import UserRouter - - userRouter = UserRouter(mock_db_interface) +def test_app_client(mock_user_db, mock_authentication) -> TestClient: + userRouter = UserRouter(mock_user_db, mock_authentication) app = FastAPI() app.include_router(userRouter) @@ -82,13 +83,14 @@ class TestLogin: response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST - def test_valid_credentials(self, test_app_client: TestClient): + def test_valid_credentials(self, test_app_client: TestClient, user: UserDB): data = { 'username': 'king.arthur@camelot.bt', 'password': 'guinevere', } response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_200_OK + assert response.json() == {'token': user.id} def test_inactive_user(self, test_app_client: TestClient): data = {