From ef4a54c2045be2d17fd3731da73256a5e0326c07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 19 Mar 2021 18:19:58 +0100 Subject: [PATCH] Allow lifetime_seconds to be None to get session cookies --- fastapi_users/authentication/cookie.py | 8 ++++---- fastapi_users/authentication/jwt.py | 2 +- fastapi_users/router/oauth.py | 2 +- fastapi_users/router/reset.py | 2 +- fastapi_users/router/verify.py | 2 +- fastapi_users/utils.py | 11 ++++++++--- tests/test_authentication_cookie.py | 2 +- tests/test_authentication_jwt.py | 2 +- tests/test_router_reset.py | 2 +- tests/test_router_verify.py | 2 +- 10 files changed, 20 insertions(+), 15 deletions(-) diff --git a/fastapi_users/authentication/cookie.py b/fastapi_users/authentication/cookie.py index 90c3488a..5b2d8ae7 100644 --- a/fastapi_users/authentication/cookie.py +++ b/fastapi_users/authentication/cookie.py @@ -30,7 +30,7 @@ class CookieAuthentication(BaseAuthentication[str]): scheme: APIKeyCookie token_audience: str = "fastapi-users:auth" secret: str - lifetime_seconds: int + lifetime_seconds: Optional[int] cookie_name: str cookie_path: str cookie_domain: Optional[str] @@ -41,10 +41,10 @@ class CookieAuthentication(BaseAuthentication[str]): def __init__( self, secret: str, - lifetime_seconds: int, + lifetime_seconds: Optional[int] = None, cookie_name: str = "fastapiusersauth", cookie_path: str = "/", - cookie_domain: str = None, + cookie_domain: Optional[str] = None, cookie_secure: bool = True, cookie_httponly: bool = True, cookie_samesite: str = "lax", @@ -112,4 +112,4 @@ class CookieAuthentication(BaseAuthentication[str]): async def _generate_token(self, user: BaseUserDB) -> str: data = {"user_id": str(user.id), "aud": self.token_audience} - return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) + return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM) diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index 755d7d49..24cb5d74 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -71,4 +71,4 @@ class JWTAuthentication(BaseAuthentication[str]): async def _generate_token(self, user: BaseUserDB) -> str: data = {"user_id": str(user.id), "aud": self.token_audience} - return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) + return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM) diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index e6aae2b7..394dd980 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -19,7 +19,7 @@ def generate_state_token( data: Dict[str, str], secret: str, lifetime_seconds: int = 3600 ) -> str: data["aud"] = STATE_TOKEN_AUDIENCE - return generate_jwt(data, lifetime_seconds, secret, JWT_ALGORITHM) + return generate_jwt(data, secret, lifetime_seconds, JWT_ALGORITHM) def decode_state_token(token: str, secret: str) -> Dict[str, str]: diff --git a/fastapi_users/router/reset.py b/fastapi_users/router/reset.py index b0b740d7..3a7c9c62 100644 --- a/fastapi_users/router/reset.py +++ b/fastapi_users/router/reset.py @@ -33,8 +33,8 @@ def get_reset_password_router( token_data = {"user_id": str(user.id), "aud": RESET_PASSWORD_TOKEN_AUDIENCE} token = generate_jwt( token_data, - reset_password_token_lifetime_seconds, reset_password_token_secret, + reset_password_token_lifetime_seconds, ) if after_forgot_password: await run_handler(after_forgot_password, user, token, request) diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py index 7da8cedb..e12fa301 100644 --- a/fastapi_users/router/verify.py +++ b/fastapi_users/router/verify.py @@ -49,8 +49,8 @@ def get_verify_router( } token = generate_jwt( token_data, - verification_token_lifetime_seconds, verification_token_secret, + verification_token_lifetime_seconds, ) if after_verification_request: diff --git a/fastapi_users/utils.py b/fastapi_users/utils.py index 1134a482..279c3cf7 100644 --- a/fastapi_users/utils.py +++ b/fastapi_users/utils.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Optional import jwt @@ -6,9 +7,13 @@ JWT_ALGORITHM = "HS256" def generate_jwt( - data: dict, lifetime_seconds: int, secret: str, algorithm: str = JWT_ALGORITHM + data: dict, + secret: str, + lifetime_seconds: Optional[int] = None, + algorithm: str = JWT_ALGORITHM, ) -> str: payload = data.copy() - expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds) - payload["exp"] = expire + if lifetime_seconds: + expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds) + payload["exp"] = expire return jwt.encode(payload, secret, algorithm=algorithm) diff --git a/tests/test_authentication_cookie.py b/tests/test_authentication_cookie.py index 9a7c179e..9d846a66 100644 --- a/tests/test_authentication_cookie.py +++ b/tests/test_authentication_cookie.py @@ -32,7 +32,7 @@ def token(): data = {"aud": "fastapi-users:auth"} if user_id is not None: data["user_id"] = str(user_id) - return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) + return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM) return _token diff --git a/tests/test_authentication_jwt.py b/tests/test_authentication_jwt.py index aebe83c9..5600de24 100644 --- a/tests/test_authentication_jwt.py +++ b/tests/test_authentication_jwt.py @@ -21,7 +21,7 @@ def token(): data = {"aud": "fastapi-users:auth"} if user_id is not None: data["user_id"] = str(user_id) - return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) + return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM) return _token diff --git a/tests/test_router_reset.py b/tests/test_router_reset.py index a44d00dd..e348a3fd 100644 --- a/tests/test_router_reset.py +++ b/tests/test_router_reset.py @@ -21,7 +21,7 @@ def forgot_password_token(): data = {"aud": "fastapi-users:reset"} if user_id is not None: data["user_id"] = str(user_id) - return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) + return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM) return _forgot_password_token diff --git a/tests/test_router_verify.py b/tests/test_router_verify.py index c54f9d8c..83c9e896 100644 --- a/tests/test_router_verify.py +++ b/tests/test_router_verify.py @@ -25,7 +25,7 @@ def verify_token(): data["user_id"] = str(user_id) if email is not None: data["email"] = email - return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) + return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM) return _verify_token