Fix example

This commit is contained in:
long2ice
2021-04-28 22:51:53 +08:00
parent 5e230eca90
commit 8bb34f2032
9 changed files with 168 additions and 54 deletions

View File

@ -1,31 +1,37 @@
import uuid
from gettext import gettext as _
from typing import Callable, Type
import bcrypt
from aioredis import Redis
from fastapi import Depends
from pydantic import EmailStr
from starlette.requests import Request
from starlette.responses import RedirectResponse
from starlette.status import HTTP_303_SEE_OTHER
from starlette.status import HTTP_303_SEE_OTHER, HTTP_401_UNAUTHORIZED
from tortoise import Model, fields
from fastapi_admin import constants
from fastapi_admin.depends import get_redis
from fastapi_admin.template import templates
class LoginProvider:
login_path = "/login"
logout_path = "/logout"
template = "login.html"
def __init__(
self, login_path="/login", logout_path="/logout", template="login.html"
):
self.template = template
self.logout_path = logout_path
self.login_path = login_path
@classmethod
async def get(
cls,
self,
request: Request,
):
return templates.TemplateResponse(cls.template, context={"request": request})
return templates.TemplateResponse(self.template, context={"request": request})
@classmethod
async def post(
cls,
self,
request: Request,
):
"""
@ -34,21 +40,22 @@ class LoginProvider:
:return:
"""
@classmethod
async def authenticate(
cls,
self,
request: Request,
call_next: Callable,
):
response = await call_next(request)
return response
@classmethod
async def logout(cls, request: Request):
def redirect_login(self, request: Request):
return RedirectResponse(
url=request.app.admin_path + cls.login_path, status_code=HTTP_303_SEE_OTHER
url=request.app.admin_path + self.login_path, status_code=HTTP_303_SEE_OTHER
)
async def logout(self, request: Request):
return self.redirect_login(request)
class UserMixin(Model):
username = fields.CharField(max_length=50, unique=True)
@ -60,44 +67,128 @@ class UserMixin(Model):
class UsernamePasswordProvider(LoginProvider):
model: Type[UserMixin]
access_token = "access_token"
@classmethod
async def post(
cls,
request: Request,
def __init__(
self,
user_model: Type[UserMixin],
enable_captcha: bool = False,
login_path="/login",
logout_path="/logout",
template="login.html",
):
super().__init__(login_path, logout_path, template)
self.user_model = user_model
self.enable_captcha = enable_captcha
async def captcha(
self,
request: Request,
width: int = 160,
height: int = 60,
redis: Redis = Depends(get_redis),
):
if not self.enable_captcha:
raise ConfigurationError(error="Should enable captcha first")
captcha = ImageCaptcha(width=width, height=height)
code = utils.generate_random_str(4)
captcha_id = uuid.uuid4().hex
captcha_key = constants.CAPTCHA_ID.format(captcha_id=captcha_id)
image = captcha.generate(code)
response = StreamingResponse(content=image, media_type="image/png")
await redis.set(captcha_key, code, expire=60)
response.set_cookie(
"captcha_id",
captcha_id,
max_age=60,
path=request.app.admin_path,
httponly=True,
)
return response
async def post(self, request: Request, redis: Redis = Depends(get_redis)):
form = await request.form()
username = form.get("username")
password = form.get("password")
user = await cls.model.get_or_none(username=username)
if not user:
return templates.TemplateResponse(
cls.template, context={"request": request, "error": _("no_such_user")}
)
if not cls.check_password(user, password):
return templates.TemplateResponse(
cls.template, context={"request": request, "error": _("password_error")}
)
return RedirectResponse(url=request.app.admin_path, status_code=HTTP_303_SEE_OTHER)
remember_me = form.get("remember_me")
@classmethod
def check_password(cls, user: UserMixin, password: str):
user = await self.user_model.get_or_none(username=username)
if not user or not self.check_password(user, password):
return templates.TemplateResponse(
self.template,
status_code=HTTP_401_UNAUTHORIZED,
context={"request": request, "error": _("login_failed")},
)
response = RedirectResponse(
url=request.app.admin_path, status_code=HTTP_303_SEE_OTHER
)
if remember_me == "on":
expire = 3600 * 24 * 30
response.set_cookie("remember_me", "on")
else:
expire = 3600
response.delete_cookie("remember_me")
token = uuid.uuid4().hex
response.set_cookie(
self.access_token,
token,
expires=expire,
path=request.app.admin_path,
httponly=True,
)
await redis.set(
constants.LOGIN_USER.format(token=token), user.pk, expire=expire
)
return response
async def logout(self, request: Request, redis: Redis = Depends(get_redis)):
response = await super(UsernamePasswordProvider, self).logout(request)
response.delete_cookie(self.access_token)
token = request.cookies.get(self.access_token)
await redis.delete(constants.LOGIN_USER.format(token=token))
return response
async def authenticate(
self,
request: Request,
call_next: Callable,
):
redis = request.app.redis # type:Redis
token = request.cookies.get(self.access_token)
path = request.scope["path"]
token_key = constants.LOGIN_USER.format(token=token)
user_id = await redis.get(token_key)
if not user_id and path != self.login_path:
return self.redirect_login(request)
user = await self.user_model.get_or_none(pk=user_id)
if not user:
if path != self.login_path:
response = self.redirect_login(request)
response.delete_cookie(self.access_token)
return response
else:
if path == self.login_path:
return RedirectResponse(
url=request.app.admin_path, status_code=HTTP_303_SEE_OTHER
)
request.state.user = user
response = await call_next(request)
return response
def check_password(self, user: UserMixin, password: str):
return bcrypt.checkpw(password.encode(), user.password.encode())
@classmethod
def hash_password(cls, password: str):
def hash_password(self, password: str):
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
@classmethod
async def create_user(cls, username: str, password: str, email: EmailStr):
return await cls.model.create(
async def create_user(self, username: str, password: str, email: EmailStr):
return await self.user_model.create(
username=username,
password=cls.hash_password(password),
password=self.hash_password(password),
email=email,
)
@classmethod
async def update_password(cls, user: UserMixin, password: str):
user.password = cls.hash_password(password)
async def update_password(self, user: UserMixin, password: str):
user.password = self.hash_password(password)
await user.save(update_fields=["password"])