fix that the data validation global exception handler does not work (#40)

* fix that the data validation global exception handler does not work

* update login api test

* update the JSON login method to create the user ID of the token
This commit is contained in:
Wu Clan
2023-05-13 01:29:15 +08:00
committed by GitHub
parent 768a13f014
commit 24024d4bf0
7 changed files with 66 additions and 49 deletions

View File

@ -26,9 +26,11 @@ class UserService:
raise errors.AuthorizationError(msg='该用户已被锁定,无法登录') raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
# 更新登陆时间 # 更新登陆时间
await UserDao.update_user_login_time(db, form_data.username) await UserDao.update_user_login_time(db, form_data.username)
# 获取最新用户信息
user = await UserDao.get_user_by_id(db, current_user.id)
# 创建token # 创建token
access_token = jwt.create_access_token(current_user.id) access_token = jwt.create_access_token(user.id)
return access_token, current_user.is_superuser return access_token, user
# @staticmethod # @staticmethod
# async def login(obj: Auth): # async def login(obj: Auth):
@ -42,9 +44,11 @@ class UserService:
# raise errors.AuthorizationError(msg='该用户已被锁定,无法登录') # raise errors.AuthorizationError(msg='该用户已被锁定,无法登录')
# # 更新登陆时间 # # 更新登陆时间
# await UserDao.update_user_login_time(db, obj.username) # await UserDao.update_user_login_time(db, obj.username)
# # 获取最新用户信息
# user = await UserDao.get_user_by_id(db, current_user.id)
# # 创建token # # 创建token
# access_token = jwt.create_access_token(current_user.id) # access_token = jwt.create_access_token(user.id)
# return access_token, current_user.is_superuser # return access_token, user
@staticmethod @staticmethod
async def register(obj: CreateUser): async def register(obj: CreateUser):

View File

@ -13,17 +13,18 @@ from backend.app.schemas.user import CreateUser, GetUserInfo, ResetPassword, Upd
router = APIRouter() router = APIRouter()
@router.post('/login', summary='表单登录', response_model=Token, description='form 格式登录支持直接在 api 文档调试接口') @router.post('/login', summary='表单登录', description='form 格式登录支持直接在 api 文档调试接口')
async def user_login(form_data: OAuth2PasswordRequestForm = Depends()): async def user_login(form_data: OAuth2PasswordRequestForm = Depends()):
token, is_super = await UserService.login(form_data) token, user = await UserService.login(form_data)
return Token(access_token=token, is_superuser=is_super) data = Token(access_token=token, user=user)
return response_base.response_200(data=data)
# @router.post('/login', summary='用户登录', response_model=Token, # @router.post('/login', summary='用户登录', description='json 格式登录, 仅支持在第三方api工具调试接口, 例如: postman')
# description='json 格式登录, 不支持api文档接口调试, 需使用第三方api工具, 例如: postman')
# async def user_login(obj: Auth): # async def user_login(obj: Auth):
# token, is_super = await UserService.login(obj) # token, user = await UserService.login(obj)
# return Token(access_token=token, is_superuser=is_super) # data = Token(access_token=token, user=user)
# return response_base.response_200(data=data)
@router.post('/register', summary='用户注册') @router.post('/register', summary='用户注册')

View File

@ -50,17 +50,15 @@ def register_exception(app: FastAPI):
headers=exc.headers, headers=exc.headers,
) )
@app.exception_handler(Exception) @app.exception_handler(RequestValidationError)
def all_exception_handler(request: Request, exc): def validation_exception_handler(request: Request, exc: RequestValidationError):
""" """
全局异常处理 数据验证异常处理
:param request: :param request:
:param exc: :param exc:
:return: :return:
""" """
# 常规
if isinstance(exc, RequestValidationError):
message = '' message = ''
data = {} data = {}
for raw_error in exc.raw_errors: for raw_error in exc.raw_errors:
@ -71,21 +69,36 @@ def register_exception(app: FastAPI):
for field_key in fields.keys(): for field_key in fields.keys():
field_title = fields.get(field_key).field_info.title field_title = fields.get(field_key).field_info.title
data[field_key] = field_title if field_title else field_key data[field_key] = field_title if field_title else field_key
errors_len = len(exc.errors())
for error in exc.errors(): for error in exc.errors():
field = str(error.get('loc')[-1]) field = str(error.get('loc')[-1])
_msg = error.get('msg') _msg = error.get('msg')
message += f'{data.get(field, field)} {_msg},' errors_len = errors_len - 1
message += (
f'{data.get(field, field)} {_msg}' + ', '
if errors_len > 0
else f'{data.get(field, field)} {_msg}'
)
elif isinstance(raw_error.exc, json.JSONDecodeError): elif isinstance(raw_error.exc, json.JSONDecodeError):
message += 'json解析失败' message += 'json解析失败'
return JSONResponse( return JSONResponse(
status_code=422, status_code=422,
content=response_base.fail( content=response_base.fail(
msg='请求参数非法' if len(message) == 0 else f'请求参数非法:{message[:-1]}', code=422,
msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message[:-1]}',
data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None, data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None,
), ),
) )
# 自定义 @app.exception_handler(Exception)
def all_exception_handler(request: Request, exc: Exception):
"""
全局异常处理
:param request:
:param exc:
:return:
"""
if isinstance(exc, BaseExceptionMixin): if isinstance(exc, BaseExceptionMixin):
return JSONResponse( return JSONResponse(
status_code=_get_exception_code(exc.code), status_code=_get_exception_code(exc.code),

View File

@ -1,12 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from pydantic import BaseModel from pydantic import BaseModel
from backend.app.schemas.user import GetUserInfo
class Token(BaseModel): class Token(BaseModel):
code: int = 200
msg: str = 'Success'
access_token: str access_token: str
token_type: str = 'Bearer' token_type: str = 'Bearer'
is_superuser: bool | None = None user: GetUserInfo

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime from datetime import datetime
from pydantic import BaseModel, Field, HttpUrl from pydantic import BaseModel, Field, HttpUrl
@ -28,10 +28,10 @@ class GetUserInfo(UpdateUser):
id: int id: int
uid: str uid: str
avatar: str | None = None avatar: str | None = None
time_joined: datetime.datetime = None
last_login: datetime.datetime | None = None
is_superuser: bool
is_active: bool is_active: bool
is_superuser: bool
time_joined: datetime = None
last_login: datetime | None = None
class Config: class Config:
orm_mode = True orm_mode = True

View File

@ -25,7 +25,7 @@ async def function_fixture(anyio_backend):
} }
async with AsyncClient() as client: async with AsyncClient() as client:
response = await client.post(**auth_data) response = await client.post(**auth_data)
token = response.json()['access_token'] token = response.json()['data']['access_token']
test_token = await redis_client.get('test_token') test_token = await redis_client.get('test_token')
if not test_token: if not test_token:
await redis_client.set('test_token', token, ex=86400) await redis_client.set('test_token', token, ex=86400)

View File

@ -31,7 +31,7 @@ class TestAuth:
url=f'{self.users_api_base_url}/login', data={'username': '1', 'password': '1'} url=f'{self.users_api_base_url}/login', data={'username': '1', 'password': '1'}
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()['token_type'] == 'Bearer' assert response.json()['data']['token_type'] == 'Bearer'
async def test_register(self): async def test_register(self):
async with AsyncClient( async with AsyncClient(