mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Implement JWT authentication
This commit is contained in:
1
Pipfile
1
Pipfile
@ -18,6 +18,7 @@ email-validator = "*"
|
||||
sqlalchemy = "*"
|
||||
databases = "*"
|
||||
python-multipart = "*"
|
||||
pyjwt = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.7"
|
||||
|
26
Pipfile.lock
generated
26
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
||||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "f950ae2475c73ff3553c19edb24c30a2ea7e0ffa630beb02b13e8588e31b9cb2"
|
||||
"sha256": "3f2120d9354ee0e8e94a16b4dd1c948e2eb13d91fd71798c4005effc786b93f0"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
@ -93,11 +93,11 @@
|
||||
},
|
||||
"fastapi": {
|
||||
"hashes": [
|
||||
"sha256:345a8b0a8761bf830c99af2c38881f2215703844e4d5e52fff1ff711824bae0b",
|
||||
"sha256:7a81762807d8ed43b2dddac8bb189d174e270f9c64493542d279673a6d768b26"
|
||||
"sha256:96d2ba720f154a7f876fc419a464faac5c4bf3e660dd72f7134f53b90a8488ff",
|
||||
"sha256:c4a6cb1b0b068efdb2b05ed27b8a88c29d7df5870c7267b02352a0a4eed5f4e9"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==0.40.0"
|
||||
"version": "==0.41.0"
|
||||
},
|
||||
"idna": {
|
||||
"hashes": [
|
||||
@ -134,6 +134,14 @@
|
||||
],
|
||||
"version": "==0.32.2"
|
||||
},
|
||||
"pyjwt": {
|
||||
"hashes": [
|
||||
"sha256:5c6eca3c2940464d106b99ba83b00c6add741c9becaec087fb7ccdefea71350e",
|
||||
"sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==1.7.1"
|
||||
},
|
||||
"python-multipart": {
|
||||
"hashes": [
|
||||
"sha256:f7bb5f611fc600d15fa47b3974c8aa16e93724513b49b5f95c81e6624c83fa43"
|
||||
@ -157,9 +165,9 @@
|
||||
},
|
||||
"starlette": {
|
||||
"hashes": [
|
||||
"sha256:f600bf9d0beeeeebcb143e6d0c4f8858c2b05067d5a4feb446ba7400ba5e5dc5"
|
||||
"sha256:c2ac9a42e0e0328ad20fe444115ac5e3760c1ee2ac1ff8cdb5ec915c4a453411"
|
||||
],
|
||||
"version": "==0.12.8"
|
||||
"version": "==0.12.9"
|
||||
}
|
||||
},
|
||||
"develop": {
|
||||
@ -300,11 +308,11 @@
|
||||
},
|
||||
"pytest": {
|
||||
"hashes": [
|
||||
"sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5",
|
||||
"sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5"
|
||||
"sha256:7e4800063ccfc306a53c461442526c5571e1462f61583506ce97e4da6a1d88c8",
|
||||
"sha256:ca563435f4941d0cb34767301c27bc65c510cb82e90b9ecf9cb52dc2c63caaa0"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==5.2.0"
|
||||
"version": "==5.2.1"
|
||||
},
|
||||
"pytest-asyncio": {
|
||||
"hashes": [
|
||||
|
20
fastapi_users/authentication/__init__.py
Normal file
20
fastapi_users/authentication/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Callable
|
||||
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
|
||||
userDB: BaseUserDatabase
|
||||
|
||||
def __init__(self, userDB: BaseUserDatabase):
|
||||
self.userDB = userDB
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_authentication_method(self) -> Callable[..., UserDB]:
|
||||
raise NotImplementedError()
|
58
fastapi_users/authentication/jwt.py
Normal file
58
fastapi_users/authentication/jwt.py
Normal file
@ -0,0 +1,58 @@
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.models import UserDB
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
|
||||
|
||||
def generate_jwt(data: dict, lifetime_seconds: int, secret: str, algorithm: str) -> str:
|
||||
payload = data.copy()
|
||||
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
|
||||
payload['exp'] = expire
|
||||
return jwt.encode(payload, secret, algorithm=algorithm).decode('utf-8')
|
||||
|
||||
|
||||
class JWTAuthentication(BaseAuthentication):
|
||||
|
||||
algorithm: str = 'HS256'
|
||||
secret: str
|
||||
lifetime_seconds: int
|
||||
|
||||
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.secret = secret
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
data = {'user_id': user.id}
|
||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
||||
|
||||
return {'token': token}
|
||||
|
||||
def get_authentication_method(self):
|
||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
try:
|
||||
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||
user_id: str = data.get('user_id')
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
except jwt.PyJWTError:
|
||||
raise credentials_exception
|
||||
|
||||
user = await self.userDB.get(user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
return authentication_method
|
@ -3,7 +3,9 @@ from typing import List
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.password import get_password_hash, verify_and_update_password
|
||||
from fastapi_users.password import (
|
||||
get_password_hash, verify_and_update_password
|
||||
)
|
||||
|
||||
|
||||
class BaseUserDatabase:
|
||||
|
@ -1,7 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserCreate, UserDB
|
||||
from fastapi_users.password import get_password_hash
|
||||
@ -9,7 +11,7 @@ from fastapi_users.password import get_password_hash
|
||||
|
||||
class UserRouter:
|
||||
|
||||
def __new__(cls, userDB: BaseUserDatabase) -> APIRouter:
|
||||
def __new__(cls, userDB: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
|
||||
router = APIRouter()
|
||||
|
||||
@router.post('/register')
|
||||
@ -20,7 +22,7 @@ class UserRouter:
|
||||
return created_user
|
||||
|
||||
@router.post('/login')
|
||||
async def login(credentials: OAuth2PasswordRequestForm = Depends()):
|
||||
async def login(response: Response, credentials: OAuth2PasswordRequestForm = Depends()):
|
||||
user = await userDB.authenticate(credentials)
|
||||
|
||||
if user is None:
|
||||
@ -28,6 +30,6 @@ class UserRouter:
|
||||
elif not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
return user
|
||||
return await auth.get_login_response(user, response)
|
||||
|
||||
return router
|
||||
|
@ -1,15 +1,21 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
active_user = UserDB(
|
||||
active_user_data = UserDB(
|
||||
id='aaa',
|
||||
email='king.arthur@camelot.bt',
|
||||
hashed_password=get_password_hash('guinevere'),
|
||||
)
|
||||
|
||||
inactive_user = UserDB(
|
||||
inactive_user_data = UserDB(
|
||||
id='bbb',
|
||||
email='percival@camelot.bt',
|
||||
hashed_password=get_password_hash('angharad'),
|
||||
is_active=False
|
||||
@ -18,16 +24,28 @@ inactive_user = UserDB(
|
||||
|
||||
@pytest.fixture
|
||||
def user() -> UserDB:
|
||||
return active_user
|
||||
return active_user_data
|
||||
|
||||
|
||||
class MockUserDBInterface(BaseUserDatabase):
|
||||
@pytest.fixture
|
||||
def inactive_user() -> UserDB:
|
||||
return inactive_user_data
|
||||
|
||||
|
||||
class MockUserDatabase(BaseUserDatabase):
|
||||
|
||||
async def get(self, id: str) -> UserDB:
|
||||
if id == active_user_data.id:
|
||||
return active_user_data
|
||||
elif id == inactive_user_data.id:
|
||||
return inactive_user_data
|
||||
return None
|
||||
|
||||
async def get_by_email(self, email: str) -> UserDB:
|
||||
if email == active_user.email:
|
||||
return active_user
|
||||
elif email == inactive_user.email:
|
||||
return inactive_user
|
||||
if email == active_user_data.email:
|
||||
return active_user_data
|
||||
elif email == inactive_user_data.email:
|
||||
return inactive_user_data
|
||||
return None
|
||||
|
||||
async def create(self, user: UserDB) -> UserDB:
|
||||
@ -35,5 +53,22 @@ class MockUserDBInterface(BaseUserDatabase):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_interface() -> MockUserDBInterface:
|
||||
return MockUserDBInterface()
|
||||
def mock_user_db() -> MockUserDatabase:
|
||||
return MockUserDatabase()
|
||||
|
||||
|
||||
class MockAuthentication(BaseAuthentication):
|
||||
|
||||
async def get_login_response(self, user: UserDB, response: Response):
|
||||
return {'token': user.id}
|
||||
|
||||
async def authenticate(self, token: str) -> UserDB:
|
||||
user = await self.userDB.get(token)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_authentication(mock_user_db) -> MockAuthentication:
|
||||
return MockAuthentication(mock_user_db)
|
||||
|
70
tests/test_authentication_jwt.py
Normal file
70
tests/test_authentication_jwt.py
Normal file
@ -0,0 +1,70 @@
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
|
||||
from fastapi_users.models import UserDB
|
||||
|
||||
SECRET = 'SECRET'
|
||||
ALGORITHM = 'HS256'
|
||||
LIFETIME = 3600
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_authentication(mock_user_db):
|
||||
return JWTAuthentication(SECRET, LIFETIME, mock_user_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token():
|
||||
def _token(user, lifetime=LIFETIME):
|
||||
data = {'user_id': user.id}
|
||||
return generate_jwt(data, lifetime, SECRET, ALGORITHM)
|
||||
return _token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_auth_client(jwt_authentication):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get('/test-auth')
|
||||
def test_auth(user: UserDB = Depends(jwt_authentication.get_authentication_method())):
|
||||
return user
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_login_response(jwt_authentication, user):
|
||||
login_response = await jwt_authentication.get_login_response(user, Response())
|
||||
|
||||
assert 'token' in login_response
|
||||
|
||||
token = login_response['token']
|
||||
decoded = jwt.decode(token, SECRET, algorithms=[ALGORITHM])
|
||||
assert decoded['user_id'] == user.id
|
||||
|
||||
|
||||
class TestGetAuthenticationMethod:
|
||||
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get('/test-auth')
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get('/test-auth', headers={'Authorization': 'Bearer foo'})
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
|
||||
response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(inactive_user)}'})
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token(self, test_auth_client, token, user):
|
||||
response = test_auth_client.get('/test-auth', headers={'Authorization': f'Bearer {token(user)}'})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
json = response.json()
|
||||
assert json['id'] == user.id
|
@ -1,5 +1,4 @@
|
||||
import pytest
|
||||
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
|
||||
@ -11,27 +10,26 @@ def create_oauth2_password_request_form():
|
||||
password=password,
|
||||
scope='',
|
||||
)
|
||||
|
||||
return _create_oauth2_password_request_form
|
||||
|
||||
|
||||
class TestAuthenticate:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_user(self, create_oauth2_password_request_form, mock_db_interface):
|
||||
async def test_unknown_user(self, create_oauth2_password_request_form, mock_user_db):
|
||||
form = create_oauth2_password_request_form('lancelot@camelot.bt', 'guinevere')
|
||||
user = await mock_db_interface.authenticate(form)
|
||||
user = await mock_user_db.authenticate(form)
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_password(self, create_oauth2_password_request_form, mock_db_interface):
|
||||
async def test_wrong_password(self, create_oauth2_password_request_form, mock_user_db):
|
||||
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'percival')
|
||||
user = await mock_db_interface.authenticate(form)
|
||||
user = await mock_user_db.authenticate(form)
|
||||
assert user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_credentials(self, create_oauth2_password_request_form, mock_db_interface):
|
||||
async def test_valid_credentials(self, create_oauth2_password_request_form, mock_user_db):
|
||||
form = create_oauth2_password_request_form('king.arthur@camelot.bt', 'guinevere')
|
||||
user = await mock_db_interface.authenticate(form)
|
||||
user = await mock_user_db.authenticate(form)
|
||||
assert user is not None
|
||||
assert user.email == 'king.arthur@camelot.bt'
|
||||
|
@ -1,14 +1,15 @@
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from starlette import status
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.models import UserDB
|
||||
from fastapi_users.router import UserRouter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app_client(mock_db_interface) -> TestClient:
|
||||
from fastapi import FastAPI
|
||||
from fastapi_users.router import UserRouter
|
||||
|
||||
userRouter = UserRouter(mock_db_interface)
|
||||
def test_app_client(mock_user_db, mock_authentication) -> TestClient:
|
||||
userRouter = UserRouter(mock_user_db, mock_authentication)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(userRouter)
|
||||
@ -82,13 +83,14 @@ class TestLogin:
|
||||
response = test_app_client.post('/login', data=data)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
def test_valid_credentials(self, test_app_client: TestClient):
|
||||
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
|
||||
data = {
|
||||
'username': 'king.arthur@camelot.bt',
|
||||
'password': 'guinevere',
|
||||
}
|
||||
response = test_app_client.post('/login', data=data)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert response.json() == {'token': user.id}
|
||||
|
||||
def test_inactive_user(self, test_app_client: TestClient):
|
||||
data = {
|
||||
|
Reference in New Issue
Block a user