Fix #42: multiple authentication backends (#47)

* Revamp authentication to allow multiple backends

* Make router generate a login route for each backend

* Apply black

* Remove unused imports

* Complete docstrings

* Update documentation

* WIP add cookie auth

* Complete cookie auth unit tests

* Add documentation for cookie auth

* Fix cookie backend default name

* Don't make cookie return a Response
This commit is contained in:
François Voron
2019-12-04 13:32:49 +01:00
committed by GitHub
parent 5e4c7996de
commit 49deb437a6
22 changed files with 591 additions and 341 deletions

View File

@ -1,12 +1,14 @@
from typing import List, Optional
from typing import Any, List, Mapping, Optional, Tuple
import http.cookies
import pytest
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer
from starlette.requests import Request
from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.authentication import Authenticator, BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash
@ -81,70 +83,76 @@ def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
return MockUserDatabase()
class MockAuthentication(BaseAuthentication):
def __init__(self, name: str = "mock"):
super().__init__(name)
self.scheme = OAuth2PasswordBearer("/users/login", auto_error=False)
async def __call__(self, request: Request, user_db: BaseUserDatabase):
token = await self.scheme.__call__(request)
if token is not None:
return await user_db.get(token)
return None
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
@pytest.fixture
def mock_authentication():
class MockAuthentication(BaseAuthentication):
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
def get_current_user(self, user_db: BaseUserDatabase):
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_user_base(user)
return _get_current_user
def get_current_active_user(self, user_db: BaseUserDatabase):
async def _get_current_active_user(
token: str = Depends(self.oauth2_scheme),
):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_active_user_base(user)
return _get_current_active_user
def get_current_superuser(self, user_db: BaseUserDatabase):
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
user = await self._get_authentication_method(user_db)(token)
return self._get_current_superuser_base(user)
return _get_current_superuser
def _get_authentication_method(self, user_db: BaseUserDatabase):
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
return await user_db.get(token)
return authentication_method
return MockAuthentication()
@pytest.fixture
def request_builder():
def _request_builder(
headers: Mapping[str, Any] = None, cookies: Mapping[str, str] = None
) -> Request:
encoded_headers: List[Tuple[bytes, bytes]] = []
if headers is not None:
encoded_headers += [
(key.lower().encode("latin-1"), headers[key].encode("latin-1"))
for key in headers
]
if cookies is not None:
for key in cookies:
cookie = http.cookies.SimpleCookie() # type: http.cookies.BaseCookie
cookie[key] = cookies[key]
cookie_val = cookie.output(header="").strip()
encoded_headers.append((b"cookie", cookie_val.encode("latin-1")))
scope = {
"type": "http",
"headers": encoded_headers,
}
return Request(scope)
return _request_builder
@pytest.fixture
def get_test_auth_client(mock_user_db):
def _get_test_auth_client(authentication):
def _get_test_auth_client(backends: List[BaseAuthentication]) -> TestClient:
app = FastAPI()
authenticator = Authenticator(backends, mock_user_db)
@app.get("/test-current-user")
def test_current_user(
user: BaseUserDB = Depends(authentication.get_current_user(mock_user_db)),
user: BaseUserDB = Depends(authenticator.get_current_user),
):
return user
@app.get("/test-current-active-user")
def test_current_active_user(
user: BaseUserDB = Depends(
authentication.get_current_active_user(mock_user_db)
),
user: BaseUserDB = Depends(authenticator.get_current_active_user),
):
return user
@app.get("/test-current-superuser")
def test_current_superuser(
user: BaseUserDB = Depends(
authentication.get_current_superuser(mock_user_db)
),
user: BaseUserDB = Depends(authenticator.get_current_superuser),
):
return user

View File

@ -0,0 +1,45 @@
from typing import Optional
import pytest
from starlette import status
from starlette.requests import Request
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import BaseUserDB
@pytest.fixture()
def auth_backend_none():
class BackendNone(BaseAuthentication):
async def __call__(
self, request: Request, user_db: BaseUserDatabase
) -> Optional[BaseUserDB]:
return None
return BackendNone()
@pytest.fixture()
def auth_backend_user(user):
class BackendUser(BaseAuthentication):
async def __call__(
self, request: Request, user_db: BaseUserDatabase
) -> Optional[BaseUserDB]:
return user
return BackendUser()
@pytest.mark.authentication
def test_authenticator(get_test_auth_client, auth_backend_none, auth_backend_user):
client = get_test_auth_client([auth_backend_none, auth_backend_user])
response = client.get("/test-current-user")
assert response.status_code == status.HTTP_200_OK
@pytest.mark.authentication
def test_authenticator_none(get_test_auth_client, auth_backend_none):
client = get_test_auth_client([auth_backend_none, auth_backend_none])
response = client.get("/test-current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@ -1,30 +1,27 @@
import pytest
from fastapi.security import OAuth2PasswordBearer
from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
@pytest.mark.asyncio
@pytest.fixture
def base_authentication():
return BaseAuthentication()
@pytest.mark.authentication
@pytest.mark.parametrize(
"constructor_kwargs", [{}, {"scheme": OAuth2PasswordBearer(tokenUrl="/foo")}]
)
async def test_not_implemented_methods(constructor_kwargs, user, mock_user_db):
response = Response()
base_authentication = BaseAuthentication(**constructor_kwargs)
class TestAuthenticate:
@pytest.mark.asyncio
async def test_not_implemented(
self, base_authentication, mock_user_db, request_builder
):
request = request_builder({})
with pytest.raises(NotImplementedError):
await base_authentication(request, mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication.get_login_response(user, response)
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(base_authentication, user):
with pytest.raises(NotImplementedError):
await base_authentication.get_current_user(mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication.get_current_active_user(mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication.get_current_superuser(mock_user_db)
with pytest.raises(NotImplementedError):
await base_authentication._get_authentication_method(mock_user_db)
await base_authentication.get_login_response(user, Response())

View File

@ -0,0 +1,105 @@
import re
import jwt
import pytest
from starlette.responses import Response
from fastapi_users.authentication.cookie import CookieAuthentication
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
SECRET = "SECRET"
LIFETIME = 3600
COOKIE_NAME = "COOKIE_NAME"
@pytest.fixture
def cookie_authentication():
return CookieAuthentication(SECRET, LIFETIME, COOKIE_NAME)
@pytest.fixture
def token():
def _token(user=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"}
if user is not None:
data["user_id"] = user.id
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _token
@pytest.mark.authentication
def test_default_name(cookie_authentication):
assert cookie_authentication.name == "cookie"
@pytest.mark.authentication
class TestAuthenticate:
@pytest.mark.asyncio
async def test_missing_token(
self, cookie_authentication, mock_user_db, request_builder
):
request = request_builder()
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(
self, cookie_authentication, mock_user_db, request_builder
):
cookies = {}
cookies[COOKIE_NAME] = "foo"
request = request_builder(cookies=cookies)
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_missing_user_payload(
self, cookie_authentication, mock_user_db, request_builder, token
):
cookies = {}
cookies[COOKIE_NAME] = token()
request = request_builder(cookies=cookies)
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token(
self, cookie_authentication, mock_user_db, request_builder, token, user
):
cookies = {}
cookies[COOKIE_NAME] = token(user)
request = request_builder(cookies=cookies)
authenticated_user = await cookie_authentication(request, mock_user_db)
assert authenticated_user.id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(cookie_authentication, user):
response = Response()
login_response = await cookie_authentication.get_login_response(user, response)
# We shouldn't return directly the response
# so that FastAPI can terminate it properly
assert login_response is None
cookies = [
header for header in response.raw_headers if header[0] == b"set-cookie"
]
assert len(cookies) == 1
cookie = cookies[0][1].decode("latin-1")
assert f"Max-Age={LIFETIME}" in cookie
cookie_name_value = re.match(r"^(\w+)=([^;]+);", cookie)
cookie_name = cookie_name_value[1]
assert cookie_name == COOKIE_NAME
cookie_value = cookie_name_value[2]
decoded = jwt.decode(
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
assert decoded["user_id"] == user.id

View File

@ -1,6 +1,5 @@
import jwt
import pytest
from starlette import status
from starlette.responses import Response
from fastapi_users.authentication.jwt import JWTAuthentication
@ -8,11 +7,12 @@ from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
SECRET = "SECRET"
LIFETIME = 3600
TOKEN_URL = "/login"
@pytest.fixture
def jwt_authentication():
return JWTAuthentication(SECRET, LIFETIME)
return JWTAuthentication(SECRET, LIFETIME, TOKEN_URL)
@pytest.fixture
@ -26,11 +26,47 @@ def token():
return _token
@pytest.fixture
def test_auth_client(get_test_auth_client, jwt_authentication):
return get_test_auth_client(jwt_authentication)
@pytest.mark.authentication
def test_default_name(jwt_authentication):
assert jwt_authentication.name == "jwt"
@pytest.mark.authentication
class TestAuthenticate:
@pytest.mark.asyncio
async def test_missing_token(
self, jwt_authentication, mock_user_db, request_builder
):
request = request_builder(headers={})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_invalid_token(
self, jwt_authentication, mock_user_db, request_builder
):
request = request_builder(headers={"Authorization": "Bearer foo"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_missing_user_payload(
self, jwt_authentication, mock_user_db, request_builder, token
):
request = request_builder(headers={"Authorization": f"Bearer {token()}"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token(
self, jwt_authentication, mock_user_db, request_builder, token, user
):
request = request_builder(headers={"Authorization": f"Bearer {token(user)}"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user.id == user.id
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_get_login_response(jwt_authentication, user):
login_response = await jwt_authentication.get_login_response(user, Response())
@ -42,108 +78,3 @@ async def test_get_login_response(jwt_authentication, user):
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
assert decoded["user_id"] == user.id
@pytest.mark.authentication
class TestGetCurrentUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_missing_user_payload(self, test_auth_client, token):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {token()}"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-user",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == inactive_user.id
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-user", headers={"Authorization": f"Bearer {token(user)}"}
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
@pytest.mark.authentication
class TestGetCurrentActiveUser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-active-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-active-user",
headers={"Authorization": f"Bearer {token(user)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == user.id
@pytest.mark.authentication
class TestGetCurrentSuperuser:
def test_missing_token(self, test_auth_client):
response = test_auth_client.get("/test-current-superuser")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_invalid_token(self, test_auth_client):
response = test_auth_client.get(
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_inactive_user(self, test_auth_client, token, inactive_user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(inactive_user)}"},
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_valid_token_regular_user(self, test_auth_client, token, user):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(user)}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_valid_token_superuser(self, test_auth_client, token, superuser):
response = test_auth_client.get(
"/test-current-superuser",
headers={"Authorization": f"Bearer {token(superuser)}"},
)
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert response_json["id"] == superuser.id

View File

@ -21,7 +21,7 @@ def fastapi_users(request, mock_user_db, mock_authentication) -> FastAPIUsers:
class User(BaseUser):
pass
fastapi_users = FastAPIUsers(mock_user_db, mock_authentication, User, "SECRET")
fastapi_users = FastAPIUsers(mock_user_db, [mock_authentication], User, "SECRET")
@fastapi_users.on_after_register()
def on_after_register():

View File

@ -7,9 +7,11 @@ from fastapi import FastAPI
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.authentication import Authenticator
from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.router import ErrorCode, Event, get_user_router
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
from tests.conftest import MockAuthentication
SECRET = "SECRET"
LIFETIME = 3600
@ -44,10 +46,13 @@ def test_app_client(mock_user_db, mock_authentication, event_handler) -> TestCli
class User(BaseUser):
pass
userRouter = get_user_router(
mock_user_db, User, mock_authentication, SECRET, LIFETIME
mock_authentication_bis = MockAuthentication(name="mock-bis")
authenticator = Authenticator(
[mock_authentication, mock_authentication_bis], mock_user_db
)
userRouter = get_user_router(mock_user_db, User, authenticator, SECRET, LIFETIME)
userRouter.add_event_handler(Event.ON_AFTER_REGISTER, event_handler)
userRouter.add_event_handler(Event.ON_AFTER_FORGOT_PASSWORD, event_handler)
@ -125,42 +130,45 @@ class TestRegister:
@pytest.mark.router
@pytest.mark.parametrize("path", ["/login/mock", "/login/mock-bis"])
class TestLogin:
def test_empty_body(self, test_app_client: TestClient):
response = test_app_client.post("/login", data={})
def test_empty_body(self, path, test_app_client: TestClient):
response = test_app_client.post(path, data={})
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_missing_username(self, test_app_client: TestClient):
def test_missing_username(self, path, test_app_client: TestClient):
data = {"password": "guinevere"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_missing_password(self, test_app_client: TestClient):
def test_missing_password(self, path, test_app_client: TestClient):
data = {"username": "king.arthur@camelot.bt"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
def test_not_existing_user(self, test_app_client: TestClient):
def test_not_existing_user(self, path, test_app_client: TestClient):
data = {"username": "lancelot@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
def test_wrong_password(self, test_app_client: TestClient):
def test_wrong_password(self, path, test_app_client: TestClient):
data = {"username": "king.arthur@camelot.bt", "password": "percival"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
def test_valid_credentials(
self, path, test_app_client: TestClient, user: BaseUserDB
):
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"token": user.id}
def test_inactive_user(self, test_app_client: TestClient):
def test_inactive_user(self, path, test_app_client: TestClient):
data = {"username": "percival@camelot.bt", "password": "angharad"}
response = test_app_client.post("/login", data=data)
response = test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json()["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS