mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-15 11:11:16 +08:00
Implement variant of dep injections to get active/super user
This commit is contained in:
@ -1,18 +1,57 @@
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from starlette import status
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import BaseUserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
|
|
||||||
|
credentials_exception = HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
|
||||||
class BaseAuthentication:
|
class BaseAuthentication:
|
||||||
"""Base adapter for generating and decoding authentication tokens."""
|
"""
|
||||||
|
Base adapter for generating and decoding authentication tokens.
|
||||||
|
|
||||||
|
Provides dependency injectors to get current active/superuser user.
|
||||||
|
"""
|
||||||
|
|
||||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_authentication_method(
|
def get_current_user(self, user_db: BaseUserDatabase):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _get_authentication_method(
|
||||||
self, user_db: BaseUserDatabase
|
self, user_db: BaseUserDatabase
|
||||||
) -> Callable[..., BaseUserDB]:
|
) -> Callable[..., BaseUserDB]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _get_current_user_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
|
if user is None:
|
||||||
|
raise self._get_credentials_exception()
|
||||||
|
return user
|
||||||
|
|
||||||
|
def _get_current_active_user_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
|
user = self._get_current_user_base(user)
|
||||||
|
if not user.is_active:
|
||||||
|
raise self._get_credentials_exception()
|
||||||
|
return user
|
||||||
|
|
||||||
|
def _get_current_superuser_base(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
|
user = self._get_current_active_user_base(user)
|
||||||
|
if not user.is_superuser:
|
||||||
|
raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN)
|
||||||
|
return user
|
||||||
|
|
||||||
|
def _get_credentials_exception(
|
||||||
|
self, status_code: int = status.HTTP_401_UNAUTHORIZED
|
||||||
|
) -> HTTPException:
|
||||||
|
return HTTPException(status_code=status_code)
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from starlette import status
|
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
from fastapi_users.authentication.base import BaseAuthentication
|
from fastapi_users.authentication.base import BaseAuthentication
|
||||||
@ -42,24 +41,36 @@ class JWTAuthentication(BaseAuthentication):
|
|||||||
|
|
||||||
return {"token": token}
|
return {"token": token}
|
||||||
|
|
||||||
def get_authentication_method(self, user_db: BaseUserDatabase):
|
def get_current_user(self, user_db: BaseUserDatabase):
|
||||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
async def _get_current_user(token: str = Depends(oauth2_scheme)):
|
||||||
credentials_exception = HTTPException(
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED
|
return self._get_current_user_base(user)
|
||||||
)
|
|
||||||
|
|
||||||
|
return _get_current_user
|
||||||
|
|
||||||
|
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||||
|
async def _get_current_active_user(token: str = Depends(oauth2_scheme)):
|
||||||
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
|
return self._get_current_active_user_base(user)
|
||||||
|
|
||||||
|
return _get_current_active_user
|
||||||
|
|
||||||
|
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||||
|
async def _get_current_superuser(token: str = Depends(oauth2_scheme)):
|
||||||
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
|
return self._get_current_superuser_base(user)
|
||||||
|
|
||||||
|
return _get_current_superuser
|
||||||
|
|
||||||
|
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||||
|
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
||||||
try:
|
try:
|
||||||
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||||
user_id = data.get("user_id")
|
user_id = data.get("user_id")
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise credentials_exception
|
return None
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
raise credentials_exception
|
return None
|
||||||
|
return await user_db.get(user_id)
|
||||||
user = await user_db.get(user_id)
|
|
||||||
if user is None or not user.is_active:
|
|
||||||
raise credentials_exception
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
return authentication_method
|
return authentication_method
|
||||||
|
@ -34,5 +34,11 @@ class FastAPIUsers:
|
|||||||
self.auth = auth
|
self.auth = auth
|
||||||
self.router = get_user_router(self.db, user_model, self.auth)
|
self.router = get_user_router(self.db, user_model, self.auth)
|
||||||
|
|
||||||
get_current_user = self.auth.get_authentication_method(self.db)
|
get_current_user = self.auth.get_current_user(self.db)
|
||||||
self.get_current_user = get_current_user # type: ignore
|
self.get_current_user = get_current_user # type: ignore
|
||||||
|
|
||||||
|
get_current_active_user = self.auth.get_current_active_user(self.db)
|
||||||
|
self.get_current_active_user = get_current_active_user # type: ignore
|
||||||
|
|
||||||
|
get_current_superuser = self.auth.get_current_superuser(self.db)
|
||||||
|
self.get_current_superuser = get_current_superuser # type: ignore
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from starlette import status
|
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
from fastapi_users.authentication import BaseAuthentication
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
@ -24,6 +23,13 @@ inactive_user_data = BaseUserDB(
|
|||||||
is_active=False,
|
is_active=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
superuser_data = BaseUserDB(
|
||||||
|
id="ccc",
|
||||||
|
email="merlin@camelot.bt",
|
||||||
|
hashed_password=get_password_hash("viviane"),
|
||||||
|
is_superuser=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user() -> BaseUserDB:
|
def user() -> BaseUserDB:
|
||||||
@ -35,12 +41,19 @@ def inactive_user() -> BaseUserDB:
|
|||||||
return inactive_user_data
|
return inactive_user_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def superuser() -> BaseUserDB:
|
||||||
|
return superuser_data
|
||||||
|
|
||||||
|
|
||||||
class MockUserDatabase(BaseUserDatabase):
|
class MockUserDatabase(BaseUserDatabase):
|
||||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||||
if id == active_user_data.id:
|
if id == active_user_data.id:
|
||||||
return active_user_data
|
return active_user_data
|
||||||
elif id == inactive_user_data.id:
|
elif id == inactive_user_data.id:
|
||||||
return inactive_user_data
|
return inactive_user_data
|
||||||
|
elif id == superuser_data.id:
|
||||||
|
return superuser_data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||||
@ -48,6 +61,8 @@ class MockUserDatabase(BaseUserDatabase):
|
|||||||
return active_user_data
|
return active_user_data
|
||||||
elif email == inactive_user_data.email:
|
elif email == inactive_user_data.email:
|
||||||
return inactive_user_data
|
return inactive_user_data
|
||||||
|
elif email == superuser_data.email:
|
||||||
|
return superuser_data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
@ -60,20 +75,35 @@ def mock_user_db() -> MockUserDatabase:
|
|||||||
|
|
||||||
|
|
||||||
class MockAuthentication(BaseAuthentication):
|
class MockAuthentication(BaseAuthentication):
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||||
|
|
||||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
return {"token": user.id}
|
return {"token": user.id}
|
||||||
|
|
||||||
def get_authentication_method(self, user_db: BaseUserDatabase):
|
def get_current_user(self, user_db: BaseUserDatabase):
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
|
||||||
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
|
return self._get_current_user_base(user)
|
||||||
|
|
||||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
return _get_current_user
|
||||||
credentials_exception = HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED
|
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||||
)
|
async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)):
|
||||||
user = await user_db.get(token)
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
if user is None or not user.is_active:
|
return self._get_current_active_user_base(user)
|
||||||
raise credentials_exception
|
|
||||||
return user
|
return _get_current_active_user
|
||||||
|
|
||||||
|
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||||
|
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
|
||||||
|
user = await self._get_authentication_method(user_db)(token)
|
||||||
|
return self._get_current_superuser_base(user)
|
||||||
|
|
||||||
|
return _get_current_superuser
|
||||||
|
|
||||||
|
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||||
|
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
|
||||||
|
return await user_db.get(token)
|
||||||
|
|
||||||
return authentication_method
|
return authentication_method
|
||||||
|
|
||||||
|
130
tests/test_authentication_base.py
Normal file
130
tests/test_authentication_base.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import Depends, FastAPI
|
||||||
|
from starlette import status
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from fastapi_users.models import BaseUserDB
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_auth_client(mock_authentication, mock_user_db):
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test-current-user")
|
||||||
|
def test_current_user(
|
||||||
|
user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db))
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
@app.get("/test-current-active-user")
|
||||||
|
def test_current_active_user(
|
||||||
|
user: BaseUserDB = Depends(
|
||||||
|
mock_authentication.get_current_active_user(mock_user_db)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
@app.get("/test-current-superuser")
|
||||||
|
def test_current_superuser(
|
||||||
|
user: BaseUserDB = Depends(
|
||||||
|
mock_authentication.get_current_superuser(mock_user_db)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
return user
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentUser:
|
||||||
|
def test_missing_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get("/test-current-user")
|
||||||
|
print(response.json())
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-user", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-user",
|
||||||
|
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == inactive_user.id
|
||||||
|
|
||||||
|
def test_valid_token(self, test_auth_client, user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentActiveUser:
|
||||||
|
def test_missing_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get("/test-current-active-user")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-active-user",
|
||||||
|
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token(self, test_auth_client, user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentSuperuser:
|
||||||
|
def test_missing_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get("/test-current-superuser")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_auth_client):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser",
|
||||||
|
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_regular_user(self, test_auth_client, user):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
def test_valid_token_superuser(self, test_auth_client, superuser):
|
||||||
|
response = test_auth_client.get(
|
||||||
|
"/test-current-superuser",
|
||||||
|
headers={"Authorization": f"Bearer {superuser.id}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == superuser.id
|
@ -33,9 +33,7 @@ def test_auth_client(jwt_authentication, mock_user_db):
|
|||||||
|
|
||||||
@app.get("/test-auth")
|
@app.get("/test-auth")
|
||||||
def test_auth(
|
def test_auth(
|
||||||
user: BaseUserDB = Depends(
|
user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db))
|
||||||
jwt_authentication.get_authentication_method(mock_user_db)
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@ -53,7 +51,7 @@ async def test_get_login_response(jwt_authentication, user):
|
|||||||
assert decoded["user_id"] == user.id
|
assert decoded["user_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
class TestGetAuthenticationMethod:
|
class TestGetCurrentUser:
|
||||||
def test_missing_token(self, test_auth_client):
|
def test_missing_token(self, test_auth_client):
|
||||||
response = test_auth_client.get("/test-auth")
|
response = test_auth_client.get("/test-auth")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
@ -68,7 +66,10 @@ class TestGetAuthenticationMethod:
|
|||||||
response = test_auth_client.get(
|
response = test_auth_client.get(
|
||||||
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
|
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
response_json = response.json()
|
||||||
|
assert response_json["id"] == inactive_user.id
|
||||||
|
|
||||||
def test_valid_token(self, test_auth_client, token, user):
|
def test_valid_token(self, test_auth_client, token, user):
|
||||||
response = test_auth_client.get(
|
response = test_auth_client.get(
|
||||||
|
@ -7,6 +7,8 @@ from databases import Database
|
|||||||
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
|
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
|
||||||
|
|
||||||
from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
|
from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
|
||||||
|
from fastapi_users.models import BaseUserDB
|
||||||
|
from fastapi_users.password import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -32,7 +34,13 @@ async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_queries(user, sqlalchemy_user_db):
|
async def test_queries(sqlalchemy_user_db):
|
||||||
|
user = BaseUserDB(
|
||||||
|
id="111",
|
||||||
|
email="lancelot@camelot.bt",
|
||||||
|
hashed_password=get_password_hash("guinevere"),
|
||||||
|
)
|
||||||
|
|
||||||
# Create
|
# Create
|
||||||
user_db = await sqlalchemy_user_db.create(user)
|
user_db = await sqlalchemy_user_db.create(user)
|
||||||
assert user_db.id is not None
|
assert user_db.id is not None
|
||||||
@ -64,5 +72,5 @@ async def test_queries(user, sqlalchemy_user_db):
|
|||||||
await sqlalchemy_user_db.create(user)
|
await sqlalchemy_user_db.create(user)
|
||||||
|
|
||||||
# Unknown user
|
# Unknown user
|
||||||
unknown_user = await sqlalchemy_user_db.get_by_email("lancelot@camelot.bt")
|
unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt")
|
||||||
assert unknown_user is None
|
assert unknown_user is None
|
||||||
|
@ -20,8 +20,16 @@ def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(fastapi_users.router)
|
app.include_router(fastapi_users.router)
|
||||||
|
|
||||||
@app.get("/authenticated")
|
@app.get("/current-user")
|
||||||
def authenticated(user=Depends(fastapi_users.get_current_user)):
|
def current_user(user=Depends(fastapi_users.get_current_user)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
@app.get("/current-active-user")
|
||||||
|
def current_active_user(user=Depends(fastapi_users.get_current_active_user)):
|
||||||
|
return user
|
||||||
|
|
||||||
|
@app.get("/current-superuser")
|
||||||
|
def current_superuser(user=Depends(fastapi_users.get_current_superuser)):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
@ -38,17 +46,72 @@ class TestRouter:
|
|||||||
|
|
||||||
class TestGetCurrentUser:
|
class TestGetCurrentUser:
|
||||||
def test_missing_token(self, test_app_client: TestClient):
|
def test_missing_token(self, test_app_client: TestClient):
|
||||||
response = test_app_client.get("/authenticated")
|
response = test_app_client.get("/current-user")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_invalid_token(self, test_app_client: TestClient):
|
def test_invalid_token(self, test_app_client: TestClient):
|
||||||
response = test_app_client.get(
|
response = test_app_client.get(
|
||||||
"/authenticated", headers={"Authorization": "Bearer foo"}
|
"/current-user", headers={"Authorization": "Bearer foo"}
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
|
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
|
||||||
response = test_app_client.get(
|
response = test_app_client.get(
|
||||||
"/authenticated", headers={"Authorization": f"Bearer {user.id}"}
|
"/current-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentActiveUser:
|
||||||
|
def test_missing_token(self, test_app_client: TestClient):
|
||||||
|
response = test_app_client.get("/current-active-user")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_app_client: TestClient):
|
||||||
|
response = test_app_client.get(
|
||||||
|
"/current-active-user", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_inactive_user(
|
||||||
|
self, test_app_client: TestClient, inactive_user: BaseUserDB
|
||||||
|
):
|
||||||
|
response = test_app_client.get(
|
||||||
|
"/current-active-user",
|
||||||
|
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
|
||||||
|
response = test_app_client.get(
|
||||||
|
"/current-active-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetCurrentSuperuser:
|
||||||
|
def test_missing_token(self, test_app_client: TestClient):
|
||||||
|
response = test_app_client.get("/current-superuser")
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_invalid_token(self, test_app_client: TestClient):
|
||||||
|
response = test_app_client.get(
|
||||||
|
"/current-superuser", headers={"Authorization": "Bearer foo"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
|
def test_valid_token_regular_user(
|
||||||
|
self, test_app_client: TestClient, user: BaseUserDB
|
||||||
|
):
|
||||||
|
response = test_app_client.get(
|
||||||
|
"/current-superuser", headers={"Authorization": f"Bearer {user.id}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
|
def test_valid_token_superuser(
|
||||||
|
self, test_app_client: TestClient, superuser: BaseUserDB
|
||||||
|
):
|
||||||
|
response = test_app_client.get(
|
||||||
|
"/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"}
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
Reference in New Issue
Block a user