diff --git a/backend/app/api/v1/auth/auth.py b/backend/app/api/v1/auth/auth.py index 85ca76d..2c69dc4 100644 --- a/backend/app/api/v1/auth/auth.py +++ b/backend/app/api/v1/auth/auth.py @@ -43,9 +43,19 @@ async def user_login(request: Request, obj: AuthLogin, background_tasks: Backgro @router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth]) -async def create_new_token(refresh_token: Annotated[str, Query(...)]): - access_token, access_expire = await AuthService.new_token(refresh_token=refresh_token) - data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire) +async def create_new_token(request: Request, refresh_token: Annotated[str, Query(...)]): + ( + new_access_token, + new_refresh_token, + new_access_token_expire_time, + new_refresh_token_expire_time, + ) = await AuthService.new_token(request=request, refresh_token=refresh_token) + data = GetNewToken( + access_token=new_access_token, + access_token_expire_time=new_access_token_expire_time, + refresh_token=new_refresh_token, + refresh_token_expire_time=new_refresh_token_expire_time, + ) return await response_base.success(data=data) diff --git a/backend/app/common/jwt.py b/backend/app/common/jwt.py index c9f6e47..4d32b23 100644 --- a/backend/app/common/jwt.py +++ b/backend/app/common/jwt.py @@ -96,19 +96,25 @@ async def create_refresh_token(sub: str, expire_time: datetime | None = None, ** return refresh_token, expire -async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str, datetime]: +async def create_new_token(sub: str, token: str, refresh_token: str, **kwargs) -> tuple[str, str, datetime, datetime]: """ Generate new token :param sub: + :param token :param refresh_token: :return: """ redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}') if not redis_refresh_token or redis_refresh_token != refresh_token: raise TokenError(msg='refresh_token 已过期') - new_token, expire = await create_access_token(sub, **kwargs) - return new_token, expire + new_access_token, new_access_token_expire_time = await create_access_token(sub, **kwargs) + new_refresh_token, new_refresh_token_expire_time = await create_refresh_token(sub, **kwargs) + token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}' + refresh_token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{refresh_token}' + await redis_client.delete(token_key) + await redis_client.delete(refresh_token_key) + return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time @sync_to_async diff --git a/backend/app/common/rbac.py b/backend/app/common/rbac.py index a158e70..bce5e34 100644 --- a/backend/app/common/rbac.py +++ b/backend/app/common/rbac.py @@ -34,6 +34,9 @@ class RBAC: :param _: :return: """ + path = request.url.path + if path in settings.TOKEN_EXCLUDE: + return # 强制校验 JWT 授权状态 if not request.auth.scopes: raise TokenError @@ -71,7 +74,6 @@ class RBAC: else: # casbin 权限校验 method = request.method - path = request.url.path forbid_menu_path = [ menu.path for role in user_roles for menu in role.menus if menu.status == StatusType.disable ] diff --git a/backend/app/core/conf.py b/backend/app/core/conf.py index b387de4..5f3fa6f 100644 --- a/backend/app/core/conf.py +++ b/backend/app/core/conf.py @@ -99,7 +99,7 @@ class Settings(BaseSettings): TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login' TOKEN_REDIS_PREFIX: str = 'fba_token' TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token' - TOKEN_WHITELIST: list[str] = [ # 白名单 + TOKEN_EXCLUDE: list[str] = [ # 白名单 f'{API_V1_STR}/auth/login', ] diff --git a/backend/app/middleware/jwt_auth_middleware.py b/backend/app/middleware/jwt_auth_middleware.py index 50a9499..091a0ec 100644 --- a/backend/app/middleware/jwt_auth_middleware.py +++ b/backend/app/middleware/jwt_auth_middleware.py @@ -36,7 +36,7 @@ class JwtAuthMiddleware(AuthenticationBackend): if not auth: return - if request.url.path in settings.TOKEN_WHITELIST: + if request.url.path in settings.TOKEN_EXCLUDE: return scheme, token = auth.split() diff --git a/backend/app/schemas/token.py b/backend/app/schemas/token.py index 242880a..e0c0846 100644 --- a/backend/app/schemas/token.py +++ b/backend/app/schemas/token.py @@ -26,4 +26,6 @@ class GetLoginToken(AccessTokenBase): class GetNewToken(AccessTokenBase): - pass + refresh_token: str + refresh_token_type: str = 'Bearer' + refresh_token_expire_time: datetime diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 7f0af3d..540d058 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -93,18 +93,26 @@ class AuthService: return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user @staticmethod - async def new_token(*, refresh_token: str) -> tuple[str, datetime]: + async def new_token(*, request: Request, refresh_token: str) -> tuple[str, str, datetime, datetime]: user_id = await jwt.jwt_decode(refresh_token) + if str(request.user.id) != user_id: + raise errors.TokenError(msg='刷新 token 无效') async with async_db_session() as db: current_user = await UserDao.get(db, user_id) if not current_user: raise errors.NotFoundError(msg='用户不存在') elif not current_user.status: - raise errors.AuthorizationError(msg='用户已锁定, 获取失败') - access_new_token, access_new_token_expire_time = await jwt.create_new_token( - str(current_user.id), refresh_token, multi_login=current_user.is_multi_login + raise errors.AuthorizationError(msg='用户已锁定,操作失败') + current_token = await get_token(request) + ( + new_access_token, + new_refresh_token, + new_access_token_expire_time, + new_refresh_token_expire_time, + ) = await jwt.create_new_token( + str(current_user.id), current_token, refresh_token, multi_login=current_user.is_multi_login ) - return access_new_token, access_new_token_expire_time + return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time @staticmethod async def logout(*, request: Request) -> NoReturn: diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py index 8e71f83..f3e6eca 100644 --- a/backend/app/services/user_service.py +++ b/backend/app/services/user_service.py @@ -5,9 +5,8 @@ import random from fastapi import Request from sqlalchemy import Select -from backend.app.common import jwt from backend.app.common.exception import errors -from backend.app.common.jwt import get_token, password_verify +from backend.app.common.jwt import get_token, password_verify, superuser_verify from backend.app.common.redis import redis_client from backend.app.core.conf import settings from backend.app.crud.crud_dept import DeptDao @@ -86,7 +85,7 @@ class UserService: @staticmethod async def update(*, request: Request, username: str, obj: UpdateUser) -> int: async with async_db_session.begin() as db: - await jwt.superuser_verify(request) + await superuser_verify(request) input_user = await UserDao.get_with_relation(db, username=username) if not input_user: raise errors.NotFoundError(msg='用户不存在') @@ -139,7 +138,7 @@ class UserService: @staticmethod async def update_permission(*, request: Request, pk: int) -> int: async with async_db_session.begin() as db: - await jwt.superuser_verify(request) + await superuser_verify(request) if not await UserDao.get(db, pk): raise errors.NotFoundError(msg='用户不存在') else: @@ -151,7 +150,7 @@ class UserService: @staticmethod async def update_staff(*, request: Request, pk: int) -> int: async with async_db_session.begin() as db: - await jwt.superuser_verify(request) + await superuser_verify(request) if not await UserDao.get(db, pk): raise errors.NotFoundError(msg='用户不存在') else: @@ -163,7 +162,7 @@ class UserService: @staticmethod async def update_status(*, request: Request, pk: int) -> int: async with async_db_session.begin() as db: - await jwt.superuser_verify(request) + await superuser_verify(request) if not await UserDao.get(db, pk): raise errors.NotFoundError(msg='用户不存在') else: @@ -175,7 +174,7 @@ class UserService: @staticmethod async def update_multi_login(*, request: Request, pk: int) -> int: async with async_db_session.begin() as db: - await jwt.superuser_verify(request) + await superuser_verify(request) if not await UserDao.get(db, pk): raise errors.NotFoundError(msg='用户不存在') else: @@ -199,7 +198,7 @@ class UserService: @staticmethod async def delete(*, request: Request, username: str) -> int: async with async_db_session.begin() as db: - await jwt.superuser_verify(request) + await superuser_verify(request) input_user = await UserDao.get_by_username(db, username) if not input_user: raise errors.NotFoundError(msg='用户不存在')