Files

42 lines
1.0 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Union

import jwt
from pydantic import SecretStr
# A signing secret: either a plain string or a pydantic ``SecretStr`` wrapper.
SecretType = Union[str, SecretStr]
# Default algorithm for generate_jwt() and decode_jwt() in this module.
JWT_ALGORITHM = "HS256"
def _get_secret_value(secret: SecretType) -> str:
    """Return the raw string behind *secret*.

    A plain ``str`` is handed back untouched; a ``SecretStr`` is unwrapped
    through its ``get_secret_value()`` accessor.
    """
    # Guard clause: anything that is not a SecretStr is passed through as-is.
    if not isinstance(secret, SecretStr):
        return secret
    return secret.get_secret_value()
def generate_jwt(
    data: dict,
    secret: SecretType,
    lifetime_seconds: Optional[int] = None,
    algorithm: str = JWT_ALGORITHM,
) -> str:
    """Encode *data* as a signed JSON Web Token.

    Args:
        data: Claims to place in the token. The mapping is shallow-copied
            first, so the caller's dict is never modified.
        secret: Signing key, as a plain string or a ``SecretStr``.
        lifetime_seconds: Validity period in seconds. A falsy value
            (``None`` or ``0``) adds no ``exp`` claim; an ``exp`` key
            already present in *data* is then passed through unchanged.
        algorithm: Signing algorithm handed to ``jwt.encode``.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    payload = data.copy()
    # Truthiness test kept deliberately: 0 means "no expiry", same as None.
    if lifetime_seconds:
        # Use a timezone-aware UTC timestamp. datetime.utcnow() is deprecated
        # since Python 3.12 and yields a naive datetime that is easy to
        # misinterpret as local time.
        payload["exp"] = datetime.now(timezone.utc) + timedelta(
            seconds=lifetime_seconds
        )
    return jwt.encode(payload, _get_secret_value(secret), algorithm=algorithm)
def decode_jwt(
    encoded_jwt: str,
    secret: SecretType,
    audience: List[str],
    algorithms: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Decode *encoded_jwt* with ``jwt.decode`` and return its payload.

    Args:
        encoded_jwt: The token string to decode.
        secret: Key used to check the signature, as a plain string or a
            ``SecretStr``.
        audience: Audience value(s) forwarded to ``jwt.decode``.
        algorithms: Accepted signing algorithms. ``None`` (the default)
            means ``[JWT_ALGORITHM]``.

    Returns:
        The decoded claims as returned by ``jwt.decode``.

    Raises:
        Any exception raised by ``jwt.decode`` propagates to the caller;
        nothing is caught here.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation (mutable-default-argument pitfall).
    if algorithms is None:
        algorithms = [JWT_ALGORITHM]
    return jwt.decode(
        encoded_jwt,
        _get_secret_value(secret),
        audience=audience,
        algorithms=algorithms,
    )