Fix token whitelist and new token storage (#220)

* Fix token whitelist and new token storage

* Fix logout interface logic
This commit is contained in:
Wu Clan
2023-09-26 18:09:55 +08:00
committed by GitHub
parent 10f5fa6198
commit de9b10a867
8 changed files with 50 additions and 23 deletions

View File

@ -43,9 +43,19 @@ async def user_login(request: Request, obj: AuthLogin, background_tasks: Backgro
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth]) @router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
async def create_new_token(refresh_token: Annotated[str, Query(...)]): async def create_new_token(request: Request, refresh_token: Annotated[str, Query(...)]):
access_token, access_expire = await AuthService.new_token(refresh_token=refresh_token) (
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire) new_access_token,
new_refresh_token,
new_access_token_expire_time,
new_refresh_token_expire_time,
) = await AuthService.new_token(request=request, refresh_token=refresh_token)
data = GetNewToken(
access_token=new_access_token,
access_token_expire_time=new_access_token_expire_time,
refresh_token=new_refresh_token,
refresh_token_expire_time=new_refresh_token_expire_time,
)
return await response_base.success(data=data) return await response_base.success(data=data)

View File

@ -96,19 +96,25 @@ async def create_refresh_token(sub: str, expire_time: datetime | None = None, **
return refresh_token, expire return refresh_token, expire
async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str, datetime]: async def create_new_token(sub: str, token: str, refresh_token: str, **kwargs) -> tuple[str, str, datetime, datetime]:
""" """
Generate new token Generate new token
:param sub: :param sub:
:param token
:param refresh_token: :param refresh_token:
:return: :return:
""" """
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}') redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}')
if not redis_refresh_token or redis_refresh_token != refresh_token: if not redis_refresh_token or redis_refresh_token != refresh_token:
raise TokenError(msg='refresh_token 已过期') raise TokenError(msg='refresh_token 已过期')
new_token, expire = await create_access_token(sub, **kwargs) new_access_token, new_access_token_expire_time = await create_access_token(sub, **kwargs)
return new_token, expire new_refresh_token, new_refresh_token_expire_time = await create_refresh_token(sub, **kwargs)
token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
refresh_token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{refresh_token}'
await redis_client.delete(token_key)
await redis_client.delete(refresh_token_key)
return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
@sync_to_async @sync_to_async

View File

@ -34,6 +34,9 @@ class RBAC:
:param _: :param _:
:return: :return:
""" """
path = request.url.path
if path in settings.TOKEN_EXCLUDE:
return
# 强制校验 JWT 授权状态 # 强制校验 JWT 授权状态
if not request.auth.scopes: if not request.auth.scopes:
raise TokenError raise TokenError
@ -71,7 +74,6 @@ class RBAC:
else: else:
# casbin 权限校验 # casbin 权限校验
method = request.method method = request.method
path = request.url.path
forbid_menu_path = [ forbid_menu_path = [
menu.path for role in user_roles for menu in role.menus if menu.status == StatusType.disable menu.path for role in user_roles for menu in role.menus if menu.status == StatusType.disable
] ]

View File

@ -99,7 +99,7 @@ class Settings(BaseSettings):
TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login' TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login'
TOKEN_REDIS_PREFIX: str = 'fba_token' TOKEN_REDIS_PREFIX: str = 'fba_token'
TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token' TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token'
TOKEN_WHITELIST: list[str] = [ # 白名单 TOKEN_EXCLUDE: list[str] = [ # 白名单
f'{API_V1_STR}/auth/login', f'{API_V1_STR}/auth/login',
] ]

View File

@ -36,7 +36,7 @@ class JwtAuthMiddleware(AuthenticationBackend):
if not auth: if not auth:
return return
if request.url.path in settings.TOKEN_WHITELIST: if request.url.path in settings.TOKEN_EXCLUDE:
return return
scheme, token = auth.split() scheme, token = auth.split()

View File

@ -26,4 +26,6 @@ class GetLoginToken(AccessTokenBase):
class GetNewToken(AccessTokenBase): class GetNewToken(AccessTokenBase):
pass refresh_token: str
refresh_token_type: str = 'Bearer'
refresh_token_expire_time: datetime

View File

@ -93,18 +93,26 @@ class AuthService:
return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user
@staticmethod @staticmethod
async def new_token(*, refresh_token: str) -> tuple[str, datetime]: async def new_token(*, request: Request, refresh_token: str) -> tuple[str, str, datetime, datetime]:
user_id = await jwt.jwt_decode(refresh_token) user_id = await jwt.jwt_decode(refresh_token)
if str(request.user.id) != user_id:
raise errors.TokenError(msg='刷新 token 无效')
async with async_db_session() as db: async with async_db_session() as db:
current_user = await UserDao.get(db, user_id) current_user = await UserDao.get(db, user_id)
if not current_user: if not current_user:
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
elif not current_user.status: elif not current_user.status:
raise errors.AuthorizationError(msg='用户已锁定, 获取失败') raise errors.AuthorizationError(msg='用户已锁定,操作失败')
access_new_token, access_new_token_expire_time = await jwt.create_new_token( current_token = await get_token(request)
str(current_user.id), refresh_token, multi_login=current_user.is_multi_login (
new_access_token,
new_refresh_token,
new_access_token_expire_time,
new_refresh_token_expire_time,
) = await jwt.create_new_token(
str(current_user.id), current_token, refresh_token, multi_login=current_user.is_multi_login
) )
return access_new_token, access_new_token_expire_time return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
@staticmethod @staticmethod
async def logout(*, request: Request) -> NoReturn: async def logout(*, request: Request) -> NoReturn:

View File

@ -5,9 +5,8 @@ import random
from fastapi import Request from fastapi import Request
from sqlalchemy import Select from sqlalchemy import Select
from backend.app.common import jwt
from backend.app.common.exception import errors from backend.app.common.exception import errors
from backend.app.common.jwt import get_token, password_verify from backend.app.common.jwt import get_token, password_verify, superuser_verify
from backend.app.common.redis import redis_client from backend.app.common.redis import redis_client
from backend.app.core.conf import settings from backend.app.core.conf import settings
from backend.app.crud.crud_dept import DeptDao from backend.app.crud.crud_dept import DeptDao
@ -86,7 +85,7 @@ class UserService:
@staticmethod @staticmethod
async def update(*, request: Request, username: str, obj: UpdateUser) -> int: async def update(*, request: Request, username: str, obj: UpdateUser) -> int:
async with async_db_session.begin() as db: async with async_db_session.begin() as db:
await jwt.superuser_verify(request) await superuser_verify(request)
input_user = await UserDao.get_with_relation(db, username=username) input_user = await UserDao.get_with_relation(db, username=username)
if not input_user: if not input_user:
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
@ -139,7 +138,7 @@ class UserService:
@staticmethod @staticmethod
async def update_permission(*, request: Request, pk: int) -> int: async def update_permission(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db: async with async_db_session.begin() as db:
await jwt.superuser_verify(request) await superuser_verify(request)
if not await UserDao.get(db, pk): if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
else: else:
@ -151,7 +150,7 @@ class UserService:
@staticmethod @staticmethod
async def update_staff(*, request: Request, pk: int) -> int: async def update_staff(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db: async with async_db_session.begin() as db:
await jwt.superuser_verify(request) await superuser_verify(request)
if not await UserDao.get(db, pk): if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
else: else:
@ -163,7 +162,7 @@ class UserService:
@staticmethod @staticmethod
async def update_status(*, request: Request, pk: int) -> int: async def update_status(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db: async with async_db_session.begin() as db:
await jwt.superuser_verify(request) await superuser_verify(request)
if not await UserDao.get(db, pk): if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
else: else:
@ -175,7 +174,7 @@ class UserService:
@staticmethod @staticmethod
async def update_multi_login(*, request: Request, pk: int) -> int: async def update_multi_login(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db: async with async_db_session.begin() as db:
await jwt.superuser_verify(request) await superuser_verify(request)
if not await UserDao.get(db, pk): if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
else: else:
@ -199,7 +198,7 @@ class UserService:
@staticmethod @staticmethod
async def delete(*, request: Request, username: str) -> int: async def delete(*, request: Request, username: str) -> int:
async with async_db_session.begin() as db: async with async_db_session.begin() as db:
await jwt.superuser_verify(request) await superuser_verify(request)
input_user = await UserDao.get_by_username(db, username) input_user = await UserDao.get_by_username(db, username)
if not input_user: if not input_user:
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')