mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Fix #701: factorize JWT handling and support secrets as SecretStr
This commit is contained in:
@ -18,7 +18,7 @@ auth_backends.append(cookie_authentication)
|
||||
|
||||
As you can see, instantiation is quite simple. It accepts the following arguments:
|
||||
|
||||
* `secret` (`str`): A constant secret which is used to encode the cookie. **Use a strong passphrase and keep it secure.**
|
||||
* `secret` (`Union[str, pydantic.SecretStr]`): A constant secret which is used to encode the cookie. **Use a strong passphrase and keep it secure.**
|
||||
* `lifetime_seconds` (`int`): The lifetime of the cookie in seconds.
|
||||
* `cookie_name` (`fastapiusersauth`): Name of the cookie.
|
||||
* `cookie_path` (`/`): Cookie path.
|
||||
|
@ -18,7 +18,7 @@ auth_backends.append(jwt_authentication)
|
||||
|
||||
As you can see, instantiation is quite simple. It accepts the following arguments:
|
||||
|
||||
* `secret` (`str`): A constant secret which is used to encode the token. **Use a strong passphrase and keep it secure.**
|
||||
* `secret` (`Union[str, pydantic.SecretStr]`): A constant secret which is used to encode the token. **Use a strong passphrase and keep it secure.**
|
||||
* `lifetime_seconds` (`int`): The lifetime of the token in seconds.
|
||||
* `tokenUrl` (`Optional[str]`): The exact path of your login endpoint. It'll allow the interactive documentation to automatically discover it and get a working *Authorize* button. In most cases, you'll probably need a **relative** path, not absolute. You can read more details about this in the [FastAPI documentation](https://fastapi.tiangolo.com/tutorial/security/first-steps/#fastapis-oauth2passwordbearer). Defaults to `auth/jwt/login`.
|
||||
* `name` (`Optional[str]`): Name of the backend. It's useful in the case you wish to have several backends of the same class. Each backend should have a unique name. Defaults to `jwt`.
|
||||
|
@ -29,8 +29,8 @@ app.include_router(
|
||||
|
||||
Parameters:
|
||||
|
||||
* `reset_password_token_secret`: Secret to encode reset password token.
|
||||
* `reset_password_token_lifetime_seconds`: Lifetime of reset password token. **Defaults to 3600**.
|
||||
* `reset_password_token_secret` (`Union[str, pydantic.SecretStr]`): Secret to encode reset password token.
|
||||
* `reset_password_token_lifetime_seconds` (`int`): Lifetime of reset password token. **Defaults to 3600**.
|
||||
* `after_forgot_password`: Optional function called after a successful forgot password request. See below.
|
||||
|
||||
## After forgot password
|
||||
|
@ -30,8 +30,8 @@ app.include_router(
|
||||
|
||||
Parameters:
|
||||
|
||||
* `verification_token_secret`: Secret to encode verify token.
|
||||
* `verification_token_lifetime_seconds`: Lifetime of verify token. **Defaults to 3600**.
|
||||
* `verification_token_secret` (`Union[str, pydantic.SecretStr]`): Secret to encode verify token.
|
||||
* `verification_token_lifetime_seconds` (`int`): Lifetime of verify token. **Defaults to 3600**.
|
||||
* `after_verification_request`: Optional function called after a successful verify request. See below.
|
||||
* `after_verification`: Optional function called after a successful verification. See below.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import Response
|
||||
@ -7,8 +7,8 @@ from pydantic import UUID4
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db.base import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
|
||||
class CookieAuthentication(BaseAuthentication[str]):
|
||||
@ -25,11 +25,12 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
:param cookie_secure: Whether to only send the cookie to the server via SSL request.
|
||||
:param cookie_httponly: Whether to prevent access to the cookie via JavaScript.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
:param token_audience: List of valid audiences for the JWT.
|
||||
"""
|
||||
|
||||
scheme: APIKeyCookie
|
||||
token_audience: str = "fastapi-users:auth"
|
||||
secret: str
|
||||
token_audience: List[str]
|
||||
secret: SecretType
|
||||
lifetime_seconds: Optional[int]
|
||||
cookie_name: str
|
||||
cookie_path: str
|
||||
@ -40,7 +41,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret: str,
|
||||
secret: SecretType,
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
cookie_name: str = "fastapiusersauth",
|
||||
cookie_path: str = "/",
|
||||
@ -49,6 +50,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
cookie_httponly: bool = True,
|
||||
cookie_samesite: str = "lax",
|
||||
name: str = "cookie",
|
||||
token_audience: List[str] = ["fastapi-users:auth"],
|
||||
):
|
||||
super().__init__(name, logout=True)
|
||||
self.secret = secret
|
||||
@ -59,6 +61,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
self.cookie_secure = cookie_secure
|
||||
self.cookie_httponly = cookie_httponly
|
||||
self.cookie_samesite = cookie_samesite
|
||||
self.token_audience = token_audience
|
||||
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
|
||||
|
||||
async def __call__(
|
||||
@ -70,12 +73,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
return None
|
||||
|
||||
try:
|
||||
data = jwt.decode(
|
||||
credentials,
|
||||
self.secret,
|
||||
audience=self.token_audience,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
data = decode_jwt(credentials, self.secret, self.token_audience)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
return None
|
||||
@ -112,4 +110,4 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
|
||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||
data = {"user_id": str(user.id), "aud": self.token_audience}
|
||||
return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM)
|
||||
return generate_jwt(data, self.secret, self.lifetime_seconds)
|
||||
|
@ -7,8 +7,8 @@ from pydantic import UUID4
|
||||
|
||||
from fastapi_users.authentication.base import BaseAuthentication
|
||||
from fastapi_users.db.base import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.models import BaseUserDB
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
|
||||
class JWTAuthentication(BaseAuthentication[str]):
|
||||
@ -19,11 +19,12 @@ class JWTAuthentication(BaseAuthentication[str]):
|
||||
:param lifetime_seconds: Lifetime duration of the JWT in seconds.
|
||||
:param tokenUrl: Path where to get a token.
|
||||
:param name: Name of the backend. It will be used to name the login route.
|
||||
:param token_audience: List of valid audiences for the JWT.
|
||||
"""
|
||||
|
||||
scheme: OAuth2PasswordBearer
|
||||
token_audience: List[str] = ["fastapi-users:auth"]
|
||||
secret: str
|
||||
token_audience: List[str]
|
||||
secret: SecretType
|
||||
lifetime_seconds: int
|
||||
|
||||
def __init__(
|
||||
@ -49,12 +50,7 @@ class JWTAuthentication(BaseAuthentication[str]):
|
||||
return None
|
||||
|
||||
try:
|
||||
data = jwt.decode(
|
||||
credentials,
|
||||
self.secret,
|
||||
audience=self.token_audience,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
data = decode_jwt(credentials, self.secret, self.token_audience)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
return None
|
||||
@ -73,4 +69,4 @@ class JWTAuthentication(BaseAuthentication[str]):
|
||||
|
||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||
data = {"user_id": str(user.id), "aud": self.token_audience}
|
||||
return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM)
|
||||
return generate_jwt(data, self.secret, self.lifetime_seconds)
|
||||
|
@ -5,6 +5,7 @@ from fastapi import APIRouter, Request
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType
|
||||
from fastapi_users.router import (
|
||||
get_auth_router,
|
||||
get_register_router,
|
||||
@ -132,7 +133,7 @@ class FastAPIUsers:
|
||||
|
||||
def get_verify_router(
|
||||
self,
|
||||
verification_token_secret: str,
|
||||
verification_token_secret: SecretType,
|
||||
verification_token_lifetime_seconds: int = 3600,
|
||||
after_verification_request: Optional[
|
||||
Callable[[models.UD, str, Request], None]
|
||||
@ -161,7 +162,7 @@ class FastAPIUsers:
|
||||
|
||||
def get_reset_password_router(
|
||||
self,
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_secret: SecretType,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
after_forgot_password: Optional[
|
||||
Callable[[models.UD, str, Request], None]
|
||||
@ -207,7 +208,7 @@ class FastAPIUsers:
|
||||
def get_oauth_router(
|
||||
self,
|
||||
oauth_client: BaseOAuth2,
|
||||
state_secret: str,
|
||||
state_secret: SecretType,
|
||||
redirect_url: str = None,
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
|
41
fastapi_users/jwt.py
Normal file
41
fastapi_users/jwt.py
Normal file
@ -0,0 +1,41 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import jwt
|
||||
from pydantic import SecretStr
|
||||
|
||||
SecretType = Union[str, SecretStr]
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def _get_secret_value(secret: SecretType) -> str:
|
||||
if isinstance(secret, SecretStr):
|
||||
return secret.get_secret_value()
|
||||
return secret
|
||||
|
||||
|
||||
def generate_jwt(
|
||||
data: dict,
|
||||
secret: SecretType,
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
algorithm: str = JWT_ALGORITHM,
|
||||
) -> str:
|
||||
payload = data.copy()
|
||||
if lifetime_seconds:
|
||||
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
|
||||
payload["exp"] = expire
|
||||
return jwt.encode(payload, _get_secret_value(secret), algorithm=algorithm)
|
||||
|
||||
|
||||
def decode_jwt(
|
||||
encoded_jwt: str,
|
||||
secret: SecretType,
|
||||
audience: List[str],
|
||||
algorithms: List[str] = [JWT_ALGORITHM],
|
||||
) -> Dict[str, Any]:
|
||||
return jwt.decode(
|
||||
encoded_jwt,
|
||||
_get_secret_value(secret),
|
||||
audience=audience,
|
||||
algorithms=algorithms,
|
||||
)
|
@ -8,27 +8,18 @@ from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.password import generate_password, get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, run_handler
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str], secret: str, lifetime_seconds: int = 3600
|
||||
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
return generate_jwt(data, secret, lifetime_seconds, JWT_ALGORITHM)
|
||||
|
||||
|
||||
def decode_state_token(token: str, secret: str) -> Dict[str, str]:
|
||||
return jwt.decode(
|
||||
token,
|
||||
secret,
|
||||
audience=STATE_TOKEN_AUDIENCE,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
def get_oauth_router(
|
||||
@ -36,7 +27,7 @@ def get_oauth_router(
|
||||
user_db: BaseUserDatabase[models.BaseUserDB],
|
||||
user_db_model: Type[models.BaseUserDB],
|
||||
authenticator: Authenticator,
|
||||
state_secret: str,
|
||||
state_secret: SecretType,
|
||||
redirect_url: str = None,
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
@ -99,7 +90,7 @@ def get_oauth_router(
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_state_token(state, state_secret)
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
|
@ -6,17 +6,17 @@ from pydantic import UUID4, EmailStr
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.password import get_password_hash
|
||||
from fastapi_users.router.common import ErrorCode, run_handler
|
||||
from fastapi_users.user import InvalidPasswordException, ValidatePasswordProtocol
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
RESET_PASSWORD_TOKEN_AUDIENCE = "fastapi-users:reset"
|
||||
|
||||
|
||||
def get_reset_password_router(
|
||||
user_db: BaseUserDatabase[models.BaseUserDB],
|
||||
reset_password_token_secret: str,
|
||||
reset_password_token_secret: SecretType,
|
||||
reset_password_token_lifetime_seconds: int = 3600,
|
||||
after_forgot_password: Optional[Callable[[models.UD, str, Request], None]] = None,
|
||||
after_reset_password: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
@ -48,11 +48,8 @@ def get_reset_password_router(
|
||||
request: Request, token: str = Body(...), password: str = Body(...)
|
||||
):
|
||||
try:
|
||||
data = jwt.decode(
|
||||
token,
|
||||
reset_password_token_secret,
|
||||
audience=RESET_PASSWORD_TOKEN_AUDIENCE,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
data = decode_jwt(
|
||||
token, reset_password_token_secret, [RESET_PASSWORD_TOKEN_AUDIENCE]
|
||||
)
|
||||
user_id = data.get("user_id")
|
||||
if user_id is None:
|
||||
|
@ -5,6 +5,7 @@ from fastapi import APIRouter, Body, HTTPException, Request, status
|
||||
from pydantic import UUID4, EmailStr
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
from fastapi_users.router.common import ErrorCode, run_handler
|
||||
from fastapi_users.user import (
|
||||
GetUserProtocol,
|
||||
@ -12,7 +13,6 @@ from fastapi_users.user import (
|
||||
UserNotExists,
|
||||
VerifyUserProtocol,
|
||||
)
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
|
||||
VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify"
|
||||
|
||||
@ -21,7 +21,7 @@ def get_verify_router(
|
||||
verify_user: VerifyUserProtocol,
|
||||
get_user: GetUserProtocol,
|
||||
user_model: Type[models.BaseUser],
|
||||
verification_token_secret: str,
|
||||
verification_token_secret: SecretType,
|
||||
verification_token_lifetime_seconds: int = 3600,
|
||||
after_verification_request: Optional[
|
||||
Callable[[models.UD, str, Request], None]
|
||||
@ -58,11 +58,8 @@ def get_verify_router(
|
||||
@router.post("/verify", response_model=user_model)
|
||||
async def verify(request: Request, token: str = Body(..., embed=True)):
|
||||
try:
|
||||
data = jwt.decode(
|
||||
token,
|
||||
verification_token_secret,
|
||||
audience=VERIFY_USER_TOKEN_AUDIENCE,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
data = decode_jwt(
|
||||
token, verification_token_secret, [VERIFY_USER_TOKEN_AUDIENCE]
|
||||
)
|
||||
except jwt.exceptions.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
|
@ -1,19 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import jwt
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def generate_jwt(
|
||||
data: dict,
|
||||
secret: str,
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
algorithm: str = JWT_ALGORITHM,
|
||||
) -> str:
|
||||
payload = data.copy()
|
||||
if lifetime_seconds:
|
||||
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
|
||||
payload["exp"] = expire
|
||||
return jwt.encode(payload, secret, algorithm=algorithm)
|
@ -8,12 +8,13 @@ from asgi_lifespan import LifespanManager
|
||||
from fastapi import Depends, FastAPI, Response
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from httpx_oauth.oauth2 import OAuth2
|
||||
from pydantic import UUID4
|
||||
from pydantic import UUID4, SecretStr
|
||||
from starlette.applications import ASGIApp
|
||||
|
||||
from fastapi_users import models
|
||||
from fastapi_users.authentication import Authenticator, BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.jwt import SecretType
|
||||
from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin, BaseUserDB
|
||||
from fastapi_users.password import get_password_hash
|
||||
from fastapi_users.user import InvalidPasswordException, ValidatePasswordProtocol
|
||||
@ -56,6 +57,11 @@ def event_loop():
|
||||
yield loop
|
||||
|
||||
|
||||
@pytest.fixture(params=["SECRET", SecretStr("SECRET")])
|
||||
def secret(request) -> SecretType:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user() -> UserDB:
|
||||
return UserDB(
|
||||
|
@ -1,90 +1,101 @@
|
||||
import re
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import Response
|
||||
|
||||
from fastapi_users.authentication.cookie import CookieAuthentication
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
COOKIE_NAME = "COOKIE_NAME"
|
||||
|
||||
cookie_authentication = CookieAuthentication(SECRET, LIFETIME, COOKIE_NAME)
|
||||
cookie_authentication_path = CookieAuthentication(
|
||||
SECRET, LIFETIME, COOKIE_NAME, cookie_path="/arthur"
|
||||
)
|
||||
cookie_authentication_domain = CookieAuthentication(
|
||||
SECRET, LIFETIME, COOKIE_NAME, cookie_domain="camelot.bt"
|
||||
)
|
||||
cookie_authentication_secure = CookieAuthentication(
|
||||
SECRET, LIFETIME, COOKIE_NAME, cookie_secure=False
|
||||
)
|
||||
cookie_authentication_httponly = CookieAuthentication(
|
||||
SECRET, LIFETIME, COOKIE_NAME, cookie_httponly=False
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
("/", None, True, True),
|
||||
("/arthur", None, True, True),
|
||||
("/", "camelot.bt", True, True),
|
||||
("/", None, False, True),
|
||||
("/", None, True, False),
|
||||
]
|
||||
)
|
||||
def cookie_authentication(secret: SecretType, request):
|
||||
path, domain, secure, httponly = request.param
|
||||
return CookieAuthentication(
|
||||
secret,
|
||||
lifetime_seconds=LIFETIME,
|
||||
cookie_name=COOKIE_NAME,
|
||||
cookie_path=path,
|
||||
cookie_domain=domain,
|
||||
cookie_secure=secure,
|
||||
cookie_httponly=httponly,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token():
|
||||
def token(secret):
|
||||
def _token(user_id=None, lifetime=LIFETIME):
|
||||
data = {"aud": "fastapi-users:auth"}
|
||||
if user_id is not None:
|
||||
data["user_id"] = str(user_id)
|
||||
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
|
||||
return generate_jwt(data, secret, lifetime)
|
||||
|
||||
return _token
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
def test_default_name():
|
||||
def test_default_name(cookie_authentication: CookieAuthentication):
|
||||
assert cookie_authentication.name == "cookie"
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
class TestAuthenticate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_token(self, mock_user_db):
|
||||
async def test_missing_token(
|
||||
self, mock_user_db, cookie_authentication: CookieAuthentication
|
||||
):
|
||||
authenticated_user = await cookie_authentication(None, mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token(self, mock_user_db):
|
||||
async def test_invalid_token(
|
||||
self, mock_user_db, cookie_authentication: CookieAuthentication
|
||||
):
|
||||
authenticated_user = await cookie_authentication("foo", mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token_missing_user_payload(self, mock_user_db, token):
|
||||
async def test_valid_token_missing_user_payload(
|
||||
self, mock_user_db, token, cookie_authentication: CookieAuthentication
|
||||
):
|
||||
authenticated_user = await cookie_authentication(token(), mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token_invalid_uuid(self, mock_user_db, token):
|
||||
async def test_valid_token_invalid_uuid(
|
||||
self, mock_user_db, token, cookie_authentication: CookieAuthentication
|
||||
):
|
||||
authenticated_user = await cookie_authentication(token("foo"), mock_user_db)
|
||||
assert authenticated_user is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token(self, mock_user_db, token, user):
|
||||
async def test_valid_token(
|
||||
self, mock_user_db, token, user, cookie_authentication: CookieAuthentication
|
||||
):
|
||||
authenticated_user = await cookie_authentication(token(user.id), mock_user_db)
|
||||
assert authenticated_user is not None
|
||||
assert authenticated_user.id == user.id
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"cookie_authentication,path,domain,secure,httponly",
|
||||
[
|
||||
(cookie_authentication, "/", None, True, True),
|
||||
(cookie_authentication_path, "/arthur", None, True, True),
|
||||
(cookie_authentication_domain, "/", "camelot.bt", True, True),
|
||||
(cookie_authentication_secure, "/", None, False, True),
|
||||
(cookie_authentication_httponly, "/", None, True, False),
|
||||
],
|
||||
)
|
||||
async def test_get_login_response(
|
||||
user, cookie_authentication, path, domain, secure, httponly
|
||||
):
|
||||
async def test_get_login_response(user, cookie_authentication: CookieAuthentication):
|
||||
secret = cookie_authentication.secret
|
||||
path = cookie_authentication.cookie_path
|
||||
domain = cookie_authentication.cookie_domain
|
||||
secure = cookie_authentication.cookie_secure
|
||||
httponly = cookie_authentication.cookie_httponly
|
||||
|
||||
response = Response()
|
||||
login_response = await cookie_authentication.get_login_response(user, response)
|
||||
|
||||
@ -116,20 +127,19 @@ async def test_get_login_response(
|
||||
assert "HttpOnly" not in cookie
|
||||
|
||||
cookie_name_value = re.match(r"^(\w+)=([^;]+);", cookie)
|
||||
assert cookie_name_value is not None
|
||||
|
||||
cookie_name = cookie_name_value[1]
|
||||
assert cookie_name == COOKIE_NAME
|
||||
|
||||
cookie_value = cookie_name_value[2]
|
||||
decoded = jwt.decode(
|
||||
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
||||
)
|
||||
decoded = decode_jwt(cookie_value, secret, audience=["fastapi-users:auth"])
|
||||
assert decoded["user_id"] == str(user.id)
|
||||
|
||||
|
||||
@pytest.mark.authentication
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_logout_response(user):
|
||||
async def test_get_logout_response(user, cookie_authentication: CookieAuthentication):
|
||||
response = Response()
|
||||
logout_response = await cookie_authentication.get_logout_response(user, response)
|
||||
|
||||
|
@ -1,27 +1,25 @@
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import Response
|
||||
|
||||
from fastapi_users.authentication.jwt import JWTAuthentication
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
from fastapi_users.jwt import decode_jwt, generate_jwt
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
TOKEN_URL = "/login"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_authentication():
|
||||
return JWTAuthentication(SECRET, LIFETIME, TOKEN_URL)
|
||||
def jwt_authentication(secret):
|
||||
return JWTAuthentication(secret, LIFETIME, TOKEN_URL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def token():
|
||||
def token(secret):
|
||||
def _token(user_id=None, lifetime=LIFETIME):
|
||||
data = {"aud": "fastapi-users:auth"}
|
||||
if user_id is not None:
|
||||
data["user_id"] = str(user_id)
|
||||
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
|
||||
return generate_jwt(data, secret, lifetime)
|
||||
|
||||
return _token
|
||||
|
||||
@ -72,8 +70,8 @@ async def test_get_login_response(jwt_authentication, user):
|
||||
assert login_response["token_type"] == "bearer"
|
||||
|
||||
token = login_response["access_token"]
|
||||
decoded = jwt.decode(
|
||||
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
||||
decoded = decode_jwt(
|
||||
token, jwt_authentication.secret, audience=["fastapi-users:auth"]
|
||||
)
|
||||
assert decoded["user_id"] == str(user.id)
|
||||
|
||||
|
@ -11,7 +11,12 @@ from tests.conftest import User, UserCreate, UserDB, UserUpdate
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_client(
|
||||
mock_user_db, mock_authentication, oauth_client, get_test_client, validate_password
|
||||
secret,
|
||||
mock_user_db,
|
||||
mock_authentication,
|
||||
oauth_client,
|
||||
get_test_client,
|
||||
validate_password,
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
fastapi_users = FastAPIUsers(
|
||||
mock_user_db,
|
||||
@ -25,11 +30,11 @@ async def test_app_client(
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(fastapi_users.get_register_router())
|
||||
app.include_router(fastapi_users.get_reset_password_router("SECRET"))
|
||||
app.include_router(fastapi_users.get_reset_password_router(secret))
|
||||
app.include_router(fastapi_users.get_auth_router(mock_authentication))
|
||||
app.include_router(fastapi_users.get_oauth_router(oauth_client, "SECRET"))
|
||||
app.include_router(fastapi_users.get_oauth_router(oauth_client, secret))
|
||||
app.include_router(fastapi_users.get_users_router(), prefix="/users")
|
||||
app.include_router(fastapi_users.get_verify_router("SECRET"))
|
||||
app.include_router(fastapi_users.get_verify_router(secret))
|
||||
|
||||
@app.delete("/users/me")
|
||||
def custom_users_route():
|
||||
|
@ -11,8 +11,6 @@ from fastapi_users.router.common import ErrorCode
|
||||
from fastapi_users.router.oauth import generate_state_token, get_oauth_router
|
||||
from tests.conftest import MockAuthentication, UserDB
|
||||
|
||||
SECRET = "SECRET"
|
||||
|
||||
|
||||
def after_register_sync():
|
||||
return MagicMock(return_value=None)
|
||||
@ -29,6 +27,7 @@ def after_register(request):
|
||||
|
||||
@pytest.fixture
|
||||
def get_test_app_client(
|
||||
secret,
|
||||
mock_user_db_oauth,
|
||||
mock_authentication,
|
||||
oauth_client,
|
||||
@ -48,7 +47,7 @@ def get_test_app_client(
|
||||
mock_user_db_oauth,
|
||||
UserDB,
|
||||
authenticator,
|
||||
SECRET,
|
||||
secret,
|
||||
redirect_url,
|
||||
after_register,
|
||||
)
|
||||
|
@ -10,9 +10,6 @@ from fastapi_users.router import ErrorCode, get_register_router
|
||||
from fastapi_users.user import get_create_user
|
||||
from tests.conftest import User, UserCreate, UserDB
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
|
||||
|
||||
def after_register_sync():
|
||||
return MagicMock(return_value=None)
|
||||
|
@ -3,25 +3,23 @@ from unittest.mock import MagicMock
|
||||
|
||||
import asynctest
|
||||
import httpx
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request, status
|
||||
|
||||
from fastapi_users.jwt import decode_jwt, generate_jwt
|
||||
from fastapi_users.router import ErrorCode, get_reset_password_router
|
||||
from fastapi_users.utils import JWT_ALGORITHM, generate_jwt
|
||||
from tests.conftest import UserDB
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def forgot_password_token():
|
||||
def forgot_password_token(secret):
|
||||
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
|
||||
data = {"aud": "fastapi-users:reset"}
|
||||
if user_id is not None:
|
||||
data["user_id"] = str(user_id)
|
||||
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
|
||||
return generate_jwt(data, secret, lifetime)
|
||||
|
||||
return _forgot_password_token
|
||||
|
||||
@ -55,6 +53,7 @@ def after_reset_password(request):
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_client(
|
||||
secret,
|
||||
mock_user_db,
|
||||
after_forgot_password,
|
||||
after_reset_password,
|
||||
@ -63,7 +62,7 @@ async def test_app_client(
|
||||
) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
reset_router = get_reset_password_router(
|
||||
mock_user_db,
|
||||
SECRET,
|
||||
secret,
|
||||
LIFETIME,
|
||||
after_forgot_password,
|
||||
after_reset_password,
|
||||
@ -107,7 +106,12 @@ class TestForgotPassword:
|
||||
"email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"]
|
||||
)
|
||||
async def test_existing_user(
|
||||
self, email, test_app_client: httpx.AsyncClient, after_forgot_password, user
|
||||
self,
|
||||
secret,
|
||||
email,
|
||||
test_app_client: httpx.AsyncClient,
|
||||
after_forgot_password,
|
||||
user,
|
||||
):
|
||||
json = {"email": email}
|
||||
response = await test_app_client.post("/forgot-password", json=json)
|
||||
@ -117,11 +121,11 @@ class TestForgotPassword:
|
||||
actual_user = after_forgot_password.call_args[0][0]
|
||||
assert actual_user.id == user.id
|
||||
actual_token = after_forgot_password.call_args[0][1]
|
||||
decoded_token = jwt.decode(
|
||||
|
||||
decoded_token = decode_jwt(
|
||||
actual_token,
|
||||
SECRET,
|
||||
audience="fastapi-users:reset",
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
secret,
|
||||
audience=["fastapi-users:reset"],
|
||||
)
|
||||
assert decoded_token["user_id"] == str(user.id)
|
||||
request = after_forgot_password.call_args[0][2]
|
||||
|
@ -10,9 +10,6 @@ from fastapi_users.authentication import Authenticator
|
||||
from fastapi_users.router import ErrorCode, get_users_router
|
||||
from tests.conftest import MockAuthentication, User, UserDB, UserUpdate
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
|
||||
|
||||
def after_update_sync():
|
||||
return MagicMock(return_value=None)
|
||||
|
@ -6,26 +6,24 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from fastapi_users.router import ErrorCode, get_verify_router
|
||||
from fastapi_users.user import get_get_user, get_verify_user
|
||||
from fastapi_users.utils import generate_jwt
|
||||
from tests.conftest import User, UserDB
|
||||
|
||||
SECRET = "SECRET"
|
||||
LIFETIME = 3600
|
||||
VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify"
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def verify_token():
|
||||
def verify_token(secret):
|
||||
def _verify_token(user_id=None, email=None, lifetime=LIFETIME):
|
||||
data = {"aud": VERIFY_USER_TOKEN_AUDIENCE}
|
||||
if user_id is not None:
|
||||
data["user_id"] = str(user_id)
|
||||
if email is not None:
|
||||
data["email"] = email
|
||||
return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
|
||||
return generate_jwt(data, secret, lifetime)
|
||||
|
||||
return _verify_token
|
||||
|
||||
@ -61,6 +59,7 @@ def after_verification_request(request):
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_client(
|
||||
secret,
|
||||
mock_user_db,
|
||||
after_verification_request,
|
||||
after_verification,
|
||||
@ -72,7 +71,7 @@ async def test_app_client(
|
||||
verify_user,
|
||||
get_user,
|
||||
User,
|
||||
SECRET,
|
||||
secret,
|
||||
LIFETIME,
|
||||
after_verification_request,
|
||||
after_verification,
|
||||
|
Reference in New Issue
Block a user