mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-01 10:25:45 +08:00
Implement variant of dep injections to get active/super user
This commit is contained in:
@ -1,9 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
@ -24,6 +23,13 @@ inactive_user_data = BaseUserDB(
|
||||
is_active=False,
|
||||
)
|
||||
|
||||
superuser_data = BaseUserDB(
|
||||
id="ccc",
|
||||
email="merlin@camelot.bt",
|
||||
hashed_password=get_password_hash("viviane"),
|
||||
is_superuser=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user() -> BaseUserDB:
|
||||
@ -35,12 +41,19 @@ def inactive_user() -> BaseUserDB:
|
||||
return inactive_user_data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def superuser() -> BaseUserDB:
|
||||
return superuser_data
|
||||
|
||||
|
||||
class MockUserDatabase(BaseUserDatabase):
|
||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||
if id == active_user_data.id:
|
||||
return active_user_data
|
||||
elif id == inactive_user_data.id:
|
||||
return inactive_user_data
|
||||
elif id == superuser_data.id:
|
||||
return superuser_data
|
||||
return None
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||
@ -48,6 +61,8 @@ class MockUserDatabase(BaseUserDatabase):
|
||||
return active_user_data
|
||||
elif email == inactive_user_data.email:
|
||||
return inactive_user_data
|
||||
elif email == superuser_data.email:
|
||||
return superuser_data
|
||||
return None
|
||||
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
@ -60,20 +75,35 @@ def mock_user_db() -> MockUserDatabase:
|
||||
|
||||
|
||||
class MockAuthentication(BaseAuthentication):
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
return {"token": user.id}
|
||||
|
||||
def get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
def get_current_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_user(token: str = Depends(self.oauth2_scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_user_base(user)
|
||||
|
||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
user = await user_db.get(token)
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
return user
|
||||
return _get_current_user
|
||||
|
||||
def get_current_active_user(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_active_user(token: str = Depends(self.oauth2_scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_active_user_base(user)
|
||||
|
||||
return _get_current_active_user
|
||||
|
||||
def get_current_superuser(self, user_db: BaseUserDatabase):
|
||||
async def _get_current_superuser(token: str = Depends(self.oauth2_scheme)):
|
||||
user = await self._get_authentication_method(user_db)(token)
|
||||
return self._get_current_superuser_base(user)
|
||||
|
||||
return _get_current_superuser
|
||||
|
||||
def _get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
async def authentication_method(token: str = Depends(self.oauth2_scheme)):
|
||||
return await user_db.get(token)
|
||||
|
||||
return authentication_method
|
||||
|
||||
|
||||
130
tests/test_authentication_base.py
Normal file
130
tests/test_authentication_base.py
Normal file
@ -0,0 +1,130 @@
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from starlette import status
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_auth_client(mock_authentication, mock_user_db):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test-current-user")
|
||||
def test_current_user(
|
||||
user: BaseUserDB = Depends(mock_authentication.get_current_user(mock_user_db))
|
||||
):
|
||||
return user
|
||||
|
||||
@app.get("/test-current-active-user")
|
||||
def test_current_active_user(
|
||||
user: BaseUserDB = Depends(
|
||||
mock_authentication.get_current_active_user(mock_user_db)
|
||||
)
|
||||
):
|
||||
return user
|
||||
|
||||
@app.get("/test-current-superuser")
|
||||
def test_current_superuser(
|
||||
user: BaseUserDB = Depends(
|
||||
mock_authentication.get_current_superuser(mock_user_db)
|
||||
)
|
||||
):
|
||||
return user
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-current-user")
|
||||
print(response.json())
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user",
|
||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == inactive_user.id
|
||||
|
||||
def test_valid_token(self, test_auth_client, user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == user.id
|
||||
|
||||
|
||||
class TestGetCurrentActiveUser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-current-active-user")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-active-user", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-active-user",
|
||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token(self, test_auth_client, user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-active-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == user.id
|
||||
|
||||
|
||||
class TestGetCurrentSuperuser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-current-superuser")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_auth_client):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(self, test_auth_client, inactive_user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser",
|
||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_regular_user(self, test_auth_client, user):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_valid_token_superuser(self, test_auth_client, superuser):
|
||||
response = test_auth_client.get(
|
||||
"/test-current-superuser",
|
||||
headers={"Authorization": f"Bearer {superuser.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == superuser.id
|
||||
@ -33,9 +33,7 @@ def test_auth_client(jwt_authentication, mock_user_db):
|
||||
|
||||
@app.get("/test-auth")
|
||||
def test_auth(
|
||||
user: BaseUserDB = Depends(
|
||||
jwt_authentication.get_authentication_method(mock_user_db)
|
||||
)
|
||||
user: BaseUserDB = Depends(jwt_authentication.get_current_user(mock_user_db))
|
||||
):
|
||||
return user
|
||||
|
||||
@ -53,7 +51,7 @@ async def test_get_login_response(jwt_authentication, user):
|
||||
assert decoded["user_id"] == user.id
|
||||
|
||||
|
||||
class TestGetAuthenticationMethod:
|
||||
class TestGetCurrentUser:
|
||||
def test_missing_token(self, test_auth_client):
|
||||
response = test_auth_client.get("/test-auth")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
@ -68,7 +66,10 @@ class TestGetAuthenticationMethod:
|
||||
response = test_auth_client.get(
|
||||
"/test-auth", headers={"Authorization": f"Bearer {token(inactive_user)}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == inactive_user.id
|
||||
|
||||
def test_valid_token(self, test_auth_client, token, user):
|
||||
response = test_auth_client.get(
|
||||
|
||||
@ -7,6 +7,8 @@ from databases import Database
|
||||
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
|
||||
|
||||
from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.password import get_password_hash
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -32,7 +34,13 @@ async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queries(user, sqlalchemy_user_db):
|
||||
async def test_queries(sqlalchemy_user_db):
|
||||
user = BaseUserDB(
|
||||
id="111",
|
||||
email="lancelot@camelot.bt",
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
|
||||
# Create
|
||||
user_db = await sqlalchemy_user_db.create(user)
|
||||
assert user_db.id is not None
|
||||
@ -64,5 +72,5 @@ async def test_queries(user, sqlalchemy_user_db):
|
||||
await sqlalchemy_user_db.create(user)
|
||||
|
||||
# Unknown user
|
||||
unknown_user = await sqlalchemy_user_db.get_by_email("lancelot@camelot.bt")
|
||||
unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt")
|
||||
assert unknown_user is None
|
||||
|
||||
@ -20,8 +20,16 @@ def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
|
||||
app = FastAPI()
|
||||
app.include_router(fastapi_users.router)
|
||||
|
||||
@app.get("/authenticated")
|
||||
def authenticated(user=Depends(fastapi_users.get_current_user)):
|
||||
@app.get("/current-user")
|
||||
def current_user(user=Depends(fastapi_users.get_current_user)):
|
||||
return user
|
||||
|
||||
@app.get("/current-active-user")
|
||||
def current_active_user(user=Depends(fastapi_users.get_current_active_user)):
|
||||
return user
|
||||
|
||||
@app.get("/current-superuser")
|
||||
def current_superuser(user=Depends(fastapi_users.get_current_superuser)):
|
||||
return user
|
||||
|
||||
return TestClient(app)
|
||||
@ -38,17 +46,72 @@ class TestRouter:
|
||||
|
||||
class TestGetCurrentUser:
|
||||
def test_missing_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get("/authenticated")
|
||||
response = test_app_client.get("/current-user")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get(
|
||||
"/authenticated", headers={"Authorization": "Bearer foo"}
|
||||
"/current-user", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
|
||||
response = test_app_client.get(
|
||||
"/authenticated", headers={"Authorization": f"Bearer {user.id}"}
|
||||
"/current-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
class TestGetCurrentActiveUser:
|
||||
def test_missing_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get("/current-active-user")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get(
|
||||
"/current-active-user", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_inactive_user(
|
||||
self, test_app_client: TestClient, inactive_user: BaseUserDB
|
||||
):
|
||||
response = test_app_client.get(
|
||||
"/current-active-user",
|
||||
headers={"Authorization": f"Bearer {inactive_user.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
|
||||
response = test_app_client.get(
|
||||
"/current-active-user", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
class TestGetCurrentSuperuser:
|
||||
def test_missing_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get("/current-superuser")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get(
|
||||
"/current-superuser", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token_regular_user(
|
||||
self, test_app_client: TestClient, user: BaseUserDB
|
||||
):
|
||||
response = test_app_client.get(
|
||||
"/current-superuser", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
|
||||
def test_valid_token_superuser(
|
||||
self, test_app_client: TestClient, superuser: BaseUserDB
|
||||
):
|
||||
response = test_app_client.get(
|
||||
"/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
Reference in New Issue
Block a user