Files
Wu Clan e1edcade21 Add RBAC authorisation and some tools and optimisations (#41)
* WIP: add rbac authorization

* Perform pre-commit fixes

* Add RBAC route whitelist

* Add initial test data for user-role associations

* Restore database table id naming to fix generic crud base

* Add uniqueness settings for database field values

* Update the test directory to tests

* Update route_name file name to health_check

* Split user auth and user action interfaces

* Fix conflict between merge and current branch

* Add pymysql dependencies

* Fix RBAC authentication method

* Add the select serialisation tool

* Fix missing return messages due to global exception handler slicing

* Update the user interface with associated relationships

* Add items to be completed

* Perform pre-commit fixes

* Add pre-made routers

* Paging data return structure optimisation

* Split user auth and user interface tests

* Fix user register test data structure error

* Fix duplicate named test classes
2023-05-17 22:13:37 +08:00

120 lines
4.9 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import NoReturn
from sqlalchemy import func, select, update, desc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.sql import Select
from backend.app.common import jwt
from backend.app.crud.base import CRUDBase
from backend.app.models import User, Role
from backend.app.schemas.user import CreateUser, UpdateUser, Avatar
class CRUDUser(CRUDBase[User, CreateUser, UpdateUser]):
    """Data-access helpers for the ``User`` model.

    None of these methods commit the session: they execute statements on, or
    stage objects in, the ``AsyncSession`` they are given.
    NOTE(review): committing appears to be the caller's responsibility — confirm.
    """

    async def get_user_by_id(self, db: AsyncSession, user_id: int) -> User | None:
        """Return the user whose primary key is ``user_id``.

        Delegates to ``CRUDBase.get``. NOTE(review): presumably yields ``None``
        when no row matches, as the annotation suggests — confirm in the base class.
        """
        return await self.get(db, user_id)

    async def get_user_by_username(self, db: AsyncSession, username: str) -> User | None:
        """Return the first user whose ``username`` equals ``username``, or ``None``."""
        user = await db.execute(select(self.model).where(self.model.username == username))
        return user.scalars().first()

    async def update_user_login_time(self, db: AsyncSession, username: str) -> int:
        """Set ``last_login`` to the database's ``now()`` for the given username.

        :return: the UPDATE statement's ``rowcount``.
        """
        user = await db.execute(
            update(self.model).where(self.model.username == username).values(last_login=func.now())
        )
        return user.rowcount

    async def _get_roles_by_ids(self, db: AsyncSession, role_ids: list[int]) -> list[Role]:
        """Load each ``Role`` listed in ``role_ids`` by primary key, preserving order.

        NOTE(review): ``db.get`` returns ``None`` for an unknown id and such
        entries are passed through unchanged; callers presumably validate role
        ids beforehand — confirm.
        """
        return [await db.get(Role, role_id) for role_id in role_ids]

    async def create_user(self, db: AsyncSession, create: CreateUser) -> None:
        """Build a new user from ``create`` and add it to the session.

        The stored password is ``jwt.get_hash_password(create.password)``; the
        caller's ``create`` object is left untouched. ``create.roles`` holds the
        role ids to attach. Nothing is flushed or committed here.
        """
        user_data = create.dict(exclude={'roles'})
        user_data['password'] = jwt.get_hash_password(create.password)
        new_user = self.model(**user_data)
        # ``list.append`` accepts exactly one item, so ``append(*roles)`` failed for
        # zero or several roles; ``extend`` attaches any number of them.
        new_user.roles.extend(await self._get_roles_by_ids(db, create.roles))
        db.add(new_user)

    async def update_userinfo(self, db: AsyncSession, input_user: User, obj: UpdateUser) -> int:
        """Update ``input_user``'s columns from ``obj`` and replace its roles with ``obj.roles``.

        NOTE(review): ``input_user.roles`` is accessed here; with an ``AsyncSession``
        the collection presumably must already be loaded (e.g. via
        ``get_user_with_relation``), because implicit lazy loading is not
        available — confirm with callers.

        :return: the column UPDATE statement's ``rowcount``.
        """
        user = await db.execute(
            update(self.model).where(self.model.id == input_user.id).values(**obj.dict(exclude={'roles'}))
        )
        # Remove all of the user's current roles ...
        input_user.roles.clear()
        # ... then attach the requested ones (``extend`` handles zero or many).
        input_user.roles.extend(await self._get_roles_by_ids(db, obj.roles))
        return user.rowcount

    async def update_avatar(self, db: AsyncSession, current_user: User, avatar: Avatar) -> int:
        """Write ``avatar`` into the ``avatar`` column of ``current_user``'s row.

        NOTE(review): the ``Avatar`` schema object itself is bound as the column
        value; confirm the column/driver accepts it, or whether one of its fields
        (e.g. a URL) should be extracted.

        :return: the UPDATE statement's ``rowcount``.
        """
        user = await db.execute(update(self.model).where(self.model.id == current_user.id).values(avatar=avatar))
        return user.rowcount

    async def delete_user(self, db: AsyncSession, user_id: int) -> int:
        """Delete the user with primary key ``user_id`` via ``CRUDBase.delete``.

        NOTE(review): the meaning of the returned int comes from the base class.
        """
        return await self.delete(db, user_id)

    async def check_email(self, db: AsyncSession, email: str) -> User | None:
        """Return the first user registered with ``email``, or ``None`` if there is none."""
        mail = await db.execute(select(self.model).where(self.model.email == email))
        return mail.scalars().first()

    async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int:
        """Store ``jwt.get_hash_password(password)`` as the password of user ``pk``.

        :return: the UPDATE statement's ``rowcount``.
        """
        user = await db.execute(
            update(self.model).where(self.model.id == pk).values(password=jwt.get_hash_password(password))
        )
        return user.rowcount

    def get_users(self) -> Select:
        """Build (but do not execute) a SELECT over all users, newest ``time_joined`` first.

        Roles and each role's menus are eager-loaded with ``selectinload``.
        """
        return (
            select(self.model)
            .options(selectinload(self.model.roles).selectinload(Role.menus))
            .order_by(desc(self.model.time_joined))
        )

    async def get_user_is_super(self, db: AsyncSession, user_id: int) -> bool:
        """Return the ``is_superuser`` flag of user ``user_id``.

        Raises ``AttributeError`` if the lookup returns ``None``.
        """
        user = await self.get_user_by_id(db, user_id)
        return user.is_superuser

    async def get_user_is_active(self, db: AsyncSession, user_id: int) -> bool:
        """Return the ``is_active`` flag of user ``user_id``.

        Raises ``AttributeError`` if the lookup returns ``None``.
        """
        user = await self.get_user_by_id(db, user_id)
        return user.is_active

    async def super_set(self, db: AsyncSession, user_id: int) -> int:
        """Toggle ``is_superuser`` for user ``user_id``.

        The current value is read first and its negation written in a separate
        UPDATE statement.

        :return: the UPDATE statement's ``rowcount``.
        """
        super_status = await self.get_user_is_super(db, user_id)
        user = await db.execute(
            update(self.model).where(self.model.id == user_id).values(is_superuser=not super_status)
        )
        return user.rowcount

    async def active_set(self, db: AsyncSession, user_id: int) -> int:
        """Toggle ``is_active`` for user ``user_id``.

        The current value is read first and its negation written in a separate
        UPDATE statement.

        :return: the UPDATE statement's ``rowcount``.
        """
        active_status = await self.get_user_is_active(db, user_id)
        user = await db.execute(
            update(self.model).where(self.model.id == user_id).values(is_active=not active_status)
        )
        return user.rowcount

    async def get_user_role_ids(self, db: AsyncSession, user_id: int) -> list[int]:
        """Return the ids of the roles assigned to user ``user_id``.

        Returns an empty list when no such user exists (previously this raised
        ``AttributeError`` on ``None``).
        """
        result = await db.execute(
            select(self.model).where(self.model.id == user_id).options(selectinload(self.model.roles))
        )
        user = result.scalars().first()
        if user is None:
            return []
        return [role.id for role in user.roles]

    async def get_user_with_relation(
        self, db: AsyncSession, *, user_id: int | None = None, username: str | None = None
    ) -> User | None:
        """Return a user with its department, roles and the roles' menus eager-loaded.

        The user is looked up by ``username`` when it is truthy, otherwise by
        ``user_id`` (the same precedence as before). Returns ``None`` if no row
        matches.

        :raises ValueError: if neither ``user_id`` nor ``username`` is truthy.
        """
        # Real SQL expressions replace the former ``eval`` of source strings, which
        # failed with an unhelpful NameError when no lookup key was supplied.
        if username:
            where = self.model.username == username
        elif user_id:
            where = self.model.id == user_id
        else:
            raise ValueError('get_user_with_relation requires user_id or username')
        user = await db.execute(
            select(self.model)
            .where(where)
            .options(selectinload(self.model.dept))
            .options(selectinload(self.model.roles).joinedload(Role.menus))
        )
        return user.scalars().first()
# Shared module-level ``CRUDUser`` instance bound to the ``User`` model
# (presumably imported elsewhere as the user data-access object).
UserDao: CRUDUser = CRUDUser(User)