mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-26 12:59:44 +08:00
Fix token whitelist and new token storage (#220)
* Fix token whitelist and new token storage * Fix logout interface logic
This commit is contained in:
@ -43,9 +43,19 @@ async def user_login(request: Request, obj: AuthLogin, background_tasks: Backgro
|
|||||||
|
|
||||||
|
|
||||||
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
|
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
|
||||||
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
|
async def create_new_token(request: Request, refresh_token: Annotated[str, Query(...)]):
|
||||||
access_token, access_expire = await AuthService.new_token(refresh_token=refresh_token)
|
(
|
||||||
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire)
|
new_access_token,
|
||||||
|
new_refresh_token,
|
||||||
|
new_access_token_expire_time,
|
||||||
|
new_refresh_token_expire_time,
|
||||||
|
) = await AuthService.new_token(request=request, refresh_token=refresh_token)
|
||||||
|
data = GetNewToken(
|
||||||
|
access_token=new_access_token,
|
||||||
|
access_token_expire_time=new_access_token_expire_time,
|
||||||
|
refresh_token=new_refresh_token,
|
||||||
|
refresh_token_expire_time=new_refresh_token_expire_time,
|
||||||
|
)
|
||||||
return await response_base.success(data=data)
|
return await response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,19 +96,25 @@ async def create_refresh_token(sub: str, expire_time: datetime | None = None, **
|
|||||||
return refresh_token, expire
|
return refresh_token, expire
|
||||||
|
|
||||||
|
|
||||||
async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str, datetime]:
|
async def create_new_token(sub: str, token: str, refresh_token: str, **kwargs) -> tuple[str, str, datetime, datetime]:
|
||||||
"""
|
"""
|
||||||
Generate new token
|
Generate new token
|
||||||
|
|
||||||
:param sub:
|
:param sub:
|
||||||
|
:param token
|
||||||
:param refresh_token:
|
:param refresh_token:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}')
|
redis_refresh_token = await redis_client.get(f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{sub}:{refresh_token}')
|
||||||
if not redis_refresh_token or redis_refresh_token != refresh_token:
|
if not redis_refresh_token or redis_refresh_token != refresh_token:
|
||||||
raise TokenError(msg='refresh_token 已过期')
|
raise TokenError(msg='refresh_token 已过期')
|
||||||
new_token, expire = await create_access_token(sub, **kwargs)
|
new_access_token, new_access_token_expire_time = await create_access_token(sub, **kwargs)
|
||||||
return new_token, expire
|
new_refresh_token, new_refresh_token_expire_time = await create_refresh_token(sub, **kwargs)
|
||||||
|
token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{token}'
|
||||||
|
refresh_token_key = f'{settings.TOKEN_REDIS_PREFIX}:{sub}:{refresh_token}'
|
||||||
|
await redis_client.delete(token_key)
|
||||||
|
await redis_client.delete(refresh_token_key)
|
||||||
|
return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
|
||||||
|
|
||||||
|
|
||||||
@sync_to_async
|
@sync_to_async
|
||||||
|
@ -34,6 +34,9 @@ class RBAC:
|
|||||||
:param _:
|
:param _:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
path = request.url.path
|
||||||
|
if path in settings.TOKEN_EXCLUDE:
|
||||||
|
return
|
||||||
# 强制校验 JWT 授权状态
|
# 强制校验 JWT 授权状态
|
||||||
if not request.auth.scopes:
|
if not request.auth.scopes:
|
||||||
raise TokenError
|
raise TokenError
|
||||||
@ -71,7 +74,6 @@ class RBAC:
|
|||||||
else:
|
else:
|
||||||
# casbin 权限校验
|
# casbin 权限校验
|
||||||
method = request.method
|
method = request.method
|
||||||
path = request.url.path
|
|
||||||
forbid_menu_path = [
|
forbid_menu_path = [
|
||||||
menu.path for role in user_roles for menu in role.menus if menu.status == StatusType.disable
|
menu.path for role in user_roles for menu in role.menus if menu.status == StatusType.disable
|
||||||
]
|
]
|
||||||
|
@ -99,7 +99,7 @@ class Settings(BaseSettings):
|
|||||||
TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login'
|
TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login'
|
||||||
TOKEN_REDIS_PREFIX: str = 'fba_token'
|
TOKEN_REDIS_PREFIX: str = 'fba_token'
|
||||||
TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token'
|
TOKEN_REFRESH_REDIS_PREFIX: str = 'fba_refresh_token'
|
||||||
TOKEN_WHITELIST: list[str] = [ # 白名单
|
TOKEN_EXCLUDE: list[str] = [ # 白名单
|
||||||
f'{API_V1_STR}/auth/login',
|
f'{API_V1_STR}/auth/login',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class JwtAuthMiddleware(AuthenticationBackend):
|
|||||||
if not auth:
|
if not auth:
|
||||||
return
|
return
|
||||||
|
|
||||||
if request.url.path in settings.TOKEN_WHITELIST:
|
if request.url.path in settings.TOKEN_EXCLUDE:
|
||||||
return
|
return
|
||||||
|
|
||||||
scheme, token = auth.split()
|
scheme, token = auth.split()
|
||||||
|
@ -26,4 +26,6 @@ class GetLoginToken(AccessTokenBase):
|
|||||||
|
|
||||||
|
|
||||||
class GetNewToken(AccessTokenBase):
|
class GetNewToken(AccessTokenBase):
|
||||||
pass
|
refresh_token: str
|
||||||
|
refresh_token_type: str = 'Bearer'
|
||||||
|
refresh_token_expire_time: datetime
|
||||||
|
@ -93,18 +93,26 @@ class AuthService:
|
|||||||
return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user
|
return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def new_token(*, refresh_token: str) -> tuple[str, datetime]:
|
async def new_token(*, request: Request, refresh_token: str) -> tuple[str, str, datetime, datetime]:
|
||||||
user_id = await jwt.jwt_decode(refresh_token)
|
user_id = await jwt.jwt_decode(refresh_token)
|
||||||
|
if str(request.user.id) != user_id:
|
||||||
|
raise errors.TokenError(msg='刷新 token 无效')
|
||||||
async with async_db_session() as db:
|
async with async_db_session() as db:
|
||||||
current_user = await UserDao.get(db, user_id)
|
current_user = await UserDao.get(db, user_id)
|
||||||
if not current_user:
|
if not current_user:
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
elif not current_user.status:
|
elif not current_user.status:
|
||||||
raise errors.AuthorizationError(msg='用户已锁定, 获取失败')
|
raise errors.AuthorizationError(msg='用户已锁定,操作失败')
|
||||||
access_new_token, access_new_token_expire_time = await jwt.create_new_token(
|
current_token = await get_token(request)
|
||||||
str(current_user.id), refresh_token, multi_login=current_user.is_multi_login
|
(
|
||||||
|
new_access_token,
|
||||||
|
new_refresh_token,
|
||||||
|
new_access_token_expire_time,
|
||||||
|
new_refresh_token_expire_time,
|
||||||
|
) = await jwt.create_new_token(
|
||||||
|
str(current_user.id), current_token, refresh_token, multi_login=current_user.is_multi_login
|
||||||
)
|
)
|
||||||
return access_new_token, access_new_token_expire_time
|
return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def logout(*, request: Request) -> NoReturn:
|
async def logout(*, request: Request) -> NoReturn:
|
||||||
|
@ -5,9 +5,8 @@ import random
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy import Select
|
from sqlalchemy import Select
|
||||||
|
|
||||||
from backend.app.common import jwt
|
|
||||||
from backend.app.common.exception import errors
|
from backend.app.common.exception import errors
|
||||||
from backend.app.common.jwt import get_token, password_verify
|
from backend.app.common.jwt import get_token, password_verify, superuser_verify
|
||||||
from backend.app.common.redis import redis_client
|
from backend.app.common.redis import redis_client
|
||||||
from backend.app.core.conf import settings
|
from backend.app.core.conf import settings
|
||||||
from backend.app.crud.crud_dept import DeptDao
|
from backend.app.crud.crud_dept import DeptDao
|
||||||
@ -86,7 +85,7 @@ class UserService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def update(*, request: Request, username: str, obj: UpdateUser) -> int:
|
async def update(*, request: Request, username: str, obj: UpdateUser) -> int:
|
||||||
async with async_db_session.begin() as db:
|
async with async_db_session.begin() as db:
|
||||||
await jwt.superuser_verify(request)
|
await superuser_verify(request)
|
||||||
input_user = await UserDao.get_with_relation(db, username=username)
|
input_user = await UserDao.get_with_relation(db, username=username)
|
||||||
if not input_user:
|
if not input_user:
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
@ -139,7 +138,7 @@ class UserService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_permission(*, request: Request, pk: int) -> int:
|
async def update_permission(*, request: Request, pk: int) -> int:
|
||||||
async with async_db_session.begin() as db:
|
async with async_db_session.begin() as db:
|
||||||
await jwt.superuser_verify(request)
|
await superuser_verify(request)
|
||||||
if not await UserDao.get(db, pk):
|
if not await UserDao.get(db, pk):
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
else:
|
else:
|
||||||
@ -151,7 +150,7 @@ class UserService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_staff(*, request: Request, pk: int) -> int:
|
async def update_staff(*, request: Request, pk: int) -> int:
|
||||||
async with async_db_session.begin() as db:
|
async with async_db_session.begin() as db:
|
||||||
await jwt.superuser_verify(request)
|
await superuser_verify(request)
|
||||||
if not await UserDao.get(db, pk):
|
if not await UserDao.get(db, pk):
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
else:
|
else:
|
||||||
@ -163,7 +162,7 @@ class UserService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_status(*, request: Request, pk: int) -> int:
|
async def update_status(*, request: Request, pk: int) -> int:
|
||||||
async with async_db_session.begin() as db:
|
async with async_db_session.begin() as db:
|
||||||
await jwt.superuser_verify(request)
|
await superuser_verify(request)
|
||||||
if not await UserDao.get(db, pk):
|
if not await UserDao.get(db, pk):
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
else:
|
else:
|
||||||
@ -175,7 +174,7 @@ class UserService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_multi_login(*, request: Request, pk: int) -> int:
|
async def update_multi_login(*, request: Request, pk: int) -> int:
|
||||||
async with async_db_session.begin() as db:
|
async with async_db_session.begin() as db:
|
||||||
await jwt.superuser_verify(request)
|
await superuser_verify(request)
|
||||||
if not await UserDao.get(db, pk):
|
if not await UserDao.get(db, pk):
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
else:
|
else:
|
||||||
@ -199,7 +198,7 @@ class UserService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def delete(*, request: Request, username: str) -> int:
|
async def delete(*, request: Request, username: str) -> int:
|
||||||
async with async_db_session.begin() as db:
|
async with async_db_session.begin() as db:
|
||||||
await jwt.superuser_verify(request)
|
await superuser_verify(request)
|
||||||
input_user = await UserDao.get_by_username(db, username)
|
input_user = await UserDao.get_by_username(db, username)
|
||||||
if not input_user:
|
if not input_user:
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
|
Reference in New Issue
Block a user