mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-18 23:11:48 +08:00
Add sync to async decorator support (#96)
* Add sync to async decorator support * Update ASyncTranslator to asgiref
This commit is contained in:
@ -18,7 +18,7 @@ router = APIRouter()
|
|||||||
@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth])
|
@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth])
|
||||||
async def get_api(pk: int):
|
async def get_api(pk: int):
|
||||||
api = await ApiService.get(pk=pk)
|
api = await ApiService.get(pk=pk)
|
||||||
return response_base.success(data=api)
|
return await response_base.success(data=api)
|
||||||
|
|
||||||
|
|
||||||
@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends])
|
@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends])
|
||||||
@ -30,26 +30,26 @@ async def get_all_apis(
|
|||||||
):
|
):
|
||||||
api_select = await ApiService.get_select(name=name, method=method, path=path)
|
api_select = await ApiService.get_select(name=name, method=method, path=path)
|
||||||
page_data = await paging_data(db, api_select, GetAllApi)
|
page_data = await paging_data(db, api_select, GetAllApi)
|
||||||
return response_base.success(data=page_data)
|
return await response_base.success(data=page_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post('', summary='创建接口', dependencies=[DependsRBAC])
|
@router.post('', summary='创建接口', dependencies=[DependsRBAC])
|
||||||
async def create_api(request: Request, obj: CreateApi):
|
async def create_api(request: Request, obj: CreateApi):
|
||||||
await ApiService.create(obj=obj, user_id=request.user.id)
|
await ApiService.create(obj=obj, user_id=request.user.id)
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
|
|
||||||
|
|
||||||
@router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC])
|
@router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC])
|
||||||
async def update_api(request: Request, pk: int, obj: UpdateApi):
|
async def update_api(request: Request, pk: int, obj: UpdateApi):
|
||||||
count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id)
|
count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.delete('', summary='(批量)删除接口', dependencies=[DependsRBAC])
|
@router.delete('', summary='(批量)删除接口', dependencies=[DependsRBAC])
|
||||||
async def delete_api(pk: Annotated[list[int], Query(...)]):
|
async def delete_api(pk: Annotated[list[int], Query(...)]):
|
||||||
count = await ApiService.delete(pk=pk)
|
count = await ApiService.delete(pk=pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
@ -39,17 +39,17 @@ async def user_login(request: Request, obj: Auth, background_tasks: BackgroundTa
|
|||||||
refresh_token_expire_time=refresh_expire,
|
refresh_token_expire_time=refresh_expire,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
return response_base.success(data=data)
|
return await response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
|
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
|
||||||
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
|
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
|
||||||
access_token, access_expire = await AuthService.new_token(refresh_token)
|
access_token, access_expire = await AuthService.new_token(refresh_token)
|
||||||
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire)
|
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire)
|
||||||
return response_base.success(data=data)
|
return await response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
|
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
|
||||||
async def user_logout(request: Request):
|
async def user_logout(request: Request):
|
||||||
await AuthService.logout(request)
|
await AuthService.logout(request)
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
|
@ -12,7 +12,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
@router.get('', summary='获取系统配置', dependencies=[DependsRBAC])
|
@router.get('', summary='获取系统配置', dependencies=[DependsRBAC])
|
||||||
async def get_sys_config():
|
async def get_sys_config():
|
||||||
return response_base.success(
|
return await response_base.success(
|
||||||
data={
|
data={
|
||||||
'title': settings.TITLE,
|
'title': settings.TITLE,
|
||||||
'version': settings.VERSION,
|
'version': settings.VERSION,
|
||||||
@ -59,4 +59,4 @@ async def get_all_route(request: Request):
|
|||||||
for route in request.app.routes:
|
for route in request.app.routes:
|
||||||
if isinstance(route, APIRoute):
|
if isinstance(route, APIRoute):
|
||||||
data.append({'path': route.path, 'name': route.name, 'summary': route.summary, 'methods': route.methods})
|
data.append({'path': route.path, 'name': route.name, 'summary': route.summary, 'methods': route.methods})
|
||||||
return response_base.success(data={'route_list': data})
|
return await response_base.success(data={'route_list': data})
|
||||||
|
@ -24,20 +24,20 @@ async def get_all_login_logs(
|
|||||||
):
|
):
|
||||||
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
||||||
page_data = await paging_data(db, log_select, GetAllLoginLog)
|
page_data = await paging_data(db, log_select, GetAllLoginLog)
|
||||||
return response_base.success(data=page_data)
|
return await response_base.success(data=page_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete('', summary='(批量)删除登录日志', dependencies=[DependsRBAC])
|
@router.delete('', summary='(批量)删除登录日志', dependencies=[DependsRBAC])
|
||||||
async def delete_login_log(pk: Annotated[list[int], Query(...)]):
|
async def delete_login_log(pk: Annotated[list[int], Query(...)]):
|
||||||
count = await LoginLogService.delete(pk)
|
count = await LoginLogService.delete(pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.delete('/all', summary='清空登录日志', dependencies=[DependsRBAC])
|
@router.delete('/all', summary='清空登录日志', dependencies=[DependsRBAC])
|
||||||
async def delete_all_login_logs():
|
async def delete_all_login_logs():
|
||||||
count = await LoginLogService.delete_all()
|
count = await LoginLogService.delete_all()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
@ -17,27 +17,27 @@ router = APIRouter()
|
|||||||
|
|
||||||
@router.get('', summary='(模糊条件)分页获取操作日志', dependencies=[DependsJwtAuth, PageDepends])
|
@router.get('', summary='(模糊条件)分页获取操作日志', dependencies=[DependsJwtAuth, PageDepends])
|
||||||
async def get_all_opera_logs(
|
async def get_all_opera_logs(
|
||||||
db: CurrentSession,
|
db: CurrentSession,
|
||||||
username: Annotated[str | None, Query()] = None,
|
username: Annotated[str | None, Query()] = None,
|
||||||
status: Annotated[bool | None, Query()] = None,
|
status: Annotated[bool | None, Query()] = None,
|
||||||
ipaddr: Annotated[str | None, Query()] = None,
|
ipaddr: Annotated[str | None, Query()] = None,
|
||||||
):
|
):
|
||||||
log_select = await OperaLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
log_select = await OperaLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
||||||
page_data = await paging_data(db, log_select, GetAllOperaLog)
|
page_data = await paging_data(db, log_select, GetAllOperaLog)
|
||||||
return response_base.success(data=page_data)
|
return await response_base.success(data=page_data)
|
||||||
|
|
||||||
|
|
||||||
@router.delete('', summary='(批量)删除操作日志', dependencies=[DependsRBAC])
|
@router.delete('', summary='(批量)删除操作日志', dependencies=[DependsRBAC])
|
||||||
async def delete_opera_log(pk: Annotated[list[int], Query(...)]):
|
async def delete_opera_log(pk: Annotated[list[int], Query(...)]):
|
||||||
count = await OperaLogService.delete(pk)
|
count = await OperaLogService.delete(pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.delete('/all', summary='清空操作日志', dependencies=[DependsRBAC])
|
@router.delete('/all', summary='清空操作日志', dependencies=[DependsRBAC])
|
||||||
async def delete_all_opera_logs():
|
async def delete_all_opera_logs():
|
||||||
count = await OperaLogService.delete_all()
|
count = await OperaLogService.delete_all()
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
@ -20,7 +20,7 @@ router = APIRouter()
|
|||||||
async def get_role(pk: int):
|
async def get_role(pk: int):
|
||||||
role = await RoleService.get(pk=pk)
|
role = await RoleService.get(pk=pk)
|
||||||
data = GetAllRole(**select_to_json(role))
|
data = GetAllRole(**select_to_json(role))
|
||||||
return response_base.success(data=data)
|
return await response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
@router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsJwtAuth, PageDepends])
|
@router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsJwtAuth, PageDepends])
|
||||||
@ -31,26 +31,26 @@ async def get_all_roles(
|
|||||||
):
|
):
|
||||||
role_select = await RoleService.get_select(name=name, data_scope=data_scope)
|
role_select = await RoleService.get_select(name=name, data_scope=data_scope)
|
||||||
page_data = await paging_data(db, role_select, GetAllRole)
|
page_data = await paging_data(db, role_select, GetAllRole)
|
||||||
return response_base.success(data=page_data)
|
return await response_base.success(data=page_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post('', summary='创建角色', dependencies=[DependsRBAC])
|
@router.post('', summary='创建角色', dependencies=[DependsRBAC])
|
||||||
async def create_role(request: Request, obj: CreateRole):
|
async def create_role(request: Request, obj: CreateRole):
|
||||||
await RoleService.create(obj=obj, user_id=request.user.id)
|
await RoleService.create(obj=obj, user_id=request.user.id)
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
|
|
||||||
|
|
||||||
@router.put('/{pk}', summary='更新角色', dependencies=[DependsRBAC])
|
@router.put('/{pk}', summary='更新角色', dependencies=[DependsRBAC])
|
||||||
async def update_role(request: Request, pk: int, obj: UpdateRole):
|
async def update_role(request: Request, pk: int, obj: UpdateRole):
|
||||||
count = await RoleService.update(pk=pk, obj=obj, user_id=request.user.id)
|
count = await RoleService.update(pk=pk, obj=obj, user_id=request.user.id)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.delete('', summary='(批量)删除角色', dependencies=[DependsRBAC])
|
@router.delete('', summary='(批量)删除角色', dependencies=[DependsRBAC])
|
||||||
async def delete_role(pk: Annotated[list[int], Query(...)]):
|
async def delete_role(pk: Annotated[list[int], Query(...)]):
|
||||||
count = await RoleService.delete(pk=pk)
|
count = await RoleService.delete(pk=pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
@ -18,38 +18,38 @@ router = APIRouter()
|
|||||||
@router.post('/register', summary='用户注册')
|
@router.post('/register', summary='用户注册')
|
||||||
async def user_register(obj: CreateUser):
|
async def user_register(obj: CreateUser):
|
||||||
await UserService.register(obj)
|
await UserService.register(obj)
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
|
|
||||||
|
|
||||||
@router.post('/password/reset', summary='密码重置')
|
@router.post('/password/reset', summary='密码重置')
|
||||||
async def password_reset(obj: ResetPassword):
|
async def password_reset(obj: ResetPassword):
|
||||||
count = await UserService.pwd_reset(obj)
|
count = await UserService.pwd_reset(obj)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
|
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
|
||||||
async def get_user(username: str):
|
async def get_user(username: str):
|
||||||
current_user = await UserService.get_userinfo(username)
|
current_user = await UserService.get_userinfo(username)
|
||||||
data = GetAllUserInfo(**select_to_json(current_user))
|
data = GetAllUserInfo(**select_to_json(current_user))
|
||||||
return response_base.success(data=data)
|
return await response_base.success(data=data)
|
||||||
|
|
||||||
|
|
||||||
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
|
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
|
||||||
async def update_userinfo(request: Request, username: str, obj: UpdateUser):
|
async def update_userinfo(request: Request, username: str, obj: UpdateUser):
|
||||||
count = await UserService.update(request=request, username=username, obj=obj)
|
count = await UserService.update(request=request, username=username, obj=obj)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
|
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
|
||||||
async def update_avatar(request: Request, username: str, avatar: Avatar):
|
async def update_avatar(request: Request, username: str, avatar: Avatar):
|
||||||
count = await UserService.update_avatar(request=request, username=username, avatar=avatar)
|
count = await UserService.update_avatar(request=request, username=username, avatar=avatar)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends])
|
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends])
|
||||||
@ -61,31 +61,31 @@ async def get_all_users(
|
|||||||
):
|
):
|
||||||
user_select = await UserService.get_select(username=username, phone=phone, status=status)
|
user_select = await UserService.get_select(username=username, phone=phone, status=status)
|
||||||
page_data = await paging_data(db, user_select, GetAllUserInfo)
|
page_data = await paging_data(db, user_select, GetAllUserInfo)
|
||||||
return response_base.success(data=page_data)
|
return await response_base.success(data=page_data)
|
||||||
|
|
||||||
|
|
||||||
@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth])
|
@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth])
|
||||||
async def super_set(request: Request, pk: int):
|
async def super_set(request: Request, pk: int):
|
||||||
count = await UserService.update_permission(request=request, pk=pk)
|
count = await UserService.update_permission(request=request, pk=pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth])
|
@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth])
|
||||||
async def active_set(request: Request, pk: int):
|
async def active_set(request: Request, pk: int):
|
||||||
count = await UserService.update_active(request=request, pk=pk)
|
count = await UserService.update_active(request=request, pk=pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth])
|
@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth])
|
||||||
async def multi_set(request: Request, pk: int):
|
async def multi_set(request: Request, pk: int):
|
||||||
count = await UserService.update_multi_login(request=request, pk=pk)
|
count = await UserService.update_multi_login(request=request, pk=pk)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
@ -97,5 +97,5 @@ async def multi_set(request: Request, pk: int):
|
|||||||
async def delete_user(request: Request, username: str):
|
async def delete_user(request: Request, username: str):
|
||||||
count = await UserService.delete(request=request, username=username)
|
count = await UserService.delete(request=request, username=username)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
return response_base.success()
|
return await response_base.success()
|
||||||
return response_base.fail()
|
return await response_base.fail()
|
||||||
|
@ -37,7 +37,7 @@ def _get_exception_code(status_code):
|
|||||||
|
|
||||||
def register_exception(app: FastAPI):
|
def register_exception(app: FastAPI):
|
||||||
@app.exception_handler(HTTPException)
|
@app.exception_handler(HTTPException)
|
||||||
def http_exception_handler(request: Request, exc: HTTPException):
|
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||||
"""
|
"""
|
||||||
全局HTTP异常处理
|
全局HTTP异常处理
|
||||||
|
|
||||||
@ -47,12 +47,12 @@ def register_exception(app: FastAPI):
|
|||||||
"""
|
"""
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=_get_exception_code(exc.status_code),
|
status_code=_get_exception_code(exc.status_code),
|
||||||
content=response_base.fail(code=exc.status_code, msg=exc.detail),
|
content=await response_base.fail(code=exc.status_code, msg=exc.detail),
|
||||||
headers=exc.headers,
|
headers=exc.headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
def validation_exception_handler(request: Request, exc: RequestValidationError):
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
"""
|
"""
|
||||||
数据验证异常处理
|
数据验证异常处理
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ def register_exception(app: FastAPI):
|
|||||||
message += 'json解析失败'
|
message += 'json解析失败'
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=422,
|
status_code=422,
|
||||||
content=response_base.fail(
|
content=await response_base.fail(
|
||||||
code=422,
|
code=422,
|
||||||
msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}',
|
msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}',
|
||||||
data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None,
|
data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None,
|
||||||
@ -92,7 +92,7 @@ def register_exception(app: FastAPI):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
def all_exception_handler(request: Request, exc: Exception):
|
async def all_exception_handler(request: Request, exc: Exception):
|
||||||
"""
|
"""
|
||||||
全局异常处理
|
全局异常处理
|
||||||
|
|
||||||
@ -103,14 +103,14 @@ def register_exception(app: FastAPI):
|
|||||||
if isinstance(exc, BaseExceptionMixin):
|
if isinstance(exc, BaseExceptionMixin):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=_get_exception_code(exc.code),
|
status_code=_get_exception_code(exc.code),
|
||||||
content=response_base.fail(code=exc.code, msg=str(exc.msg), data=exc.data if exc.data else None),
|
content=await response_base.fail(code=exc.code, msg=str(exc.msg), data=exc.data if exc.data else None),
|
||||||
background=exc.background,
|
background=exc.background,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(exc, AssertionError):
|
elif isinstance(exc, AssertionError):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
content=response_base.fail(
|
content=await response_base.fail(
|
||||||
code=500,
|
code=500,
|
||||||
msg=','.join(exc.args)
|
msg=','.join(exc.args)
|
||||||
if exc.args
|
if exc.args
|
||||||
@ -119,14 +119,14 @@ def register_exception(app: FastAPI):
|
|||||||
else exc.__doc__,
|
else exc.__doc__,
|
||||||
)
|
)
|
||||||
if settings.ENVIRONMENT == 'dev'
|
if settings.ENVIRONMENT == 'dev'
|
||||||
else response_base.fail(code=500, msg='Internal Server Error'),
|
else await response_base.fail(code=500, msg='Internal Server Error'),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
log.error(exc)
|
log.error(exc)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
content=response_base.fail(code=500, msg=str(exc))
|
content=await response_base.fail(code=500, msg=str(exc))
|
||||||
if settings.ENVIRONMENT == 'dev'
|
if settings.ENVIRONMENT == 'dev'
|
||||||
else response_base.fail(code=500, msg='Internal Server Error'),
|
else await response_base.fail(code=500, msg='Internal Server Error'),
|
||||||
)
|
)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import Depends, Request
|
from fastapi import Depends, Request
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.security.utils import get_authorization_scheme_param
|
from fastapi.security.utils import get_authorization_scheme_param
|
||||||
@ -21,6 +22,7 @@ pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
|||||||
oauth2_schema = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL_SWAGGER)
|
oauth2_schema = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL_SWAGGER)
|
||||||
|
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
def get_hash_password(password: str) -> str:
|
def get_hash_password(password: str) -> str:
|
||||||
"""
|
"""
|
||||||
Encrypt passwords using the hash algorithm
|
Encrypt passwords using the hash algorithm
|
||||||
@ -31,6 +33,7 @@ def get_hash_password(password: str) -> str:
|
|||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
def password_verify(plain_password: str, hashed_password: str) -> bool:
|
def password_verify(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Password verification
|
Password verification
|
||||||
@ -107,6 +110,7 @@ async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str,
|
|||||||
return new_token, expire
|
return new_token, expire
|
||||||
|
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
def get_token(request: Request) -> str:
|
def get_token(request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
Get token for request header
|
Get token for request header
|
||||||
@ -120,6 +124,7 @@ def get_token(request: Request) -> str:
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
def jwt_decode(token: str) -> tuple[int, list[int]]:
|
def jwt_decode(token: str) -> tuple[int, list[int]]:
|
||||||
"""
|
"""
|
||||||
Decode token
|
Decode token
|
||||||
@ -145,7 +150,7 @@ async def jwt_authentication(token: str) -> dict[str, int]:
|
|||||||
:param token:
|
:param token:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
user_id, _ = jwt_decode(token)
|
user_id, _ = await jwt_decode(token)
|
||||||
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
|
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
|
||||||
token_verify = await redis_client.get(key)
|
token_verify = await redis_client.get(key)
|
||||||
if not token_verify:
|
if not token_verify:
|
||||||
@ -170,7 +175,8 @@ async def get_current_user(db: AsyncSession, data: dict) -> User:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def superuser_verify(request: Request) -> bool:
|
@sync_to_async
|
||||||
|
def superuser_verify(request: Request) -> bool:
|
||||||
"""
|
"""
|
||||||
Verify the current user permissions through token
|
Verify the current user permissions through token
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from pydantic import validate_arguments, BaseModel
|
from pydantic import validate_arguments, BaseModel
|
||||||
|
|
||||||
from backend.app.utils.encoders import jsonable_encoder
|
from backend.app.utils.encoders import jsonable_encoder
|
||||||
@ -51,19 +52,26 @@ class ResponseBase:
|
|||||||
|
|
||||||
@router.get('/test')
|
@router.get('/test')
|
||||||
def test():
|
def test():
|
||||||
return response_base.success(data={'test': 'test'})
|
return await response_base.success(data={'test': 'test'})
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@sync_to_async
|
||||||
def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs):
|
def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs):
|
||||||
custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}
|
custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}
|
||||||
kwargs.update({'custom_encoder': custom_encoder})
|
kwargs.update({'custom_encoder': custom_encoder})
|
||||||
return jsonable_encoder(data, exclude=exclude, **kwargs)
|
result = jsonable_encoder(data, exclude=exclude, **kwargs)
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@validate_arguments
|
@validate_arguments
|
||||||
def success(
|
async def success(
|
||||||
*, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _ExcludeData | None = None, **kwargs
|
self,
|
||||||
|
*,
|
||||||
|
code: int = 200,
|
||||||
|
msg: str = 'Success',
|
||||||
|
data: Any | None = None,
|
||||||
|
exclude: _ExcludeData | None = None,
|
||||||
|
**kwargs
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
请求成功返回通用方法
|
请求成功返回通用方法
|
||||||
@ -74,15 +82,20 @@ class ResponseBase:
|
|||||||
:param exclude: 排除返回数据(data)字段
|
:param exclude: 排除返回数据(data)字段
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs)
|
data = data if data is None else await self.__json_encoder(data, exclude, **kwargs)
|
||||||
return {'code': code, 'msg': msg, 'data': data}
|
return {'code': code, 'msg': msg, 'data': data}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@validate_arguments
|
@validate_arguments
|
||||||
def fail(
|
async def fail(
|
||||||
*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _ExcludeData | None = None, **kwargs
|
self,
|
||||||
|
*,
|
||||||
|
code: int = 400,
|
||||||
|
msg: str = 'Bad Request',
|
||||||
|
data: Any = None,
|
||||||
|
exclude: _ExcludeData | None = None,
|
||||||
|
**kwargs
|
||||||
) -> dict:
|
) -> dict:
|
||||||
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs)
|
data = data if data is None else await self.__json_encoder(data, exclude, **kwargs)
|
||||||
return {'code': code, 'msg': msg, 'data': data}
|
return {'code': code, 'msg': msg, 'data': data}
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,7 +11,9 @@ from backend.app.schemas.opera_log import CreateOperaLog, UpdateOperaLog
|
|||||||
|
|
||||||
|
|
||||||
class CRUDOperaLogDao(CRUDBase[OperaLog, CreateOperaLog, UpdateOperaLog]):
|
class CRUDOperaLogDao(CRUDBase[OperaLog, CreateOperaLog, UpdateOperaLog]):
|
||||||
async def get_all(self, username: str | None = None, status: bool | None = None, ipaddr: str | None = None) -> Select:
|
async def get_all(
|
||||||
|
self, username: str | None = None, status: bool | None = None, ipaddr: str | None = None
|
||||||
|
) -> Select:
|
||||||
se = select(self.model).order_by(desc(self.model.create_time))
|
se = select(self.model).order_by(desc(self.model.create_time))
|
||||||
where_list = []
|
where_list = []
|
||||||
if username:
|
if username:
|
||||||
|
@ -28,7 +28,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
|
|||||||
return user.rowcount
|
return user.rowcount
|
||||||
|
|
||||||
async def create(self, db: AsyncSession, create: CreateUser) -> NoReturn:
|
async def create(self, db: AsyncSession, create: CreateUser) -> NoReturn:
|
||||||
create.password = jwt.get_hash_password(create.password)
|
create.password = await jwt.get_hash_password(create.password)
|
||||||
new_user = self.model(**create.dict(exclude={'roles'}))
|
new_user = self.model(**create.dict(exclude={'roles'}))
|
||||||
role_list = []
|
role_list = []
|
||||||
for role_id in create.roles:
|
for role_id in create.roles:
|
||||||
@ -63,7 +63,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
|
|||||||
|
|
||||||
async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int:
|
async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int:
|
||||||
user = await db.execute(
|
user = await db.execute(
|
||||||
update(self.model).where(self.model.id == pk).values(password=jwt.get_hash_password(password))
|
update(self.model).where(self.model.id == pk).values(password=await jwt.get_hash_password(password))
|
||||||
)
|
)
|
||||||
return user.rowcount
|
return user.rowcount
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ class InitTestData:
|
|||||||
user_obj = User(
|
user_obj = User(
|
||||||
username=username,
|
username=username,
|
||||||
nickname=username,
|
nickname=username,
|
||||||
password=get_hash_password(password),
|
password=await get_hash_password(password),
|
||||||
email=email,
|
email=email,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
dept_id=1,
|
dept_id=1,
|
||||||
@ -70,7 +70,7 @@ class InitTestData:
|
|||||||
user_obj = User(
|
user_obj = User(
|
||||||
username=username,
|
username=username,
|
||||||
nickname=username,
|
nickname=username,
|
||||||
password=get_hash_password(password),
|
password=await get_hash_password(password),
|
||||||
email=email,
|
email=email,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
dept_id=1,
|
dept_id=1,
|
||||||
@ -88,7 +88,7 @@ class InitTestData:
|
|||||||
user_obj = User(
|
user_obj = User(
|
||||||
username=username,
|
username=username,
|
||||||
nickname=username,
|
nickname=username,
|
||||||
password=get_hash_password(password),
|
password=await get_hash_password(password),
|
||||||
email=email,
|
email=email,
|
||||||
is_superuser=False,
|
is_superuser=False,
|
||||||
dept_id=1,
|
dept_id=1,
|
||||||
@ -106,7 +106,7 @@ class InitTestData:
|
|||||||
user_obj = User(
|
user_obj = User(
|
||||||
username=username,
|
username=username,
|
||||||
nickname=username,
|
nickname=username,
|
||||||
password=get_hash_password(password),
|
password=await get_hash_password(password),
|
||||||
email=email,
|
email=email,
|
||||||
is_active=False,
|
is_active=False,
|
||||||
is_superuser=False,
|
is_superuser=False,
|
||||||
@ -125,7 +125,7 @@ class InitTestData:
|
|||||||
user_obj = User(
|
user_obj = User(
|
||||||
username=username,
|
username=username,
|
||||||
nickname=username,
|
nickname=username,
|
||||||
password=get_hash_password(password),
|
password=await get_hash_password(password),
|
||||||
email=email,
|
email=email,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
dept_id=1,
|
dept_id=1,
|
||||||
@ -143,7 +143,7 @@ class InitTestData:
|
|||||||
user_obj = User(
|
user_obj = User(
|
||||||
username=username,
|
username=username,
|
||||||
nickname=username,
|
nickname=username,
|
||||||
password=get_hash_password(password),
|
password=await get_hash_password(password),
|
||||||
email=email,
|
email=email,
|
||||||
is_active=False,
|
is_active=False,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
|
@ -43,7 +43,7 @@ class OperaLogMiddleware:
|
|||||||
if settings.LOCATION_PARSE == 'online':
|
if settings.LOCATION_PARSE == 'online':
|
||||||
location = await request_parse.get_location_online(ip, user_agent)
|
location = await request_parse.get_location_online(ip, user_agent)
|
||||||
elif settings.LOCATION_PARSE == 'offline':
|
elif settings.LOCATION_PARSE == 'offline':
|
||||||
location = request_parse.get_location_offline(ip)
|
location = await request_parse.get_location_offline(ip)
|
||||||
else:
|
else:
|
||||||
location = '未知'
|
location = '未知'
|
||||||
try:
|
try:
|
||||||
|
@ -28,7 +28,7 @@ class AuthService:
|
|||||||
current_user = await UserDao.get_by_username(db, form_data.username)
|
current_user = await UserDao.get_by_username(db, form_data.username)
|
||||||
if not current_user:
|
if not current_user:
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
elif not jwt.password_verify(form_data.password, current_user.password):
|
elif not await jwt.password_verify(form_data.password, current_user.password):
|
||||||
raise errors.AuthorizationError(msg='密码错误')
|
raise errors.AuthorizationError(msg='密码错误')
|
||||||
elif not current_user.is_active:
|
elif not current_user.is_active:
|
||||||
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
||||||
@ -50,7 +50,7 @@ class AuthService:
|
|||||||
current_user = await UserDao.get_by_username(db, obj.username)
|
current_user = await UserDao.get_by_username(db, obj.username)
|
||||||
if not current_user:
|
if not current_user:
|
||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
elif not jwt.password_verify(obj.password, current_user.password):
|
elif not await jwt.password_verify(obj.password, current_user.password):
|
||||||
raise errors.AuthorizationError(msg='密码错误')
|
raise errors.AuthorizationError(msg='密码错误')
|
||||||
elif not current_user.is_active:
|
elif not current_user.is_active:
|
||||||
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
||||||
@ -92,7 +92,7 @@ class AuthService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def new_token(refresh_token: str) -> tuple[str, datetime]:
|
async def new_token(refresh_token: str) -> tuple[str, datetime]:
|
||||||
user_id, role_ids = jwt.jwt_decode(refresh_token)
|
user_id, role_ids = await jwt.jwt_decode(refresh_token)
|
||||||
async with async_db_session() as db:
|
async with async_db_session() as db:
|
||||||
current_user = await UserDao.get(db, user_id)
|
current_user = await UserDao.get(db, user_id)
|
||||||
if not current_user:
|
if not current_user:
|
||||||
@ -106,7 +106,7 @@ class AuthService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def logout(request: Request) -> NoReturn:
|
async def logout(request: Request) -> NoReturn:
|
||||||
token = get_token(request)
|
token = await get_token(request)
|
||||||
if request.user.is_multi_login:
|
if request.user.is_multi_login:
|
||||||
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
|
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
|
||||||
await redis_client.delete(key)
|
await redis_client.delete(key)
|
||||||
|
@ -6,10 +6,8 @@ from typing import NoReturn
|
|||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy import Select
|
from sqlalchemy import Select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from user_agents import parse
|
|
||||||
|
|
||||||
from backend.app.common.log import log
|
from backend.app.common.log import log
|
||||||
from backend.app.core.conf import settings
|
|
||||||
from backend.app.crud.crud_login_log import LoginLogDao
|
from backend.app.crud.crud_login_log import LoginLogDao
|
||||||
from backend.app.database.db_mysql import async_db_session
|
from backend.app.database.db_mysql import async_db_session
|
||||||
from backend.app.models import User
|
from backend.app.models import User
|
||||||
|
@ -135,8 +135,8 @@ class UserService:
|
|||||||
raise errors.NotFoundError(msg='用户不存在')
|
raise errors.NotFoundError(msg='用户不存在')
|
||||||
else:
|
else:
|
||||||
count = await UserDao.set_multi_login(db, pk)
|
count = await UserDao.set_multi_login(db, pk)
|
||||||
token = get_token(request)
|
token = await get_token(request)
|
||||||
user_id, role_ids = jwt_decode(token)
|
user_id, role_ids = await jwt_decode(token)
|
||||||
latest_multi_login = await UserDao.get_multi_login(db, pk)
|
latest_multi_login = await UserDao.get_multi_login(db, pk)
|
||||||
# TODO: 删除用户 refresh token, 此操作需要传参,暂时不考虑实现
|
# TODO: 删除用户 refresh token, 此操作需要传参,暂时不考虑实现
|
||||||
# 当前用户修改自身时(普通/超级),除当前token外,其他token失效
|
# 当前用户修改自身时(普通/超级),除当前token外,其他token失效
|
||||||
|
@ -4,7 +4,7 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
def get_uuid() -> str:
|
def get_uuid_str() -> str:
|
||||||
"""
|
"""
|
||||||
生成uuid
|
生成uuid
|
||||||
|
|
@ -2,13 +2,15 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import httpx
|
import httpx
|
||||||
from XdbSearchIP.xdbSearcher import XdbSearcher
|
from XdbSearchIP.xdbSearcher import XdbSearcher
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from backend.app.core.path_conf import IP2REGION_XDB
|
from backend.app.core.path_conf import IP2REGION_XDB
|
||||||
|
|
||||||
|
|
||||||
async def get_request_ip(request: Request) -> str:
|
@sync_to_async
|
||||||
|
def get_request_ip(request: Request) -> str:
|
||||||
"""获取请求的 ip 地址"""
|
"""获取请求的 ip 地址"""
|
||||||
real = request.headers.get('X-Real-IP')
|
real = request.headers.get('X-Real-IP')
|
||||||
if real:
|
if real:
|
||||||
@ -46,6 +48,7 @@ async def get_location_online(ipaddr: str, user_agent: str) -> str:
|
|||||||
return city or '未知' if city != '' else '未知'
|
return city or '未知' if city != '' else '未知'
|
||||||
|
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
def get_location_offline(ipaddr: str) -> str:
|
def get_location_offline(ipaddr: str) -> str:
|
||||||
"""
|
"""
|
||||||
离线获取 ip 地址属地,无法保证准确率,100%可用
|
离线获取 ip 地址属地,无法保证准确率,100%可用
|
||||||
|
@ -3,6 +3,7 @@ aioredis==2.0.1
|
|||||||
aiosmtplib==1.1.6
|
aiosmtplib==1.1.6
|
||||||
alembic==1.7.4
|
alembic==1.7.4
|
||||||
APScheduler==3.8.1
|
APScheduler==3.8.1
|
||||||
|
asgiref==3.7.2
|
||||||
asynccasbin==1.1.8
|
asynccasbin==1.1.8
|
||||||
asyncmy==0.2.5
|
asyncmy==0.2.5
|
||||||
bcrypt==3.2.2
|
bcrypt==3.2.2
|
||||||
|
Reference in New Issue
Block a user