Files
fastapi-users/tests/test_authentication_authenticator.py
2024-11-03 12:51:32 +00:00

160 lines
4.9 KiB
Python

from collections.abc import AsyncGenerator, Sequence
from typing import Generic, Optional
import httpx
import pytest
import pytest_asyncio
from fastapi import Depends, FastAPI, Request, status
from fastapi.security.base import SecurityBase
from fastapi_users import models
from fastapi_users.authentication import AuthenticationBackend, Authenticator
from fastapi_users.authentication.authenticator import DuplicateBackendNamesError
from fastapi_users.authentication.strategy import Strategy
from fastapi_users.authentication.transport import Transport
from fastapi_users.manager import BaseUserManager
from fastapi_users.types import DependencyCallable
from tests.conftest import User, UserModel
class MockSecurityScheme(SecurityBase):
def __call__(self, request: Request) -> Optional[str]:
return "mock"
class MockTransport(Transport):
scheme: MockSecurityScheme
def __init__(self):
self.scheme = MockSecurityScheme()
class NoneStrategy(Strategy):
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
return None
class UserStrategy(Strategy, Generic[models.UP]):
def __init__(self, user: models.UP):
self.user = user
async def read_token(
self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID]
) -> Optional[models.UP]:
return self.user
@pytest.fixture
def get_backend_none():
def _get_backend_none(name: str = "none"):
return AuthenticationBackend(
name=name, transport=MockTransport(), get_strategy=lambda: NoneStrategy()
)
return _get_backend_none
@pytest.fixture
def get_backend_user(user: UserModel):
def _get_backend_user(name: str = "user"):
return AuthenticationBackend(
name=name,
transport=MockTransport(),
get_strategy=lambda: UserStrategy(user),
)
return _get_backend_user
@pytest_asyncio.fixture
def get_test_auth_client(get_user_manager, get_test_client):
async def _get_test_auth_client(
backends: list[AuthenticationBackend],
get_enabled_backends: Optional[
DependencyCallable[Sequence[AuthenticationBackend]]
] = None,
) -> AsyncGenerator[httpx.AsyncClient, None]:
app = FastAPI()
authenticator = Authenticator(backends, get_user_manager)
@app.get("/test-current-user", response_model=User)
def test_current_user(
user: UserModel = Depends(
authenticator.current_user(get_enabled_backends=get_enabled_backends)
),
):
return user
@app.get("/test-current-active-user", response_model=User)
def test_current_active_user(
user: UserModel = Depends(
authenticator.current_user(
active=True, get_enabled_backends=get_enabled_backends
)
),
):
return user
@app.get("/test-current-superuser", response_model=User)
def test_current_superuser(
user: UserModel = Depends(
authenticator.current_user(
active=True,
superuser=True,
get_enabled_backends=get_enabled_backends,
)
),
):
return user
async for client in get_test_client(app):
yield client
return _get_test_auth_client
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticator(get_test_auth_client, get_backend_none, get_backend_user):
async for client in get_test_auth_client([get_backend_none(), get_backend_user()]):
response = await client.get("/test-current-user")
assert response.status_code == status.HTTP_200_OK
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticator_none(get_test_auth_client, get_backend_none):
async for client in get_test_auth_client(
[get_backend_none(), get_backend_none(name="none-bis")]
):
response = await client.get("/test-current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticator_none_enabled(
get_test_auth_client, get_backend_none, get_backend_user
):
backend_none = get_backend_none()
backend_user = get_backend_user()
async def get_enabled_backends():
return [backend_none]
async for client in get_test_auth_client(
[backend_none, backend_user], get_enabled_backends
):
response = await client.get("/test-current-user")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticators_with_same_name(get_test_auth_client, get_backend_none):
with pytest.raises(DuplicateBackendNamesError):
async for _ in get_test_auth_client([get_backend_none(), get_backend_none()]):
pass