Fix token whitelist and new token storage (#220)

* Fix token whitelist and new token storage

* Fix logout interface logic
This commit is contained in:
Wu Clan
2023-09-26 18:09:55 +08:00
committed by GitHub
parent 10f5fa6198
commit de9b10a867
8 changed files with 50 additions and 23 deletions

View File

@ -43,9 +43,19 @@ async def user_login(request: Request, obj: AuthLogin, background_tasks: Backgro
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
access_token, access_expire = await AuthService.new_token(refresh_token=refresh_token)
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire)
async def create_new_token(request: Request, refresh_token: Annotated[str, Query(...)]):
(
new_access_token,
new_refresh_token,
new_access_token_expire_time,
new_refresh_token_expire_time,
) = await AuthService.new_token(request=request, refresh_token=refresh_token)
data = GetNewToken(
access_token=new_access_token,
access_token_expire_time=new_access_token_expire_time,
refresh_token=new_refresh_token,
refresh_token_expire_time=new_refresh_token_expire_time,
)
return await response_base.success(data=data)

View File

@ -96,19 +96,25 @@ async def create_refresh_token(sub: str, expire_time: datetime | None = None, **
return refresh_token, expire
async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str, datetime]:
async def create_new_token(sub: str, token: str, refresh_token: str, **kwargs) -> tuple[str, str, datetime, datetime]:
"""
Generate new token
:param sub:
:param token
:param refresh_token:
:return:
"""
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}')
if not redis_refresh_token or redis_refresh_token != refresh_token:
raise TokenError(msg='refresh_token 已过期')
new_token, expire = await create_access_token(sub, **kwargs)
return new_token, expire
new_access_token, new_access_token_expire_time = await create_access_token(sub, **kwargs)
new_refresh_token, new_refresh_token_expire_time = await create_refresh_token(sub, **kwargs)
token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
refresh_token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{refresh_token}'
await redis_client.delete(token_key)
await redis_client.delete(refresh_token_key)
return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
@sync_to_async

View File

@ -34,6 +34,9 @@ class RBAC:
:param _:
:return:
"""
path = request.url.path
if path in settings.TOKEN_EXCLUDE:
return
# 强制校验 JWT 授权状态
if not request.auth.scopes:
raise TokenError
@ -71,7 +74,6 @@ class RBAC:
else:
# casbin 权限校验
method = request.method
path = request.url.path
forbid_menu_path = [
menu.path for role in user_roles for menu in role.menus if menu.status == StatusType.disable
]

View File

@ -99,7 +99,7 @@ class Settings(BaseSettings):
TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login'
TOKEN_REDIS_PREFIX: str = 'fba_token'
TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token'
TOKEN_WHITELIST: list[str] = [ # 白名单
TOKEN_EXCLUDE: list[str] = [ # 白名单
f'{API_V1_STR}/auth/login',
]

View File

@ -36,7 +36,7 @@ class JwtAuthMiddleware(AuthenticationBackend):
if not auth:
return
if request.url.path in settings.TOKEN_WHITELIST:
if request.url.path in settings.TOKEN_EXCLUDE:
return
scheme, token = auth.split()

View File

@ -26,4 +26,6 @@ class GetLoginToken(AccessTokenBase):
class GetNewToken(AccessTokenBase):
pass
refresh_token: str
refresh_token_type: str = 'Bearer'
refresh_token_expire_time: datetime

View File

@ -93,18 +93,26 @@ class AuthService:
return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user
@staticmethod
async def new_token(*, refresh_token: str) -> tuple[str, datetime]:
async def new_token(*, request: Request, refresh_token: str) -> tuple[str, str, datetime, datetime]:
user_id = await jwt.jwt_decode(refresh_token)
if str(request.user.id) != user_id:
raise errors.TokenError(msg='刷新 token 无效')
async with async_db_session() as db:
current_user = await UserDao.get(db, user_id)
if not current_user:
raise errors.NotFoundError(msg='用户不存在')
elif not current_user.status:
raise errors.AuthorizationError(msg='用户已锁定, 获取失败')
access_new_token, access_new_token_expire_time = await jwt.create_new_token(
str(current_user.id), refresh_token, multi_login=current_user.is_multi_login
raise errors.AuthorizationError(msg='用户已锁定,操作失败')
current_token = await get_token(request)
(
new_access_token,
new_refresh_token,
new_access_token_expire_time,
new_refresh_token_expire_time,
) = await jwt.create_new_token(
str(current_user.id), current_token, refresh_token, multi_login=current_user.is_multi_login
)
return access_new_token, access_new_token_expire_time
return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
@staticmethod
async def logout(*, request: Request) -> NoReturn:

View File

@ -5,9 +5,8 @@ import random
from fastapi import Request
from sqlalchemy import Select
from backend.app.common import jwt
from backend.app.common.exception import errors
from backend.app.common.jwt import get_token, password_verify
from backend.app.common.jwt import get_token, password_verify, superuser_verify
from backend.app.common.redis import redis_client
from backend.app.core.conf import settings
from backend.app.crud.crud_dept import DeptDao
@ -86,7 +85,7 @@ class UserService:
@staticmethod
async def update(*, request: Request, username: str, obj: UpdateUser) -> int:
async with async_db_session.begin() as db:
await jwt.superuser_verify(request)
await superuser_verify(request)
input_user = await UserDao.get_with_relation(db, username=username)
if not input_user:
raise errors.NotFoundError(msg='用户不存在')
@ -139,7 +138,7 @@ class UserService:
@staticmethod
async def update_permission(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db:
await jwt.superuser_verify(request)
await superuser_verify(request)
if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在')
else:
@ -151,7 +150,7 @@ class UserService:
@staticmethod
async def update_staff(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db:
await jwt.superuser_verify(request)
await superuser_verify(request)
if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在')
else:
@ -163,7 +162,7 @@ class UserService:
@staticmethod
async def update_status(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db:
await jwt.superuser_verify(request)
await superuser_verify(request)
if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在')
else:
@ -175,7 +174,7 @@ class UserService:
@staticmethod
async def update_multi_login(*, request: Request, pk: int) -> int:
async with async_db_session.begin() as db:
await jwt.superuser_verify(request)
await superuser_verify(request)
if not await UserDao.get(db, pk):
raise errors.NotFoundError(msg='用户不存在')
else:
@ -199,7 +198,7 @@ class UserService:
@staticmethod
async def delete(*, request: Request, username: str) -> int:
async with async_db_session.begin() as db:
await jwt.superuser_verify(request)
await superuser_verify(request)
input_user = await UserDao.get_by_username(db, username)
if not input_user:
raise errors.NotFoundError(msg='用户不存在')