mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-18 15:00:46 +08:00
Add sync to async decorator support (#96)
* Add sync to async decorator support * Update ASyncTranslator to asgiref
This commit is contained in:
@ -18,7 +18,7 @@ router = APIRouter()
|
||||
@router.get('/{pk}', summary='获取接口详情', dependencies=[DependsJwtAuth])
|
||||
async def get_api(pk: int):
|
||||
api = await ApiService.get(pk=pk)
|
||||
return response_base.success(data=api)
|
||||
return await response_base.success(data=api)
|
||||
|
||||
|
||||
@router.get('', summary='(模糊条件)分页获取所有接口', dependencies=[DependsJwtAuth, PageDepends])
|
||||
@ -30,26 +30,26 @@ async def get_all_apis(
|
||||
):
|
||||
api_select = await ApiService.get_select(name=name, method=method, path=path)
|
||||
page_data = await paging_data(db, api_select, GetAllApi)
|
||||
return response_base.success(data=page_data)
|
||||
return await response_base.success(data=page_data)
|
||||
|
||||
|
||||
@router.post('', summary='创建接口', dependencies=[DependsRBAC])
|
||||
async def create_api(request: Request, obj: CreateApi):
|
||||
await ApiService.create(obj=obj, user_id=request.user.id)
|
||||
return response_base.success()
|
||||
return await response_base.success()
|
||||
|
||||
|
||||
@router.put('/{pk}', summary='更新接口', dependencies=[DependsRBAC])
|
||||
async def update_api(request: Request, pk: int, obj: UpdateApi):
|
||||
count = await ApiService.update(pk=pk, obj=obj, user_id=request.user.id)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.delete('', summary='(批量)删除接口', dependencies=[DependsRBAC])
|
||||
async def delete_api(pk: Annotated[list[int], Query(...)]):
|
||||
count = await ApiService.delete(pk=pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
@ -39,17 +39,17 @@ async def user_login(request: Request, obj: Auth, background_tasks: BackgroundTa
|
||||
refresh_token_expire_time=refresh_expire,
|
||||
user=user,
|
||||
)
|
||||
return response_base.success(data=data)
|
||||
return await response_base.success(data=data)
|
||||
|
||||
|
||||
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
|
||||
async def create_new_token(refresh_token: Annotated[str, Query(...)]):
|
||||
access_token, access_expire = await AuthService.new_token(refresh_token)
|
||||
data = GetNewToken(access_token=access_token, access_token_expire_time=access_expire)
|
||||
return response_base.success(data=data)
|
||||
return await response_base.success(data=data)
|
||||
|
||||
|
||||
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
|
||||
async def user_logout(request: Request):
|
||||
await AuthService.logout(request)
|
||||
return response_base.success()
|
||||
return await response_base.success()
|
||||
|
@ -12,7 +12,7 @@ router = APIRouter()
|
||||
|
||||
@router.get('', summary='获取系统配置', dependencies=[DependsRBAC])
|
||||
async def get_sys_config():
|
||||
return response_base.success(
|
||||
return await response_base.success(
|
||||
data={
|
||||
'title': settings.TITLE,
|
||||
'version': settings.VERSION,
|
||||
@ -59,4 +59,4 @@ async def get_all_route(request: Request):
|
||||
for route in request.app.routes:
|
||||
if isinstance(route, APIRoute):
|
||||
data.append({'path': route.path, 'name': route.name, 'summary': route.summary, 'methods': route.methods})
|
||||
return response_base.success(data={'route_list': data})
|
||||
return await response_base.success(data={'route_list': data})
|
||||
|
@ -24,20 +24,20 @@ async def get_all_login_logs(
|
||||
):
|
||||
log_select = await LoginLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
||||
page_data = await paging_data(db, log_select, GetAllLoginLog)
|
||||
return response_base.success(data=page_data)
|
||||
return await response_base.success(data=page_data)
|
||||
|
||||
|
||||
@router.delete('', summary='(批量)删除登录日志', dependencies=[DependsRBAC])
|
||||
async def delete_login_log(pk: Annotated[list[int], Query(...)]):
|
||||
count = await LoginLogService.delete(pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.delete('/all', summary='清空登录日志', dependencies=[DependsRBAC])
|
||||
async def delete_all_login_logs():
|
||||
count = await LoginLogService.delete_all()
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
@ -24,20 +24,20 @@ async def get_all_opera_logs(
|
||||
):
|
||||
log_select = await OperaLogService.get_select(username=username, status=status, ipaddr=ipaddr)
|
||||
page_data = await paging_data(db, log_select, GetAllOperaLog)
|
||||
return response_base.success(data=page_data)
|
||||
return await response_base.success(data=page_data)
|
||||
|
||||
|
||||
@router.delete('', summary='(批量)删除操作日志', dependencies=[DependsRBAC])
|
||||
async def delete_opera_log(pk: Annotated[list[int], Query(...)]):
|
||||
count = await OperaLogService.delete(pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.delete('/all', summary='清空操作日志', dependencies=[DependsRBAC])
|
||||
async def delete_all_opera_logs():
|
||||
count = await OperaLogService.delete_all()
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
@ -20,7 +20,7 @@ router = APIRouter()
|
||||
async def get_role(pk: int):
|
||||
role = await RoleService.get(pk=pk)
|
||||
data = GetAllRole(**select_to_json(role))
|
||||
return response_base.success(data=data)
|
||||
return await response_base.success(data=data)
|
||||
|
||||
|
||||
@router.get('', summary='(模糊条件)分页获取所有角色', dependencies=[DependsJwtAuth, PageDepends])
|
||||
@ -31,26 +31,26 @@ async def get_all_roles(
|
||||
):
|
||||
role_select = await RoleService.get_select(name=name, data_scope=data_scope)
|
||||
page_data = await paging_data(db, role_select, GetAllRole)
|
||||
return response_base.success(data=page_data)
|
||||
return await response_base.success(data=page_data)
|
||||
|
||||
|
||||
@router.post('', summary='创建角色', dependencies=[DependsRBAC])
|
||||
async def create_role(request: Request, obj: CreateRole):
|
||||
await RoleService.create(obj=obj, user_id=request.user.id)
|
||||
return response_base.success()
|
||||
return await response_base.success()
|
||||
|
||||
|
||||
@router.put('/{pk}', summary='更新角色', dependencies=[DependsRBAC])
|
||||
async def update_role(request: Request, pk: int, obj: UpdateRole):
|
||||
count = await RoleService.update(pk=pk, obj=obj, user_id=request.user.id)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.delete('', summary='(批量)删除角色', dependencies=[DependsRBAC])
|
||||
async def delete_role(pk: Annotated[list[int], Query(...)]):
|
||||
count = await RoleService.delete(pk=pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
@ -18,38 +18,38 @@ router = APIRouter()
|
||||
@router.post('/register', summary='用户注册')
|
||||
async def user_register(obj: CreateUser):
|
||||
await UserService.register(obj)
|
||||
return response_base.success()
|
||||
return await response_base.success()
|
||||
|
||||
|
||||
@router.post('/password/reset', summary='密码重置')
|
||||
async def password_reset(obj: ResetPassword):
|
||||
count = await UserService.pwd_reset(obj)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.get('/{username}', summary='查看用户信息', dependencies=[DependsJwtAuth])
|
||||
async def get_user(username: str):
|
||||
current_user = await UserService.get_userinfo(username)
|
||||
data = GetAllUserInfo(**select_to_json(current_user))
|
||||
return response_base.success(data=data)
|
||||
return await response_base.success(data=data)
|
||||
|
||||
|
||||
@router.put('/{username}', summary='更新用户信息', dependencies=[DependsJwtAuth])
|
||||
async def update_userinfo(request: Request, username: str, obj: UpdateUser):
|
||||
count = await UserService.update(request=request, username=username, obj=obj)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.put('/{username}/avatar', summary='更新头像', dependencies=[DependsJwtAuth])
|
||||
async def update_avatar(request: Request, username: str, avatar: Avatar):
|
||||
count = await UserService.update_avatar(request=request, username=username, avatar=avatar)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.get('', summary='(模糊条件)分页获取所有用户', dependencies=[DependsJwtAuth, PageDepends])
|
||||
@ -61,31 +61,31 @@ async def get_all_users(
|
||||
):
|
||||
user_select = await UserService.get_select(username=username, phone=phone, status=status)
|
||||
page_data = await paging_data(db, user_select, GetAllUserInfo)
|
||||
return response_base.success(data=page_data)
|
||||
return await response_base.success(data=page_data)
|
||||
|
||||
|
||||
@router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsJwtAuth])
|
||||
async def super_set(request: Request, pk: int):
|
||||
count = await UserService.update_permission(request=request, pk=pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.post('/{pk}/action', summary='修改用户状态', dependencies=[DependsJwtAuth])
|
||||
async def active_set(request: Request, pk: int):
|
||||
count = await UserService.update_active(request=request, pk=pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.post('/{pk}/multi', summary='修改用户多点登录状态', dependencies=[DependsJwtAuth])
|
||||
async def multi_set(request: Request, pk: int):
|
||||
count = await UserService.update_multi_login(request=request, pk=pk)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
||||
|
||||
@router.delete(
|
||||
@ -97,5 +97,5 @@ async def multi_set(request: Request, pk: int):
|
||||
async def delete_user(request: Request, username: str):
|
||||
count = await UserService.delete(request=request, username=username)
|
||||
if count > 0:
|
||||
return response_base.success()
|
||||
return response_base.fail()
|
||||
return await response_base.success()
|
||||
return await response_base.fail()
|
||||
|
@ -37,7 +37,7 @@ def _get_exception_code(status_code):
|
||||
|
||||
def register_exception(app: FastAPI):
|
||||
@app.exception_handler(HTTPException)
|
||||
def http_exception_handler(request: Request, exc: HTTPException):
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""
|
||||
全局HTTP异常处理
|
||||
|
||||
@ -47,12 +47,12 @@ def register_exception(app: FastAPI):
|
||||
"""
|
||||
return JSONResponse(
|
||||
status_code=_get_exception_code(exc.status_code),
|
||||
content=response_base.fail(code=exc.status_code, msg=exc.detail),
|
||||
content=await response_base.fail(code=exc.status_code, msg=exc.detail),
|
||||
headers=exc.headers,
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""
|
||||
数据验证异常处理
|
||||
|
||||
@ -84,7 +84,7 @@ def register_exception(app: FastAPI):
|
||||
message += 'json解析失败'
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content=response_base.fail(
|
||||
content=await response_base.fail(
|
||||
code=422,
|
||||
msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}',
|
||||
data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None,
|
||||
@ -92,7 +92,7 @@ def register_exception(app: FastAPI):
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
def all_exception_handler(request: Request, exc: Exception):
|
||||
async def all_exception_handler(request: Request, exc: Exception):
|
||||
"""
|
||||
全局异常处理
|
||||
|
||||
@ -103,14 +103,14 @@ def register_exception(app: FastAPI):
|
||||
if isinstance(exc, BaseExceptionMixin):
|
||||
return JSONResponse(
|
||||
status_code=_get_exception_code(exc.code),
|
||||
content=response_base.fail(code=exc.code, msg=str(exc.msg), data=exc.data if exc.data else None),
|
||||
content=await response_base.fail(code=exc.code, msg=str(exc.msg), data=exc.data if exc.data else None),
|
||||
background=exc.background,
|
||||
)
|
||||
|
||||
elif isinstance(exc, AssertionError):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=response_base.fail(
|
||||
content=await response_base.fail(
|
||||
code=500,
|
||||
msg=','.join(exc.args)
|
||||
if exc.args
|
||||
@ -119,14 +119,14 @@ def register_exception(app: FastAPI):
|
||||
else exc.__doc__,
|
||||
)
|
||||
if settings.ENVIRONMENT == 'dev'
|
||||
else response_base.fail(code=500, msg='Internal Server Error'),
|
||||
else await response_base.fail(code=500, msg='Internal Server Error'),
|
||||
)
|
||||
|
||||
else:
|
||||
log.error(exc)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=response_base.fail(code=500, msg=str(exc))
|
||||
content=await response_base.fail(code=500, msg=str(exc))
|
||||
if settings.ENVIRONMENT == 'dev'
|
||||
else response_base.fail(code=500, msg='Internal Server Error'),
|
||||
else await response_base.fail(code=500, msg='Internal Server Error'),
|
||||
)
|
||||
|
@ -2,6 +2,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
@ -21,6 +22,7 @@ pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
||||
oauth2_schema = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL_SWAGGER)
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def get_hash_password(password: str) -> str:
|
||||
"""
|
||||
Encrypt passwords using the hash algorithm
|
||||
@ -31,6 +33,7 @@ def get_hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def password_verify(plain_password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
Password verification
|
||||
@ -107,6 +110,7 @@ async def create_new_token(sub: str, refresh_token: str, **kwargs) -> tuple[str,
|
||||
return new_token, expire
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def get_token(request: Request) -> str:
|
||||
"""
|
||||
Get token for request header
|
||||
@ -120,6 +124,7 @@ def get_token(request: Request) -> str:
|
||||
return token
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def jwt_decode(token: str) -> tuple[int, list[int]]:
|
||||
"""
|
||||
Decode token
|
||||
@ -145,7 +150,7 @@ async def jwt_authentication(token: str) -> dict[str, int]:
|
||||
:param token:
|
||||
:return:
|
||||
"""
|
||||
user_id, _ = jwt_decode(token)
|
||||
user_id, _ = await jwt_decode(token)
|
||||
key = f'{settings.TOKEN_REDIS_PREFIX}:{user_id}:{token}'
|
||||
token_verify = await redis_client.get(key)
|
||||
if not token_verify:
|
||||
@ -170,7 +175,8 @@ async def get_current_user(db: AsyncSession, data: dict) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def superuser_verify(request: Request) -> bool:
|
||||
@sync_to_async
|
||||
def superuser_verify(request: Request) -> bool:
|
||||
"""
|
||||
Verify the current user permissions through token
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from pydantic import validate_arguments, BaseModel
|
||||
|
||||
from backend.app.utils.encoders import jsonable_encoder
|
||||
@ -51,19 +52,26 @@ class ResponseBase:
|
||||
|
||||
@router.get('/test')
|
||||
def test():
|
||||
return response_base.success(data={'test': 'test'})
|
||||
return await response_base.success(data={'test': 'test'})
|
||||
""" # noqa: E501
|
||||
|
||||
@staticmethod
|
||||
@sync_to_async
|
||||
def __json_encoder(data: Any, exclude: _ExcludeData | None = None, **kwargs):
|
||||
custom_encoder = {datetime: lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}
|
||||
kwargs.update({'custom_encoder': custom_encoder})
|
||||
return jsonable_encoder(data, exclude=exclude, **kwargs)
|
||||
result = jsonable_encoder(data, exclude=exclude, **kwargs)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
@validate_arguments
|
||||
def success(
|
||||
*, code: int = 200, msg: str = 'Success', data: Any | None = None, exclude: _ExcludeData | None = None, **kwargs
|
||||
async def success(
|
||||
self,
|
||||
*,
|
||||
code: int = 200,
|
||||
msg: str = 'Success',
|
||||
data: Any | None = None,
|
||||
exclude: _ExcludeData | None = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
请求成功返回通用方法
|
||||
@ -74,15 +82,20 @@ class ResponseBase:
|
||||
:param exclude: 排除返回数据(data)字段
|
||||
:return:
|
||||
"""
|
||||
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs)
|
||||
data = data if data is None else await self.__json_encoder(data, exclude, **kwargs)
|
||||
return {'code': code, 'msg': msg, 'data': data}
|
||||
|
||||
@staticmethod
|
||||
@validate_arguments
|
||||
def fail(
|
||||
*, code: int = 400, msg: str = 'Bad Request', data: Any = None, exclude: _ExcludeData | None = None, **kwargs
|
||||
async def fail(
|
||||
self,
|
||||
*,
|
||||
code: int = 400,
|
||||
msg: str = 'Bad Request',
|
||||
data: Any = None,
|
||||
exclude: _ExcludeData | None = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
data = data if data is None else ResponseBase.__json_encoder(data, exclude, **kwargs)
|
||||
data = data if data is None else await self.__json_encoder(data, exclude, **kwargs)
|
||||
return {'code': code, 'msg': msg, 'data': data}
|
||||
|
||||
|
||||
|
@ -11,7 +11,9 @@ from backend.app.schemas.opera_log import CreateOperaLog, UpdateOperaLog
|
||||
|
||||
|
||||
class CRUDOperaLogDao(CRUDBase[OperaLog, CreateOperaLog, UpdateOperaLog]):
|
||||
async def get_all(self, username: str | None = None, status: bool | None = None, ipaddr: str | None = None) -> Select:
|
||||
async def get_all(
|
||||
self, username: str | None = None, status: bool | None = None, ipaddr: str | None = None
|
||||
) -> Select:
|
||||
se = select(self.model).order_by(desc(self.model.create_time))
|
||||
where_list = []
|
||||
if username:
|
||||
|
@ -28,7 +28,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
|
||||
return user.rowcount
|
||||
|
||||
async def create(self, db: AsyncSession, create: CreateUser) -> NoReturn:
|
||||
create.password = jwt.get_hash_password(create.password)
|
||||
create.password = await jwt.get_hash_password(create.password)
|
||||
new_user = self.model(**create.dict(exclude={'roles'}))
|
||||
role_list = []
|
||||
for role_id in create.roles:
|
||||
@ -63,7 +63,7 @@ class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
|
||||
|
||||
async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int:
|
||||
user = await db.execute(
|
||||
update(self.model).where(self.model.id == pk).values(password=jwt.get_hash_password(password))
|
||||
update(self.model).where(self.model.id == pk).values(password=await jwt.get_hash_password(password))
|
||||
)
|
||||
return user.rowcount
|
||||
|
||||
|
@ -41,7 +41,7 @@ class InitTestData:
|
||||
user_obj = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
password=get_hash_password(password),
|
||||
password=await get_hash_password(password),
|
||||
email=email,
|
||||
is_superuser=True,
|
||||
dept_id=1,
|
||||
@ -70,7 +70,7 @@ class InitTestData:
|
||||
user_obj = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
password=get_hash_password(password),
|
||||
password=await get_hash_password(password),
|
||||
email=email,
|
||||
is_superuser=True,
|
||||
dept_id=1,
|
||||
@ -88,7 +88,7 @@ class InitTestData:
|
||||
user_obj = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
password=get_hash_password(password),
|
||||
password=await get_hash_password(password),
|
||||
email=email,
|
||||
is_superuser=False,
|
||||
dept_id=1,
|
||||
@ -106,7 +106,7 @@ class InitTestData:
|
||||
user_obj = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
password=get_hash_password(password),
|
||||
password=await get_hash_password(password),
|
||||
email=email,
|
||||
is_active=False,
|
||||
is_superuser=False,
|
||||
@ -125,7 +125,7 @@ class InitTestData:
|
||||
user_obj = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
password=get_hash_password(password),
|
||||
password=await get_hash_password(password),
|
||||
email=email,
|
||||
is_superuser=True,
|
||||
dept_id=1,
|
||||
@ -143,7 +143,7 @@ class InitTestData:
|
||||
user_obj = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
password=get_hash_password(password),
|
||||
password=await get_hash_password(password),
|
||||
email=email,
|
||||
is_active=False,
|
||||
is_superuser=True,
|
||||
|
@ -43,7 +43,7 @@ class OperaLogMiddleware:
|
||||
if settings.LOCATION_PARSE == 'online':
|
||||
location = await request_parse.get_location_online(ip, user_agent)
|
||||
elif settings.LOCATION_PARSE == 'offline':
|
||||
location = request_parse.get_location_offline(ip)
|
||||
location = await request_parse.get_location_offline(ip)
|
||||
else:
|
||||
location = '未知'
|
||||
try:
|
||||
|
@ -28,7 +28,7 @@ class AuthService:
|
||||
current_user = await UserDao.get_by_username(db, form_data.username)
|
||||
if not current_user:
|
||||
raise errors.NotFoundError(msg='用户不存在')
|
||||
elif not jwt.password_verify(form_data.password, current_user.password):
|
||||
elif not await jwt.password_verify(form_data.password, current_user.password):
|
||||
raise errors.AuthorizationError(msg='密码错误')
|
||||
elif not current_user.is_active:
|
||||
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
||||
@ -50,7 +50,7 @@ class AuthService:
|
||||
current_user = await UserDao.get_by_username(db, obj.username)
|
||||
if not current_user:
|
||||
raise errors.NotFoundError(msg='用户不存在')
|
||||
elif not jwt.password_verify(obj.password, current_user.password):
|
||||
elif not await jwt.password_verify(obj.password, current_user.password):
|
||||
raise errors.AuthorizationError(msg='密码错误')
|
||||
elif not current_user.is_active:
|
||||
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
||||
@ -92,7 +92,7 @@ class AuthService:
|
||||
|
||||
@staticmethod
|
||||
async def new_token(refresh_token: str) -> tuple[str, datetime]:
|
||||
user_id, role_ids = jwt.jwt_decode(refresh_token)
|
||||
user_id, role_ids = await jwt.jwt_decode(refresh_token)
|
||||
async with async_db_session() as db:
|
||||
current_user = await UserDao.get(db, user_id)
|
||||
if not current_user:
|
||||
@ -106,7 +106,7 @@ class AuthService:
|
||||
|
||||
@staticmethod
|
||||
async def logout(request: Request) -> NoReturn:
|
||||
token = get_token(request)
|
||||
token = await get_token(request)
|
||||
if request.user.is_multi_login:
|
||||
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
|
||||
await redis_client.delete(key)
|
||||
|
@ -6,10 +6,8 @@ from typing import NoReturn
|
||||
from fastapi import Request
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from user_agents import parse
|
||||
|
||||
from backend.app.common.log import log
|
||||
from backend.app.core.conf import settings
|
||||
from backend.app.crud.crud_login_log import LoginLogDao
|
||||
from backend.app.database.db_mysql import async_db_session
|
||||
from backend.app.models import User
|
||||
|
@ -135,8 +135,8 @@ class UserService:
|
||||
raise errors.NotFoundError(msg='用户不存在')
|
||||
else:
|
||||
count = await UserDao.set_multi_login(db, pk)
|
||||
token = get_token(request)
|
||||
user_id, role_ids = jwt_decode(token)
|
||||
token = await get_token(request)
|
||||
user_id, role_ids = await jwt_decode(token)
|
||||
latest_multi_login = await UserDao.get_multi_login(db, pk)
|
||||
# TODO: 删除用户 refresh token, 此操作需要传参,暂时不考虑实现
|
||||
# 当前用户修改自身时(普通/超级),除当前token外,其他token失效
|
||||
|
@ -4,7 +4,7 @@ import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
def get_uuid() -> str:
|
||||
def get_uuid_str() -> str:
|
||||
"""
|
||||
生成uuid
|
||||
|
@ -2,13 +2,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import httpx
|
||||
from XdbSearchIP.xdbSearcher import XdbSearcher
|
||||
from asgiref.sync import sync_to_async
|
||||
from httpx import HTTPError
|
||||
from fastapi import Request
|
||||
|
||||
from backend.app.core.path_conf import IP2REGION_XDB
|
||||
|
||||
|
||||
async def get_request_ip(request: Request) -> str:
|
||||
@sync_to_async
|
||||
def get_request_ip(request: Request) -> str:
|
||||
"""获取请求的 ip 地址"""
|
||||
real = request.headers.get('X-Real-IP')
|
||||
if real:
|
||||
@ -46,6 +48,7 @@ async def get_location_online(ipaddr: str, user_agent: str) -> str:
|
||||
return city or '未知' if city != '' else '未知'
|
||||
|
||||
|
||||
@sync_to_async
|
||||
def get_location_offline(ipaddr: str) -> str:
|
||||
"""
|
||||
离线获取 ip 地址属地,无法保证准确率,100%可用
|
||||
|
@ -3,6 +3,7 @@ aioredis==2.0.1
|
||||
aiosmtplib==1.1.6
|
||||
alembic==1.7.4
|
||||
APScheduler==3.8.1
|
||||
asgiref==3.7.2
|
||||
asynccasbin==1.1.8
|
||||
asyncmy==0.2.5
|
||||
bcrypt==3.2.2
|
||||
|
Reference in New Issue
Block a user