mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-19 07:21:31 +08:00
Fix the merge issues (#87)
* Fix the login log status value. * Fix config information interface constants * Add fuzzy paging query for login logs * Fix fuzzy paging query for query user interface * Fix jwt middleware internal exception not caught
This commit is contained in:
@ -44,7 +44,8 @@ async def get_sys_config():
|
|||||||
'token_algorithm': settings.TOKEN_ALGORITHM,
|
'token_algorithm': settings.TOKEN_ALGORITHM,
|
||||||
'token_expire_seconds': settings.TOKEN_EXPIRE_SECONDS,
|
'token_expire_seconds': settings.TOKEN_EXPIRE_SECONDS,
|
||||||
'token_swagger_url': settings.TOKEN_URL_SWAGGER,
|
'token_swagger_url': settings.TOKEN_URL_SWAGGER,
|
||||||
'log_file_name': settings.LOG_FILE_NAME,
|
'access_log_filename': settings.LOG_STDOUT_FILENAME,
|
||||||
|
'error_log_filename': settings.LOG_STDERR_FILENAME,
|
||||||
'middleware_cors': settings.MIDDLEWARE_CORS,
|
'middleware_cors': settings.MIDDLEWARE_CORS,
|
||||||
'middleware_gzip': settings.MIDDLEWARE_GZIP,
|
'middleware_gzip': settings.MIDDLEWARE_GZIP,
|
||||||
'middleware_access': settings.MIDDLEWARE_ACCESS,
|
'middleware_access': settings.MIDDLEWARE_ACCESS,
|
||||||
|
@ -15,9 +15,14 @@ from backend.app.services.login_log_service import LoginLogService
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get('', summary='获取所有登录日志', dependencies=[DependsJwtAuth, PageDepends])
|
@router.get('', summary='(模糊条件)分页获取登录日志', dependencies=[DependsJwtAuth, PageDepends])
|
||||||
async def get_all_login_logs(db: CurrentSession):
|
async def get_all_login_logs(
|
||||||
log_select = await LoginLogService.get_select()
|
db: CurrentSession,
|
||||||
|
username: Annotated[str | None, Query()] = None,
|
||||||
|
status: Annotated[bool | None, Query()] = None,
|
||||||
|
ipaddr: Annotated[str | None, Query()] = None,
|
||||||
|
):
|
||||||
|
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
||||||
page_data = await paging_data(db, log_select, GetAllLoginLog)
|
page_data = await paging_data(db, log_select, GetAllLoginLog)
|
||||||
return response_base.success(data=page_data)
|
return response_base.success(data=page_data)
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ async def get_all_users(
|
|||||||
db: CurrentSession,
|
db: CurrentSession,
|
||||||
username: Annotated[str | None, Query()] = None,
|
username: Annotated[str | None, Query()] = None,
|
||||||
phone: Annotated[str | None, Query()] = None,
|
phone: Annotated[str | None, Query()] = None,
|
||||||
status: Annotated[int | None, Query()] = None,
|
status: Annotated[bool | None, Query()] = None,
|
||||||
):
|
):
|
||||||
user_select = await UserService.get_select(username=username, phone=phone, status=status)
|
user_select = await UserService.get_select(username=username, phone=phone, status=status)
|
||||||
page_data = await paging_data(db, user_select, GetAllUserInfo)
|
page_data = await paging_data(db, user_select, GetAllUserInfo)
|
||||||
|
@ -91,7 +91,6 @@ class Settings(BaseSettings):
|
|||||||
# Middleware
|
# Middleware
|
||||||
MIDDLEWARE_CORS: bool = True
|
MIDDLEWARE_CORS: bool = True
|
||||||
MIDDLEWARE_GZIP: bool = True
|
MIDDLEWARE_GZIP: bool = True
|
||||||
MIDDLEWARE_JWT_AUTH: bool = True
|
|
||||||
MIDDLEWARE_ACCESS: bool = False
|
MIDDLEWARE_ACCESS: bool = False
|
||||||
|
|
||||||
# Casbin
|
# Casbin
|
||||||
|
@ -13,6 +13,7 @@ from backend.app.common.redis import redis_client
|
|||||||
from backend.app.common.task import scheduler
|
from backend.app.common.task import scheduler
|
||||||
from backend.app.core.conf import settings
|
from backend.app.core.conf import settings
|
||||||
from backend.app.database.db_mysql import create_table
|
from backend.app.database.db_mysql import create_table
|
||||||
|
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
|
||||||
from backend.app.utils.health_check import ensure_unique_route_names
|
from backend.app.utils.health_check import ensure_unique_route_names
|
||||||
from backend.app.utils.openapi import simplify_operation_ids
|
from backend.app.utils.openapi import simplify_operation_ids
|
||||||
|
|
||||||
@ -90,7 +91,21 @@ def register_static_file(app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
def register_middleware(app: FastAPI):
|
def register_middleware(app: FastAPI):
|
||||||
# CORS
|
# Gzip
|
||||||
|
if settings.MIDDLEWARE_GZIP:
|
||||||
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
|
app.add_middleware(GZipMiddleware)
|
||||||
|
# Api access logs
|
||||||
|
if settings.MIDDLEWARE_ACCESS:
|
||||||
|
from backend.app.middleware.access_middleware import AccessMiddleware
|
||||||
|
|
||||||
|
app.add_middleware(AccessMiddleware)
|
||||||
|
# JWT auth: Always open
|
||||||
|
app.add_middleware(
|
||||||
|
AuthenticationMiddleware, backend=JwtAuthMiddleware(), on_error=JwtAuthMiddleware.auth_exception_handler
|
||||||
|
)
|
||||||
|
# CORS: Always at the end
|
||||||
if settings.MIDDLEWARE_CORS:
|
if settings.MIDDLEWARE_CORS:
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
@ -101,21 +116,6 @@ def register_middleware(app: FastAPI):
|
|||||||
allow_methods=['*'],
|
allow_methods=['*'],
|
||||||
allow_headers=['*'],
|
allow_headers=['*'],
|
||||||
)
|
)
|
||||||
# Gzip
|
|
||||||
if settings.MIDDLEWARE_GZIP:
|
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
|
||||||
|
|
||||||
app.add_middleware(GZipMiddleware)
|
|
||||||
# JWT auth
|
|
||||||
if settings.MIDDLEWARE_JWT_AUTH:
|
|
||||||
from backend.app.middleware.jwt_auth_middleware import JwtAuthMiddleware
|
|
||||||
|
|
||||||
app.add_middleware(AuthenticationMiddleware, backend=JwtAuthMiddleware())
|
|
||||||
# Api access logs
|
|
||||||
if settings.MIDDLEWARE_ACCESS:
|
|
||||||
from backend.app.middleware.access_middleware import AccessMiddleware
|
|
||||||
|
|
||||||
app.add_middleware(AccessMiddleware)
|
|
||||||
|
|
||||||
|
|
||||||
def register_router(app: FastAPI):
|
def register_router(app: FastAPI):
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from typing import NoReturn
|
from typing import NoReturn
|
||||||
|
|
||||||
from sqlalchemy import Select, select, desc, delete
|
from sqlalchemy import Select, select, desc, delete, and_
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from backend.app.crud.base import CRUDBase
|
from backend.app.crud.base import CRUDBase
|
||||||
@ -11,8 +11,18 @@ from backend.app.schemas.login_log import CreateLoginLog, UpdateLoginLog
|
|||||||
|
|
||||||
|
|
||||||
class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]):
|
class CRUDLoginLog(CRUDBase[LoginLog, CreateLoginLog, UpdateLoginLog]):
|
||||||
async def get_all(self) -> Select:
|
async def get_all(self, username: str = None, status: bool = None, ipaddr: str = None) -> Select:
|
||||||
return select(self.model).order_by(desc(self.model.create_time))
|
se = select(self.model).order_by(desc(self.model.create_time))
|
||||||
|
where_list = []
|
||||||
|
if username:
|
||||||
|
where_list.append(self.model.username.like(f'%{username}%'))
|
||||||
|
if status is not None:
|
||||||
|
where_list.append(self.model.status == status)
|
||||||
|
if ipaddr:
|
||||||
|
where_list.append(self.model.ipaddr.like(f'%{ipaddr}%'))
|
||||||
|
if where_list:
|
||||||
|
se = se.where(and_(*where_list))
|
||||||
|
return se
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, obj_in: CreateLoginLog) -> NoReturn:
|
async def create(self, db: AsyncSession, obj_in: CreateLoginLog) -> NoReturn:
|
||||||
await self.create_(db, obj_in)
|
await self.create_(db, obj_in)
|
||||||
|
@ -70,6 +70,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
|
|||||||
async def get_all(self, username: str = None, phone: str = None, status: int = None) -> Select:
|
async def get_all(self, username: str = None, phone: str = None, status: int = None) -> Select:
|
||||||
se = (
|
se = (
|
||||||
select(self.model)
|
select(self.model)
|
||||||
|
.options(selectinload(self.model.dept))
|
||||||
.options(selectinload(self.model.roles).selectinload(Role.menus))
|
.options(selectinload(self.model.roles).selectinload(Role.menus))
|
||||||
.order_by(desc(self.model.time_joined))
|
.order_by(desc(self.model.time_joined))
|
||||||
)
|
)
|
||||||
|
@ -1,15 +1,37 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from starlette.authentication import AuthenticationBackend
|
from typing import Any
|
||||||
from fastapi import Request
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.authentication import AuthenticationBackend, AuthenticationError
|
||||||
|
from starlette.requests import HTTPConnection
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from backend.app.common import jwt
|
from backend.app.common import jwt
|
||||||
|
from backend.app.common.exception.errors import TokenError
|
||||||
|
from backend.app.core.conf import settings
|
||||||
from backend.app.database.db_mysql import async_db_session
|
from backend.app.database.db_mysql import async_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class _AuthenticationError(AuthenticationError):
|
||||||
|
"""重写内部认证错误类"""
|
||||||
|
|
||||||
|
def __init__(self, *, code: int = None, msg: str = None, headers: dict[str, Any] | None = None):
|
||||||
|
self.code = code
|
||||||
|
self.msg = msg
|
||||||
|
self.headers = headers
|
||||||
|
|
||||||
|
|
||||||
class JwtAuthMiddleware(AuthenticationBackend):
|
class JwtAuthMiddleware(AuthenticationBackend):
|
||||||
"""JWT 认证中间件"""
|
"""JWT 认证中间件"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def auth_exception_handler(conn: HTTPConnection, exc: Exception) -> Response:
|
||||||
|
"""覆盖内部认证错误处理"""
|
||||||
|
code = getattr(exc, 'code', 500)
|
||||||
|
msg = getattr(exc, 'msg', 'Internal Server Error')
|
||||||
|
return JSONResponse(content={'code': code, 'msg': msg, 'data': None}, status_code=code)
|
||||||
|
|
||||||
async def authenticate(self, request: Request):
|
async def authenticate(self, request: Request):
|
||||||
auth = request.headers.get('Authorization')
|
auth = request.headers.get('Authorization')
|
||||||
if not auth:
|
if not auth:
|
||||||
@ -19,9 +41,15 @@ class JwtAuthMiddleware(AuthenticationBackend):
|
|||||||
if scheme.lower() != 'bearer':
|
if scheme.lower() != 'bearer':
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
sub = await jwt.jwt_authentication(token)
|
sub = await jwt.jwt_authentication(token)
|
||||||
|
|
||||||
async with async_db_session() as db:
|
async with async_db_session() as db:
|
||||||
user = await jwt.get_current_user(db, data=sub)
|
user = await jwt.get_current_user(db, data=sub)
|
||||||
|
except TokenError as exc:
|
||||||
|
raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers)
|
||||||
|
except Exception:
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
raise _AuthenticationError(msg=traceback.format_exc() if settings.ENVIRONMENT == 'dev' else None)
|
||||||
|
|
||||||
return auth, user
|
return auth, user
|
||||||
|
@ -16,7 +16,7 @@ class LoginLog(DataClassBase):
|
|||||||
id: Mapped[id_key] = mapped_column(init=False)
|
id: Mapped[id_key] = mapped_column(init=False)
|
||||||
user_uuid: Mapped[str] = mapped_column(String(50), nullable=False, comment='用户UUID')
|
user_uuid: Mapped[str] = mapped_column(String(50), nullable=False, comment='用户UUID')
|
||||||
username: Mapped[str] = mapped_column(String(20), nullable=False, comment='用户名')
|
username: Mapped[str] = mapped_column(String(20), nullable=False, comment='用户名')
|
||||||
status: Mapped[int] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
|
status: Mapped[bool] = mapped_column(insert_default=0, comment='登录状态(0失败 1成功)')
|
||||||
ipaddr: Mapped[str] = mapped_column(String(50), nullable=False, comment='登录IP地址')
|
ipaddr: Mapped[str] = mapped_column(String(50), nullable=False, comment='登录IP地址')
|
||||||
location: Mapped[str] = mapped_column(String(255), nullable=False, comment='归属地')
|
location: Mapped[str] = mapped_column(String(255), nullable=False, comment='归属地')
|
||||||
browser: Mapped[str] = mapped_column(String(255), nullable=False, comment='浏览器')
|
browser: Mapped[str] = mapped_column(String(255), nullable=False, comment='浏览器')
|
||||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel
|
|||||||
class LoginLogBase(BaseModel):
|
class LoginLogBase(BaseModel):
|
||||||
user_uuid: str
|
user_uuid: str
|
||||||
username: str
|
username: str
|
||||||
status: int
|
status: bool
|
||||||
ipaddr: str
|
ipaddr: str
|
||||||
location: str
|
location: str
|
||||||
browser: str
|
browser: str
|
||||||
|
@ -71,7 +71,7 @@ class AuthService:
|
|||||||
request=request,
|
request=request,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
login_time=self.login_time,
|
login_time=self.login_time,
|
||||||
status=LoginLogStatus.fail,
|
status=LoginLogStatus.fail.value,
|
||||||
msg=e.msg,
|
msg=e.msg,
|
||||||
)
|
)
|
||||||
task = BackgroundTask(LoginLogService.create, **err_log_info)
|
task = BackgroundTask(LoginLogService.create, **err_log_info)
|
||||||
@ -84,7 +84,7 @@ class AuthService:
|
|||||||
request=request,
|
request=request,
|
||||||
user=user,
|
user=user,
|
||||||
login_time=self.login_time,
|
login_time=self.login_time,
|
||||||
status=LoginLogStatus.success,
|
status=LoginLogStatus.success.value,
|
||||||
msg='登录成功',
|
msg='登录成功',
|
||||||
)
|
)
|
||||||
background_tasks.add_task(LoginLogService.create, **log_info)
|
background_tasks.add_task(LoginLogService.create, **log_info)
|
||||||
|
@ -19,8 +19,8 @@ from backend.app.utils import request_parse
|
|||||||
|
|
||||||
class LoginLogService:
|
class LoginLogService:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_select() -> Select:
|
async def get_select(*, username: str, status: bool, ipaddr: str) -> Select:
|
||||||
return await LoginLogDao.get_all()
|
return await LoginLogDao.get_all(username=username, status=status, ipaddr=ipaddr)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create(
|
async def create(
|
||||||
|
Reference in New Issue
Block a user