diff --git a/Pipfile b/Pipfile index 79c1b45e..1ea22fae 100644 --- a/Pipfile +++ b/Pipfile @@ -17,6 +17,7 @@ passlib = {extras = ["bcrypt"],version = "*"} email-validator = "*" sqlalchemy = "*" databases = "*" +python-multipart = "*" [requires] python_version = "3.7" diff --git a/Pipfile.lock b/Pipfile.lock index b606c97a..3985cdb6 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "a2653ab0f39cfc4780097259fdeef1f152ec4472b5b3b5cc80cf08997d7a1e81" + "sha256": "f950ae2475c73ff3553c19edb24c30a2ea7e0ffa630beb02b13e8588e31b9cb2" }, "pipfile-spec": 6, "requires": { @@ -134,6 +134,13 @@ ], "version": "==0.32.2" }, + "python-multipart": { + "hashes": [ + "sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43" + ], + "index": "pypi", + "version": "==0.0.5" + }, "six": { "hashes": [ "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 75fb8568..e7964ad1 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -1,6 +1,8 @@ from typing import List -from fastapi_users.models import UserDB, UserLogin +from fastapi.security import OAuth2PasswordRequestForm + +from fastapi_users.models import UserDB from fastapi_users.password import get_password_hash, verify_password @@ -19,14 +21,14 @@ class UserDBInterface: async def create(self, user: UserDB) -> UserDB: raise NotImplementedError() - async def authenticate(self, user_login) -> UserLogin: - user = await self.get_by_email(user_login.email) + async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB: + user = await self.get_by_email(credentials.username) # Always run the hasher to mitigate timing attack # Inspired from Django: https://code.djangoproject.com/ticket/20760 - get_password_hash(user_login.password) + get_password_hash(credentials.password) - if user is None or not verify_password(user_login.password, user.hashed_password): + if user is None or not verify_password(credentials.password, user.hashed_password): return None return user diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 46f23d9f..4b2b90ea 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -6,11 +6,6 @@ from pydantic import BaseModel from pydantic.types import EmailStr -class EmailPasswordMixin(BaseModel): - email: EmailStr - password: str - - class UserBase(BaseModel): id: str = None email: Optional[EmailStr] = None @@ -22,8 +17,9 @@ class UserBase(BaseModel): return v or str(uuid.uuid4()) -class UserCreate(EmailPasswordMixin, UserBase): - pass +class UserCreate(UserBase): + email: EmailStr + password: str class UserUpdate(UserBase): @@ -36,7 +32,3 @@ class UserDB(UserBase): class User(UserBase): pass - - -class UserLogin(EmailPasswordMixin): - pass diff --git a/fastapi_users/router.py b/fastapi_users/router.py index 0fc332df..3a8876c1 100644 --- a/fastapi_users/router.py +++ b/fastapi_users/router.py @@ -1,8 +1,9 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import OAuth2PasswordRequestForm from starlette import status from fastapi_users.db import UserDBInterface -from fastapi_users.models import UserCreate, UserDB, UserLogin +from fastapi_users.models import UserCreate, UserDB from fastapi_users.password import get_password_hash @@ -19,8 +20,8 @@ class UserRouter: return created_user @router.post('/login') - async def login(user_login: UserLogin): - user = await userDB.authenticate(user_login) + async def login(credentials: OAuth2PasswordRequestForm = Depends()): + user = await userDB.authenticate(credentials) if user is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/tests/test_db_interface.py b/tests/test_db_interface.py index 7830bb95..158861ed 100644 --- a/tests/test_db_interface.py +++ b/tests/test_db_interface.py @@ -1,31 +1,37 @@ import pytest -from fastapi_users.models import UserLogin +from fastapi.security import OAuth2PasswordRequestForm + + +@pytest.fixture +def create_oauth2_password_request_form(): + def _create_oauth2_password_request_form(username, password): + return OAuth2PasswordRequestForm( + username=username, + password=password, + scope='', + ) + + return _create_oauth2_password_request_form class TestAuthenticate: @pytest.mark.asyncio - async def test_unknown_user(self, mock_db_interface): - user = await mock_db_interface.authenticate(UserLogin( - email='lancelot@camelot.bt', - password='guinevere', - )) + async def test_unknown_user(self, create_oauth2_password_request_form, mock_db_interface): + form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere') + user = await mock_db_interface.authenticate(form) assert user is None @pytest.mark.asyncio - async def test_wrong_password(self, mock_db_interface): - user = await mock_db_interface.authenticate(UserLogin( - email='king.arthur@camelot.bt', - password='percival', - )) + async def test_wrong_password(self, create_oauth2_password_request_form, mock_db_interface): + form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival') + user = await mock_db_interface.authenticate(form) assert user is None @pytest.mark.asyncio - async def test_valid_credentials(self, mock_db_interface): - user = await mock_db_interface.authenticate(UserLogin( - email='king.arthur@camelot.bt', - password='guinevere', - )) + async def test_valid_credentials(self, create_oauth2_password_request_form, mock_db_interface): + form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere') + user = await mock_db_interface.authenticate(form) assert user is not None assert user.email == 'king.arthur@camelot.bt' diff --git a/tests/test_router.py b/tests/test_router.py index 19d938b5..edaf9efb 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -49,51 +49,51 @@ class TestRegister: class TestLogin: def test_empty_body(self, test_app_client: TestClient): - response = test_app_client.post('/login', json={}) + response = test_app_client.post('/login', data={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - def test_missing_email(self, test_app_client: TestClient): - json = { + def test_missing_username(self, test_app_client: TestClient): + data = { 'password': 'guinevere', } - response = test_app_client.post('/login', json=json) + response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY def test_missing_password(self, test_app_client: TestClient): - json = { - 'email': 'king.arthur@camelot.bt', + data = { + 'username': 'king.arthur@camelot.bt', } - response = test_app_client.post('/login', json=json) + response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY def test_not_existing_user(self, test_app_client: TestClient): - json = { - 'email': 'lancelot@camelot.bt', + data = { + 'username': 'lancelot@camelot.bt', 'password': 'guinevere', } - response = test_app_client.post('/login', json=json) + response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST def test_wrong_password(self, test_app_client: TestClient): - json = { - 'email': 'king.arthur@camelot.bt', + data = { + 'username': 'king.arthur@camelot.bt', 'password': 'percival', } - response = test_app_client.post('/login', json=json) + response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST def test_valid_credentials(self, test_app_client: TestClient): - json = { - 'email': 'king.arthur@camelot.bt', + data = { + 'username': 'king.arthur@camelot.bt', 'password': 'guinevere', } - response = test_app_client.post('/login', json=json) + response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_200_OK def test_inactive_user(self, test_app_client: TestClient): - json = { - 'email': 'percival@camelot.bt', + data = { + 'username': 'percival@camelot.bt', 'password': 'angharad', } - response = test_app_client.post('/login', json=json) + response = test_app_client.post('/login', data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST