Allow lifetime_seconds to be None to get session cookies

This commit is contained in:
François Voron
2021-03-19 18:19:58 +01:00
parent 902bcdb8d2
commit ef4a54c204
10 changed files with 20 additions and 15 deletions

View File

@ -30,7 +30,7 @@ class CookieAuthentication(BaseAuthentication[str]):
scheme: APIKeyCookie scheme: APIKeyCookie
token_audience: str = "fastapi-users:auth" token_audience: str = "fastapi-users:auth"
secret: str secret: str
lifetime_seconds: int lifetime_seconds: Optional[int]
cookie_name: str cookie_name: str
cookie_path: str cookie_path: str
cookie_domain: Optional[str] cookie_domain: Optional[str]
@ -41,10 +41,10 @@ class CookieAuthentication(BaseAuthentication[str]):
def __init__( def __init__(
self, self,
secret: str, secret: str,
lifetime_seconds: int, lifetime_seconds: Optional[int] = None,
cookie_name: str = "fastapiusersauth", cookie_name: str = "fastapiusersauth",
cookie_path: str = "/", cookie_path: str = "/",
cookie_domain: str = None, cookie_domain: Optional[str] = None,
cookie_secure: bool = True, cookie_secure: bool = True,
cookie_httponly: bool = True, cookie_httponly: bool = True,
cookie_samesite: str = "lax", cookie_samesite: str = "lax",
@ -112,4 +112,4 @@ class CookieAuthentication(BaseAuthentication[str]):
async def _generate_token(self, user: BaseUserDB) -> str: async def _generate_token(self, user: BaseUserDB) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience} data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM)

View File

@ -71,4 +71,4 @@ class JWTAuthentication(BaseAuthentication[str]):
async def _generate_token(self, user: BaseUserDB) -> str: async def _generate_token(self, user: BaseUserDB) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience} data = {"user_id": str(user.id), "aud": self.token_audience}
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM) return generate_jwt(data, self.secret, self.lifetime_seconds, JWT_ALGORITHM)

View File

@ -19,7 +19,7 @@ def generate_state_token(
data: Dict[str, str], secret: str, lifetime_seconds: int = 3600 data: Dict[str, str], secret: str, lifetime_seconds: int = 3600
) -> str: ) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, lifetime_seconds, secret, JWT_ALGORITHM) return generate_jwt(data, secret, lifetime_seconds, JWT_ALGORITHM)
def decode_state_token(token: str, secret: str) -> Dict[str, str]: def decode_state_token(token: str, secret: str) -> Dict[str, str]:

View File

@ -33,8 +33,8 @@ def get_reset_password_router(
token_data = {"user_id": str(user.id), "aud": RESET_PASSWORD_TOKEN_AUDIENCE} token_data = {"user_id": str(user.id), "aud": RESET_PASSWORD_TOKEN_AUDIENCE}
token = generate_jwt( token = generate_jwt(
token_data, token_data,
reset_password_token_lifetime_seconds,
reset_password_token_secret, reset_password_token_secret,
reset_password_token_lifetime_seconds,
) )
if after_forgot_password: if after_forgot_password:
await run_handler(after_forgot_password, user, token, request) await run_handler(after_forgot_password, user, token, request)

View File

@ -49,8 +49,8 @@ def get_verify_router(
} }
token = generate_jwt( token = generate_jwt(
token_data, token_data,
verification_token_lifetime_seconds,
verification_token_secret, verification_token_secret,
verification_token_lifetime_seconds,
) )
if after_verification_request: if after_verification_request:

View File

@ -1,4 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
import jwt import jwt
@ -6,9 +7,13 @@ JWT_ALGORITHM = "HS256"
def generate_jwt( def generate_jwt(
data: dict, lifetime_seconds: int, secret: str, algorithm: str = JWT_ALGORITHM data: dict,
secret: str,
lifetime_seconds: Optional[int] = None,
algorithm: str = JWT_ALGORITHM,
) -> str: ) -> str:
payload = data.copy() payload = data.copy()
expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds) if lifetime_seconds:
payload["exp"] = expire expire = datetime.utcnow() + timedelta(seconds=lifetime_seconds)
payload["exp"] = expire
return jwt.encode(payload, secret, algorithm=algorithm) return jwt.encode(payload, secret, algorithm=algorithm)

View File

@ -32,7 +32,7 @@ def token():
data = {"aud": "fastapi-users:auth"} data = {"aud": "fastapi-users:auth"}
if user_id is not None: if user_id is not None:
data["user_id"] = str(user_id) data["user_id"] = str(user_id)
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return _token return _token

View File

@ -21,7 +21,7 @@ def token():
data = {"aud": "fastapi-users:auth"} data = {"aud": "fastapi-users:auth"}
if user_id is not None: if user_id is not None:
data["user_id"] = str(user_id) data["user_id"] = str(user_id)
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return _token return _token

View File

@ -21,7 +21,7 @@ def forgot_password_token():
data = {"aud": "fastapi-users:reset"} data = {"aud": "fastapi-users:reset"}
if user_id is not None: if user_id is not None:
data["user_id"] = str(user_id) data["user_id"] = str(user_id)
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return _forgot_password_token return _forgot_password_token

View File

@ -25,7 +25,7 @@ def verify_token():
data["user_id"] = str(user_id) data["user_id"] = str(user_id)
if email is not None: if email is not None:
data["email"] = email data["email"] = email
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) return generate_jwt(data, SECRET, lifetime, JWT_ALGORITHM)
return _verify_token return _verify_token