mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Use built-in OAuth2PasswordRequestForm for login
This commit is contained in:
1
Pipfile
1
Pipfile
@ -17,6 +17,7 @@ passlib = {extras = ["bcrypt"],version = "*"}
|
|||||||
email-validator = "*"
|
email-validator = "*"
|
||||||
sqlalchemy = "*"
|
sqlalchemy = "*"
|
||||||
databases = "*"
|
databases = "*"
|
||||||
|
python-multipart = "*"
|
||||||
|
|
||||||
[requires]
|
[requires]
|
||||||
python_version = "3.7"
|
python_version = "3.7"
|
||||||
|
9
Pipfile.lock
generated
9
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "a2653ab0f39cfc4780097259fdeef1f152ec4472b5b3b5cc80cf08997d7a1e81"
|
"sha256": "f950ae2475c73ff3553c19edb24c30a2ea7e0ffa630beb02b13e8588e31b9cb2"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {
|
"requires": {
|
||||||
@ -134,6 +134,13 @@
|
|||||||
],
|
],
|
||||||
"version": "==0.32.2"
|
"version": "==0.32.2"
|
||||||
},
|
},
|
||||||
|
"python-multipart": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43"
|
||||||
|
],
|
||||||
|
"index": "pypi",
|
||||||
|
"version": "==0.0.5"
|
||||||
|
},
|
||||||
"six": {
|
"six": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
|
"sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi_users.models import UserDB, UserLogin
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from fastapi_users.models import UserDB
|
||||||
from fastapi_users.password import get_password_hash, verify_password
|
from fastapi_users.password import get_password_hash, verify_password
|
||||||
|
|
||||||
|
|
||||||
@ -19,14 +21,14 @@ class UserDBInterface:
|
|||||||
async def create(self, user: UserDB) -> UserDB:
|
async def create(self, user: UserDB) -> UserDB:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def authenticate(self, user_login) -> UserLogin:
|
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
|
||||||
user = await self.get_by_email(user_login.email)
|
user = await self.get_by_email(credentials.username)
|
||||||
|
|
||||||
# Always run the hasher to mitigate timing attack
|
# Always run the hasher to mitigate timing attack
|
||||||
# Inspired from Django: https://code.djangoproject.com/ticket/20760
|
# Inspired from Django: https://code.djangoproject.com/ticket/20760
|
||||||
get_password_hash(user_login.password)
|
get_password_hash(credentials.password)
|
||||||
|
|
||||||
if user is None or not verify_password(user_login.password, user.hashed_password):
|
if user is None or not verify_password(credentials.password, user.hashed_password):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
@ -6,11 +6,6 @@ from pydantic import BaseModel
|
|||||||
from pydantic.types import EmailStr
|
from pydantic.types import EmailStr
|
||||||
|
|
||||||
|
|
||||||
class EmailPasswordMixin(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
id: str = None
|
id: str = None
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
@ -22,8 +17,9 @@ class UserBase(BaseModel):
|
|||||||
return v or str(uuid.uuid4())
|
return v or str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(EmailPasswordMixin, UserBase):
|
class UserCreate(UserBase):
|
||||||
pass
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(UserBase):
|
class UserUpdate(UserBase):
|
||||||
@ -36,7 +32,3 @@ class UserDB(UserBase):
|
|||||||
|
|
||||||
class User(UserBase):
|
class User(UserBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UserLogin(EmailPasswordMixin):
|
|
||||||
pass
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
|
||||||
from fastapi_users.db import UserDBInterface
|
from fastapi_users.db import UserDBInterface
|
||||||
from fastapi_users.models import UserCreate, UserDB, UserLogin
|
from fastapi_users.models import UserCreate, UserDB
|
||||||
from fastapi_users.password import get_password_hash
|
from fastapi_users.password import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
@ -19,8 +20,8 @@ class UserRouter:
|
|||||||
return created_user
|
return created_user
|
||||||
|
|
||||||
@router.post('/login')
|
@router.post('/login')
|
||||||
async def login(user_login: UserLogin):
|
async def login(credentials: OAuth2PasswordRequestForm = Depends()):
|
||||||
user = await userDB.authenticate(user_login)
|
user = await userDB.authenticate(credentials)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
@ -1,31 +1,37 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from fastapi_users.models import UserLogin
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_oauth2_password_request_form():
|
||||||
|
def _create_oauth2_password_request_form(username, password):
|
||||||
|
return OAuth2PasswordRequestForm(
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
scope='',
|
||||||
|
)
|
||||||
|
|
||||||
|
return _create_oauth2_password_request_form
|
||||||
|
|
||||||
|
|
||||||
class TestAuthenticate:
|
class TestAuthenticate:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_user(self, mock_db_interface):
|
async def test_unknown_user(self, create_oauth2_password_request_form, mock_db_interface):
|
||||||
user = await mock_db_interface.authenticate(UserLogin(
|
form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere')
|
||||||
email='lancelot@camelot.bt',
|
user = await mock_db_interface.authenticate(form)
|
||||||
password='guinevere',
|
|
||||||
))
|
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_wrong_password(self, mock_db_interface):
|
async def test_wrong_password(self, create_oauth2_password_request_form, mock_db_interface):
|
||||||
user = await mock_db_interface.authenticate(UserLogin(
|
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival')
|
||||||
email='king.arthur@camelot.bt',
|
user = await mock_db_interface.authenticate(form)
|
||||||
password='percival',
|
|
||||||
))
|
|
||||||
assert user is None
|
assert user is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_credentials(self, mock_db_interface):
|
async def test_valid_credentials(self, create_oauth2_password_request_form, mock_db_interface):
|
||||||
user = await mock_db_interface.authenticate(UserLogin(
|
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere')
|
||||||
email='king.arthur@camelot.bt',
|
user = await mock_db_interface.authenticate(form)
|
||||||
password='guinevere',
|
|
||||||
))
|
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.email == 'king.arthur@camelot.bt'
|
assert user.email == 'king.arthur@camelot.bt'
|
||||||
|
@ -49,51 +49,51 @@ class TestRegister:
|
|||||||
class TestLogin:
|
class TestLogin:
|
||||||
|
|
||||||
def test_empty_body(self, test_app_client: TestClient):
|
def test_empty_body(self, test_app_client: TestClient):
|
||||||
response = test_app_client.post('/login', json={})
|
response = test_app_client.post('/login', data={})
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_missing_email(self, test_app_client: TestClient):
|
def test_missing_username(self, test_app_client: TestClient):
|
||||||
json = {
|
data = {
|
||||||
'password': 'guinevere',
|
'password': 'guinevere',
|
||||||
}
|
}
|
||||||
response = test_app_client.post('/login', json=json)
|
response = test_app_client.post('/login', data=data)
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_missing_password(self, test_app_client: TestClient):
|
def test_missing_password(self, test_app_client: TestClient):
|
||||||
json = {
|
data = {
|
||||||
'email': 'king.arthur@camelot.bt',
|
'username': 'king.arthur@camelot.bt',
|
||||||
}
|
}
|
||||||
response = test_app_client.post('/login', json=json)
|
response = test_app_client.post('/login', data=data)
|
||||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||||
|
|
||||||
def test_not_existing_user(self, test_app_client: TestClient):
|
def test_not_existing_user(self, test_app_client: TestClient):
|
||||||
json = {
|
data = {
|
||||||
'email': 'lancelot@camelot.bt',
|
'username': 'lancelot@camelot.bt',
|
||||||
'password': 'guinevere',
|
'password': 'guinevere',
|
||||||
}
|
}
|
||||||
response = test_app_client.post('/login', json=json)
|
response = test_app_client.post('/login', data=data)
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
|
||||||
def test_wrong_password(self, test_app_client: TestClient):
|
def test_wrong_password(self, test_app_client: TestClient):
|
||||||
json = {
|
data = {
|
||||||
'email': 'king.arthur@camelot.bt',
|
'username': 'king.arthur@camelot.bt',
|
||||||
'password': 'percival',
|
'password': 'percival',
|
||||||
}
|
}
|
||||||
response = test_app_client.post('/login', json=json)
|
response = test_app_client.post('/login', data=data)
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
|
||||||
def test_valid_credentials(self, test_app_client: TestClient):
|
def test_valid_credentials(self, test_app_client: TestClient):
|
||||||
json = {
|
data = {
|
||||||
'email': 'king.arthur@camelot.bt',
|
'username': 'king.arthur@camelot.bt',
|
||||||
'password': 'guinevere',
|
'password': 'guinevere',
|
||||||
}
|
}
|
||||||
response = test_app_client.post('/login', json=json)
|
response = test_app_client.post('/login', data=data)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
def test_inactive_user(self, test_app_client: TestClient):
|
def test_inactive_user(self, test_app_client: TestClient):
|
||||||
json = {
|
data = {
|
||||||
'email': 'percival@camelot.bt',
|
'username': 'percival@camelot.bt',
|
||||||
'password': 'angharad',
|
'password': 'angharad',
|
||||||
}
|
}
|
||||||
response = test_app_client.post('/login', json=json)
|
response = test_app_client.post('/login', data=data)
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
Reference in New Issue
Block a user