Use real UUID for User id. and OAuthAccount id. (#198)

* Use UUID for user id and oauth account id

* Update documentation for UUID

* Tweak GUID definition of SQLAlchemy to match Tortoise ORM one

* Write migration doc
This commit is contained in:
François Voron
2020-05-21 16:40:33 +02:00
committed by GitHub
parent df479a9003
commit 0a0dcadfdc
24 changed files with 260 additions and 98 deletions

View File

@ -8,6 +8,7 @@ from asgi_lifespan import LifespanManager
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer
from httpx_oauth.oauth2 import OAuth2
from pydantic import UUID4
from starlette.applications import ASGIApp
from starlette.requests import Request
from starlette.responses import Response
@ -58,16 +59,13 @@ def event_loop():
@pytest.fixture
def user() -> UserDB:
return UserDB(
id="aaa",
email="king.arthur@camelot.bt",
hashed_password=guinevere_password_hash,
email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash,
)
@pytest.fixture
def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth:
return UserDBOAuth(
id="aaa",
email="king.arthur@camelot.bt",
hashed_password=guinevere_password_hash,
oauth_accounts=[oauth_account1, oauth_account2],
@ -77,7 +75,6 @@ def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth:
@pytest.fixture
def inactive_user() -> UserDB:
return UserDB(
id="bbb",
email="percival@camelot.bt",
hashed_password=angharad_password_hash,
is_active=False,
@ -87,7 +84,6 @@ def inactive_user() -> UserDB:
@pytest.fixture
def inactive_user_oauth(oauth_account3) -> UserDBOAuth:
return UserDBOAuth(
id="bbb",
email="percival@camelot.bt",
hashed_password=angharad_password_hash,
is_active=False,
@ -98,7 +94,6 @@ def inactive_user_oauth(oauth_account3) -> UserDBOAuth:
@pytest.fixture
def superuser() -> UserDB:
return UserDB(
id="ccc",
email="merlin@camelot.bt",
hashed_password=viviane_password_hash,
is_superuser=True,
@ -108,7 +103,6 @@ def superuser() -> UserDB:
@pytest.fixture
def superuser_oauth() -> UserDBOAuth:
return UserDBOAuth(
id="ccc",
email="merlin@camelot.bt",
hashed_password=viviane_password_hash,
is_superuser=True,
@ -119,7 +113,6 @@ def superuser_oauth() -> UserDBOAuth:
@pytest.fixture
def oauth_account1() -> BaseOAuthAccount:
return BaseOAuthAccount(
id="aaa",
oauth_name="service1",
access_token="TOKEN",
expires_at=1579000751,
@ -131,7 +124,6 @@ def oauth_account1() -> BaseOAuthAccount:
@pytest.fixture
def oauth_account2() -> BaseOAuthAccount:
return BaseOAuthAccount(
id="bbb",
oauth_name="service2",
access_token="TOKEN",
expires_at=1579000751,
@ -143,7 +135,6 @@ def oauth_account2() -> BaseOAuthAccount:
@pytest.fixture
def oauth_account3() -> BaseOAuthAccount:
return BaseOAuthAccount(
id="ccc",
oauth_name="service3",
access_token="TOKEN",
expires_at=1579000751,
@ -155,7 +146,7 @@ def oauth_account3() -> BaseOAuthAccount:
@pytest.fixture
def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
class MockUserDatabase(BaseUserDatabase[UserDB]):
async def get(self, id: str) -> Optional[UserDB]:
async def get(self, id: UUID4) -> Optional[UserDB]:
if id == user.id:
return user
if id == inactive_user.id:
@ -190,7 +181,7 @@ def mock_user_db_oauth(
user_oauth, inactive_user_oauth, superuser_oauth
) -> BaseUserDatabase:
class MockUserDatabase(BaseUserDatabase[UserDBOAuth]):
async def get(self, id: str) -> Optional[UserDBOAuth]:
async def get(self, id: UUID4) -> Optional[UserDBOAuth]:
if id == user_oauth.id:
return user_oauth
if id == inactive_user_oauth.id:
@ -246,7 +237,11 @@ class MockAuthentication(BaseAuthentication):
async def __call__(self, request: Request, user_db: BaseUserDatabase):
token = await self.scheme.__call__(request)
if token is not None:
return await user_db.get(token)
try:
token_uuid = UUID4(token)
return await user_db.get(token_uuid)
except ValueError:
return None
return None
async def get_login_response(self, user: BaseUserDB, response: Response):

View File

@ -31,7 +31,7 @@ def token():
def _token(user=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"}
if user is not None:
data["user_id"] = user.id
data["user_id"] = str(user.id)
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _token
@ -131,7 +131,7 @@ async def test_get_login_response(
decoded = jwt.decode(
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
assert decoded["user_id"] == user.id
assert decoded["user_id"] == str(user.id)
@pytest.mark.authentication
@ -149,4 +149,4 @@ async def test_get_logout_response(user):
cookie = cookies[0][1].decode("latin-1")
assert f"Max-Age=0" in cookie
assert "Max-Age=0" in cookie

View File

@ -17,10 +17,10 @@ def jwt_authentication():
@pytest.fixture
def token():
def _token(user=None, lifetime=LIFETIME):
def _token(user_id=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"}
if user is not None:
data["user_id"] = user.id
if user_id is not None:
data["user_id"] = str(user_id)
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _token
@ -57,11 +57,19 @@ class TestAuthenticate:
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token_invalid_uuid(
self, jwt_authentication, mock_user_db, request_builder, token
):
request = request_builder(headers={"Authorization": f"Bearer {token('foo')}"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user is None
@pytest.mark.asyncio
async def test_valid_token(
self, jwt_authentication, mock_user_db, request_builder, token, user
):
request = request_builder(headers={"Authorization": f"Bearer {token(user)}"})
request = request_builder(headers={"Authorization": f"Bearer {token(user.id)}"})
authenticated_user = await jwt_authentication(request, mock_user_db)
assert authenticated_user.id == user.id
@ -77,7 +85,7 @@ async def test_get_login_response(jwt_authentication, user):
decoded = jwt.decode(
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
)
assert decoded["user_id"] == user.id
assert decoded["user_id"] == str(user.id)
@pytest.mark.authentication

View File

@ -15,7 +15,9 @@ def get_mongodb_user_db():
user_model,
) -> AsyncGenerator[MongoDBUserDatabase, None]:
client = motor.motor_asyncio.AsyncIOMotorClient(
"mongodb://localhost:27017", serverSelectionTimeoutMS=100
"mongodb://localhost:27017",
serverSelectionTimeoutMS=100,
uuidRepresentation="standard",
)
try:
@ -50,9 +52,7 @@ async def mongodb_user_db_oauth(get_mongodb_user_db):
@pytest.mark.db
async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
user = UserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
)
# Create
@ -96,7 +96,6 @@ async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
async def test_queries_custom_fields(mongodb_user_db: MongoDBUserDatabase[UserDB]):
"""It should output custom fields in query result."""
user = UserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
first_name="Lancelot",
@ -117,7 +116,6 @@ async def test_queries_oauth(
oauth_account2,
):
user = UserDBOAuth(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
oauth_accounts=[oauth_account1, oauth_account2],

View File

@ -70,9 +70,7 @@ async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, N
@pytest.mark.db
async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
user = UserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
)
# Create
@ -103,7 +101,7 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
# Exception when inserting non-nullable fields
with pytest.raises(sqlite3.IntegrityError):
wrong_user = UserDB(id="222", hashed_password="aaa")
wrong_user = UserDB(hashed_password="aaa")
await sqlalchemy_user_db.create(wrong_user)
# Unknown user
@ -117,9 +115,7 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
# Exception when creating/updating a OAuth user
user_oauth = UserDBOAuth(
id="222",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
)
with pytest.raises(NotSetOAuthAccountTableError):
await sqlalchemy_user_db.create(user_oauth)
@ -138,7 +134,6 @@ async def test_queries_custom_fields(
):
"""It should output custom fields in query result."""
user = UserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
first_name="Lancelot",
@ -159,7 +154,6 @@ async def test_queries_oauth(
oauth_account2,
):
user = UserDBOAuth(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
oauth_accounts=[oauth_account1, oauth_account2],

View File

@ -55,9 +55,7 @@ async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None]
@pytest.mark.db
async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
user = UserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
)
# Create
@ -88,7 +86,7 @@ async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
# Exception when inserting non-nullable fields
with pytest.raises(ValueError):
wrong_user = UserDB(id="222", hashed_password="aaa")
wrong_user = UserDB(hashed_password="aaa")
await tortoise_user_db.create(wrong_user)
# Unknown user
@ -106,7 +104,6 @@ async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
async def test_queries_custom_fields(tortoise_user_db: TortoiseUserDatabase[UserDB]):
"""It should output custom fields in query result."""
user = UserDB(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
first_name="Lancelot",
@ -127,7 +124,6 @@ async def test_queries_oauth(
oauth_account2,
):
user = UserDBOAuth(
id="111",
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
oauth_accounts=[oauth_account1, oauth_account2],

View File

@ -204,7 +204,7 @@ class TestCallback:
user_update_mock.assert_awaited_once()
data = cast(Dict[str, Any], response.json())
assert data["token"] == user_oauth.id
assert data["token"] == str(user_oauth.id)
assert event_handler.called is False
@ -242,7 +242,7 @@ class TestCallback:
user_update_mock.assert_awaited_once()
data = cast(Dict[str, Any], response.json())
assert data["token"] == superuser_oauth.id
assert data["token"] == str(superuser_oauth.id)
assert event_handler.called is False
@ -283,7 +283,7 @@ class TestCallback:
assert event_handler.called is True
actual_user = event_handler.call_args[0][0]
assert actual_user.id == data["token"]
assert str(actual_user.id) == data["token"]
request = event_handler.call_args[0][1]
assert isinstance(request, Request)
@ -348,4 +348,4 @@ class TestCallback:
)
data = cast(Dict[str, Any], response.json())
assert data["token"] == user_oauth.id
assert data["token"] == str(user_oauth.id)

View File

@ -23,7 +23,7 @@ def forgot_password_token():
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:reset"}
if user_id is not None:
data["user_id"] = user_id
data["user_id"] = str(user_id)
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
return _forgot_password_token
@ -117,7 +117,7 @@ class TestRegister:
assert data["id"] is not None
actual_user = event_handler.call_args[0][0]
assert actual_user.id == data["id"]
assert str(actual_user.id) == data["id"]
request = event_handler.call_args[0][1]
assert isinstance(request, Request)
@ -190,7 +190,7 @@ class TestLogin:
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
response = await test_app_client.post(path, data=data)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {"token": user.id}
assert response.json() == {"token": str(user.id)}
async def test_inactive_user(self, path, test_app_client: httpx.AsyncClient):
data = {"username": "percival@camelot.bt", "password": "angharad"}
@ -261,7 +261,7 @@ class TestForgotPassword:
audience="fastapi-users:reset",
algorithms=[JWT_ALGORITHM],
)
assert decoded_token["user_id"] == user.id
assert decoded_token["user_id"] == str(user.id)
request = event_handler.call_args[0][2]
assert isinstance(request, Request)
@ -306,6 +306,22 @@ class TestResetPassword:
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
assert mock_user_db.update.called is False
async def test_valid_token_invalid_uuid(
self,
mocker,
mock_user_db,
test_app_client: httpx.AsyncClient,
forgot_password_token,
):
mocker.spy(mock_user_db, "update")
json = {"token": forgot_password_token("foo"), "password": "holygrail"}
response = await test_app_client.post("/reset-password", json=json)
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = cast(Dict[str, Any], response.json())
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
assert mock_user_db.update.called is False
async def test_inactive_user(
self,
mocker,
@ -368,7 +384,7 @@ class TestMe:
assert response.status_code == status.HTTP_200_OK
data = cast(Dict[str, Any], response.json())
assert data["id"] == user.id
assert data["id"] == str(user.id)
assert data["email"] == user.email
@ -504,12 +520,13 @@ class TestUpdateMe:
@pytest.mark.asyncio
class TestGetUser:
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
response = await test_app_client.get("/000")
response = await test_app_client.get("/d35d213e-f3d8-4f08-954a-7e0d1bea286f")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
response = await test_app_client.get(
"/000", headers={"Authorization": f"Bearer {user.id}"}
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
headers={"Authorization": f"Bearer {user.id}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@ -517,7 +534,8 @@ class TestGetUser:
self, test_app_client: httpx.AsyncClient, superuser: UserDB
):
response = await test_app_client.get(
"/000", headers={"Authorization": f"Bearer {superuser.id}"}
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
headers={"Authorization": f"Bearer {superuser.id}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@ -530,7 +548,7 @@ class TestGetUser:
assert response.status_code == status.HTTP_200_OK
data = cast(Dict[str, Any], response.json())
assert data["id"] == user.id
assert data["id"] == str(user.id)
assert "hashed_password" not in data
@ -538,12 +556,13 @@ class TestGetUser:
@pytest.mark.asyncio
class TestUpdateUser:
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
response = await test_app_client.patch("/000")
response = await test_app_client.patch("/d35d213e-f3d8-4f08-954a-7e0d1bea286f")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
response = await test_app_client.patch(
"/000", headers={"Authorization": f"Bearer {user.id}"}
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
headers={"Authorization": f"Bearer {user.id}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@ -551,7 +570,9 @@ class TestUpdateUser:
self, test_app_client: httpx.AsyncClient, superuser: UserDB
):
response = await test_app_client.patch(
"/000", json={}, headers={"Authorization": f"Bearer {superuser.id}"}
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
json={},
headers={"Authorization": f"Bearer {superuser.id}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@ -636,12 +657,13 @@ class TestUpdateUser:
@pytest.mark.asyncio
class TestDeleteUser:
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
response = await test_app_client.delete("/000")
response = await test_app_client.delete("/d35d213e-f3d8-4f08-954a-7e0d1bea286f")
assert response.status_code == status.HTTP_401_UNAUTHORIZED
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
response = await test_app_client.delete(
"/000", headers={"Authorization": f"Bearer {user.id}"}
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
headers={"Authorization": f"Bearer {user.id}"},
)
assert response.status_code == status.HTTP_403_FORBIDDEN
@ -649,7 +671,8 @@ class TestDeleteUser:
self, test_app_client: httpx.AsyncClient, superuser: UserDB
):
response = await test_app_client.delete(
"/000", headers={"Authorization": f"Bearer {superuser.id}"}
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
headers={"Authorization": f"Bearer {superuser.id}"},
)
assert response.status_code == status.HTTP_404_NOT_FOUND