Fix the merge issues (#87)

* Fix the login log status value.

* Fix config information interface constants

* Add fuzzy paging query for login logs

* Fix fuzzy paging query for query user interface

* Fix jwt middleware internal exception not caught
This commit is contained in:
Wu Clan
2023-06-01 16:04:59 +08:00
committed by GitHub
parent 61147d4636
commit e6640e7936
12 changed files with 80 additions and 36 deletions

View File

@ -44,7 +44,8 @@ async def get_sys_config():
'token_algorithm': settings.TOKEN_ALGORITHM, 'token_algorithm': settings.TOKEN_ALGORITHM,
'token_expire_seconds': settings.TOKEN_EXPIRE_SECONDS, 'token_expire_seconds': settings.TOKEN_EXPIRE_SECONDS,
'token_swagger_url': settings.TOKEN_URL_SWAGGER, 'token_swagger_url': settings.TOKEN_URL_SWAGGER,
'log_file_name': settings.LOG_FILE_NAME, 'access_log_filename': settings.LOG_STDOUT_FILENAME,
'error_log_filename': settings.LOG_STDERR_FILENAME,
'middleware_cors': settings.MIDDLEWARE_CORS, 'middleware_cors': settings.MIDDLEWARE_CORS,
'middleware_gzip': settings.MIDDLEWARE_GZIP, 'middleware_gzip': settings.MIDDLEWARE_GZIP,
'middleware_access': settings.MIDDLEWARE_ACCESS, 'middleware_access': settings.MIDDLEWARE_ACCESS,

View File

@ -15,9 +15,14 @@ from backend.app.services.login_log_service import LoginLogService
router = APIRouter() router = APIRouter()
@router.get('', summary='获取所有登录日志', dependencies=[DependsJwtAuth, PageDepends]) @router.get('', summary='(模糊条件)分页获取登录日志', dependencies=[DependsJwtAuth, PageDepends])
async def get_all_login_logs(db: CurrentSession): async def get_all_login_logs(
log_select = await LoginLogService.get_select() db: CurrentSession,
username: Annotated[str | None, Query()] = None,
status: Annotated[bool | None, Query()] = None,
ipaddr: Annotated[str | None, Query()] = None,
):
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
page_data = await paging_data(db, log_select, GetAllLoginLog) page_data = await paging_data(db, log_select, GetAllLoginLog)
return response_base.success(data=page_data) return response_base.success(data=page_data)

View File

@ -57,7 +57,7 @@ async def get_all_users(
db: CurrentSession, db: CurrentSession,
username: Annotated[str | None, Query()] = None, username: Annotated[str | None, Query()] = None,
phone: Annotated[str | None, Query()] = None, phone: Annotated[str | None, Query()] = None,
status: Annotated[int | None, Query()] = None, status: Annotated[bool | None, Query()] = None,
): ):
user_select = await UserService.get_select(username=username, phone=phone, status=status) user_select = await UserService.get_select(username=username, phone=phone, status=status)
page_data = await paging_data(db, user_select, GetAllUserInfo) page_data = await paging_data(db, user_select, GetAllUserInfo)

View File

@ -91,7 +91,6 @@ class Settings(BaseSettings):
# Middleware # Middleware
MIDDLEWARE_CORS: bool = True MIDDLEWARE_CORS: bool = True
MIDDLEWARE_GZIP: bool = True MIDDLEWARE_GZIP: bool = True
MIDDLEWARE_JWT_AUTH: bool = True
MIDDLEWARE_ACCESS: bool = False MIDDLEWARE_ACCESS: bool = False
# Casbin # Casbin

View File

@ -13,6 +13,7 @@ from backend.app.common.redis import redis_client
from backend.app.common.task import scheduler from backend.app.common.task import scheduler
from backend.app.core.conf import settings from backend.app.core.conf import settings
from backend.app.database.db_mysql import create_table from backend.app.database.db_mysql import create_table
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
from backend.app.utils.health_check import ensure_unique_route_names from backend.app.utils.health_check import ensure_unique_route_names
from backend.app.utils.openapi import simplify_operation_ids from backend.app.utils.openapi import simplify_operation_ids
@ -90,7 +91,21 @@ def register_static_file(app: FastAPI):
def register_middleware(app: FastAPI): def register_middleware(app: FastAPI):
# CORS # Gzip
if settings.MIDDLEWARE_GZIP:
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware)
# Api access logs
if settings.MIDDLEWARE_ACCESS:
from backend.app.middleware.access_middleware import AccessMiddleware
app.add_middleware(AccessMiddleware)
# JWT auth: Always open
app.add_middleware(
AuthenticationMiddleware, backend=JwtAuthMiddleware(), on_error=JwtAuthMiddleware.auth_exception_handler
)
# CORS: Always at the end
if settings.MIDDLEWARE_CORS: if settings.MIDDLEWARE_CORS:
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -101,21 +116,6 @@ def register_middleware(app: FastAPI):
allow_methods=['*'], allow_methods=['*'],
allow_headers=['*'], allow_headers=['*'],
) )
# Gzip
if settings.MIDDLEWARE_GZIP:
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware)
# JWT auth
if settings.MIDDLEWARE_JWT_AUTH:
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
app.add_middleware(AuthenticationMiddleware, backend=JwtAuthMiddleware())
# Api access logs
if settings.MIDDLEWARE_ACCESS:
from backend.app.middleware.access_middleware import AccessMiddleware
app.add_middleware(AccessMiddleware)
def register_router(app: FastAPI): def register_router(app: FastAPI):

View File

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from typing import NoReturn from typing import NoReturn
from sqlalchemy import Select, select, desc, delete from sqlalchemy import Select, select, desc, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.crud.base import CRUDBase from backend.app.crud.base import CRUDBase
@ -11,8 +11,18 @@ from backend.app.schemas.login_log import CreateLoginLog, UpdateLoginLog
class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]): class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]):
async def get_all(self) -> Select: async def get_all(self, username: str = None, status: bool = None, ipaddr: str = None) -> Select:
return select(self.model).order_by(desc(self.model.create_time)) se = select(self.model).order_by(desc(self.model.create_time))
where_list = []
if username:
where_list.append(self.model.username.like(f'%{username}%'))
if status is not None:
where_list.append(self.model.status == status)
if ipaddr:
where_list.append(self.model.ipaddr.like(f'%{ipaddr}%'))
if where_list:
se = se.where(and_(*where_list))
return se
async def create(self, db: AsyncSession, obj_in: CreateLoginLog) -> NoReturn: async def create(self, db: AsyncSession, obj_in: CreateLoginLog) -> NoReturn:
await self.create_(db, obj_in) await self.create_(db, obj_in)

View File

@ -70,6 +70,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
async def get_all(self, username: str = None, phone: str = None, status: int = None) -> Select: async def get_all(self, username: str = None, phone: str = None, status: int = None) -> Select:
se = ( se = (
select(self.model) select(self.model)
.options(selectinload(self.model.dept))
.options(selectinload(self.model.roles).selectinload(Role.menus)) .options(selectinload(self.model.roles).selectinload(Role.menus))
.order_by(desc(self.model.time_joined)) .order_by(desc(self.model.time_joined))
) )

View File

@ -1,15 +1,37 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from starlette.authentication import AuthenticationBackend from typing import Any
from fastapi import Request
from fastapi import Request, Response
from starlette.authentication import AuthenticationBackend, AuthenticationError
from starlette.requests import HTTPConnection
from starlette.responses import JSONResponse
from backend.app.common import jwt from backend.app.common import jwt
from backend.app.common.exception.errors import TokenError
from backend.app.core.conf import settings
from backend.app.database.db_mysql import async_db_session from backend.app.database.db_mysql import async_db_session
class _AuthenticationError(AuthenticationError):
"""重写内部认证错误类"""
def __init__(self, *, code: int = None, msg: str = None, headers: dict[str, Any] | None = None):
self.code = code
self.msg = msg
self.headers = headers
class JwtAuthMiddleware(AuthenticationBackend): class JwtAuthMiddleware(AuthenticationBackend):
"""JWT 认证中间件""" """JWT 认证中间件"""
@staticmethod
def auth_exception_handler(conn: HTTPConnection, exc: Exception) -> Response:
"""覆盖内部认证错误处理"""
code = getattr(exc, 'code', 500)
msg = getattr(exc, 'msg', 'Internal Server Error')
return JSONResponse(content={'code': code, 'msg': msg, 'data': None}, status_code=code)
async def authenticate(self, request: Request): async def authenticate(self, request: Request):
auth = request.headers.get('Authorization') auth = request.headers.get('Authorization')
if not auth: if not auth:
@ -19,9 +41,15 @@ class JwtAuthMiddleware(AuthenticationBackend):
if scheme.lower() != 'bearer': if scheme.lower() != 'bearer':
return return
try:
sub = await jwt.jwt_authentication(token) sub = await jwt.jwt_authentication(token)
async with async_db_session() as db: async with async_db_session() as db:
user = await jwt.get_current_user(db, data=sub) user = await jwt.get_current_user(db, data=sub)
except TokenError as exc:
raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers)
except Exception:
import traceback
raise _AuthenticationError(msg=traceback.format_exc() if settings.ENVIRONMENT == 'dev' else None)
return auth, user return auth, user

View File

@ -16,7 +16,7 @@ class LoginLog(DataClassBase):
id: Mapped[id_key] = mapped_column(init=False) id: Mapped[id_key] = mapped_column(init=False)
user_uuid: Mapped[str] = mapped_column(String(50), nullable=False, comment='用户UUID') user_uuid: Mapped[str] = mapped_column(String(50), nullable=False, comment='用户UUID')
username: Mapped[str] = mapped_column(String(20), nullable=False, comment='用户名') username: Mapped[str] = mapped_column(String(20), nullable=False, comment='用户名')
status: Mapped[int] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)') status: Mapped[bool] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
ipaddr: Mapped[str] = mapped_column(String(50), nullable=False, comment='登录IP地址') ipaddr: Mapped[str] = mapped_column(String(50), nullable=False, comment='登录IP地址')
location: Mapped[str] = mapped_column(String(255), nullable=False, comment='归属地') location: Mapped[str] = mapped_column(String(255), nullable=False, comment='归属地')
browser: Mapped[str] = mapped_column(String(255), nullable=False, comment='浏览器') browser: Mapped[str] = mapped_column(String(255), nullable=False, comment='浏览器')

View File

@ -8,7 +8,7 @@ from pydantic import BaseModel
class LoginLogBase(BaseModel): class LoginLogBase(BaseModel):
user_uuid: str user_uuid: str
username: str username: str
status: int status: bool
ipaddr: str ipaddr: str
location: str location: str
browser: str browser: str

View File

@ -71,7 +71,7 @@ class AuthService:
request=request, request=request,
user=current_user, user=current_user,
login_time=self.login_time, login_time=self.login_time,
status=LoginLogStatus.fail, status=LoginLogStatus.fail.value,
msg=e.msg, msg=e.msg,
) )
task = BackgroundTask(LoginLogService.create, **err_log_info) task = BackgroundTask(LoginLogService.create, **err_log_info)
@ -84,7 +84,7 @@ class AuthService:
request=request, request=request,
user=user, user=user,
login_time=self.login_time, login_time=self.login_time,
status=LoginLogStatus.success, status=LoginLogStatus.success.value,
msg='登录成功', msg='登录成功',
) )
background_tasks.add_task(LoginLogService.create, **log_info) background_tasks.add_task(LoginLogService.create, **log_info)

View File

@ -19,8 +19,8 @@ from backend.app.utils import request_parse
class LoginLogService: class LoginLogService:
@staticmethod @staticmethod
async def get_select() -> Select: async def get_select(*, username: str, status: bool, ipaddr: str) -> Select:
return await LoginLogDao.get_all() return await LoginLogDao.get_all(username=username, status=status, ipaddr=ipaddr)
@staticmethod @staticmethod
async def create( async def create(