mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2026-03-13 07:49:55 +08:00
Add a double-submit cookie in the OAuth flow
Prevents CSRF attacks by ensuring that the state parameter is tied to a cookie. Fix https://github.com/fastapi-users/fastapi-users/security/advisories/GHSA-5j53-63w8-8625
This commit is contained in:
@@ -41,6 +41,7 @@ Notice that we also manually added a `relationship` on `User` so that SQLAlchemy
|
||||
Besides, when instantiating the database adapter, we need pass this SQLAlchemy model as third argument.
|
||||
|
||||
!!! tip "Primary key is defined as UUID"
|
||||
|
||||
By default, we use UUID as a primary key ID for your user. If you want to use another type, like an auto-incremented integer, you can use `SQLAlchemyBaseOAuthAccountTable` as base class and define your own `id` and `user_id` column.
|
||||
|
||||
```py
|
||||
@@ -78,8 +79,25 @@ app.include_router(
|
||||
```
|
||||
|
||||
!!! tip
|
||||
|
||||
If you have several OAuth clients and/or several authentication backends, you'll need to create a router for each pair you want to support.
|
||||
|
||||
#### CSRF Cookie configuration
|
||||
|
||||
For security purposes, OAuth routers set a CSRF cookie when the authentication flow is initiated. By default, the cookie is configured with the following parameters:
|
||||
|
||||
- `csrf_token_cookie_name` (`fastapiusersoauthcsrf`): Name of the cookie.
|
||||
- `csrf_token_cookie_max_age` (`Optional[int]`): The lifetime of the cookie in seconds. `None` by default, which means it's a session cookie.
|
||||
- `csrf_token_cookie_path` (`/`): Cookie path.
|
||||
- `csrf_token_cookie_domain` (`None`): Cookie domain.
|
||||
- `csrf_token_cookie_secure` (`True`): Whether to only send the cookie to the server via SSL request.
|
||||
- `csrf_token_cookie_httponly` (`True`): Whether to prevent access to the cookie via JavaScript.
|
||||
- `csrf_token_cookie_samesite` (`lax`): A string that specifies the samesite strategy for the cookie. Valid values are `lax`, `strict` and `none`. Defaults to `lax`.
|
||||
|
||||
!!! tip
|
||||
|
||||
In local development, if you're not using HTTPS, you may want to set `csrf_token_cookie_secure` to `False` so that the cookie is sent by the browser.
|
||||
|
||||
#### Existing account association
|
||||
|
||||
If a user with the same e-mail address already exists, an HTTP 400 error will be raised by default.
|
||||
@@ -101,11 +119,11 @@ app.include_router(
|
||||
|
||||
Bear in mind though that it can lead to security breaches if the OAuth provider does not validate e-mail addresses. How?
|
||||
|
||||
* Let's say your app support an OAuth provider, *Merlinbook*, which does not validate e-mail addresses.
|
||||
* Imagine a user registers to your app with the e-mail address `lancelot@camelot.bt`.
|
||||
* Now, a malicious user creates an account on *Merlinbook* with the same e-mail address. Without e-mail validation, the malicious user can use this account without limitation.
|
||||
* The malicious user authenticates using *Merlinbook* OAuth on your app, which automatically associates to the existing `lancelot@camelot.bt`.
|
||||
* Now, the malicious user has full access to the user account on your app 😞
|
||||
- Let's say your app support an OAuth provider, _Merlinbook_, which does not validate e-mail addresses.
|
||||
- Imagine a user registers to your app with the e-mail address `lancelot@camelot.bt`.
|
||||
- Now, a malicious user creates an account on _Merlinbook_ with the same e-mail address. Without e-mail validation, the malicious user can use this account without limitation.
|
||||
- The malicious user authenticates using _Merlinbook_ OAuth on your app, which automatically associates to the existing `lancelot@camelot.bt`.
|
||||
- Now, the malicious user has full access to the user account on your app 😞
|
||||
|
||||
#### Association router for authenticated users
|
||||
|
||||
@@ -124,6 +142,7 @@ Notice that, just like for the [Users router](./routers/users.md), you have to p
|
||||
#### Set `is_verified` to `True` by default
|
||||
|
||||
!!! tip "This section is only useful if you set up email verification"
|
||||
|
||||
You can read more about this feature [here](./routers/verify.md).
|
||||
|
||||
When a new user registers with an OAuth provider, the `is_verified` flag is set to `False`, which requires the user to verify its email address.
|
||||
@@ -144,11 +163,13 @@ app.include_router(
|
||||
```
|
||||
|
||||
!!! danger "Make sure you can trust the OAuth provider"
|
||||
|
||||
Make sure the OAuth provider you're using **does verify** the email address before enabling this flag.
|
||||
|
||||
### Full example
|
||||
|
||||
!!! warning
|
||||
|
||||
Notice that **SECRET** should be changed to a strong passphrase.
|
||||
Insecure passwords may give attackers full access to your database.
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class ErrorCode(str, Enum):
|
||||
REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS"
|
||||
OAUTH_NOT_AVAILABLE_EMAIL = "OAUTH_NOT_AVAILABLE_EMAIL"
|
||||
OAUTH_USER_ALREADY_EXISTS = "OAUTH_USER_ALREADY_EXISTS"
|
||||
OAUTH_INVALID_STATE = "OAUTH_INVALID_STATE"
|
||||
LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS"
|
||||
LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED"
|
||||
RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN"
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import secrets
|
||||
from typing import Literal
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
@@ -12,6 +15,8 @@ from fastapi_users.manager import BaseUserManager, UserManagerDependency
|
||||
from fastapi_users.router.common import ErrorCode, ErrorModel
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
CSRF_TOKEN_KEY = "csrftoken"
|
||||
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
@@ -25,6 +30,10 @@ def generate_state_token(
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend[models.UP, models.ID],
|
||||
@@ -33,6 +42,13 @@ def get_oauth_router(
|
||||
redirect_url: str | None = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
*,
|
||||
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
|
||||
csrf_token_cookie_path: str = "/",
|
||||
csrf_token_cookie_domain: str | None = None,
|
||||
csrf_token_cookie_secure: bool = True,
|
||||
csrf_token_cookie_httponly: bool = True,
|
||||
csrf_token_cookie_samesite: Literal["lax", "strict", "none"] = "lax",
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
@@ -55,14 +71,15 @@ def get_oauth_router(
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request, scopes: list[str] = Query(None)
|
||||
request: Request, response: Response, scopes: list[str] = Query(None)
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
state_data: dict[str, str] = {}
|
||||
csrf_token = generate_csrf_token()
|
||||
state_data: dict[str, str] = {CSRF_TOKEN_KEY: csrf_token}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
@@ -70,6 +87,17 @@ def get_oauth_router(
|
||||
scopes,
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
csrf_token_cookie_name,
|
||||
csrf_token,
|
||||
max_age=3600,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
@@ -117,18 +145,9 @@ def get_oauth_router(
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
):
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
try:
|
||||
decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -140,6 +159,28 @@ def get_oauth_router(
|
||||
detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED,
|
||||
)
|
||||
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_INVALID_STATE,
|
||||
)
|
||||
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
@@ -180,6 +221,13 @@ def get_oauth_associate_router(
|
||||
state_secret: SecretType,
|
||||
redirect_url: str | None = None,
|
||||
requires_verification: bool = False,
|
||||
*,
|
||||
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
|
||||
csrf_token_cookie_path: str = "/",
|
||||
csrf_token_cookie_domain: str | None = None,
|
||||
csrf_token_cookie_secure: bool = True,
|
||||
csrf_token_cookie_httponly: bool = True,
|
||||
csrf_token_cookie_samesite: Literal["lax", "strict", "none"] = "lax",
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes to associate an authenticated user."""
|
||||
router = APIRouter()
|
||||
@@ -208,6 +256,7 @@ def get_oauth_associate_router(
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response: Response,
|
||||
scopes: list[str] = Query(None),
|
||||
user: models.UP = Depends(get_current_active_user),
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
@@ -216,7 +265,8 @@ def get_oauth_associate_router(
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
state_data: dict[str, str] = {"sub": str(user.id)}
|
||||
csrf_token = generate_csrf_token()
|
||||
state_data: dict[str, str] = {"sub": str(user.id), CSRF_TOKEN_KEY: csrf_token}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
@@ -224,6 +274,17 @@ def get_oauth_associate_router(
|
||||
scopes,
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
csrf_token_cookie_name,
|
||||
csrf_token,
|
||||
max_age=3600,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
@@ -268,15 +329,6 @@ def get_oauth_associate_router(
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
):
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
@@ -291,9 +343,31 @@ def get_oauth_associate_router(
|
||||
detail=ErrorCode.ACCESS_TOKEN_ALREADY_EXPIRED,
|
||||
)
|
||||
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_INVALID_STATE,
|
||||
)
|
||||
|
||||
if state_data["sub"] != str(user.id):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
user = await user_manager.oauth_associate_callback(
|
||||
user,
|
||||
oauth_client.name,
|
||||
|
||||
@@ -106,6 +106,8 @@ class TestAuthorize:
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
|
||||
assert response.cookies.get("fastapiusersoauthcsrf") is not None
|
||||
|
||||
async def test_with_redirect_url(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
@@ -126,6 +128,8 @@ class TestAuthorize:
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
|
||||
assert response.cookies.get("fastapiusersoauthcsrf") is not None
|
||||
|
||||
|
||||
@pytest.mark.router
|
||||
@pytest.mark.oauth
|
||||
@@ -157,7 +161,33 @@ class TestCallback:
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
get_id_email_mock.assert_called_once_with("TOKEN")
|
||||
get_id_email_mock.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("csrf_token", [None, "invalid_csrf_token"])
|
||||
async def test_invalid_csrf_state(
|
||||
self,
|
||||
csrf_token: str | None,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
oauth_client: BaseOAuth2,
|
||||
user_oauth: UserOAuthModel,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
get_id_email_mock = async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
|
||||
if csrf_token is not None:
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", csrf_token)
|
||||
response = await test_app_client.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
get_id_email_mock.assert_not_called()
|
||||
|
||||
async def test_already_exists_error(
|
||||
self,
|
||||
@@ -168,7 +198,7 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "SECRET")
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
@@ -177,6 +207,7 @@ class TestCallback:
|
||||
user_manager_oauth, "oauth_callback"
|
||||
).side_effect = exceptions.UserAlreadyExists
|
||||
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -198,7 +229,7 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "SECRET")
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
@@ -207,6 +238,7 @@ class TestCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -228,7 +260,7 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "SECRET")
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client,
|
||||
@@ -239,6 +271,7 @@ class TestCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=inactive_user_oauth
|
||||
)
|
||||
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -256,7 +289,7 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "SECRET")
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
get_access_token_mock = async_method_mocker(
|
||||
oauth_client, "get_access_token", return_value=access_token
|
||||
)
|
||||
@@ -267,6 +300,7 @@ class TestCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
test_app_client_redirect_url.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -291,7 +325,7 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "SECRET")
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", None)
|
||||
@@ -300,6 +334,7 @@ class TestCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
test_app_client_redirect_url.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -318,11 +353,14 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "SECRET", lifetime_seconds=-1)
|
||||
state_jwt = generate_state_token(
|
||||
{"csrftoken": "CSRFTOKEN"}, "SECRET", lifetime_seconds=-1
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -344,11 +382,12 @@ class TestCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({}, "RANDOM")
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "RANDOM")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -406,6 +445,8 @@ class TestAssociateAuthorize:
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
|
||||
assert response.cookies.get("fastapiusersoauthcsrf") is not None
|
||||
|
||||
async def test_with_redirect_url(
|
||||
self,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
@@ -429,6 +470,8 @@ class TestAssociateAuthorize:
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
|
||||
assert response.cookies.get("fastapiusersoauthcsrf") is not None
|
||||
|
||||
|
||||
@pytest.mark.router
|
||||
@pytest.mark.oauth
|
||||
@@ -485,7 +528,34 @@ class TestAssociateCallback:
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
get_id_email_mock.assert_called_once_with("TOKEN")
|
||||
get_id_email_mock.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("csrf_token", [None, "invalid_csrf_token"])
|
||||
async def test_invalid_csrf_state(
|
||||
self,
|
||||
csrf_token: str | None,
|
||||
async_method_mocker: AsyncMethodMocker,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
oauth_client: BaseOAuth2,
|
||||
user_oauth: UserOAuthModel,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"csrftoken": "CSRFTOKEN"}, "SECRET")
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
get_id_email_mock = async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
|
||||
if csrf_token is not None:
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", csrf_token)
|
||||
response = await test_app_client.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
headers={"Authorization": f"Bearer {user_oauth.id}"},
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
get_id_email_mock.assert_not_called()
|
||||
|
||||
async def test_state_with_different_user_id(
|
||||
self,
|
||||
@@ -496,12 +566,16 @@ class TestAssociateCallback:
|
||||
user: UserModel,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"sub": str(user.id)}, "SECRET")
|
||||
state_jwt = generate_state_token(
|
||||
{"sub": str(user.id), "csrftoken": "CSRFTOKEN"}, "SECRET"
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
get_id_email_mock = async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
)
|
||||
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -509,7 +583,7 @@ class TestAssociateCallback:
|
||||
)
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
|
||||
get_id_email_mock.assert_called_once_with("TOKEN")
|
||||
get_id_email_mock.assert_not_called()
|
||||
|
||||
async def test_active_user(
|
||||
self,
|
||||
@@ -520,7 +594,9 @@ class TestAssociateCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"sub": str(user_oauth.id)}, "SECRET")
|
||||
state_jwt = generate_state_token(
|
||||
{"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "SECRET"
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
@@ -529,6 +605,7 @@ class TestAssociateCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
test_app_client.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -549,7 +626,9 @@ class TestAssociateCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"sub": str(user_oauth.id)}, "SECRET")
|
||||
state_jwt = generate_state_token(
|
||||
{"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "SECRET"
|
||||
)
|
||||
get_access_token_mock = async_method_mocker(
|
||||
oauth_client, "get_access_token", return_value=access_token
|
||||
)
|
||||
@@ -560,6 +639,7 @@ class TestAssociateCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
test_app_client_redirect_url.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -584,7 +664,9 @@ class TestAssociateCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"sub": str(user_oauth.id)}, "SECRET")
|
||||
state_jwt = generate_state_token(
|
||||
{"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "SECRET"
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", None)
|
||||
@@ -593,6 +675,7 @@ class TestAssociateCallback:
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
|
||||
test_app_client_redirect_url.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -613,7 +696,9 @@ class TestAssociateCallback:
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token(
|
||||
{"sub": str(user_oauth.id)}, "SECRET", lifetime_seconds=-1
|
||||
{"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"},
|
||||
"SECRET",
|
||||
lifetime_seconds=-1,
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
@@ -622,6 +707,7 @@ class TestAssociateCallback:
|
||||
async_method_mocker(
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
test_app_client_redirect_url.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
@@ -642,7 +728,9 @@ class TestAssociateCallback:
|
||||
user_manager_oauth: UserManagerMock,
|
||||
access_token: str,
|
||||
):
|
||||
state_jwt = generate_state_token({"sub": str(user_oauth.id)}, "RANDOM")
|
||||
state_jwt = generate_state_token(
|
||||
{"sub": str(user_oauth.id), "csrftoken": "CSRFTOKEN"}, "RANDOM"
|
||||
)
|
||||
async_method_mocker(oauth_client, "get_access_token", return_value=access_token)
|
||||
async_method_mocker(
|
||||
oauth_client, "get_id_email", return_value=("user_oauth1", user_oauth.email)
|
||||
@@ -650,6 +738,7 @@ class TestAssociateCallback:
|
||||
async_method_mocker(
|
||||
user_manager_oauth, "oauth_callback", return_value=user_oauth
|
||||
)
|
||||
test_app_client_redirect_url.cookies.set("fastapiusersoauthcsrf", "CSRFTOKEN")
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/oauth-associate/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
|
||||
Reference in New Issue
Block a user