mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-18 15:00:46 +08:00
Fix the merge issues (#87)
* Fix the login log status value. * Fix config information interface constants * Add fuzzy paging query for login logs * Fix fuzzy paging query for query user interface * Fix jwt middleware internal exception not caught
This commit is contained in:
@ -44,7 +44,8 @@ async def get_sys_config():
|
||||
'token_algorithm': settings.TOKEN_ALGORITHM,
|
||||
'token_expire_seconds': settings.TOKEN_EXPIRE_SECONDS,
|
||||
'token_swagger_url': settings.TOKEN_URL_SWAGGER,
|
||||
'log_file_name': settings.LOG_FILE_NAME,
|
||||
'access_log_filename': settings.LOG_STDOUT_FILENAME,
|
||||
'error_log_filename': settings.LOG_STDERR_FILENAME,
|
||||
'middleware_cors': settings.MIDDLEWARE_CORS,
|
||||
'middleware_gzip': settings.MIDDLEWARE_GZIP,
|
||||
'middleware_access': settings.MIDDLEWARE_ACCESS,
|
||||
|
@ -15,9 +15,14 @@ from backend.app.services.login_log_service import LoginLogService
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get('', summary='获取所有登录日志', dependencies=[DependsJwtAuth, PageDepends])
|
||||
async def get_all_login_logs(db: CurrentSession):
|
||||
log_select = await LoginLogService.get_select()
|
||||
@router.get('', summary='(模糊条件)分页获取登录日志', dependencies=[DependsJwtAuth, PageDepends])
|
||||
async def get_all_login_logs(
|
||||
db: CurrentSession,
|
||||
username: Annotated[str | None, Query()] = None,
|
||||
status: Annotated[bool | None, Query()] = None,
|
||||
ipaddr: Annotated[str | None, Query()] = None,
|
||||
):
|
||||
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
||||
page_data = await paging_data(db, log_select, GetAllLoginLog)
|
||||
return response_base.success(data=page_data)
|
||||
|
||||
|
@ -57,7 +57,7 @@ async def get_all_users(
|
||||
db: CurrentSession,
|
||||
username: Annotated[str | None, Query()] = None,
|
||||
phone: Annotated[str | None, Query()] = None,
|
||||
status: Annotated[int | None, Query()] = None,
|
||||
status: Annotated[bool | None, Query()] = None,
|
||||
):
|
||||
user_select = await UserService.get_select(username=username, phone=phone, status=status)
|
||||
page_data = await paging_data(db, user_select, GetAllUserInfo)
|
||||
|
@ -91,7 +91,6 @@ class Settings(BaseSettings):
|
||||
# Middleware
|
||||
MIDDLEWARE_CORS: bool = True
|
||||
MIDDLEWARE_GZIP: bool = True
|
||||
MIDDLEWARE_JWT_AUTH: bool = True
|
||||
MIDDLEWARE_ACCESS: bool = False
|
||||
|
||||
# Casbin
|
||||
|
@ -13,6 +13,7 @@ from backend.app.common.redis import redis_client
|
||||
from backend.app.common.task import scheduler
|
||||
from backend.app.core.conf import settings
|
||||
from backend.app.database.db_mysql import create_table
|
||||
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
|
||||
from backend.app.utils.health_check import ensure_unique_route_names
|
||||
from backend.app.utils.openapi import simplify_operation_ids
|
||||
|
||||
@ -90,7 +91,21 @@ def register_static_file(app: FastAPI):
|
||||
|
||||
|
||||
def register_middleware(app: FastAPI):
|
||||
# CORS
|
||||
# Gzip
|
||||
if settings.MIDDLEWARE_GZIP:
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
|
||||
app.add_middleware(GZipMiddleware)
|
||||
# Api access logs
|
||||
if settings.MIDDLEWARE_ACCESS:
|
||||
from backend.app.middleware.access_middleware import AccessMiddleware
|
||||
|
||||
app.add_middleware(AccessMiddleware)
|
||||
# JWT auth: Always open
|
||||
app.add_middleware(
|
||||
AuthenticationMiddleware, backend=JwtAuthMiddleware(), on_error=JwtAuthMiddleware.auth_exception_handler
|
||||
)
|
||||
# CORS: Always at the end
|
||||
if settings.MIDDLEWARE_CORS:
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
@ -101,21 +116,6 @@ def register_middleware(app: FastAPI):
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
# Gzip
|
||||
if settings.MIDDLEWARE_GZIP:
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
|
||||
app.add_middleware(GZipMiddleware)
|
||||
# JWT auth
|
||||
if settings.MIDDLEWARE_JWT_AUTH:
|
||||
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
|
||||
|
||||
app.add_middleware(AuthenticationMiddleware, backend=JwtAuthMiddleware())
|
||||
# Api access logs
|
||||
if settings.MIDDLEWARE_ACCESS:
|
||||
from backend.app.middleware.access_middleware import AccessMiddleware
|
||||
|
||||
app.add_middleware(AccessMiddleware)
|
||||
|
||||
|
||||
def register_router(app: FastAPI):
|
||||
|
@ -2,7 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import NoReturn
|
||||
|
||||
from sqlalchemy import Select, select, desc, delete
|
||||
from sqlalchemy import Select, select, desc, delete, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.app.crud.base import CRUDBase
|
||||
@ -11,8 +11,18 @@ from backend.app.schemas.login_log import CreateLoginLog, UpdateLoginLog
|
||||
|
||||
|
||||
class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]):
|
||||
async def get_all(self) -> Select:
|
||||
return select(self.model).order_by(desc(self.model.create_time))
|
||||
async def get_all(self, username: str = None, status: bool = None, ipaddr: str = None) -> Select:
|
||||
se = select(self.model).order_by(desc(self.model.create_time))
|
||||
where_list = []
|
||||
if username:
|
||||
where_list.append(self.model.username.like(f'%{username}%'))
|
||||
if status is not None:
|
||||
where_list.append(self.model.status == status)
|
||||
if ipaddr:
|
||||
where_list.append(self.model.ipaddr.like(f'%{ipaddr}%'))
|
||||
if where_list:
|
||||
se = se.where(and_(*where_list))
|
||||
return se
|
||||
|
||||
async def create(self, db: AsyncSession, obj_in: CreateLoginLog) -> NoReturn:
|
||||
await self.create_(db, obj_in)
|
||||
|
@ -70,6 +70,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
|
||||
async def get_all(self, username: str = None, phone: str = None, status: int = None) -> Select:
|
||||
se = (
|
||||
select(self.model)
|
||||
.options(selectinload(self.model.dept))
|
||||
.options(selectinload(self.model.roles).selectinload(Role.menus))
|
||||
.order_by(desc(self.model.time_joined))
|
||||
)
|
||||
|
@ -1,15 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from starlette.authentication import AuthenticationBackend
|
||||
from fastapi import Request
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.authentication import AuthenticationBackend, AuthenticationError
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from backend.app.common import jwt
|
||||
from backend.app.common.exception.errors import TokenError
|
||||
from backend.app.core.conf import settings
|
||||
from backend.app.database.db_mysql import async_db_session
|
||||
|
||||
|
||||
class _AuthenticationError(AuthenticationError):
|
||||
"""重写内部认证错误类"""
|
||||
|
||||
def __init__(self, *, code: int = None, msg: str = None, headers: dict[str, Any] | None = None):
|
||||
self.code = code
|
||||
self.msg = msg
|
||||
self.headers = headers
|
||||
|
||||
|
||||
class JwtAuthMiddleware(AuthenticationBackend):
|
||||
"""JWT 认证中间件"""
|
||||
|
||||
@staticmethod
|
||||
def auth_exception_handler(conn: HTTPConnection, exc: Exception) -> Response:
|
||||
"""覆盖内部认证错误处理"""
|
||||
code = getattr(exc, 'code', 500)
|
||||
msg = getattr(exc, 'msg', 'Internal Server Error')
|
||||
return JSONResponse(content={'code': code, 'msg': msg, 'data': None}, status_code=code)
|
||||
|
||||
async def authenticate(self, request: Request):
|
||||
auth = request.headers.get('Authorization')
|
||||
if not auth:
|
||||
@ -19,9 +41,15 @@ class JwtAuthMiddleware(AuthenticationBackend):
|
||||
if scheme.lower() != 'bearer':
|
||||
return
|
||||
|
||||
try:
|
||||
sub = await jwt.jwt_authentication(token)
|
||||
|
||||
async with async_db_session() as db:
|
||||
user = await jwt.get_current_user(db, data=sub)
|
||||
except TokenError as exc:
|
||||
raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
raise _AuthenticationError(msg=traceback.format_exc() if settings.ENVIRONMENT == 'dev' else None)
|
||||
|
||||
return auth, user
|
||||
|
@ -16,7 +16,7 @@ class LoginLog(DataClassBase):
|
||||
id: Mapped[id_key] = mapped_column(init=False)
|
||||
user_uuid: Mapped[str] = mapped_column(String(50), nullable=False, comment='用户UUID')
|
||||
username: Mapped[str] = mapped_column(String(20), nullable=False, comment='用户名')
|
||||
status: Mapped[int] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
|
||||
status: Mapped[bool] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
|
||||
ipaddr: Mapped[str] = mapped_column(String(50), nullable=False, comment='登录IP地址')
|
||||
location: Mapped[str] = mapped_column(String(255), nullable=False, comment='归属地')
|
||||
browser: Mapped[str] = mapped_column(String(255), nullable=False, comment='浏览器')
|
||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel
|
||||
class LoginLogBase(BaseModel):
|
||||
user_uuid: str
|
||||
username: str
|
||||
status: int
|
||||
status: bool
|
||||
ipaddr: str
|
||||
location: str
|
||||
browser: str
|
||||
|
@ -71,7 +71,7 @@ class AuthService:
|
||||
request=request,
|
||||
user=current_user,
|
||||
login_time=self.login_time,
|
||||
status=LoginLogStatus.fail,
|
||||
status=LoginLogStatus.fail.value,
|
||||
msg=e.msg,
|
||||
)
|
||||
task = BackgroundTask(LoginLogService.create, **err_log_info)
|
||||
@ -84,7 +84,7 @@ class AuthService:
|
||||
request=request,
|
||||
user=user,
|
||||
login_time=self.login_time,
|
||||
status=LoginLogStatus.success,
|
||||
status=LoginLogStatus.success.value,
|
||||
msg='登录成功',
|
||||
)
|
||||
background_tasks.add_task(LoginLogService.create, **log_info)
|
||||
|
@ -19,8 +19,8 @@ from backend.app.utils import request_parse
|
||||
|
||||
class LoginLogService:
|
||||
@staticmethod
|
||||
async def get_select() -> Select:
|
||||
return await LoginLogDao.get_all()
|
||||
async def get_select(*, username: str, status: bool, ipaddr: str) -> Select:
|
||||
return await LoginLogDao.get_all(username=username, status=status, ipaddr=ipaddr)
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
|
Reference in New Issue
Block a user