mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-11-02 12:21:53 +08:00
Improve typing and make User pydantic models dynamic
This commit is contained in:
1
Pipfile
1
Pipfile
@ -14,6 +14,7 @@ flake8-docstrings = "*"
|
|||||||
mkdocs = "*"
|
mkdocs = "*"
|
||||||
mkdocs-material = "*"
|
mkdocs-material = "*"
|
||||||
black = "*"
|
black = "*"
|
||||||
|
mypy = "*"
|
||||||
|
|
||||||
[packages]
|
[packages]
|
||||||
fastapi = "*"
|
fastapi = "*"
|
||||||
|
|||||||
53
Pipfile.lock
generated
53
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "4687ef95ee5576f1882e551641586bdbeda40a663bb6d9b9ff95d4259e4cd023"
|
"sha256": "c1aa33adc0a5d3c81741b012a2cab88f01697fa74b7259bfbdaaec18ebe60036"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {
|
"requires": {
|
||||||
@ -383,6 +383,29 @@
|
|||||||
],
|
],
|
||||||
"version": "==7.2.0"
|
"version": "==7.2.0"
|
||||||
},
|
},
|
||||||
|
"mypy": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
|
||||||
|
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
|
||||||
|
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
|
||||||
|
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
|
||||||
|
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
|
||||||
|
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
|
||||||
|
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
|
||||||
|
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
|
||||||
|
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
|
||||||
|
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
|
||||||
|
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
|
||||||
|
],
|
||||||
|
"index": "pypi",
|
||||||
|
"version": "==0.730"
|
||||||
|
},
|
||||||
|
"mypy-extensions": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
|
||||||
|
],
|
||||||
|
"version": "==0.4.2"
|
||||||
|
},
|
||||||
"packaging": {
|
"packaging": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47",
|
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47",
|
||||||
@ -535,6 +558,34 @@
|
|||||||
],
|
],
|
||||||
"version": "==6.0.3"
|
"version": "==6.0.3"
|
||||||
},
|
},
|
||||||
|
"typed-ast": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
|
||||||
|
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
|
||||||
|
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
|
||||||
|
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
|
||||||
|
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
|
||||||
|
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
|
||||||
|
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
|
||||||
|
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
|
||||||
|
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
|
||||||
|
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
|
||||||
|
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
|
||||||
|
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
|
||||||
|
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
|
||||||
|
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
|
||||||
|
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
|
||||||
|
],
|
||||||
|
"version": "==1.4.0"
|
||||||
|
},
|
||||||
|
"typing-extensions": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95",
|
||||||
|
"sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87",
|
||||||
|
"sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed"
|
||||||
|
],
|
||||||
|
"version": "==3.7.4"
|
||||||
|
},
|
||||||
"urllib3": {
|
"urllib3": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398",
|
"sha256:3de946ffbed6e6746608990594d08faac602528ac7015ac28d33cee6a45b7398",
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Callable
|
|||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
|
|
||||||
|
|
||||||
class BaseAuthentication:
|
class BaseAuthentication:
|
||||||
@ -13,8 +13,8 @@ class BaseAuthentication:
|
|||||||
def __init__(self, user_db: BaseUserDatabase):
|
def __init__(self, user_db: BaseUserDatabase):
|
||||||
self.user_db = user_db
|
self.user_db = user_db
|
||||||
|
|
||||||
async def get_login_response(self, user: UserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_authentication_method(self) -> Callable[..., UserDB]:
|
def get_authentication_method(self) -> Callable[..., BaseUserDB]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from starlette import status
|
|||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
from fastapi_users.authentication import BaseAuthentication
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ class JWTAuthentication(BaseAuthentication):
|
|||||||
self.secret = secret
|
self.secret = secret
|
||||||
self.lifetime_seconds = lifetime_seconds
|
self.lifetime_seconds = lifetime_seconds
|
||||||
|
|
||||||
async def get_login_response(self, user: UserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
data = {"user_id": user.id}
|
data = {"user_id": user.id}
|
||||||
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
token = generate_jwt(data, self.lifetime_seconds, self.secret, self.algorithm)
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ class JWTAuthentication(BaseAuthentication):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
data = jwt.decode(token, self.secret, algorithms=[self.algorithm])
|
||||||
user_id: str = data.get("user_id")
|
user_id = data.get("user_id")
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
from fastapi_users.password import get_password_hash, verify_and_update_password
|
from fastapi_users.password import get_password_hash, verify_and_update_password
|
||||||
|
|
||||||
|
|
||||||
@ -12,22 +12,24 @@ class BaseUserDatabase:
|
|||||||
the database.
|
the database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def list(self) -> List[UserDB]:
|
async def list(self) -> List[BaseUserDB]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def get(self, id: str) -> UserDB:
|
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> UserDB:
|
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def create(self, user: UserDB) -> UserDB:
|
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def update(self, user: UserDB) -> UserDB:
|
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def authenticate(self, credentials: OAuth2PasswordRequestForm) -> UserDB:
|
async def authenticate(
|
||||||
|
self, credentials: OAuth2PasswordRequestForm
|
||||||
|
) -> Optional[BaseUserDB]:
|
||||||
user = await self.get_by_email(credentials.username)
|
user = await self.get_by_email(credentials.username)
|
||||||
|
|
||||||
# Always run the hasher to mitigate timing attack
|
# Always run the hasher to mitigate timing attack
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
from typing import List
|
from typing import List, cast
|
||||||
|
|
||||||
from databases import Database
|
from databases import Database
|
||||||
from sqlalchemy import Boolean, Column, String, Table
|
from sqlalchemy import Boolean, Column, String, Table
|
||||||
|
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
|
|
||||||
|
|
||||||
class BaseUser:
|
class BaseUserTable:
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
@ -26,24 +26,24 @@ class SQLAlchemyUserDatabase(BaseUserDatabase):
|
|||||||
self.database = database
|
self.database = database
|
||||||
self.users = users
|
self.users = users
|
||||||
|
|
||||||
async def list(self) -> List[UserDB]:
|
async def list(self) -> List[BaseUserDB]:
|
||||||
query = self.users.select()
|
query = self.users.select()
|
||||||
return await self.database.fetch_all(query)
|
return cast(List[BaseUserDB], await self.database.fetch_all(query))
|
||||||
|
|
||||||
async def get(self, id: str) -> UserDB:
|
async def get(self, id: str) -> BaseUserDB:
|
||||||
query = self.users.select().where(self.users.c.id == id)
|
query = self.users.select().where(self.users.c.id == id)
|
||||||
return await self.database.fetch_one(query)
|
return cast(BaseUserDB, await self.database.fetch_one(query))
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> UserDB:
|
async def get_by_email(self, email: str) -> BaseUserDB:
|
||||||
query = self.users.select().where(self.users.c.email == email)
|
query = self.users.select().where(self.users.c.email == email)
|
||||||
return await self.database.fetch_one(query)
|
return cast(BaseUserDB, await self.database.fetch_one(query))
|
||||||
|
|
||||||
async def create(self, user: UserDB) -> UserDB:
|
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
query = self.users.insert().values(**user.dict())
|
query = self.users.insert().values(**user.dict())
|
||||||
await self.database.execute(query)
|
await self.database.execute(query)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def update(self, user: UserDB) -> UserDB:
|
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
query = (
|
query = (
|
||||||
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
self.users.update().where(self.users.c.id == user.id).values(**user.dict())
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.types import EmailStr
|
from pydantic.types import EmailStr
|
||||||
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class BaseUser(BaseModel):
|
||||||
id: str = None
|
id: Optional[str] = None
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
is_active: Optional[bool] = True
|
is_active: Optional[bool] = True
|
||||||
is_superuser: Optional[bool] = False
|
is_superuser: Optional[bool] = False
|
||||||
@ -17,18 +17,31 @@ class UserBase(BaseModel):
|
|||||||
return v or str(uuid.uuid4())
|
return v or str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(UserBase):
|
class BaseUserCreate(BaseUser):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
class UserUpdate(UserBase):
|
class BaseUserUpdate(BaseUser):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UserDB(UserBase):
|
class BaseUserDB(BaseUser):
|
||||||
hashed_password: str
|
hashed_password: str
|
||||||
|
|
||||||
|
|
||||||
class User(UserBase):
|
class Models:
|
||||||
pass
|
def __init__(self, user_model: Type[BaseUser]):
|
||||||
|
class UserCreate(user_model, BaseUserCreate): # type: ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserUpdate(user_model, BaseUserUpdate): # type: ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserDB(user_model, BaseUserDB): # type: ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.User = user_model
|
||||||
|
self.UserCreate = UserCreate
|
||||||
|
self.UserUpdate = UserUpdate
|
||||||
|
self.UserDB = UserDB
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from starlette import status
|
from starlette import status
|
||||||
@ -5,32 +7,34 @@ from starlette.responses import Response
|
|||||||
|
|
||||||
from fastapi_users.authentication import BaseAuthentication
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import User, UserCreate, UserDB
|
from fastapi_users.models import BaseUser, Models
|
||||||
from fastapi_users.password import get_password_hash
|
from fastapi_users.password import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
class UserRouter:
|
def get_user_router(
|
||||||
def __new__(cls, user_db: BaseUserDatabase, auth: BaseAuthentication) -> APIRouter:
|
user_db: BaseUserDatabase, user_model: Type[BaseUser], auth: BaseAuthentication
|
||||||
router = APIRouter()
|
) -> APIRouter:
|
||||||
|
router = APIRouter()
|
||||||
|
models = Models(user_model)
|
||||||
|
|
||||||
@router.post("/register", response_model=User)
|
@router.post("/register", response_model=models.User)
|
||||||
async def register(user: UserCreate):
|
async def register(user: models.UserCreate): # type: ignore
|
||||||
hashed_password = get_password_hash(user.password)
|
hashed_password = get_password_hash(user.password)
|
||||||
db_user = UserDB(**user.dict(), hashed_password=hashed_password)
|
db_user = models.UserDB(**user.dict(), hashed_password=hashed_password)
|
||||||
created_user = await user_db.create(db_user)
|
created_user = await user_db.create(db_user)
|
||||||
return created_user
|
return created_user
|
||||||
|
|
||||||
@router.post("/login")
|
@router.post("/login")
|
||||||
async def login(
|
async def login(
|
||||||
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
response: Response, credentials: OAuth2PasswordRequestForm = Depends()
|
||||||
):
|
):
|
||||||
user = await user_db.authenticate(credentials)
|
user = await user_db.authenticate(credentials)
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
elif not user.is_active:
|
elif not user.is_active:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
return await auth.get_login_response(user, response)
|
return await auth.get_login_response(user, response)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from starlette import status
|
from starlette import status
|
||||||
@ -5,16 +7,16 @@ from starlette.responses import Response
|
|||||||
|
|
||||||
from fastapi_users.authentication import BaseAuthentication
|
from fastapi_users.authentication import BaseAuthentication
|
||||||
from fastapi_users.db import BaseUserDatabase
|
from fastapi_users.db import BaseUserDatabase
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
from fastapi_users.password import get_password_hash
|
from fastapi_users.password import get_password_hash
|
||||||
|
|
||||||
active_user_data = UserDB(
|
active_user_data = BaseUserDB(
|
||||||
id="aaa",
|
id="aaa",
|
||||||
email="king.arthur@camelot.bt",
|
email="king.arthur@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
)
|
)
|
||||||
|
|
||||||
inactive_user_data = UserDB(
|
inactive_user_data = BaseUserDB(
|
||||||
id="bbb",
|
id="bbb",
|
||||||
email="percival@camelot.bt",
|
email="percival@camelot.bt",
|
||||||
hashed_password=get_password_hash("angharad"),
|
hashed_password=get_password_hash("angharad"),
|
||||||
@ -23,31 +25,31 @@ inactive_user_data = UserDB(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user() -> UserDB:
|
def user() -> BaseUserDB:
|
||||||
return active_user_data
|
return active_user_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def inactive_user() -> UserDB:
|
def inactive_user() -> BaseUserDB:
|
||||||
return inactive_user_data
|
return inactive_user_data
|
||||||
|
|
||||||
|
|
||||||
class MockUserDatabase(BaseUserDatabase):
|
class MockUserDatabase(BaseUserDatabase):
|
||||||
async def get(self, id: str) -> UserDB:
|
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||||
if id == active_user_data.id:
|
if id == active_user_data.id:
|
||||||
return active_user_data
|
return active_user_data
|
||||||
elif id == inactive_user_data.id:
|
elif id == inactive_user_data.id:
|
||||||
return inactive_user_data
|
return inactive_user_data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> UserDB:
|
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||||
if email == active_user_data.email:
|
if email == active_user_data.email:
|
||||||
return active_user_data
|
return active_user_data
|
||||||
elif email == inactive_user_data.email:
|
elif email == inactive_user_data.email:
|
||||||
return inactive_user_data
|
return inactive_user_data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def create(self, user: UserDB) -> UserDB:
|
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@ -57,10 +59,10 @@ def mock_user_db() -> MockUserDatabase:
|
|||||||
|
|
||||||
|
|
||||||
class MockAuthentication(BaseAuthentication):
|
class MockAuthentication(BaseAuthentication):
|
||||||
async def get_login_response(self, user: UserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
return {"token": user.id}
|
return {"token": user.id}
|
||||||
|
|
||||||
async def authenticate(self, token: str) -> UserDB:
|
async def authenticate(self, token: str) -> BaseUserDB:
|
||||||
user = await self.user_db.get(token)
|
user = await self.user_db.get(token)
|
||||||
if user is None or not user.is_active:
|
if user is None or not user.is_active:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from starlette.responses import Response
|
|||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
|
from fastapi_users.authentication.jwt import JWTAuthentication, generate_jwt
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUserDB
|
||||||
|
|
||||||
SECRET = "SECRET"
|
SECRET = "SECRET"
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
@ -33,7 +33,7 @@ def test_auth_client(jwt_authentication):
|
|||||||
|
|
||||||
@app.get("/test-auth")
|
@app.get("/test-auth")
|
||||||
def test_auth(
|
def test_auth(
|
||||||
user: UserDB = Depends(jwt_authentication.get_authentication_method())
|
user: BaseUserDB = Depends(jwt_authentication.get_authentication_method())
|
||||||
):
|
):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|||||||
@ -1,18 +1,19 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from databases import Database
|
from databases import Database
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
|
||||||
|
|
||||||
from fastapi_users.db.sqlalchemy import BaseUser, SQLAlchemyUserDatabase
|
from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def sqlalchemy_user_db() -> SQLAlchemyUserDatabase:
|
async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
|
||||||
Base = declarative_base()
|
Base: DeclarativeMeta = declarative_base()
|
||||||
|
|
||||||
class User(BaseUser, Base):
|
class User(BaseUserTable, Base):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
DATABASE_URL = "sqlite:///./test.db"
|
DATABASE_URL = "sqlite:///./test.db"
|
||||||
|
|||||||
@ -3,13 +3,16 @@ from fastapi import FastAPI
|
|||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from fastapi_users.models import UserDB
|
from fastapi_users.models import BaseUser, BaseUserDB
|
||||||
from fastapi_users.router import UserRouter
|
from fastapi_users.router import get_user_router
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_app_client(mock_user_db, mock_authentication) -> TestClient:
|
def test_app_client(mock_user_db, mock_authentication) -> TestClient:
|
||||||
userRouter = UserRouter(mock_user_db, mock_authentication)
|
class User(BaseUser):
|
||||||
|
pass
|
||||||
|
|
||||||
|
userRouter = get_user_router(mock_user_db, User, mock_authentication)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(userRouter)
|
app.include_router(userRouter)
|
||||||
@ -68,7 +71,7 @@ class TestLogin:
|
|||||||
response = test_app_client.post("/login", data=data)
|
response = test_app_client.post("/login", data=data)
|
||||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
|
||||||
def test_valid_credentials(self, test_app_client: TestClient, user: UserDB):
|
def test_valid_credentials(self, test_app_client: TestClient, user: BaseUserDB):
|
||||||
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
||||||
response = test_app_client.post("/login", data=data)
|
response = test_app_client.post("/login", data=data)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|||||||
Reference in New Issue
Block a user