Fix the merge issues (#87)

* Fix the login log status value.

* Fix config information interface constants

* Add fuzzy paging query for login logs

* Fix fuzzy paging query for query user interface

* Fix jwt middleware internal exception not caught
This commit is contained in:
Wu Clan
2023-06-01 16:04:59 +08:00
committed by GitHub
parent 61147d4636
commit e6640e7936
12 changed files with 80 additions and 36 deletions

View File

@ -44,7 +44,8 @@ async def get_sys_config():
'token_algorithm': settings.TOKEN_ALGORITHM,
'token_expire_seconds': settings.TOKEN_EXPIRE_SECONDS,
'token_swagger_url': settings.TOKEN_URL_SWAGGER,
'log_file_name': settings.LOG_FILE_NAME,
'access_log_filename': settings.LOG_STDOUT_FILENAME,
'error_log_filename': settings.LOG_STDERR_FILENAME,
'middleware_cors': settings.MIDDLEWARE_CORS,
'middleware_gzip': settings.MIDDLEWARE_GZIP,
'middleware_access': settings.MIDDLEWARE_ACCESS,

View File

@ -15,9 +15,14 @@ from backend.app.services.login_log_service import LoginLogService
router = APIRouter()
@router.get('', summary='获取所有登录日志', dependencies=[DependsJwtAuth, PageDepends])
async def get_all_login_logs(db: CurrentSession):
log_select = await LoginLogService.get_select()
@router.get('', summary='(模糊条件)分页获取登录日志', dependencies=[DependsJwtAuth, PageDepends])
async def get_all_login_logs(
db: CurrentSession,
username: Annotated[str | None, Query()] = None,
status: Annotated[bool | None, Query()] = None,
ipaddr: Annotated[str | None, Query()] = None,
):
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
page_data = await paging_data(db, log_select, GetAllLoginLog)
return response_base.success(data=page_data)

View File

@ -57,7 +57,7 @@ async def get_all_users(
db: CurrentSession,
username: Annotated[str | None, Query()] = None,
phone: Annotated[str | None, Query()] = None,
status: Annotated[int | None, Query()] = None,
status: Annotated[bool | None, Query()] = None,
):
user_select = await UserService.get_select(username=username, phone=phone, status=status)
page_data = await paging_data(db, user_select, GetAllUserInfo)

View File

@ -91,7 +91,6 @@ class Settings(BaseSettings):
# Middleware
MIDDLEWARE_CORS: bool = True
MIDDLEWARE_GZIP: bool = True
MIDDLEWARE_JWT_AUTH: bool = True
MIDDLEWARE_ACCESS: bool = False
# Casbin

View File

@ -13,6 +13,7 @@ from backend.app.common.redis import redis_client
from backend.app.common.task import scheduler
from backend.app.core.conf import settings
from backend.app.database.db_mysql import create_table
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
from backend.app.utils.health_check import ensure_unique_route_names
from backend.app.utils.openapi import simplify_operation_ids
@ -90,7 +91,21 @@ def register_static_file(app: FastAPI):
def register_middleware(app: FastAPI):
# CORS
# Gzip
if settings.MIDDLEWARE_GZIP:
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware)
# Api access logs
if settings.MIDDLEWARE_ACCESS:
from backend.app.middleware.access_middleware import AccessMiddleware
app.add_middleware(AccessMiddleware)
# JWT auth: Always open
app.add_middleware(
AuthenticationMiddleware, backend=JwtAuthMiddleware(), on_error=JwtAuthMiddleware.auth_exception_handler
)
# CORS: Always at the end
if settings.MIDDLEWARE_CORS:
from fastapi.middleware.cors import CORSMiddleware
@ -101,21 +116,6 @@ def register_middleware(app: FastAPI):
allow_methods=['*'],
allow_headers=['*'],
)
# Gzip
if settings.MIDDLEWARE_GZIP:
from fastapi.middleware.gzip import GZipMiddleware
app.add_middleware(GZipMiddleware)
# JWT auth
if settings.MIDDLEWARE_JWT_AUTH:
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
app.add_middleware(AuthenticationMiddleware, backend=JwtAuthMiddleware())
# Api access logs
if settings.MIDDLEWARE_ACCESS:
from backend.app.middleware.access_middleware import AccessMiddleware
app.add_middleware(AccessMiddleware)
def register_router(app: FastAPI):

View File

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
from typing import NoReturn
from sqlalchemy import Select, select, desc, delete
from sqlalchemy import Select, select, desc, delete, and_
from sqlalchemy.ext.asyncio import AsyncSession
from backend.app.crud.base import CRUDBase
@ -11,8 +11,18 @@ from backend.app.schemas.login_log import CreateLoginLog, UpdateLoginLog
class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]):
async def get_all(self) -> Select:
return select(self.model).order_by(desc(self.model.create_time))
async def get_all(self, username: str = None, status: bool = None, ipaddr: str = None) -> Select:
se = select(self.model).order_by(desc(self.model.create_time))
where_list = []
if username:
where_list.append(self.model.username.like(f'%{username}%'))
if status is not None:
where_list.append(self.model.status == status)
if ipaddr:
where_list.append(self.model.ipaddr.like(f'%{ipaddr}%'))
if where_list:
se = se.where(and_(*where_list))
return se
async def create(self, db: AsyncSession, obj_in: CreateLoginLog) -> NoReturn:
await self.create_(db, obj_in)

View File

@ -70,6 +70,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
async def get_all(self, username: str = None, phone: str = None, status: int = None) -> Select:
se = (
select(self.model)
.options(selectinload(self.model.dept))
.options(selectinload(self.model.roles).selectinload(Role.menus))
.order_by(desc(self.model.time_joined))
)

View File

@ -1,15 +1,37 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from starlette.authentication import AuthenticationBackend
from fastapi import Request
from typing import Any
from fastapi import Request, Response
from starlette.authentication import AuthenticationBackend, AuthenticationError
from starlette.requests import HTTPConnection
from starlette.responses import JSONResponse
from backend.app.common import jwt
from backend.app.common.exception.errors import TokenError
from backend.app.core.conf import settings
from backend.app.database.db_mysql import async_db_session
class _AuthenticationError(AuthenticationError):
"""重写内部认证错误类"""
def __init__(self, *, code: int = None, msg: str = None, headers: dict[str, Any] | None = None):
self.code = code
self.msg = msg
self.headers = headers
class JwtAuthMiddleware(AuthenticationBackend):
"""JWT 认证中间件"""
@staticmethod
def auth_exception_handler(conn: HTTPConnection, exc: Exception) -> Response:
"""覆盖内部认证错误处理"""
code = getattr(exc, 'code', 500)
msg = getattr(exc, 'msg', 'Internal Server Error')
return JSONResponse(content={'code': code, 'msg': msg, 'data': None}, status_code=code)
async def authenticate(self, request: Request):
auth = request.headers.get('Authorization')
if not auth:
@ -19,9 +41,15 @@ class JwtAuthMiddleware(AuthenticationBackend):
if scheme.lower() != 'bearer':
return
try:
sub = await jwt.jwt_authentication(token)
async with async_db_session() as db:
user = await jwt.get_current_user(db, data=sub)
except TokenError as exc:
raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers)
except Exception:
import traceback
raise _AuthenticationError(msg=traceback.format_exc() if settings.ENVIRONMENT == 'dev' else None)
return auth, user

View File

@ -16,7 +16,7 @@ class LoginLog(DataClassBase):
id: Mapped[id_key] = mapped_column(init=False)
user_uuid: Mapped[str] = mapped_column(String(50), nullable=False, comment='用户UUID')
username: Mapped[str] = mapped_column(String(20), nullable=False, comment='用户名')
status: Mapped[int] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
status: Mapped[bool] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
ipaddr: Mapped[str] = mapped_column(String(50), nullable=False, comment='登录IP地址')
location: Mapped[str] = mapped_column(String(255), nullable=False, comment='归属地')
browser: Mapped[str] = mapped_column(String(255), nullable=False, comment='浏览器')

View File

@ -8,7 +8,7 @@ from pydantic import BaseModel
class LoginLogBase(BaseModel):
user_uuid: str
username: str
status: int
status: bool
ipaddr: str
location: str
browser: str

View File

@ -71,7 +71,7 @@ class AuthService:
request=request,
user=current_user,
login_time=self.login_time,
status=LoginLogStatus.fail,
status=LoginLogStatus.fail.value,
msg=e.msg,
)
task = BackgroundTask(LoginLogService.create, **err_log_info)
@ -84,7 +84,7 @@ class AuthService:
request=request,
user=user,
login_time=self.login_time,
status=LoginLogStatus.success,
status=LoginLogStatus.success.value,
msg='登录成功',
)
background_tasks.add_task(LoginLogService.create, **log_info)

View File

@ -19,8 +19,8 @@ from backend.app.utils import request_parse
class LoginLogService:
@staticmethod
async def get_select() -> Select:
return await LoginLogDao.get_all()
async def get_select(*, username: str, status: bool, ipaddr: str) -> Select:
return await LoginLogDao.get_all(username=username, status=status, ipaddr=ipaddr)
@staticmethod
async def create(