Improve typing and make User pydantic models dynamic

This commit is contained in:
François Voron
2019-10-10 13:37:52 +02:00
parent d9e6e93f08
commit 0112e700ac
12 changed files with 153 additions and 76 deletions

View File

@ -1,3 +1,5 @@
from typing import Optional
import pytest
from fastapi import HTTPException
from starlette import status
@ -5,16 +7,16 @@ from starlette.responses import Response
from fastapi_users.authentication import BaseAuthentication
from fastapi_users.db import BaseUserDatabase
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
from fastapi_users.password import get_password_hash
active_user_data = UserDB(
active_user_data = BaseUserDB(
id="aaa",
email="king.arthur@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
inactive_user_data = UserDB(
inactive_user_data = BaseUserDB(
id="bbb",
email="percival@camelot.bt",
hashed_password=get_password_hash("angharad"),
@ -23,31 +25,31 @@ inactive_user_data = UserDB(
@pytest.fixture
def user() -> UserDB:
def user() -> BaseUserDB:
return active_user_data
@pytest.fixture
def inactive_user() -> UserDB:
def inactive_user() -> BaseUserDB:
return inactive_user_data
class MockUserDatabase(BaseUserDatabase):
async def get(self, id: str) -> UserDB:
async def get(self, id: str) -> Optional[BaseUserDB]:
if id == active_user_data.id:
return active_user_data
elif id == inactive_user_data.id:
return inactive_user_data
return None
async def get_by_email(self, email: str) -> UserDB:
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
if email == active_user_data.email:
return active_user_data
elif email == inactive_user_data.email:
return inactive_user_data
return None
async def create(self, user: UserDB) -> UserDB:
async def create(self, user: BaseUserDB) -> BaseUserDB:
return user
@ -57,10 +59,10 @@ def mock_user_db() -> MockUserDatabase:
class MockAuthentication(BaseAuthentication):
async def get_login_response(self, user: UserDB, response: Response):
async def get_login_response(self, user: BaseUserDB, response: Response):
return {"token": user.id}
async def authenticate(self, token: str) -> UserDB:
async def authenticate(self, token: str) -> BaseUserDB:
user = await self.user_db.get(token)
if user is None or not user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

View File

@ -6,7 +6,7 @@ from starlette.responses import Response
from starlette.testclient import TestClient
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
from fastapi_users.models import UserDB
from fastapi_users.models import BaseUserDB
SECRET = "SECRET"
ALGORITHM = "HS256"
@ -33,7 +33,7 @@ def test_auth_client(jwt_authentication):
@app.get("/test-auth")
def test_auth(
user: UserDB = Depends(jwt_authentication.get_authentication_method())
user: BaseUserDB = Depends(jwt_authentication.get_authentication_method())
):
return user

View File

@ -1,18 +1,19 @@
import sqlite3
from typing import AsyncGenerator
import pytest
import sqlalchemy
from databases import Database
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase
from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase
@pytest.fixture
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
Base = declarative_base()
async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
Base: DeclarativeMeta = declarative_base()
class User(BaseUser, Base):
class User(BaseUserTable, Base):
pass
DATABASE_URL = "sqlite:///./test.db"

View File

@ -3,13 +3,16 @@ from fastapi import FastAPI
from starlette import status
from starlette.testclient import TestClient
from fastapi_users.models import UserDB
from fastapi_users.router import UserRouter
from fastapi_users.models import BaseUser, BaseUserDB
from fastapi_users.router import get_user_router
@pytest.fixture
def test_app_client(mock_user_db, mock_authentication) -> TestClient:
userRouter = UserRouter(mock_user_db, mock_authentication)
class User(BaseUser):
pass
userRouter = get_user_router(mock_user_db, User, mock_authentication)
app = FastAPI()
app.include_router(userRouter)
@ -68,7 +71,7 @@ class TestLogin:
response = test_app_client.post("/login", data=data)
assert response.status_code == status.HTTP_400_BAD_REQUEST
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
response = test_app_client.post("/login", data=data)
assert response.status_code == status.HTTP_200_OK