diff --git a/README.md b/README.md index 6a8979f..8771475 100644 --- a/README.md +++ b/README.md @@ -84,15 +84,21 @@ Execute the `backend/app/init_test_data.py` file Perform tests via pytest -**Tip**: Before the test starts, please execute init the test data first, also, the fastapi service needs to be started +1. Create a database `fba_test`, choose utf8mb4 encode -1. First, go to the app directory +2. First, go to the app directory ```shell cd backend/app/ ``` + +3. Init the test data -2. Execute the test command + ```shell + python tests/init_test_data.py + ``` + +4. Execute the test command ```shell pytest -vs --disable-warnings diff --git a/README.zh-CN.md b/README.zh-CN.md index 9bb1d13..6fac221 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -77,15 +77,21 @@ git clone https://github.com/wu-clan/fastapi_best_architecture.git 通过 pytest 进行测试 -**提示**: 在测试开始前,请先执行初始化测试数据,同时,需要启动 fastapi 服务。 +1. 创建一个数据库`fba_test`,选择 utf8mb4 编码 -1. 首先,进入app目录 +2. 首先,进入app目录 ```shell cd backend/app/ ``` + +3. 初始化测试数据 -2. 执行测试命令 + ```shell + python tests/init_test_data.py + ``` + +4. 执行测试命令 ```shell pytest -vs --disable-warnings diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 0000000..56fafa5 --- /dev/null +++ b/backend/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/backend/app/api/v1/auth/__init__.py b/backend/app/api/v1/auth/__init__.py index 2ceb01e..e4d7a75 100644 --- a/backend/app/api/v1/auth/__init__.py +++ b/backend/app/api/v1/auth/__init__.py @@ -5,4 +5,4 @@ from backend.app.api.v1.auth.auth import router as auth_router router = APIRouter(prefix='/auth', tags=['认证']) -router.include_router(auth_router, prefix='/users') +router.include_router(auth_router) diff --git a/backend/app/core/conf.py b/backend/app/core/conf.py index b1e602b..91ad89b 100644 --- a/backend/app/core/conf.py +++ b/backend/app/core/conf.py @@ -32,12 +32,13 @@ class Settings(BaseSettings): TOKEN_WHITE_LIST: list[str] # 白名单用户ID,可多点登录 # FastAPI + API_V1_STR: str = '/v1' TITLE: str = 'FastAPI' VERSION: str = '0.0.1' DESCRIPTION: str = 'FastAPI Best Architecture' - DOCS_URL: str | None = '/v1/docs' - REDOCS_URL: str | None = '/v1/redocs' - OPENAPI_URL: str | None = '/v1/openapi' + DOCS_URL: str | None = f'{API_V1_STR}/docs' + REDOCS_URL: str | None = f'{API_V1_STR}/redocs' + OPENAPI_URL: str | None = f'{API_V1_STR}/openapi' @root_validator def validator_api_url(cls, values): @@ -54,7 +55,7 @@ class Settings(BaseSettings): STATIC_FILES: bool = False # MySQL - DB_ECHO: bool = True + DB_ECHO: bool = False DB_DATABASE: str = 'fba' DB_CHARSET: str = 'utf8mb4' @@ -72,7 +73,7 @@ class Settings(BaseSettings): # Token TOKEN_ALGORITHM: str = 'HS256' # 算法 TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒 - TOKEN_URL_SWAGGER: str = '/v1/auth/users/swagger_login' + TOKEN_URL_SWAGGER: str = f'{API_V1_STR}/auth/swagger_login' TOKEN_REDIS_PREFIX: str = 'fba_token' # Log @@ -86,10 +87,10 @@ class Settings(BaseSettings): # Casbin CASBIN_RBAC_MODEL_NAME: str = 'rbac_model.conf' CASBIN_EXCLUDE: list[dict[str, str], dict[str, str]] = [ - {'method': 'POST', 'path': '/api/v1/auth/users/swagger_login'}, - {'method': 'POST', 'path': '/api/v1/auth/users/login'}, - {'method': 'POST', 'path': '/api/v1/auth/users/register'}, - {'method': 'POST', 'path': '/api/v1/auth/users/password/reset'}, + {'method': 'POST', 'path': '/api/v1/auth/swagger_login'}, + {'method': 'POST', 'path': '/api/v1/auth/login'}, + {'method': 'POST', 'path': '/api/v1/auth/register'}, + {'method': 'POST', 'path': '/api/v1/auth/password/reset'}, ] class Config: diff --git a/backend/app/database/db_mysql.py b/backend/app/database/db_mysql.py index ecc2a63..4313626 100644 --- a/backend/app/database/db_mysql.py +++ b/backend/app/database/db_mysql.py @@ -3,6 +3,7 @@ import sys from fastapi import Depends +from sqlalchemy import URL from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from typing_extensions import Annotated @@ -14,20 +15,26 @@ from backend.app.database.base_class import MappedBase 说明:SqlAlchemy """ + +def create_engine_and_session(url: str | URL): + try: + # 数据库引擎 + engine = create_async_engine(url, echo=settings.DB_ECHO, future=True, pool_pre_ping=True) + # log.success('数据库连接成功') + except Exception as e: + log.error('❌ 数据库链接失败 {}', e) + sys.exit() + else: + db_session = async_sessionmaker(bind=engine, autoflush=False, expire_on_commit=False) + return engine, db_session + + SQLALCHEMY_DATABASE_URL = ( f'mysql+asyncmy://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:' f'{settings.DB_PORT}/{settings.DB_DATABASE}?charset={settings.DB_CHARSET}' ) -try: - # 数据库引擎 - async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=settings.DB_ECHO, future=True, pool_pre_ping=True) - # log.success('数据库连接成功') -except Exception as e: - log.error('❌ 数据库链接失败 {}', e) - sys.exit() -else: - async_db_session = async_sessionmaker(bind=async_engine, autoflush=False, expire_on_commit=False) +async_engine, async_db_session = create_engine_and_session(SQLALCHEMY_DATABASE_URL) async def get_db() -> AsyncSession: diff --git a/backend/app/init_test_data.py b/backend/app/init_test_data.py index 3d9bab0..1802715 100644 --- a/backend/app/init_test_data.py +++ b/backend/app/init_test_data.py @@ -14,28 +14,26 @@ from backend.app.models import User, Role, Menu, Dept class InitTestData: """初始化测试数据""" - def __init__(self): + def __init__(self, session): self.fake = Faker('zh_CN') + self.session = session - @staticmethod - async def create_dept(): + async def create_dept(self): """自动创建部门""" - async with async_db_session.begin() as db: + async with self.session.begin() as db: department_obj = Dept(name='test', create_user=1) db.add(department_obj) log.info('部门 test 创建成功') - @staticmethod - async def create_role(): + async def create_role(self): """自动创建角色""" - async with async_db_session.begin() as db: + async with self.session.begin() as db: role_obj = Role(name='test', create_user=1) role_obj.menus.append(Menu(name='test', create_user=1)) db.add(role_obj) log.info('角色 test 创建成功') - @staticmethod - async def create_test_user(): + async def create_test_user(self): """创建测试用户""" username = 'test' password = 'test' @@ -48,13 +46,12 @@ class InitTestData: is_superuser=True, dept_id=1, ) - async with async_db_session.begin() as db: + async with self.session.begin() as db: user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'测试用户创建成功,账号:{username},密码:{password}') - @staticmethod - async def create_superuser_by_yourself(): + async def create_superuser_by_yourself(self): """手动创建管理员账户""" log.info('开始创建自定义管理员用户') print('请输入用户名:') @@ -78,7 +75,7 @@ class InitTestData: is_superuser=True, dept_id=1, ) - async with async_db_session.begin() as db: + async with self.session.begin() as db: user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'自定义管理员用户创建成功,账号:{username},密码:{password}') @@ -96,7 +93,7 @@ class InitTestData: is_superuser=False, dept_id=1, ) - async with async_db_session.begin() as db: + async with self.session.begin() as db: user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'普通用户创建成功,账号:{username},密码:{password}') @@ -115,7 +112,7 @@ class InitTestData: is_superuser=False, dept_id=1, ) - async with async_db_session.begin() as db: + async with self.session.begin() as db: user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'普通锁定用户创建成功,账号:{username},密码:{password}') @@ -133,7 +130,7 @@ class InitTestData: is_superuser=True, dept_id=1, ) - async with async_db_session.begin() as db: + async with self.session.begin() as db: user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'管理员用户创建成功,账号:{username},密码:{password}') @@ -152,7 +149,7 @@ class InitTestData: is_superuser=True, dept_id=1, ) - async with async_db_session.begin() as db: + async with self.session.begin() as db: user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'管理员锁定用户创建成功,账号:{username},密码:{password}') @@ -172,6 +169,6 @@ class InitTestData: if __name__ == '__main__': - init = InitTestData() + init = InitTestData(session=async_db_session) loop = asyncio.get_event_loop() loop.run_until_complete(init.init_data()) diff --git a/backend/app/tests/api_v1/__init__.py b/backend/app/tests/api_v1/__init__.py new file mode 100644 index 0000000..56fafa5 --- /dev/null +++ b/backend/app/tests/api_v1/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/backend/app/tests/api_v1/test_auth.py b/backend/app/tests/api_v1/test_auth.py new file mode 100644 index 0000000..3b24633 --- /dev/null +++ b/backend/app/tests/api_v1/test_auth.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from starlette.testclient import TestClient + +from backend.app.core.conf import settings + + +def test_login(client: TestClient) -> None: + data = { + 'username': 'test', + 'password': 'test', + } + response = client.post(f'{settings.API_V1_STR}/auth/login', json=data) + assert response.status_code == 200 + assert response.json()['data']['access_token_type'] == 'Bearer' + + +def test_logout(client: TestClient, token_headers: dict[str, str]) -> None: + response = client.post(f'{settings.API_V1_STR}/auth/logout', headers=token_headers) + assert response.status_code == 200 + assert response.json()['code'] == 200 diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 1113405..f412605 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -2,30 +2,27 @@ # -*- coding: utf-8 -*- import sys -import pytest -from httpx import AsyncClient - sys.path.append('../../') -from backend.app.common.redis import redis_client # noqa: E402 -from backend.app.core.conf import settings # noqa: E402 +import pytest +from typing import Generator, Dict + +from starlette.testclient import TestClient + +from backend.app.main import app +from backend.app.tests.utils.get_headers import get_token_headers +from backend.app.database.db_mysql import get_db +from backend.app.tests.utils.db_mysql import override_get_db + +app.dependency_overrides[get_db] = override_get_db -@pytest.fixture(scope='session') -def anyio_backend(): - return 'asyncio' +@pytest.fixture(scope='module') +def client() -> Generator: + with TestClient(app) as c: + yield c -@pytest.fixture(scope='package', autouse=True) -async def function_fixture(anyio_backend): - auth_data = { - 'url': f'http://{settings.UVICORN_HOST}:{settings.UVICORN_PORT}/v1/auth/users/login', - 'headers': {'accept': 'application/json', 'Content-Type': 'application/json'}, - 'json': {'username': 'test', 'password': 'test'}, - } - async with AsyncClient() as client: - response = await client.post(**auth_data) - token = response.json()['data']['access_token'] - test_token = await redis_client.get('test_token') - if not test_token: - await redis_client.set('test_token', token, ex=86400) +@pytest.fixture(scope='module') +def token_headers(client: TestClient) -> Dict[str, str]: + return get_token_headers(client=client, username='test', password='test') diff --git a/backend/app/tests/init_test_data.py b/backend/app/tests/init_test_data.py new file mode 100644 index 0000000..fb7ae73 --- /dev/null +++ b/backend/app/tests/init_test_data.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys + +sys.path.append('../../') + +import asyncio + +from backend.app.init_test_data import InitTestData +from backend.app.tests.utils.db_mysql import async_db_session, create_table + +if __name__ == '__main__': + init = InitTestData(session=async_db_session) + loop = asyncio.get_event_loop() + loop.run_until_complete(create_table()) + loop.run_until_complete(init.init_data()) diff --git a/backend/app/tests/test_auth.py b/backend/app/tests/test_auth.py deleted file mode 100644 index 148ab5c..0000000 --- a/backend/app/tests/test_auth.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -import sys - -import pytest -from httpx import AsyncClient - -sys.path.append('../../') - -from backend.app.core.conf import settings # noqa: E402 -from backend.app.main import app # noqa: E402 - - -class TestAuth: - pytestmark = pytest.mark.anyio - - async def test_login(self): - async with AsyncClient( - app=app, headers={'accept': 'application/json', 'Content-Type': 'application/json'} - ) as client: - response = await client.post( - url=f'http://{settings.UVICORN_HOST}:{settings.UVICORN_PORT}/v1/auth/users/login', - json={'username': 'test', 'password': 'test'}, - ) - assert response.status_code == 200 - assert response.json()['data']['token_type'] == 'Bearer' diff --git a/backend/app/tests/test_user.py b/backend/app/tests/test_user.py deleted file mode 100644 index d930ab3..0000000 --- a/backend/app/tests/test_user.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -import sys - -import pytest -from faker import Faker -from httpx import AsyncClient - -sys.path.append('../../') - -from backend.app.core.conf import settings # noqa: E402 -from backend.app.main import app # noqa: E402 -from backend.app.common.redis import redis_client # noqa: E402 - - -class TestUser: - pytestmark = pytest.mark.anyio - faker = Faker(locale='zh_CN') - users_api_base_url = f'http://{settings.UVICORN_HOST}:{settings.UVICORN_PORT}/v1/users' - - @property - async def get_token(self): - token = await redis_client.get('test_token') - return token - - async def test_register(self): - async with AsyncClient( - app=app, headers={'accept': 'application/json', 'Content-Type': 'application/json'} - ) as client: - response = await client.post( - url=f'{self.users_api_base_url}/register', - json={ - 'username': f'{self.faker.user_name()}', - 'nickname': f'{self.faker.name()}', - 'password': f'{self.faker.password()}', - 'email': f'{self.faker.email()}', - 'dept_id': 1, - 'roles': [1], - }, - ) - assert response.status_code == 200 - r_json = response.json() - assert r_json['code'] == 200 - assert r_json['msg'] == 'Success' - - async def test_get_userinfo(self): - async with AsyncClient( - app=app, headers={'accept': 'application/json', 'Authorization': f'Bearer {await self.get_token}'} - ) as client: - response = await client.get(url=f'{self.users_api_base_url}/1') - assert response.status_code == 200 - r_json = response.json() - assert r_json['code'] == 200 - assert r_json['msg'] == 'Success' - - async def test_get_all_users(self): - async with AsyncClient( - app=app, headers={'accept': 'application/json', 'Authorization': f'Bearer {await self.get_token}'} - ) as client: - response = await client.get(url=f'{self.users_api_base_url}?page=1&size=20') - assert response.status_code == 200 - r_json = response.json() - assert isinstance(r_json['data']['items'], list) - assert isinstance(r_json['data']['links'], dict) - assert isinstance(r_json['data']['links']['self'], str) diff --git a/backend/app/tests/utils/__init__.py b/backend/app/tests/utils/__init__.py new file mode 100644 index 0000000..56fafa5 --- /dev/null +++ b/backend/app/tests/utils/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- diff --git a/backend/app/tests/utils/db_mysql.py b/backend/app/tests/utils/db_mysql.py new file mode 100644 index 0000000..d349b11 --- /dev/null +++ b/backend/app/tests/utils/db_mysql.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from sqlalchemy.ext.asyncio import AsyncSession + + +from backend.app.core.conf import settings +from backend.app.database.base_class import MappedBase +from backend.app.database.db_mysql import create_engine_and_session + +TEST_DB_DATABASE = settings.DB_DATABASE + '_test' + +SQLALCHEMY_DATABASE_URL = ( + f'mysql+asyncmy://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:' + f'{settings.DB_PORT}/{TEST_DB_DATABASE}?charset={settings.DB_CHARSET}' +) + +async_engine, async_db_session = create_engine_and_session(SQLALCHEMY_DATABASE_URL) + + +async def override_get_db() -> AsyncSession: + """ + session 生成器 + + :return: + """ + session = async_db_session() + try: + yield session + except Exception as se: + await session.rollback() + raise se + finally: + await session.close() + + +async def create_table(): + """ + 创建数据库表 + """ + async with async_engine.begin() as coon: + await coon.run_sync(MappedBase.metadata.create_all) diff --git a/backend/app/tests/utils/get_headers.py b/backend/app/tests/utils/get_headers.py new file mode 100644 index 0000000..8ad1368 --- /dev/null +++ b/backend/app/tests/utils/get_headers.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from typing import Dict + +from starlette.testclient import TestClient + +from backend.app.core.conf import settings + + +def get_token_headers(client: TestClient, username: str, password: str) -> Dict[str, str]: + data = { + 'username': username, + 'password': password, + } + response = client.post(f'{settings.API_V1_STR}/auth/login', json=data) + token_type = response.json()['data']['access_token_type'] + access_token = response.json()['data']['access_token'] + headers = {'Authorization': f'{token_type} {access_token}'} + return headers