mirror of
https://github.com/fastapi-practices/fastapi_best_architecture.git
synced 2025-08-26 04:33:09 +08:00
Add OAuth 2.0 authorization login (#293)
* [WIP] Add OAuth 2.0 authorization login * Add social user relationship table * Update social user relationship table back_populates * Add OAuth 2.0 related interface * Automatically redirect authorization addresses * Update OAuth2 authorization to GitHub * Add implementation code * fix the callback interface return * fix typo * fix the api return * fix imports * Fix logic for creating system users and social tables * Fix user information storage * Add OAuth2 source link * remove unnecessary db refresh * remove the front end docker-compose annotation
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@ __pycache__/
|
||||
.idea/
|
||||
.env
|
||||
venv/
|
||||
.venv/
|
||||
.mypy_cache/
|
||||
backend/app/log/
|
||||
backend/app/alembic/versions/
|
||||
|
@ -173,9 +173,7 @@ Initialize the test data using the `backend/sql/init_test_data.sql` file
|
||||
|
||||
## Development process
|
||||
|
||||
For reference only
|
||||
|
||||
### BackEnd
|
||||
(For reference only)
|
||||
|
||||
1. Define the database model (model) and remember to perform database migration for each change
|
||||
2. Define the data validation model (schema)
|
||||
@ -183,10 +181,6 @@ For reference only
|
||||
4. Define the business logic (service)
|
||||
5. Write database operations (crud)
|
||||
|
||||
### Front
|
||||
|
||||
Go to [fastapi_best_architecture_ui](https://github.com/fastapi-practices/fastapi_best_architecture_ui) for details
|
||||
|
||||
## Test
|
||||
|
||||
Execute unittests via pytest
|
||||
|
@ -166,9 +166,7 @@ mvc 架构作为常规设计模式,在 python web 中也很常见,但是三
|
||||
|
||||
## 开发流程
|
||||
|
||||
仅供参考
|
||||
|
||||
### 后端:
|
||||
(仅供参考)
|
||||
|
||||
1. 定义数据库模型(model),每次变化记得执行数据库迁移
|
||||
2. 定义数据验证模型(schema)
|
||||
@ -176,10 +174,6 @@ mvc 架构作为常规设计模式,在 python web 中也很常见,但是三
|
||||
4. 定义业务逻辑(service)
|
||||
5. 编写数据库操作(crud)
|
||||
|
||||
### 前端
|
||||
|
||||
跳转 [fastapi_best_architecture_ui](https://github.com/fastapi-practices/fastapi_best_architecture_ui) 查看详情
|
||||
|
||||
## 测试
|
||||
|
||||
通过 pytest 执行单元测试
|
||||
|
@ -25,3 +25,6 @@ RABBITMQ_PASSWORD='guest'
|
||||
TOKEN_SECRET_KEY='1VkVF75nsNABBjK_7-qz7GtzNy3AMvktc9TCPwKczCk'
|
||||
# Opera Log
|
||||
OPERA_LOG_ENCRYPT_SECRET_KEY='d77b25790a804c2b4a339dd0207941e4cefa5751935a33735bc73bb7071a005b'
|
||||
# OAuth2
|
||||
OAUTH2_GITHUB_CLIENT_ID='test'
|
||||
OAUTH2_GITHUB_CLIENT_SECRET='test'
|
||||
|
@ -4,8 +4,10 @@ from fastapi import APIRouter
|
||||
|
||||
from backend.app.api.v1.auth.auth import router as auth_router
|
||||
from backend.app.api.v1.auth.captcha import router as captcha_router
|
||||
from backend.app.api.v1.auth.github import router as github_router
|
||||
|
||||
router = APIRouter(prefix='/auth', tags=['授权管理'])
|
||||
|
||||
router.include_router(auth_router)
|
||||
router.include_router(captcha_router)
|
||||
router.include_router(github_router)
|
||||
|
@ -3,27 +3,22 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.security import HTTPBasicCredentials
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from backend.app.common.jwt import DependsJwtAuth
|
||||
from backend.app.common.response.response_schema import ResponseModel, response_base
|
||||
from backend.app.schemas.token import GetLoginToken, GetNewToken, GetSwaggerToken
|
||||
from backend.app.schemas.token import GetSwaggerToken
|
||||
from backend.app.schemas.user import AuthLoginParam
|
||||
from backend.app.services.auth_service import auth_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
'/swagger_login',
|
||||
summary='swagger 表单登录',
|
||||
description='form 格式登录,用于 swagger 文档调试以及获取 JWT Auth',
|
||||
deprecated=True,
|
||||
)
|
||||
async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> GetSwaggerToken:
|
||||
token, user = await auth_service.swagger_login(form_data=form_data)
|
||||
@router.post('/login/swagger', summary='swagger 调试专用', description='用于快捷获取 token 进行 swagger 认证')
|
||||
async def swagger_user_login(obj: Annotated[HTTPBasicCredentials, Depends()]) -> GetSwaggerToken:
|
||||
token, user = await auth_service.swagger_login(obj=obj)
|
||||
return GetSwaggerToken(access_token=token, user=user) # type: ignore
|
||||
|
||||
|
||||
@ -34,33 +29,13 @@ async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -
|
||||
dependencies=[Depends(RateLimiter(times=5, minutes=1))],
|
||||
)
|
||||
async def user_login(request: Request, obj: AuthLoginParam, background_tasks: BackgroundTasks) -> ResponseModel:
|
||||
access_token, refresh_token, access_expire, refresh_expire, user = await auth_service.login(
|
||||
request=request, obj=obj, background_tasks=background_tasks
|
||||
)
|
||||
data = GetLoginToken(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
access_token_expire_time=access_expire,
|
||||
refresh_token_expire_time=refresh_expire,
|
||||
user=user, # type: ignore
|
||||
)
|
||||
data = await auth_service.login(request=request, obj=obj, background_tasks=background_tasks)
|
||||
return await response_base.success(data=data)
|
||||
|
||||
|
||||
@router.post('/new_token', summary='创建新 token', dependencies=[DependsJwtAuth])
|
||||
async def create_new_token(request: Request, refresh_token: Annotated[str, Query(...)]) -> ResponseModel:
|
||||
(
|
||||
new_access_token,
|
||||
new_refresh_token,
|
||||
new_access_token_expire_time,
|
||||
new_refresh_token_expire_time,
|
||||
) = await auth_service.new_token(request=request, refresh_token=refresh_token)
|
||||
data = GetNewToken(
|
||||
access_token=new_access_token,
|
||||
access_token_expire_time=new_access_token_expire_time,
|
||||
refresh_token=new_refresh_token,
|
||||
refresh_token_expire_time=new_refresh_token_expire_time,
|
||||
)
|
||||
data = await auth_service.new_token(request=request, refresh_token=refresh_token)
|
||||
return await response_base.success(data=data)
|
||||
|
||||
|
||||
|
35
backend/app/api/v1/auth/github.py
Normal file
35
backend/app/api/v1/auth/github.py
Normal file
@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Request
|
||||
from fastapi_oauth20 import FastAPIOAuth20, GitHubOAuth20
|
||||
|
||||
from backend.app.common.response.response_schema import ResponseModel, response_base
|
||||
from backend.app.core.conf import settings
|
||||
from backend.app.services.github_service import github_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
github_client = GitHubOAuth20(settings.OAUTH2_GITHUB_CLIENT_ID, settings.OAUTH2_GITHUB_CLIENT_SECRET)
|
||||
github_oauth2 = FastAPIOAuth20(github_client, settings.OAUTH2_GITHUB_REDIRECT_URI)
|
||||
|
||||
|
||||
@router.get('/github', summary='获取 Github 授权链接')
|
||||
async def auth_github() -> ResponseModel:
|
||||
auth_url = await github_client.get_authorization_url(redirect_uri=settings.OAUTH2_GITHUB_REDIRECT_URI)
|
||||
return await response_base.success(data=auth_url)
|
||||
|
||||
|
||||
@router.get(
|
||||
'/github/callback',
|
||||
summary='Github 授权重定向',
|
||||
description='Github 授权后,自动重定向到当前地址并获取用户信息,通过用户信息自动创建系统用户',
|
||||
response_model=None,
|
||||
)
|
||||
async def login_github(
|
||||
request: Request, background_tasks: BackgroundTasks, oauth: FastAPIOAuth20 = Depends(github_oauth2)
|
||||
) -> ResponseModel:
|
||||
token, state = oauth
|
||||
access_token = token['access_token']
|
||||
user = await github_client.get_userinfo(access_token)
|
||||
data = await github_service.add_with_login(request, background_tasks, user)
|
||||
return await response_base.success(data=data)
|
@ -81,3 +81,9 @@ class StatusType(IntEnum):
|
||||
|
||||
disable = 0
|
||||
enable = 1
|
||||
|
||||
|
||||
class UserSocialType(StrEnum):
|
||||
"""用户社交类型"""
|
||||
|
||||
github = 'GitHub'
|
||||
|
@ -4,7 +4,7 @@ from datetime import datetime, timedelta
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security import HTTPBearer, OAuth2PasswordBearer
|
||||
from fastapi.security import HTTPBearer
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
@ -19,8 +19,6 @@ from backend.app.utils.timezone import timezone
|
||||
|
||||
pwd_context = CryptContext(schemes=['bcrypt'], deprecated='auto')
|
||||
|
||||
# Deprecated, may be enabled when oauth2 is actually integrated
|
||||
oauth2_schema = OAuth2PasswordBearer(tokenUrl=settings.TOKEN_URL_SWAGGER)
|
||||
|
||||
# JWT authorizes dependency injection
|
||||
DependsJwtAuth = Depends(HTTPBearer())
|
||||
|
@ -45,6 +45,10 @@ class Settings(BaseSettings):
|
||||
# Env Opera Log
|
||||
OPERA_LOG_ENCRYPT_SECRET_KEY: str # 密钥 os.urandom(32), 需使用 bytes.hex() 方法转换为 str
|
||||
|
||||
# OAuth2:https://github.com/fastapi-practices/fastapi_oauth20
|
||||
OAUTH2_GITHUB_CLIENT_ID: str
|
||||
OAUTH2_GITHUB_CLIENT_SECRET: str
|
||||
|
||||
# FastAPI
|
||||
API_V1_STR: str = '/api/v1'
|
||||
TITLE: str = 'FastAPI'
|
||||
@ -70,6 +74,9 @@ class Settings(BaseSettings):
|
||||
('GET', f'{API_V1_STR}/auth/captcha'),
|
||||
}
|
||||
|
||||
# OAuth2
|
||||
OAUTH2_GITHUB_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/auth/github/callback'
|
||||
|
||||
# Uvicorn
|
||||
UVICORN_HOST: str = '127.0.0.1'
|
||||
UVICORN_PORT: int = 8000
|
||||
|
@ -1,7 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from fast_captcha import text_captcha
|
||||
from sqlalchemy import and_, desc, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -12,6 +10,7 @@ from backend.app.common import jwt
|
||||
from backend.app.crud.base import CRUDBase
|
||||
from backend.app.models import Role, User
|
||||
from backend.app.schemas.user import AddUserParam, AvatarParam, RegisterUserParam, UpdateUserParam, UpdateUserRoleParam
|
||||
from backend.app.utils.timezone import timezone
|
||||
|
||||
|
||||
class CRUDUser(CRUDBase[User, RegisterUserParam, UpdateUserParam]):
|
||||
@ -26,24 +25,27 @@ class CRUDUser(CRUDBase[User, RegisterUserParam, UpdateUserParam]):
|
||||
user = await db.execute(select(self.model).where(self.model.nickname == nickname))
|
||||
return user.scalars().first()
|
||||
|
||||
async def update_login_time(self, db: AsyncSession, username: str, login_time: datetime) -> int:
|
||||
async def update_login_time(self, db: AsyncSession, username: str) -> int:
|
||||
user = await db.execute(
|
||||
update(self.model).where(self.model.username == username).values(last_login_time=login_time)
|
||||
update(self.model).where(self.model.username == username).values(last_login_time=timezone.now())
|
||||
)
|
||||
await db.commit()
|
||||
return user.rowcount
|
||||
|
||||
async def create(self, db: AsyncSession, obj: RegisterUserParam) -> None:
|
||||
salt = text_captcha(5)
|
||||
obj.password = await jwt.get_hash_password(obj.password + salt)
|
||||
dict_obj = obj.model_dump()
|
||||
dict_obj.update({'salt': salt})
|
||||
async def create(self, db: AsyncSession, obj: RegisterUserParam, *, social: bool = False) -> None:
|
||||
if not social:
|
||||
salt = text_captcha(5)
|
||||
obj.password = await jwt.get_hash_password(f'{obj.password}{salt}')
|
||||
dict_obj = obj.model_dump()
|
||||
dict_obj.update({'salt': salt})
|
||||
else:
|
||||
dict_obj = obj.model_dump()
|
||||
dict_obj.update({'salt': None})
|
||||
new_user = self.model(**dict_obj)
|
||||
db.add(new_user)
|
||||
|
||||
async def add(self, db: AsyncSession, obj: AddUserParam) -> None:
|
||||
salt = text_captcha(5)
|
||||
obj.password = await jwt.get_hash_password(obj.password + salt)
|
||||
obj.password = await jwt.get_hash_password(f'{obj.password}{salt}')
|
||||
dict_obj = obj.model_dump(exclude={'roles'})
|
||||
dict_obj.update({'salt': salt})
|
||||
new_user = self.model(**dict_obj)
|
||||
|
25
backend/app/crud/crud_user_social.py
Normal file
25
backend/app/crud/crud_user_social.py
Normal file
@ -0,0 +1,25 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from backend.app.common.enums import UserSocialType
|
||||
from backend.app.crud.base import CRUDBase
|
||||
from backend.app.models import UserSocial
|
||||
from backend.app.schemas.user_social import CreateUserSocialParam, UpdateUserSocialParam
|
||||
|
||||
|
||||
class CRUDOUserSocial(CRUDBase[UserSocial, CreateUserSocialParam, UpdateUserSocialParam]):
|
||||
async def get(self, db: AsyncSession, pk: int, source: UserSocialType) -> UserSocial | None:
|
||||
se = select(self.model).where(and_(self.model.id == pk, self.model.source == source))
|
||||
user_social = await db.execute(se)
|
||||
return user_social.scalars().first()
|
||||
|
||||
async def create(self, db: AsyncSession, obj_in: CreateUserSocialParam) -> None:
|
||||
await self.create_(db, obj_in)
|
||||
|
||||
async def delete(self, db: AsyncSession, social_id: int) -> int:
|
||||
return await self.delete_(db, social_id)
|
||||
|
||||
|
||||
user_social_dao: CRUDOUserSocial = CRUDOUserSocial(UserSocial)
|
@ -15,3 +15,4 @@ from backend.app.models.sys_menu import Menu
|
||||
from backend.app.models.sys_opera_log import OperaLog
|
||||
from backend.app.models.sys_role import Role
|
||||
from backend.app.models.sys_user import User
|
||||
from backend.app.models.sys_user_social import UserSocial
|
||||
|
@ -21,8 +21,8 @@ class User(Base):
|
||||
uuid: Mapped[str] = mapped_column(String(50), init=False, default_factory=uuid4_str, unique=True)
|
||||
username: Mapped[str] = mapped_column(String(20), unique=True, index=True, comment='用户名')
|
||||
nickname: Mapped[str] = mapped_column(String(20), unique=True, comment='昵称')
|
||||
password: Mapped[str] = mapped_column(String(255), comment='密码')
|
||||
salt: Mapped[str] = mapped_column(String(5), comment='加密盐')
|
||||
password: Mapped[str | None] = mapped_column(String(255), comment='密码')
|
||||
salt: Mapped[str | None] = mapped_column(String(5), comment='加密盐')
|
||||
email: Mapped[str] = mapped_column(String(50), unique=True, index=True, comment='邮箱')
|
||||
is_superuser: Mapped[bool] = mapped_column(default=False, comment='超级权限(0否 1是)')
|
||||
is_staff: Mapped[bool] = mapped_column(default=False, comment='后台管理登陆(0否 1是)')
|
||||
@ -41,3 +41,5 @@ class User(Base):
|
||||
roles: Mapped[list['Role']] = relationship( # noqa: F821
|
||||
init=False, secondary=sys_user_role, back_populates='users'
|
||||
)
|
||||
# 用户 OAuth2 一对多
|
||||
socials: Mapped[list['UserSocial']] = relationship(init=False, back_populates='user') # noqa: F821
|
||||
|
27
backend/app/models/sys_user_social.py
Normal file
27
backend/app/models/sys_user_social.py
Normal file
@ -0,0 +1,27 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from backend.app.models.base import Base, id_key
|
||||
|
||||
|
||||
class UserSocial(Base):
|
||||
"""用户社交表(OAuth2)"""
|
||||
|
||||
__tablename__ = 'sys_user_social'
|
||||
|
||||
id: Mapped[id_key] = mapped_column(init=False)
|
||||
source: Mapped[str] = mapped_column(String(20), comment='第三方用户来源')
|
||||
open_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户的 open id')
|
||||
uid: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户的 ID')
|
||||
union_id: Mapped[str | None] = mapped_column(String(20), default=None, comment='第三方用户的 union id')
|
||||
scope: Mapped[str | None] = mapped_column(String(120), default=None, comment='第三方用户授予的权限')
|
||||
code: Mapped[str | None] = mapped_column(String(50), default=None, comment='用户的授权 code')
|
||||
# 用户 OAuth2 一对多
|
||||
user_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey('sys_user.id', ondelete='SET NULL'), default=None, comment='用户关联ID'
|
||||
)
|
||||
user: Mapped[Union['User', None]] = relationship(init=False, back_populates='socials') # noqa: F821
|
@ -12,7 +12,7 @@ from backend.app.schemas.role import GetRoleListDetails
|
||||
|
||||
class AuthSchemaBase(SchemaBase):
|
||||
username: str
|
||||
password: str
|
||||
password: str | None
|
||||
|
||||
|
||||
class AuthLoginParam(AuthSchemaBase):
|
||||
|
21
backend/app/schemas/user_social.py
Normal file
21
backend/app/schemas/user_social.py
Normal file
@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from backend.app.common.enums import UserSocialType
|
||||
from backend.app.schemas.base import SchemaBase
|
||||
|
||||
|
||||
class UserSocialSchemaBase(SchemaBase):
|
||||
source: UserSocialType
|
||||
open_id: str | None = None
|
||||
uid: str | None = None
|
||||
union_id: str | None = None
|
||||
scope: str | None = None
|
||||
code: str | None = None
|
||||
|
||||
|
||||
class CreateUserSocialParam(UserSocialSchemaBase):
|
||||
user_id: int
|
||||
|
||||
|
||||
class UpdateUserSocialParam(SchemaBase):
|
||||
pass
|
@ -1,50 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.security import HTTPBasicCredentials
|
||||
from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from backend.app.common import jwt
|
||||
from backend.app.common.enums import LoginLogStatusType
|
||||
from backend.app.common.exception import errors
|
||||
from backend.app.common.jwt import get_token
|
||||
from backend.app.common.redis import redis_client
|
||||
from backend.app.common.response.response_code import CustomErrorCode
|
||||
from backend.app.core.conf import settings
|
||||
from backend.app.crud.crud_user import user_dao
|
||||
from backend.app.database.db_mysql import async_db_session
|
||||
from backend.app.models import User
|
||||
from backend.app.schemas.token import GetLoginToken, GetNewToken
|
||||
from backend.app.schemas.user import AuthLoginParam
|
||||
from backend.app.services.login_log_service import LoginLogService
|
||||
from backend.app.utils.timezone import timezone
|
||||
|
||||
|
||||
class AuthService:
|
||||
login_time = timezone.now()
|
||||
|
||||
async def swagger_login(self, *, form_data: OAuth2PasswordRequestForm) -> tuple[str, User]:
|
||||
async with async_db_session() as db:
|
||||
current_user = await user_dao.get_by_username(db, form_data.username)
|
||||
@staticmethod
|
||||
async def swagger_login(obj: HTTPBasicCredentials) -> tuple[str, User]:
|
||||
async with async_db_session.begin() as db:
|
||||
current_user = await user_dao.get_by_username(db, obj.username)
|
||||
if not current_user:
|
||||
raise errors.NotFoundError(msg='用户不存在')
|
||||
elif not await jwt.password_verify(form_data.password + current_user.salt, current_user.password):
|
||||
elif not await jwt.password_verify(f'{obj.password}{current_user.salt}', current_user.password):
|
||||
raise errors.AuthorizationError(msg='密码错误')
|
||||
elif not current_user.status:
|
||||
raise errors.AuthorizationError(msg='用户已锁定, 登陆失败')
|
||||
# 更新登陆时间
|
||||
await user_dao.update_login_time(db, form_data.username, self.login_time)
|
||||
# 获取最新用户信息
|
||||
user = await user_dao.get(db, current_user.id)
|
||||
# 创建token
|
||||
access_token, _ = await jwt.create_access_token(str(user.id), multi_login=user.is_multi_login)
|
||||
return access_token, user
|
||||
access_token, _ = await jwt.create_access_token(
|
||||
str(current_user.id), multi_login=current_user.is_multi_login
|
||||
)
|
||||
await user_dao.update_login_time(db, obj.username)
|
||||
return access_token, current_user
|
||||
|
||||
async def login(
|
||||
self, *, request: Request, obj: AuthLoginParam, background_tasks: BackgroundTasks
|
||||
) -> tuple[str, str, datetime, datetime, User]:
|
||||
async with async_db_session() as db:
|
||||
@staticmethod
|
||||
async def login(*, request: Request, obj: AuthLoginParam, background_tasks: BackgroundTasks) -> GetLoginToken:
|
||||
async with async_db_session.begin() as db:
|
||||
try:
|
||||
current_user = await user_dao.get_by_username(db, obj.username)
|
||||
if not current_user:
|
||||
@ -58,14 +52,14 @@ class AuthService:
|
||||
raise errors.AuthorizationError(msg='验证码失效,请重新获取')
|
||||
if captcha_code.lower() != obj.captcha.lower():
|
||||
raise errors.CustomError(error=CustomErrorCode.CAPTCHA_ERROR)
|
||||
await user_dao.update_login_time(db, obj.username, self.login_time)
|
||||
user = await user_dao.get(db, current_user.id)
|
||||
access_token, access_token_expire_time = await jwt.create_access_token(
|
||||
str(user.id), multi_login=user.is_multi_login
|
||||
str(current_user.id), multi_login=current_user.is_multi_login
|
||||
)
|
||||
refresh_token, refresh_token_expire_time = await jwt.create_refresh_token(
|
||||
str(user.id), access_token_expire_time, multi_login=user.is_multi_login
|
||||
str(current_user.id), access_token_expire_time, multi_login=current_user.is_multi_login
|
||||
)
|
||||
await user_dao.update_login_time(db, obj.username)
|
||||
await db.refresh(current_user)
|
||||
except errors.NotFoundError as e:
|
||||
raise errors.NotFoundError(msg=e.msg)
|
||||
except (errors.AuthorizationError, errors.CustomError) as e:
|
||||
@ -73,7 +67,7 @@ class AuthService:
|
||||
db=db,
|
||||
request=request,
|
||||
user=current_user,
|
||||
login_time=self.login_time,
|
||||
login_time=timezone.now(),
|
||||
status=LoginLogStatusType.fail.value,
|
||||
msg=e.msg,
|
||||
)
|
||||
@ -82,20 +76,27 @@ class AuthService:
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
log_info = dict(
|
||||
login_log = dict(
|
||||
db=db,
|
||||
request=request,
|
||||
user=user,
|
||||
login_time=self.login_time,
|
||||
user=current_user,
|
||||
login_time=timezone.now(),
|
||||
status=LoginLogStatusType.success.value,
|
||||
msg='登录成功',
|
||||
)
|
||||
background_tasks.add_task(LoginLogService.create, **log_info)
|
||||
background_tasks.add_task(LoginLogService.create, **login_log)
|
||||
await redis_client.delete(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}')
|
||||
return access_token, refresh_token, access_token_expire_time, refresh_token_expire_time, user
|
||||
data = GetLoginToken(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
access_token_expire_time=access_token_expire_time,
|
||||
refresh_token_expire_time=refresh_token_expire_time,
|
||||
user=current_user, # type: ignore
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def new_token(*, request: Request, refresh_token: str) -> tuple[str, str, datetime, datetime]:
|
||||
async def new_token(*, request: Request, refresh_token: str) -> GetNewToken:
|
||||
user_id = await jwt.jwt_decode(refresh_token)
|
||||
if request.user.id != user_id:
|
||||
raise errors.TokenError(msg='刷新 token 无效')
|
||||
@ -105,7 +106,7 @@ class AuthService:
|
||||
raise errors.NotFoundError(msg='用户不存在')
|
||||
elif not current_user.status:
|
||||
raise errors.AuthorizationError(msg='用户已锁定,操作失败')
|
||||
current_token = await get_token(request)
|
||||
current_token = await jwt.get_token(request)
|
||||
(
|
||||
new_access_token,
|
||||
new_refresh_token,
|
||||
@ -114,11 +115,17 @@ class AuthService:
|
||||
) = await jwt.create_new_token(
|
||||
str(current_user.id), current_token, refresh_token, multi_login=current_user.is_multi_login
|
||||
)
|
||||
return new_access_token, new_refresh_token, new_access_token_expire_time, new_refresh_token_expire_time
|
||||
data = GetNewToken(
|
||||
access_token=new_access_token,
|
||||
access_token_expire_time=new_access_token_expire_time,
|
||||
refresh_token=new_refresh_token,
|
||||
refresh_token_expire_time=new_refresh_token_expire_time,
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def logout(*, request: Request) -> None:
|
||||
token = await get_token(request)
|
||||
token = await jwt.get_token(request)
|
||||
if request.user.is_multi_login:
|
||||
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
|
||||
await redis_client.delete(key)
|
||||
|
82
backend/app/services/github_service.py
Normal file
82
backend/app/services/github_service.py
Normal file
@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from fast_captcha import text_captcha
|
||||
from fastapi import BackgroundTasks, Request
|
||||
|
||||
from backend.app.common import jwt
|
||||
from backend.app.common.enums import LoginLogStatusType, UserSocialType
|
||||
from backend.app.common.exception.errors import AuthorizationError
|
||||
from backend.app.common.redis import redis_client
|
||||
from backend.app.core.conf import settings
|
||||
from backend.app.crud.crud_user import user_dao
|
||||
from backend.app.crud.crud_user_social import user_social_dao
|
||||
from backend.app.database.db_mysql import async_db_session
|
||||
from backend.app.schemas.token import GetLoginToken
|
||||
from backend.app.schemas.user import RegisterUserParam
|
||||
from backend.app.schemas.user_social import CreateUserSocialParam
|
||||
from backend.app.services.login_log_service import LoginLogService
|
||||
from backend.app.utils.timezone import timezone
|
||||
|
||||
|
||||
class GithubService:
|
||||
@staticmethod
|
||||
async def add_with_login(request: Request, background_tasks: BackgroundTasks, user: dict) -> GetLoginToken | None:
|
||||
async with async_db_session.begin() as db:
|
||||
github_email = user['email']
|
||||
if not github_email:
|
||||
raise AuthorizationError(msg='授权失败,GitHub 账户未绑定邮箱')
|
||||
github_id = user['id']
|
||||
github_username = user['login']
|
||||
github_nickname = user['name']
|
||||
sys_user = await user_dao.check_email(db, github_email)
|
||||
if not sys_user:
|
||||
# 创建系统用户
|
||||
sys_user = await user_dao.get_by_username(db, github_username)
|
||||
if sys_user:
|
||||
github_username = f'{github_username}{text_captcha(5)}'
|
||||
sys_user = await user_dao.get_by_nickname(db, github_nickname)
|
||||
if sys_user:
|
||||
github_nickname = f'{github_nickname}{text_captcha(5)}'
|
||||
new_sys_user = RegisterUserParam(
|
||||
username=github_username, password=None, nickname=github_nickname, email=github_email
|
||||
)
|
||||
await user_dao.create(db, new_sys_user, social=True)
|
||||
await db.flush()
|
||||
sys_user = await user_dao.check_email(db, github_email)
|
||||
# 绑定社交用户
|
||||
user_social = await user_social_dao.get(db, sys_user.id, UserSocialType.github)
|
||||
if not user_social:
|
||||
new_user_social = CreateUserSocialParam(
|
||||
source=UserSocialType.github, uid=str(github_id), user_id=sys_user.id
|
||||
)
|
||||
await user_social_dao.create(db, new_user_social)
|
||||
# 创建 token
|
||||
access_token, access_token_expire_time = await jwt.create_access_token(
|
||||
str(sys_user.id), multi_login=sys_user.is_multi_login
|
||||
)
|
||||
refresh_token, refresh_token_expire_time = await jwt.create_refresh_token(
|
||||
str(sys_user.id), access_token_expire_time, multi_login=sys_user.is_multi_login
|
||||
)
|
||||
await user_dao.update_login_time(db, sys_user.username)
|
||||
await db.refresh(sys_user)
|
||||
login_log = dict(
|
||||
db=db,
|
||||
request=request,
|
||||
user=sys_user,
|
||||
login_time=timezone.now(),
|
||||
status=LoginLogStatusType.success.value,
|
||||
msg='登录成功(OAuth2)',
|
||||
)
|
||||
background_tasks.add_task(LoginLogService.create, **login_log)
|
||||
await redis_client.delete(f'{settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}')
|
||||
data = GetLoginToken(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
access_token_expire_time=access_token_expire_time,
|
||||
refresh_token_expire_time=refresh_token_expire_time,
|
||||
user=sys_user, # type: ignore
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
github_service: GithubService = GithubService()
|
@ -28,6 +28,8 @@ class UserService:
|
||||
@staticmethod
|
||||
async def register(*, obj: RegisterUserParam) -> None:
|
||||
async with async_db_session.begin() as db:
|
||||
if not obj.password:
|
||||
raise errors.ForbiddenError(msg='密码为空')
|
||||
username = await user_dao.get_by_username(db, obj.username)
|
||||
if username:
|
||||
raise errors.ForbiddenError(msg='该用户名已注册')
|
||||
@ -51,6 +53,11 @@ class UserService:
|
||||
nickname = await user_dao.get_by_nickname(db, obj.nickname)
|
||||
if nickname:
|
||||
raise errors.ForbiddenError(msg='昵称已注册')
|
||||
if not obj.password:
|
||||
raise errors.ForbiddenError(msg='密码为空')
|
||||
email = await user_dao.check_email(db, obj.email)
|
||||
if email:
|
||||
raise errors.ForbiddenError(msg='该邮箱已注册')
|
||||
dept = await dept_dao.get(db, obj.dept_id)
|
||||
if not dept:
|
||||
raise errors.NotFoundError(msg='部门不存在')
|
||||
@ -58,16 +65,12 @@ class UserService:
|
||||
role = await role_dao.get(db, role_id)
|
||||
if not role:
|
||||
raise errors.NotFoundError(msg='角色不存在')
|
||||
email = await user_dao.check_email(db, obj.email)
|
||||
if email:
|
||||
raise errors.ForbiddenError(msg='该邮箱已注册')
|
||||
await user_dao.add(db, obj)
|
||||
|
||||
@staticmethod
|
||||
async def pwd_reset(*, request: Request, obj: ResetPasswordParam) -> int:
|
||||
async with async_db_session.begin() as db:
|
||||
op = obj.old_password
|
||||
if not await password_verify(op + request.user.salt, request.user.password):
|
||||
if not await password_verify(f'{obj.old_password}{request.user.salt}', request.user.password):
|
||||
raise errors.ForbiddenError(msg='旧密码错误')
|
||||
np1 = obj.new_password
|
||||
np2 = obj.confirm_password
|
||||
|
@ -76,28 +76,8 @@ services:
|
||||
networks:
|
||||
- fba_network
|
||||
|
||||
# # For use fastapi_best_architecture_ui, or online/pro env
|
||||
# fba_ui:
|
||||
# build:
|
||||
# context: ⚠️ your fba_ui folder directory ⚠️
|
||||
# dockerfile: Dockerfile
|
||||
# ports:
|
||||
# - "80:80"
|
||||
# - "443:443"
|
||||
# container_name: fba_ui
|
||||
# restart: always
|
||||
# depends_on:
|
||||
# - fba_server
|
||||
# command:
|
||||
# - nginx
|
||||
# - -g
|
||||
# - daemon off;
|
||||
# volumes:
|
||||
# - fba_static:/www/fba_server/backend/app/static
|
||||
# networks:
|
||||
# - fba_network
|
||||
|
||||
# Good for server dev env
|
||||
# The backend is dedicated, which conflicts with fba_ui,If you choose to use fba_ui,
|
||||
# you should stop using fba_nginx container
|
||||
fba_nginx:
|
||||
image: nginx
|
||||
ports:
|
||||
|
17
pdm.lock
generated
17
pdm.lock
generated
@ -5,7 +5,7 @@
|
||||
groups = ["default", "deploy"]
|
||||
strategy = ["cross_platform", "inherit_metadata"]
|
||||
lock_version = "4.4.1"
|
||||
content_hash = "sha256:4dec29a5afb6315a1fb1ec1782576c9e54e27104b7e883525cc65cd097a39430"
|
||||
content_hash = "sha256:06e039c331218934838e6c6865b0c78b8cb14142c3f90f23f7d36844f9a20a09"
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@ -533,6 +533,21 @@ files = [
|
||||
{file = "fastapi_limiter-0.1.6.tar.gz", hash = "sha256:6f5fde8efebe12eb33861bdffb91009f699369a3c2862cdc7c1d9acf912ff443"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi-oauth20"
|
||||
version = "0.0.1a1"
|
||||
requires_python = ">=3.10"
|
||||
summary = "在 FastAPI 中异步授权 OAuth2 客户端"
|
||||
groups = ["default"]
|
||||
dependencies = [
|
||||
"fastapi>=0.100.0",
|
||||
"httpx>=0.18.0",
|
||||
]
|
||||
files = [
|
||||
{file = "fastapi_oauth20-0.0.1a1-py3-none-any.whl", hash = "sha256:02247a49f1c9ffc364d13857dc29abf49783f36bbefa4632e196aca778f58ede"},
|
||||
{file = "fastapi_oauth20-0.0.1a1.tar.gz", hash = "sha256:f3d2eda24c10fdfe81735859f0346a8bdb687246193134bc285c8110daf1c221"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi-pagination"
|
||||
version = "0.12.13"
|
||||
|
@ -46,6 +46,7 @@ dependencies = [
|
||||
"user-agents==2.2.0",
|
||||
"uvicorn[standard]==0.24.0",
|
||||
"XdbSearchIP==1.0.2",
|
||||
"fastapi-oauth20>=0.0.1a1",
|
||||
]
|
||||
requires-python = ">=3.10"
|
||||
readme = "README.md"
|
||||
|
@ -33,6 +33,7 @@ exceptiongroup==1.2.0; python_version < "3.11"
|
||||
fast-captcha==0.2.1
|
||||
fastapi==0.108.0
|
||||
fastapi-limiter==0.1.6
|
||||
fastapi-oauth20==0.0.1a1
|
||||
fastapi-pagination==0.12.13
|
||||
filelock==3.13.1
|
||||
greenlet==3.0.3; platform_machine == "win32" or platform_machine == "WIN32" or platform_machine == "AMD64" or platform_machine == "amd64" or platform_machine == "x86_64" or platform_machine == "ppc64le" or platform_machine == "aarch64"
|
||||
|
Reference in New Issue
Block a user