Add sync to async decorator support (#96)

* Add sync to async decorator support

* Update ASyncTranslator to asgiref
This commit is contained in:
Wu Clan
2023-06-06 20:13:33 +08:00
committed by GitHub
parent 9e833db5f5
commit abcc9d2308
20 changed files with 115 additions and 92 deletions

View File

@ -18,7 +18,7 @@ router = APIRouter()
@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth]) @router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth])
async def get_api(pk: int): async def get_api(pk: int):
api = await ApiService.get(pk=pk) api = await ApiService.get(pk=pk)
return response_base.success(data=api) return await response_base.success(data=api)
@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends]) @router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends])
@ -30,26 +30,26 @@ async def get_all_apis(
): ):
api_select = await ApiService.get_select(name=name, method=method, path=path) api_select = await ApiService.get_select(name=name, method=method, path=path)
page_data = await paging_data(db, api_select, GetAllApi) page_data = await paging_data(db, api_select, GetAllApi)
return response_base.success(data=page_data) return await response_base.success(data=page_data)
@router.post('', summary='创建接口', dependencies=[DependsRBAC]) @router.post('', summary='创建接口', dependencies=[DependsRBAC])
async def create_api(request: Request, obj: CreateApi): async def create_api(request: Request, obj: CreateApi):
await ApiService.create(obj=obj, user_id=request.user.id) await ApiService.create(obj=obj, user_id=request.user.id)
return response_base.success() return await response_base.success()
@router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC]) @router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC])
async def update_api(request: Request, pk: int, obj: UpdateApi): async def update_api(request: Request, pk: int, obj: UpdateApi):
count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id) count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.delete('', summary='(批量)删除接口', dependencies=[DependsRBAC]) @router.delete('', summary='(批量)删除接口', dependencies=[DependsRBAC])
async def delete_api(pk: Annotated[list[int], Query(...)]): async def delete_api(pk: Annotated[list[int], Query(...)]):
count = await ApiService.delete(pk=pk) count = await ApiService.delete(pk=pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()

View File

@ -39,17 +39,17 @@ async def user_login(request: Request, obj: Auth, background_tasks: BackgroundTa
refresh_token_expire_time=refresh_expire, refresh_token_expire_time=refresh_expire,
user=user, user=user,
) )
return response_base.success(data=data) return await response_base.success(data=data)
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth]) @router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
async def create_new_token(refresh_token: Annotated[str, Query(...)]): async def create_new_token(refresh_token: Annotated[str, Query(...)]):
access_token, access_expire = await AuthService.new_token(refresh_token) access_token, access_expire = await AuthService.new_token(refresh_token)
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire) data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire)
return response_base.success(data=data) return await response_base.success(data=data)
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth]) @router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
async def user_logout(request: Request): async def user_logout(request: Request):
await AuthService.logout(request) await AuthService.logout(request)
return response_base.success() return await response_base.success()

View File

@ -12,7 +12,7 @@ router = APIRouter()
@router.get('', summary='获取系统配置', dependencies=[DependsRBAC]) @router.get('', summary='获取系统配置', dependencies=[DependsRBAC])
async def get_sys_config(): async def get_sys_config():
return response_base.success( return await response_base.success(
data={ data={
'title': settings.TITLE, 'title': settings.TITLE,
'version': settings.VERSION, 'version': settings.VERSION,
@ -59,4 +59,4 @@ async def get_all_route(request: Request):
for route in request.app.routes: for route in request.app.routes:
if isinstance(route, APIRoute): if isinstance(route, APIRoute):
data.append({'path': route.path, 'name': route.name, 'summary': route.summary, 'methods': route.methods}) data.append({'path': route.path, 'name': route.name, 'summary': route.summary, 'methods': route.methods})
return response_base.success(data={'route_list': data}) return await response_base.success(data={'route_list': data})

View File

@ -24,20 +24,20 @@ async def get_all_login_logs(
): ):
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr) log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
page_data = await paging_data(db, log_select, GetAllLoginLog) page_data = await paging_data(db, log_select, GetAllLoginLog)
return response_base.success(data=page_data) return await response_base.success(data=page_data)
@router.delete('', summary='(批量)删除登录日志', dependencies=[DependsRBAC]) @router.delete('', summary='(批量)删除登录日志', dependencies=[DependsRBAC])
async def delete_login_log(pk: Annotated[list[int], Query(...)]): async def delete_login_log(pk: Annotated[list[int], Query(...)]):
count = await LoginLogService.delete(pk) count = await LoginLogService.delete(pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.delete('/all', summary='清空登录日志', dependencies=[DependsRBAC]) @router.delete('/all', summary='清空登录日志', dependencies=[DependsRBAC])
async def delete_all_login_logs(): async def delete_all_login_logs():
count = await LoginLogService.delete_all() count = await LoginLogService.delete_all()
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()

View File

@ -17,27 +17,27 @@ router = APIRouter()
@router.get('', summary='(模糊条件)分页获取操作日志', dependencies=[DependsJwtAuth, PageDepends]) @router.get('', summary='(模糊条件)分页获取操作日志', dependencies=[DependsJwtAuth, PageDepends])
async def get_all_opera_logs( async def get_all_opera_logs(
db: CurrentSession, db: CurrentSession,
username: Annotated[str | None, Query()] = None, username: Annotated[str | None, Query()] = None,
status: Annotated[bool | None, Query()] = None, status: Annotated[bool | None, Query()] = None,
ipaddr: Annotated[str | None, Query()] = None, ipaddr: Annotated[str | None, Query()] = None,
): ):
log_select = await OperaLogService.get_select(username=username, status=status, ipaddr=ipaddr) log_select = await OperaLogService.get_select(username=username, status=status, ipaddr=ipaddr)
page_data = await paging_data(db, log_select, GetAllOperaLog) page_data = await paging_data(db, log_select, GetAllOperaLog)
return response_base.success(data=page_data) return await response_base.success(data=page_data)
@router.delete('', summary='(批量)删除操作日志', dependencies=[DependsRBAC]) @router.delete('', summary='(批量)删除操作日志', dependencies=[DependsRBAC])
async def delete_opera_log(pk: Annotated[list[int], Query(...)]): async def delete_opera_log(pk: Annotated[list[int], Query(...)]):
count = await OperaLogService.delete(pk) count = await OperaLogService.delete(pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.delete('/all', summary='清空操作日志', dependencies=[DependsRBAC]) @router.delete('/all', summary='清空操作日志', dependencies=[DependsRBAC])
async def delete_all_opera_logs(): async def delete_all_opera_logs():
count = await OperaLogService.delete_all() count = await OperaLogService.delete_all()
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()

View File

@ -20,7 +20,7 @@ router = APIRouter()
async def get_role(pk: int): async def get_role(pk: int):
role = await RoleService.get(pk=pk) role = await RoleService.get(pk=pk)
data = GetAllRole(**select_to_json(role)) data = GetAllRole(**select_to_json(role))
return response_base.success(data=data) return await response_base.success(data=data)
@router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsJwtAuth, PageDepends]) @router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsJwtAuth, PageDepends])
@ -31,26 +31,26 @@ async def get_all_roles(
): ):
role_select = await RoleService.get_select(name=name, data_scope=data_scope) role_select = await RoleService.get_select(name=name, data_scope=data_scope)
page_data = await paging_data(db, role_select, GetAllRole) page_data = await paging_data(db, role_select, GetAllRole)
return response_base.success(data=page_data) return await response_base.success(data=page_data)
@router.post('', summary='创建角色', dependencies=[DependsRBAC]) @router.post('', summary='创建角色', dependencies=[DependsRBAC])
async def create_role(request: Request, obj: CreateRole): async def create_role(request: Request, obj: CreateRole):
await RoleService.create(obj=obj, user_id=request.user.id) await RoleService.create(obj=obj, user_id=request.user.id)
return response_base.success() return await response_base.success()
@router.put('/{pk}', summary='更新角色', dependencies=[DependsRBAC]) @router.put('/{pk}', summary='更新角色', dependencies=[DependsRBAC])
async def update_role(request: Request, pk: int, obj: UpdateRole): async def update_role(request: Request, pk: int, obj: UpdateRole):
count = await RoleService.update(pk=pk, obj=obj, user_id=request.user.id) count = await RoleService.update(pk=pk, obj=obj, user_id=request.user.id)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.delete('', summary='(批量)删除角色', dependencies=[DependsRBAC]) @router.delete('', summary='(批量)删除角色', dependencies=[DependsRBAC])
async def delete_role(pk: Annotated[list[int], Query(...)]): async def delete_role(pk: Annotated[list[int], Query(...)]):
count = await RoleService.delete(pk=pk) count = await RoleService.delete(pk=pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()

View File

@ -18,38 +18,38 @@ router = APIRouter()
@router.post('/register', summary='用户注册') @router.post('/register', summary='用户注册')
async def user_register(obj: CreateUser): async def user_register(obj: CreateUser):
await UserService.register(obj) await UserService.register(obj)
return response_base.success() return await response_base.success()
@router.post('/password/reset', summary='密码重置') @router.post('/password/reset', summary='密码重置')
async def password_reset(obj: ResetPassword): async def password_reset(obj: ResetPassword):
count = await UserService.pwd_reset(obj) count = await UserService.pwd_reset(obj)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth]) @router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
async def get_user(username: str): async def get_user(username: str):
current_user = await UserService.get_userinfo(username) current_user = await UserService.get_userinfo(username)
data = GetAllUserInfo(**select_to_json(current_user)) data = GetAllUserInfo(**select_to_json(current_user))
return response_base.success(data=data) return await response_base.success(data=data)
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth]) @router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
async def update_userinfo(request: Request, username: str, obj: UpdateUser): async def update_userinfo(request: Request, username: str, obj: UpdateUser):
count = await UserService.update(request=request, username=username, obj=obj) count = await UserService.update(request=request, username=username, obj=obj)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth]) @router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
async def update_avatar(request: Request, username: str, avatar: Avatar): async def update_avatar(request: Request, username: str, avatar: Avatar):
count = await UserService.update_avatar(request=request, username=username, avatar=avatar) count = await UserService.update_avatar(request=request, username=username, avatar=avatar)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends]) @router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends])
@ -61,31 +61,31 @@ async def get_all_users(
): ):
user_select = await UserService.get_select(username=username, phone=phone, status=status) user_select = await UserService.get_select(username=username, phone=phone, status=status)
page_data = await paging_data(db, user_select, GetAllUserInfo) page_data = await paging_data(db, user_select, GetAllUserInfo)
return response_base.success(data=page_data) return await response_base.success(data=page_data)
@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth]) @router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth])
async def super_set(request: Request, pk: int): async def super_set(request: Request, pk: int):
count = await UserService.update_permission(request=request, pk=pk) count = await UserService.update_permission(request=request, pk=pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth]) @router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth])
async def active_set(request: Request, pk: int): async def active_set(request: Request, pk: int):
count = await UserService.update_active(request=request, pk=pk) count = await UserService.update_active(request=request, pk=pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth]) @router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth])
async def multi_set(request: Request, pk: int): async def multi_set(request: Request, pk: int):
count = await UserService.update_multi_login(request=request, pk=pk) count = await UserService.update_multi_login(request=request, pk=pk)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()
@router.delete( @router.delete(
@ -97,5 +97,5 @@ async def multi_set(request: Request, pk: int):
async def delete_user(request: Request, username: str): async def delete_user(request: Request, username: str):
count = await UserService.delete(request=request, username=username) count = await UserService.delete(request=request, username=username)
if count > 0: if count > 0:
return response_base.success() return await response_base.success()
return response_base.fail() return await response_base.fail()

View File

@ -37,7 +37,7 @@ def _get_exception_code(status_code):
def register_exception(app: FastAPI): def register_exception(app: FastAPI):
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
def http_exception_handler(request: Request, exc: HTTPException): async def http_exception_handler(request: Request, exc: HTTPException):
""" """
全局HTTP异常处理 全局HTTP异常处理
@ -47,12 +47,12 @@ def register_exception(app: FastAPI):
""" """
return JSONResponse( return JSONResponse(
status_code=_get_exception_code(exc.status_code), status_code=_get_exception_code(exc.status_code),
content=response_base.fail(code=exc.status_code, msg=exc.detail), content=await response_base.fail(code=exc.status_code, msg=exc.detail),
headers=exc.headers, headers=exc.headers,
) )
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
def validation_exception_handler(request: Request, exc: RequestValidationError): async def validation_exception_handler(request: Request, exc: RequestValidationError):
""" """
数据验证异常处理 数据验证异常处理
@ -84,7 +84,7 @@ def register_exception(app: FastAPI):
message += 'json解析失败' message += 'json解析失败'
return JSONResponse( return JSONResponse(
status_code=422, status_code=422,
content=response_base.fail( content=await response_base.fail(
code=422, code=422,
msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}', msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}',
data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None, data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None,
@ -92,7 +92,7 @@ def register_exception(app: FastAPI):
) )
@app.exception_handler(Exception) @app.exception_handler(Exception)
def all_exception_handler(request: Request, exc: Exception): async def all_exception_handler(request: Request, exc: Exception):
""" """
全局异常处理 全局异常处理
@ -103,14 +103,14 @@ def register_exception(app: FastAPI):
if isinstance(exc, BaseExceptionMixin): if isinstance(exc, BaseExceptionMixin):
return JSONResponse( return JSONResponse(
status_code=_get_exception_code(exc.code), status_code=_get_exception_code(exc.code),
content=response_base.fail(code=exc.code, msg=str(exc.msg), data=exc.data if exc.data else None), content=await response_base.fail(code=exc.code, msg=str(exc.msg), data=exc.data if exc.data else None),
background=exc.background, background=exc.background,
) )
elif isinstance(exc, AssertionError): elif isinstance(exc, AssertionError):
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content=response_base.fail( content=await response_base.fail(
code=500, code=500,
msg=','.join(exc.args) msg=','.join(exc.args)
if exc.args if exc.args
@ -119,14 +119,14 @@ def register_exception(app: FastAPI):
else exc.__doc__, else exc.__doc__,
) )
if settings.ENVIRONMENT == 'dev' if settings.ENVIRONMENT == 'dev'
else response_base.fail(code=500, msg='Internal Server Error'), else await response_base.fail(code=500, msg='Internal Server Error'),
) )
else: else:
log.error(exc) log.error(exc)
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content=response_base.fail(code=500, msg=str(exc)) content=await response_base.fail(code=500, msg=str(exc))
if settings.ENVIRONMENT == 'dev' if settings.ENVIRONMENT == 'dev'
else response_base.fail(code=500, msg='Internal Server Error'), else await response_base.fail(code=500, msg='Internal Server Error'),
) )

View File

@ -2,6 +2,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from datetime import datetime, timedelta from datetime import datetime, timedelta
from asgiref.sync import sync_to_async
from fastapi import Depends, Request from fastapi import Depends, Request
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from fastapi.security.utils import get_authorization_scheme_param from fastapi.security.utils import get_authorization_scheme_param
@ -21,6 +22,7 @@ pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
oauth2_schema = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL_SWAGGER) oauth2_schema = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL_SWAGGER)
@sync_to_async
def get_hash_password(password: str) -> str: def get_hash_password(password: str) -> str:
""" """
Encrypt passwords using the hash algorithm Encrypt passwords using the hash algorithm
@ -31,6 +33,7 @@ def get_hash_password(password: str) -> str:
return pwd_context.hash(password) return pwd_context.hash(password)
@sync_to_async
def password_verify(plain_password: str, hashed_password: str) -> bool: def password_verify(plain_password: str, hashed_password: str) -> bool:
""" """
Password verification Password verification
@ -107,6 +110,7 @@ async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str,
return new_token, expire return new_token, expire
@sync_to_async
def get_token(request: Request) -> str: def get_token(request: Request) -> str:
""" """
Get token for request header Get token for request header
@ -120,6 +124,7 @@ def get_token(request: Request) -> str:
return token return token
@sync_to_async
def jwt_decode(token: str) -> tuple[int, list[int]]: def jwt_decode(token: str) -> tuple[int, list[int]]:
""" """
Decode token Decode token
@ -145,7 +150,7 @@ async def jwt_authentication(token: str) -> dict[str, int]:
:param token: :param token:
:return: :return:
""" """
user_id, _ = jwt_decode(token) user_id, _ = await jwt_decode(token)
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}' key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
token_verify = await redis_client.get(key) token_verify = await redis_client.get(key)
if not token_verify: if not token_verify:
@ -170,7 +175,8 @@ async def get_current_user(db: AsyncSession, data: dict) -> User:
return user return user
async def superuser_verify(request: Request) -> bool: @sync_to_async
def superuser_verify(request: Request) -> bool:
""" """
Verify the current user permissions through token Verify the current user permissions through token

View File

@ -3,6 +3,7 @@
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from asgiref.sync import sync_to_async
from pydantic import validate_arguments, BaseModel from pydantic import validate_arguments, BaseModel
from backend.app.utils.encoders import jsonable_encoder from backend.app.utils.encoders import jsonable_encoder
@ -51,19 +52,26 @@ class ResponseBase:
@router.get('/test') @router.get('/test')
def test(): def test():
return response_base.success(data={'test': 'test'}) return await response_base.success(data={'test': 'test'})
""" # noqa: E501 """ # noqa: E501
@staticmethod @staticmethod
@sync_to_async
def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs): def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs):
custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')} custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}
kwargs.update({'custom_encoder': custom_encoder}) kwargs.update({'custom_encoder': custom_encoder})
return jsonable_encoder(data, exclude=exclude, **kwargs) result = jsonable_encoder(data, exclude=exclude, **kwargs)
return result
@staticmethod
@validate_arguments @validate_arguments
def success( async def success(
*, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _ExcludeData | None = None, **kwargs self,
*,
code: int = 200,
msg: str = 'Success',
data: Any | None = None,
exclude: _ExcludeData | None = None,
**kwargs
) -> dict: ) -> dict:
""" """
请求成功返回通用方法 请求成功返回通用方法
@ -74,15 +82,20 @@ class ResponseBase:
:param exclude: 排除返回数据(data)字段 :param exclude: 排除返回数据(data)字段
:return: :return:
""" """
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs) data = data if data is None else await self.__json_encoder(data, exclude, **kwargs)
return {'code': code, 'msg': msg, 'data': data} return {'code': code, 'msg': msg, 'data': data}
@staticmethod
@validate_arguments @validate_arguments
def fail( async def fail(
*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _ExcludeData | None = None, **kwargs self,
*,
code: int = 400,
msg: str = 'Bad Request',
data: Any = None,
exclude: _ExcludeData | None = None,
**kwargs
) -> dict: ) -> dict:
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs) data = data if data is None else await self.__json_encoder(data, exclude, **kwargs)
return {'code': code, 'msg': msg, 'data': data} return {'code': code, 'msg': msg, 'data': data}

View File

@ -11,7 +11,9 @@ from backend.app.schemas.opera_log import CreateOperaLog, UpdateOperaLog
class CRUDOperaLogDao(CRUDBase[OperaLog, CreateOperaLog, UpdateOperaLog]): class CRUDOperaLogDao(CRUDBase[OperaLog, CreateOperaLog, UpdateOperaLog]):
async def get_all(self, username: str | None = None, status: bool | None = None, ipaddr: str | None = None) -> Select: async def get_all(
self, username: str | None = None, status: bool | None = None, ipaddr: str | None = None
) -> Select:
se = select(self.model).order_by(desc(self.model.create_time)) se = select(self.model).order_by(desc(self.model.create_time))
where_list = [] where_list = []
if username: if username:

View File

@ -28,7 +28,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
return user.rowcount return user.rowcount
async def create(self, db: AsyncSession, create: CreateUser) -> NoReturn: async def create(self, db: AsyncSession, create: CreateUser) -> NoReturn:
create.password = jwt.get_hash_password(create.password) create.password = await jwt.get_hash_password(create.password)
new_user = self.model(**create.dict(exclude={'roles'})) new_user = self.model(**create.dict(exclude={'roles'}))
role_list = [] role_list = []
for role_id in create.roles: for role_id in create.roles:
@ -63,7 +63,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int: async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int:
user = await db.execute( user = await db.execute(
update(self.model).where(self.model.id == pk).values(password=jwt.get_hash_password(password)) update(self.model).where(self.model.id == pk).values(password=await jwt.get_hash_password(password))
) )
return user.rowcount return user.rowcount

View File

@ -41,7 +41,7 @@ class InitTestData:
user_obj = User( user_obj = User(
username=username, username=username,
nickname=username, nickname=username,
password=get_hash_password(password), password=await get_hash_password(password),
email=email, email=email,
is_superuser=True, is_superuser=True,
dept_id=1, dept_id=1,
@ -70,7 +70,7 @@ class InitTestData:
user_obj = User( user_obj = User(
username=username, username=username,
nickname=username, nickname=username,
password=get_hash_password(password), password=await get_hash_password(password),
email=email, email=email,
is_superuser=True, is_superuser=True,
dept_id=1, dept_id=1,
@ -88,7 +88,7 @@ class InitTestData:
user_obj = User( user_obj = User(
username=username, username=username,
nickname=username, nickname=username,
password=get_hash_password(password), password=await get_hash_password(password),
email=email, email=email,
is_superuser=False, is_superuser=False,
dept_id=1, dept_id=1,
@ -106,7 +106,7 @@ class InitTestData:
user_obj = User( user_obj = User(
username=username, username=username,
nickname=username, nickname=username,
password=get_hash_password(password), password=await get_hash_password(password),
email=email, email=email,
is_active=False, is_active=False,
is_superuser=False, is_superuser=False,
@ -125,7 +125,7 @@ class InitTestData:
user_obj = User( user_obj = User(
username=username, username=username,
nickname=username, nickname=username,
password=get_hash_password(password), password=await get_hash_password(password),
email=email, email=email,
is_superuser=True, is_superuser=True,
dept_id=1, dept_id=1,
@ -143,7 +143,7 @@ class InitTestData:
user_obj = User( user_obj = User(
username=username, username=username,
nickname=username, nickname=username,
password=get_hash_password(password), password=await get_hash_password(password),
email=email, email=email,
is_active=False, is_active=False,
is_superuser=True, is_superuser=True,

View File

@ -43,7 +43,7 @@ class OperaLogMiddleware:
if settings.LOCATION_PARSE == 'online': if settings.LOCATION_PARSE == 'online':
location = await request_parse.get_location_online(ip, user_agent) location = await request_parse.get_location_online(ip, user_agent)
elif settings.LOCATION_PARSE == 'offline': elif settings.LOCATION_PARSE == 'offline':
location = request_parse.get_location_offline(ip) location = await request_parse.get_location_offline(ip)
else: else:
location = '未知' location = '未知'
try: try:

View File

@ -28,7 +28,7 @@ class AuthService:
current_user = await UserDao.get_by_username(db, form_data.username) current_user = await UserDao.get_by_username(db, form_data.username)
if not current_user: if not current_user:
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
elif not jwt.password_verify(form_data.password, current_user.password): elif not await jwt.password_verify(form_data.password, current_user.password):
raise errors.AuthorizationError(msg='密码错误') raise errors.AuthorizationError(msg='密码错误')
elif not current_user.is_active: elif not current_user.is_active:
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败') raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
@ -50,7 +50,7 @@ class AuthService:
current_user = await UserDao.get_by_username(db, obj.username) current_user = await UserDao.get_by_username(db, obj.username)
if not current_user: if not current_user:
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
elif not jwt.password_verify(obj.password, current_user.password): elif not await jwt.password_verify(obj.password, current_user.password):
raise errors.AuthorizationError(msg='密码错误') raise errors.AuthorizationError(msg='密码错误')
elif not current_user.is_active: elif not current_user.is_active:
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败') raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
@ -92,7 +92,7 @@ class AuthService:
@staticmethod @staticmethod
async def new_token(refresh_token: str) -> tuple[str, datetime]: async def new_token(refresh_token: str) -> tuple[str, datetime]:
user_id, role_ids = jwt.jwt_decode(refresh_token) user_id, role_ids = await jwt.jwt_decode(refresh_token)
async with async_db_session() as db: async with async_db_session() as db:
current_user = await UserDao.get(db, user_id) current_user = await UserDao.get(db, user_id)
if not current_user: if not current_user:
@ -106,7 +106,7 @@ class AuthService:
@staticmethod @staticmethod
async def logout(request: Request) -> NoReturn: async def logout(request: Request) -> NoReturn:
token = get_token(request) token = await get_token(request)
if request.user.is_multi_login: if request.user.is_multi_login:
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}' key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
await redis_client.delete(key) await redis_client.delete(key)

View File

@ -6,10 +6,8 @@ from typing import NoReturn
from fastapi import Request from fastapi import Request
from sqlalchemy import Select from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from user_agents import parse
from backend.app.common.log import log from backend.app.common.log import log
from backend.app.core.conf import settings
from backend.app.crud.crud_login_log import LoginLogDao from backend.app.crud.crud_login_log import LoginLogDao
from backend.app.database.db_mysql import async_db_session from backend.app.database.db_mysql import async_db_session
from backend.app.models import User from backend.app.models import User

View File

@ -135,8 +135,8 @@ class UserService:
raise errors.NotFoundError(msg='用户不存在') raise errors.NotFoundError(msg='用户不存在')
else: else:
count = await UserDao.set_multi_login(db, pk) count = await UserDao.set_multi_login(db, pk)
token = get_token(request) token = await get_token(request)
user_id, role_ids = jwt_decode(token) user_id, role_ids = await jwt_decode(token)
latest_multi_login = await UserDao.get_multi_login(db, pk) latest_multi_login = await UserDao.get_multi_login(db, pk)
# TODO: 删除用户 refresh token, 此操作需要传参,暂时不考虑实现 # TODO: 删除用户 refresh token, 此操作需要传参,暂时不考虑实现
# 当前用户修改自身时(普通/超级除当前token外其他token失效 # 当前用户修改自身时(普通/超级除当前token外其他token失效

View File

@ -4,7 +4,7 @@ import datetime
import uuid import uuid
def get_uuid() -> str: def get_uuid_str() -> str:
""" """
生成uuid 生成uuid

View File

@ -2,13 +2,15 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import httpx import httpx
from XdbSearchIP.xdbSearcher import XdbSearcher from XdbSearchIP.xdbSearcher import XdbSearcher
from asgiref.sync import sync_to_async
from httpx import HTTPError from httpx import HTTPError
from fastapi import Request from fastapi import Request
from backend.app.core.path_conf import IP2REGION_XDB from backend.app.core.path_conf import IP2REGION_XDB
async def get_request_ip(request: Request) -> str: @sync_to_async
def get_request_ip(request: Request) -> str:
"""获取请求的 ip 地址""" """获取请求的 ip 地址"""
real = request.headers.get('X-Real-IP') real = request.headers.get('X-Real-IP')
if real: if real:
@ -46,6 +48,7 @@ async def get_location_online(ipaddr: str, user_agent: str) -> str:
return city or '未知' if city != '' else '未知' return city or '未知' if city != '' else '未知'
@sync_to_async
def get_location_offline(ipaddr: str) -> str: def get_location_offline(ipaddr: str) -> str:
""" """
离线获取 ip 地址属地无法保证准确率100%可用 离线获取 ip 地址属地无法保证准确率100%可用

View File

@ -3,6 +3,7 @@ aioredis==2.0.1
aiosmtplib==1.1.6 aiosmtplib==1.1.6
alembic==1.7.4 alembic==1.7.4
APScheduler==3.8.1 APScheduler==3.8.1
asgiref==3.7.2
asynccasbin==1.1.8 asynccasbin==1.1.8
asyncmy==0.2.5 asyncmy==0.2.5
bcrypt==3.2.2 bcrypt==3.2.2