mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-15 11:11:16 +08:00

Refactoring commit:
* Replace unused `for` loop index with an underscore
* Use `items()` to unpack dictionary values directly
* Merge duplicate blocks in a conditional
* Use `any()` instead of a `for` loop
* Format `__init__.py`
142 lines
4.2 KiB
Python
142 lines
4.2 KiB
Python
from typing import AsyncGenerator, Callable, Generic, List, Optional, Sequence
|
|
|
|
import httpx
|
|
import pytest
|
|
from fastapi import Depends, FastAPI, Request, status
|
|
from fastapi.security.base import SecurityBase
|
|
|
|
from fastapi_users import models
|
|
from fastapi_users.authentication import (
|
|
Authenticator,
|
|
BaseAuthentication,
|
|
DuplicateBackendNamesError,
|
|
)
|
|
from fastapi_users.manager import BaseUserManager
|
|
from tests.conftest import UserDB
|
|
|
|
|
|
class MockSecurityScheme(SecurityBase):
    """Security-scheme stub that ignores the request and yields ``"mock"``.

    Instances are installed as the ``scheme`` attribute of the fake
    backends in this module.
    """

    def __call__(self, request: Request) -> Optional[str]:
        # The incoming request is not inspected: every call sees the same token.
        fixed_credential = "mock"
        return fixed_credential
|
|
|
|
|
|
class BackendNone(
    Generic[models.UC, models.UD], BaseAuthentication[str, models.UC, models.UD]
):
    """Fake authentication backend that never resolves a user.

    Whatever credentials or user manager it receives, ``__call__`` answers
    ``None``. It is registered with ``logout=False``.
    """

    def __init__(self, name: str = "none"):
        super().__init__(name, logout=False)
        # Stub scheme whose only job is to hand back the "mock" credential.
        self.scheme = MockSecurityScheme()

    async def __call__(
        self,
        credentials: Optional[str],
        user_manager: BaseUserManager[models.UC, models.UD],
    ) -> Optional[models.UD]:
        # Both arguments are intentionally unused: this backend rejects everyone.
        resolved_user = None
        return resolved_user
|
|
|
|
|
|
class BackendUser(
    Generic[models.UC, models.UD], BaseAuthentication[str, models.UC, models.UD]
):
    """Fake authentication backend that always resolves the same user.

    The user given at construction time is returned by every call, whatever
    credentials are presented. It is registered with ``logout=False``.
    """

    def __init__(self, user: models.UD, name: str = "user"):
        super().__init__(name, logout=False)
        # Stub scheme whose only job is to hand back the "mock" credential.
        self.scheme = MockSecurityScheme()
        # The user handed back on every authentication attempt.
        self.user = user

    async def __call__(
        self,
        credentials: Optional[str],
        user_manager: BaseUserManager[models.UC, models.UD],
    ) -> Optional[models.UD]:
        # Credentials and manager are ignored; authentication always "succeeds".
        authenticated = self.user
        return authenticated
|
|
|
|
|
|
@pytest.fixture
def get_test_auth_client(get_user_manager, get_test_client):
    """Provide a factory that builds a test client around an ``Authenticator``.

    The returned async generator function mounts three routes on a fresh
    ``FastAPI`` app. Each route is guarded by ``authenticator.current_user``
    with progressively stricter options and returns the resolved user:

    * ``/test-current-user``: no extra options;
    * ``/test-current-active-user``: ``active=True``;
    * ``/test-current-superuser``: ``active=True`` and ``superuser=True``.

    Clients are yielded from the ``get_test_client`` fixture.
    """
    # NOTE: no ``@pytest.mark.asyncio`` on this fixture. Marks have no effect
    # on fixture functions (recent pytest versions warn about it); the async
    # tests themselves carry the mark.

    async def _get_test_auth_client(
        backends: List[BaseAuthentication],
        get_enabled_backends: Optional[
            Callable[..., Sequence[BaseAuthentication]]
        ] = None,
    ) -> AsyncGenerator[httpx.AsyncClient, None]:
        """Yield an ``httpx.AsyncClient`` for an app authenticated by *backends*.

        :param backends: Authentication backends handed to ``Authenticator``.
        :param get_enabled_backends: Optional callable forwarded unchanged to
            every ``current_user`` dependency. Presumably it selects which of
            *backends* are consulted per request; confirm against
            ``Authenticator``.
        :raises DuplicateBackendNamesError: when *backends* share a name
            (presumably raised by the ``Authenticator`` constructor).
        """
        app = FastAPI()
        authenticator = Authenticator(backends, get_user_manager)

        @app.get("/test-current-user")
        def test_current_user(
            user: UserDB = Depends(
                authenticator.current_user(get_enabled_backends=get_enabled_backends)
            ),
        ):
            return user

        @app.get("/test-current-active-user")
        def test_current_active_user(
            user: UserDB = Depends(
                authenticator.current_user(
                    active=True, get_enabled_backends=get_enabled_backends
                )
            ),
        ):
            return user

        @app.get("/test-current-superuser")
        def test_current_superuser(
            user: UserDB = Depends(
                authenticator.current_user(
                    active=True,
                    superuser=True,
                    get_enabled_backends=get_enabled_backends,
                )
            ),
        ):
            return user

        async for client in get_test_client(app):
            yield client

    return _get_test_auth_client
|
|
|
|
|
|
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticator(get_test_auth_client, user):
    """Access is granted when one of several backends resolves a user.

    ``BackendNone`` yields nothing while ``BackendUser`` yields ``user``;
    the plain current-user route is expected to answer 200.
    """
    mixed_backends = [BackendNone(), BackendUser(user)]
    async for client in get_test_auth_client(mixed_backends):
        response = await client.get("/test-current-user")
        assert response.status_code == status.HTTP_200_OK
|
|
|
|
|
|
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticator_none(get_test_auth_client):
    """The current-user route answers 401 when no backend resolves a user."""
    # Two "never authenticate" backends; the second needs a distinct name
    # because backends sharing a name are rejected.
    silent_backends = [BackendNone(), BackendNone(name="none-bis")]
    async for client in get_test_auth_client(silent_backends):
        response = await client.get("/test-current-user")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticator_none_enabled(get_test_auth_client, user):
    """A registered but non-enabled backend must not authenticate the request.

    ``BackendUser`` would resolve ``user``, but ``get_enabled_backends`` only
    returns the ``BackendNone`` instance, so a 401 is expected.
    """
    rejecting_backend = BackendNone()
    user_backend = BackendUser(user)

    async def get_enabled_backends():
        # Deliberately leave ``user_backend`` out of the enabled set.
        return [rejecting_backend]

    registered = [rejecting_backend, user_backend]
    async for client in get_test_auth_client(registered, get_enabled_backends):
        response = await client.get("/test-current-user")
        assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_authenticators_with_same_name(get_test_auth_client):
    """Building an authenticator from two same-named backends must fail."""
    with pytest.raises(DuplicateBackendNamesError):
        # Both backends fall back to the default name "none".
        client_stream = get_test_auth_client([BackendNone(), BackendNone()])
        async for _ in client_stream:
            pass
|