mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-03 13:42:16 +08:00
Refactor OAuth logic into manager
This commit is contained in:
@ -126,7 +126,6 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
|
||||
oauth_client: BaseOAuth2,
|
||||
state_secret: SecretType,
|
||||
redirect_url: str = None,
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Return an OAuth router for a given OAuth client.
|
||||
@ -135,17 +134,13 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]):
|
||||
:param state_secret: Secret used to encode the state JWT.
|
||||
:param redirect_url: Optional arbitrary redirect URL for the OAuth2 flow.
|
||||
If not given, the URL to the callback endpoint will be generated.
|
||||
:param after_register: Optional function called
|
||||
after a successful registration.
|
||||
"""
|
||||
return get_oauth_router(
|
||||
oauth_client,
|
||||
self.get_user_manager,
|
||||
self._user_db_model,
|
||||
self.authenticator,
|
||||
state_secret,
|
||||
redirect_url,
|
||||
after_register,
|
||||
)
|
||||
|
||||
def get_users_router(
|
||||
|
||||
@ -8,7 +8,7 @@ from pydantic import UUID4
|
||||
from fastapi_users import models, password
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.password import get_password_hash
|
||||
from fastapi_users.password import generate_password, get_password_hash
|
||||
|
||||
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
|
||||
|
||||
@ -104,6 +104,42 @@ class BaseUserManager(Generic[models.UC, models.UD]):
|
||||
|
||||
return created_user
|
||||
|
||||
async def oauth_callback(
|
||||
self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None
|
||||
) -> models.UD:
|
||||
try:
|
||||
user = await self.get_by_oauth_account(
|
||||
oauth_account.oauth_name, oauth_account.account_id
|
||||
)
|
||||
except UserNotExists:
|
||||
try:
|
||||
# Link account
|
||||
user = await self.get_by_email(oauth_account.account_email)
|
||||
user.oauth_accounts.append(oauth_account) # type: ignore
|
||||
await self.user_db.update(user)
|
||||
except UserNotExists:
|
||||
# Create account
|
||||
password = generate_password()
|
||||
user = self.user_db_model(
|
||||
email=oauth_account.account_email,
|
||||
hashed_password=get_password_hash(password),
|
||||
oauth_accounts=[oauth_account],
|
||||
)
|
||||
await self.user_db.create(user)
|
||||
await self.on_after_register(user, request)
|
||||
else:
|
||||
# Update oauth
|
||||
updated_oauth_accounts = []
|
||||
for existing_oauth_account in user.oauth_accounts: # type: ignore
|
||||
if existing_oauth_account.account_id == oauth_account.account_id:
|
||||
updated_oauth_accounts.append(oauth_account)
|
||||
else:
|
||||
updated_oauth_accounts.append(existing_oauth_account)
|
||||
user.oauth_accounts = updated_oauth_accounts # type: ignore
|
||||
await self.user_db.update(user)
|
||||
|
||||
return user
|
||||
|
||||
async def forgot_password(
|
||||
self, user: models.UD, request: Optional[Request] = None
|
||||
) -> None:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Callable, Dict, List, Optional, Type
|
||||
from typing import Dict, List
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
|
||||
@ -8,9 +8,8 @@ from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import BaseUserManager, UserManagerDependency, UserNotExists
|
||||
from fastapi_users.password import generate_password, get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, run_handler
|
||||
from fastapi_users.manager import BaseUserManager, UserManagerDependency
|
||||
from fastapi_users.router.common import ErrorCode
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
|
||||
@ -25,11 +24,9 @@ def generate_state_token(
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
get_user_manager: UserManagerDependency[models.UC, models.UD],
|
||||
user_db_model: Type[models.UD],
|
||||
authenticator: Authenticator,
|
||||
state_secret: SecretType,
|
||||
redirect_url: str = None,
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
@ -104,37 +101,7 @@ def get_oauth_router(
|
||||
account_email=account_email,
|
||||
)
|
||||
|
||||
try:
|
||||
user = await user_manager.get_by_oauth_account(
|
||||
oauth_client.name, account_id
|
||||
)
|
||||
except UserNotExists:
|
||||
try:
|
||||
# Link account
|
||||
user = await user_manager.get_by_email(account_email)
|
||||
user.oauth_accounts.append(new_oauth_account) # type: ignore
|
||||
await user_manager.user_db.update(user)
|
||||
except UserNotExists:
|
||||
# Create account
|
||||
password = generate_password()
|
||||
user = user_db_model(
|
||||
email=account_email,
|
||||
hashed_password=get_password_hash(password),
|
||||
oauth_accounts=[new_oauth_account],
|
||||
)
|
||||
await user_manager.user_db.create(user)
|
||||
if after_register:
|
||||
await run_handler(after_register, user, request)
|
||||
else:
|
||||
# Update oauth
|
||||
updated_oauth_accounts = []
|
||||
for oauth_account in user.oauth_accounts: # type: ignore
|
||||
if oauth_account.account_id == account_id:
|
||||
updated_oauth_accounts.append(new_oauth_account)
|
||||
else:
|
||||
updated_oauth_accounts.append(oauth_account)
|
||||
user.oauth_accounts = updated_oauth_accounts # type: ignore
|
||||
await user_manager.user_db.update(user)
|
||||
user = await user_manager.oauth_callback(new_oauth_account, request)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Callable, List, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
@ -65,16 +65,6 @@ class UserManager(BaseUserManager[UserCreate, UserDB]):
|
||||
reason="Password should be at least 3 characters"
|
||||
)
|
||||
|
||||
def mock_method(self, name: str):
|
||||
mock = MagicMock()
|
||||
|
||||
future: asyncio.Future = asyncio.Future()
|
||||
future.set_result(None)
|
||||
mock.return_value = future
|
||||
mock.side_effect = None
|
||||
|
||||
setattr(self, name, mock)
|
||||
|
||||
|
||||
class UserManagerMock(UserManager):
|
||||
get_by_email: MagicMock
|
||||
@ -94,6 +84,30 @@ def event_loop():
|
||||
yield loop
|
||||
|
||||
|
||||
AsyncMethodMocker = Callable[..., MagicMock]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_method_mocker(mocker: MockerFixture) -> AsyncMethodMocker:
|
||||
def _async_method_mocker(
|
||||
object: Any,
|
||||
method: str,
|
||||
return_value: Any = None,
|
||||
) -> MagicMock:
|
||||
mock: MagicMock = mocker.MagicMock()
|
||||
|
||||
future: asyncio.Future = asyncio.Future()
|
||||
future.set_result(return_value)
|
||||
mock.return_value = future
|
||||
mock.side_effect = None
|
||||
|
||||
setattr(object, method, mock)
|
||||
|
||||
return mock
|
||||
|
||||
return _async_method_mocker
|
||||
|
||||
|
||||
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
|
||||
def secret(request) -> SecretType:
|
||||
return request.param
|
||||
@ -361,33 +375,30 @@ def mock_user_db_oauth(
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_mock_user_db(mock_user_db):
|
||||
def _get_mock_user_db():
|
||||
yield mock_user_db
|
||||
def make_user_manager(mocker: MockerFixture):
|
||||
def _make_user_manager(user_db_model, mock_user_db):
|
||||
user_manager = UserManager(user_db_model, mock_user_db)
|
||||
mocker.spy(user_manager, "get_by_email")
|
||||
mocker.spy(user_manager, "forgot_password")
|
||||
mocker.spy(user_manager, "reset_password")
|
||||
mocker.spy(user_manager, "on_after_register")
|
||||
mocker.spy(user_manager, "on_after_update")
|
||||
mocker.spy(user_manager, "on_after_forgot_password")
|
||||
mocker.spy(user_manager, "on_after_reset_password")
|
||||
mocker.spy(user_manager, "_update")
|
||||
return user_manager
|
||||
|
||||
return _get_mock_user_db
|
||||
return _make_user_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_mock_user_db_oauth(mock_user_db_oauth):
|
||||
def _get_mock_user_db_oauth():
|
||||
yield mock_user_db_oauth
|
||||
|
||||
return _get_mock_user_db_oauth
|
||||
def user_manager(make_user_manager, mock_user_db):
|
||||
return make_user_manager(UserDB, mock_user_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(mocker: MockerFixture, mock_user_db):
|
||||
user_manager = UserManager(UserDB, mock_user_db)
|
||||
mocker.spy(user_manager, "get_by_email")
|
||||
mocker.spy(user_manager, "forgot_password")
|
||||
mocker.spy(user_manager, "reset_password")
|
||||
mocker.spy(user_manager, "on_after_register")
|
||||
mocker.spy(user_manager, "on_after_update")
|
||||
mocker.spy(user_manager, "on_after_forgot_password")
|
||||
mocker.spy(user_manager, "on_after_reset_password")
|
||||
mocker.spy(user_manager, "_update")
|
||||
return user_manager
|
||||
def user_manager_oauth(make_user_manager, mock_user_db_oauth):
|
||||
return make_user_manager(UserDBOAuth, mock_user_db_oauth)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -399,9 +410,9 @@ def get_user_manager(user_manager):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def get_user_manager_oauth(get_mock_user_db_oauth):
|
||||
def _get_user_manager_oauth(user_db=Depends(get_mock_user_db_oauth)):
|
||||
return UserManager(UserDBOAuth, user_db)
|
||||
def get_user_manager_oauth(user_manager_oauth):
|
||||
def _get_user_manager_oauth():
|
||||
return user_manager_oauth
|
||||
|
||||
return _get_user_manager_oauth
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from typing import Callable
|
||||
from typing import Callable, cast
|
||||
|
||||
import pytest
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.jwt import decode_jwt, generate_jwt
|
||||
from fastapi_users.manager import (
|
||||
InvalidPasswordException,
|
||||
@ -13,7 +14,7 @@ from fastapi_users.manager import (
|
||||
UserInactive,
|
||||
UserNotExists,
|
||||
)
|
||||
from tests.conftest import UserCreate, UserDB, UserManagerMock, UserUpdate
|
||||
from tests.conftest import UserCreate, UserDB, UserDBOAuth, UserManagerMock, UserUpdate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -85,6 +86,61 @@ class TestCreateUser:
|
||||
assert user_manager.on_after_register.called is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestOAuthCallback:
|
||||
async def test_existing_user_with_oauth(
|
||||
self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth
|
||||
):
|
||||
oauth_account = models.BaseOAuthAccount(
|
||||
**user_oauth.oauth_accounts[0].dict(exclude={"id", "access_token"}),
|
||||
access_token="UPDATED_TOKEN"
|
||||
)
|
||||
user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account))
|
||||
|
||||
assert user.id == user_oauth.id
|
||||
assert len(user.oauth_accounts) == 2
|
||||
assert user.oauth_accounts[0].oauth_name == "service1"
|
||||
assert user.oauth_accounts[0].access_token == "UPDATED_TOKEN"
|
||||
assert user.oauth_accounts[1].access_token == "TOKEN"
|
||||
assert user.oauth_accounts[1].oauth_name == "service2"
|
||||
|
||||
assert user_manager_oauth.on_after_register.called is False
|
||||
|
||||
async def test_existing_user_without_oauth(
|
||||
self, user_manager_oauth: UserManagerMock, superuser_oauth: UserDBOAuth
|
||||
):
|
||||
oauth_account = models.BaseOAuthAccount(
|
||||
oauth_name="service1",
|
||||
access_token="TOKEN",
|
||||
expires_at=1579000751,
|
||||
account_id="superuser_oauth1",
|
||||
account_email=superuser_oauth.email,
|
||||
)
|
||||
user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account))
|
||||
|
||||
assert user.id == superuser_oauth.id
|
||||
assert len(user.oauth_accounts) == 1
|
||||
assert user.oauth_accounts[0].id == oauth_account.id
|
||||
|
||||
assert user_manager_oauth.on_after_register.called is False
|
||||
|
||||
async def test_new_user(self, user_manager_oauth: UserManagerMock):
|
||||
oauth_account = models.BaseOAuthAccount(
|
||||
oauth_name="service1",
|
||||
access_token="TOKEN",
|
||||
expires_at=1579000751,
|
||||
account_id="new_user_oauth1",
|
||||
account_email="galahad@camelot.bt",
|
||||
)
|
||||
user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account))
|
||||
|
||||
assert user.email == "galahad@camelot.bt"
|
||||
assert len(user.oauth_accounts) == 1
|
||||
assert user.oauth_accounts[0].id == oauth_account.id
|
||||
|
||||
assert user_manager_oauth.on_after_register.called is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestUpdateUser:
|
||||
async def test_safe_update(self, user: UserDB, user_manager: UserManagerMock):
|
||||
|
||||
@ -1,28 +1,18 @@
|
||||
from typing import Any, AsyncGenerator, Dict, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import asynctest
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi import FastAPI, status
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
|
||||
from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.router.common import ErrorCode
|
||||
from fastapi_users.router.oauth import generate_state_token, get_oauth_router
|
||||
from tests.conftest import MockAuthentication, UserDB
|
||||
|
||||
|
||||
def after_register_sync():
|
||||
return MagicMock(return_value=None)
|
||||
|
||||
|
||||
def after_register_async():
|
||||
return asynctest.CoroutineMock(return_value=None)
|
||||
|
||||
|
||||
@pytest.fixture(params=[after_register_sync, after_register_async])
|
||||
def after_register(request):
|
||||
return request.param()
|
||||
from tests.conftest import (
|
||||
AsyncMethodMocker,
|
||||
MockAuthentication,
|
||||
UserDB,
|
||||
UserManagerMock,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -31,7 +21,6 @@ def get_test_app_client(
|
||||
get_user_manager_oauth,
|
||||
mock_authentication,
|
||||
oauth_client,
|
||||
after_register,
|
||||
get_test_client,
|
||||
):
|
||||
async def _get_test_app_client(
|
||||
@ -45,11 +34,9 @@ def get_test_app_client(
|
||||
oauth_router = get_oauth_router(
|
||||
oauth_client,
|
||||
get_user_manager_oauth,
|
||||
UserDB,
|
||||
authenticator,
|
||||
secret,
|
||||
redirect_url,
|
||||
after_register,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
@ -80,64 +67,86 @@ async def test_app_client_redirect_url(get_test_app_client):
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthorize:
|
||||
async def test_missing_authentication_backend(
|
||||
self, test_app_client: httpx.AsyncClient, oauth_client
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
oauth_client: BaseOAuth2,
|
||||
):
|
||||
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
|
||||
mock.return_value = "AUTHORIZATION_URL"
|
||||
response = await test_app_client.get(
|
||||
"/authorize",
|
||||
params={"scopes": ["scope1", "scope2"]},
|
||||
)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
|
||||
)
|
||||
|
||||
response = await test_app_client.get(
|
||||
"/authorize",
|
||||
params={"scopes": ["scope1", "scope2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_wrong_authentication_backend(
|
||||
self, test_app_client: httpx.AsyncClient, oauth_client
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
oauth_client: BaseOAuth2,
|
||||
):
|
||||
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
|
||||
mock.return_value = "AUTHORIZATION_URL"
|
||||
response = await test_app_client.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"authentication_backend": "foo",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
},
|
||||
)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
|
||||
)
|
||||
|
||||
response = await test_app_client.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"authentication_backend": "foo",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
async def test_success(self, test_app_client: httpx.AsyncClient, oauth_client):
|
||||
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
|
||||
mock.return_value = "AUTHORIZATION_URL"
|
||||
response = await test_app_client.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"authentication_backend": "mock",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
},
|
||||
)
|
||||
async def test_success(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
oauth_client: BaseOAuth2,
|
||||
):
|
||||
get_authorization_url_mock = async_method_mocker(
|
||||
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
|
||||
)
|
||||
|
||||
response = await test_app_client.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"authentication_backend": "mock",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
mock.assert_awaited_once()
|
||||
get_authorization_url_mock.assert_called_once()
|
||||
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
|
||||
async def test_with_redirect_url(
|
||||
self, test_app_client_redirect_url: httpx.AsyncClient, oauth_client
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client_redirect_url: httpx.AsyncClient,
|
||||
oauth_client: BaseOAuth2,
|
||||
):
|
||||
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
|
||||
mock.return_value = "AUTHORIZATION_URL"
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"authentication_backend": "mock",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
},
|
||||
)
|
||||
get_authorization_url_mock = async_method_mocker(
|
||||
oauth_client, "get_authorization_url", return_value="AUTHORIZATION_URL"
|
||||
)
|
||||
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/authorize",
|
||||
params={
|
||||
"authentication_backend": "mock",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
mock.assert_awaited_once()
|
||||
get_authorization_url_mock.assert_called_once()
|
||||
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
@ -156,196 +165,110 @@ class TestAuthorize:
|
||||
class TestCallback:
|
||||
async def test_invalid_state(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
access_token,
|
||||
oauth_client,
|
||||
user_oauth,
|
||||
after_register,
|
||||
oauth_client: BaseOAuth2,
|
||||
user_oauth: UserDB,
|
||||
access_token: str,
|
||||
):
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_access_token"
|
||||
) as get_access_token_mock:
|
||||
get_access_token_mock.return_value = access_token
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_id_email"
|
||||
) as get_id_email_mock:
|
||||
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": "STATE"},
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
get_id_email_mock = async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": "STATE"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
assert after_register.called is False
|
||||
get_id_email_mock.assert_called_once_with("TOKEN")
|
||||
|
||||
async def test_existing_user_with_oauth(
|
||||
async def test_active_user(
|
||||
self,
|
||||
mock_user_db_oauth,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
access_token,
|
||||
oauth_client,
|
||||
user_oauth,
|
||||
after_register,
|
||||
oauth_client: BaseOAuth2,
|
||||
user_oauth: UserDB,
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_access_token"
|
||||
) as get_access_token_mock:
|
||||
get_access_token_mock.return_value = access_token
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_id_email"
|
||||
) as get_id_email_mock:
|
||||
with asynctest.patch.object(
|
||||
mock_user_db_oauth, "update"
|
||||
) as user_update_mock:
|
||||
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
async_method_mocker(
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
user_update_mock.assert_awaited_once()
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
|
||||
assert data["token"] == str(user_oauth.id)
|
||||
|
||||
assert after_register.called is False
|
||||
|
||||
async def test_existing_user_without_oauth(
|
||||
self,
|
||||
mock_user_db_oauth,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
access_token,
|
||||
oauth_client,
|
||||
superuser_oauth,
|
||||
after_register,
|
||||
):
|
||||
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_access_token"
|
||||
) as get_access_token_mock:
|
||||
get_access_token_mock.return_value = access_token
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_id_email"
|
||||
) as get_id_email_mock:
|
||||
with asynctest.patch.object(
|
||||
mock_user_db_oauth, "update"
|
||||
) as user_update_mock:
|
||||
get_id_email_mock.return_value = (
|
||||
"superuser_oauth1",
|
||||
superuser_oauth.email,
|
||||
)
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
user_update_mock.assert_awaited_once()
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
|
||||
assert data["token"] == str(superuser_oauth.id)
|
||||
|
||||
assert after_register.called is False
|
||||
|
||||
async def test_unknown_user(
|
||||
self,
|
||||
mock_user_db_oauth,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
access_token,
|
||||
oauth_client,
|
||||
after_register,
|
||||
):
|
||||
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_access_token"
|
||||
) as get_access_token_mock:
|
||||
get_access_token_mock.return_value = access_token
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_id_email"
|
||||
) as get_id_email_mock:
|
||||
with asynctest.patch.object(
|
||||
mock_user_db_oauth, "create"
|
||||
) as user_create_mock:
|
||||
get_id_email_mock.return_value = (
|
||||
"unknown_user_oauth1",
|
||||
"galahad@camelot.bt",
|
||||
)
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
user_create_mock.assert_awaited_once()
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
|
||||
assert "token" in data
|
||||
|
||||
assert after_register.called is True
|
||||
actual_user = after_register.call_args[0][0]
|
||||
assert str(actual_user.id) == data["token"]
|
||||
request = after_register.call_args[0][1]
|
||||
assert isinstance(request, Request)
|
||||
|
||||
async def test_inactive_user(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
access_token,
|
||||
oauth_client,
|
||||
inactive_user_oauth,
|
||||
after_register,
|
||||
oauth_client: BaseOAuth2,
|
||||
inactive_user_oauth: UserDB,
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_access_token"
|
||||
) as get_access_token_mock:
|
||||
get_access_token_mock.return_value = access_token
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_id_email"
|
||||
) as get_id_email_mock:
|
||||
get_id_email_mock.return_value = (
|
||||
"inactive_user_oauth1",
|
||||
inactive_user_oauth.email,
|
||||
)
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client,
|
||||
"get_id_email",
|
||||
return_value=("user_oauth1", inactive_user_oauth.email),
|
||||
)
|
||||
async_method_mocker(
|
||||
user_manager_oauth, "oauth_callback", return_value=inactive_user_oauth
|
||||
)
|
||||
|
||||
response = await test_app_client.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS
|
||||
|
||||
assert after_register.called is False
|
||||
|
||||
async def test_redirect_url_router(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client_redirect_url: httpx.AsyncClient,
|
||||
access_token,
|
||||
oauth_client,
|
||||
user_oauth,
|
||||
oauth_client: BaseOAuth2,
|
||||
user_oauth: UserDB,
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"authentication_backend": "mock"}, "SECRET")
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_access_token"
|
||||
) as get_access_token_mock:
|
||||
get_access_token_mock.return_value = access_token
|
||||
with asynctest.patch.object(
|
||||
oauth_client, "get_id_email"
|
||||
) as get_id_email_mock:
|
||||
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
get_access_token_mock = async_method_mocker(
|
||||
oauth_client, "get_access_token", return_value=access_token
|
||||
)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
async_method_mocker(
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
get_access_token_mock.assert_awaited_once_with(
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
get_access_token_mock.assert_called_once_with(
|
||||
"CODE", "http://www.tintagel.bt/callback"
|
||||
)
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
|
||||
data = cast(Dict[str, Any], response.json())
|
||||
assert data["token"] == str(user_oauth.id)
|
||||
|
||||
@ -11,7 +11,7 @@ from fastapi_users.manager import (
|
||||
UserNotExists,
|
||||
)
|
||||
from fastapi_users.router import ErrorCode, get_reset_password_router
|
||||
from tests.conftest import UserManagerMock
|
||||
from tests.conftest import AsyncMethodMocker, UserManagerMock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -57,10 +57,11 @@ class TestForgotPassword:
|
||||
|
||||
async def test_existing_user(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
user_manager.mock_method("forgot_password")
|
||||
async_method_mocker(user_manager, "forgot_password", return_value=None)
|
||||
json = {"email": "king.arthur@camelot.bt"}
|
||||
response = await test_app_client.post("/forgot-password", json=json)
|
||||
assert response.status_code == status.HTTP_202_ACCEPTED
|
||||
@ -139,10 +140,11 @@ class TestResetPassword:
|
||||
|
||||
async def test_valid_user_password(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
user_manager: UserManagerMock,
|
||||
):
|
||||
user_manager.mock_method("reset_password")
|
||||
async_method_mocker(user_manager, "reset_password", return_value=None)
|
||||
json = {"token": "foo", "password": "guinevere"}
|
||||
response = await test_app_client.post("/reset-password", json=json)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
Reference in New Issue
Block a user