mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Use real UUID for User id. and OAuthAccount id. (#198)
* Use UUID for user id and oauth account id * Update documentation for UUID * Tweak GUID definition of SQLAlchemy to match Tortoise ORM one * Write migration doc
This commit is contained in:
2
Makefile
2
Makefile
@ -13,7 +13,7 @@ format: isort-src isort-docs
|
|||||||
test:
|
test:
|
||||||
docker stop $(MONGODB_CONTAINER_NAME) || true
|
docker stop $(MONGODB_CONTAINER_NAME) || true
|
||||||
docker run -d --rm --name $(MONGODB_CONTAINER_NAME) -p 27017:27017 mvertes/alpine-mongo
|
docker run -d --rm --name $(MONGODB_CONTAINER_NAME) -p 27017:27017 mvertes/alpine-mongo
|
||||||
$(PIPENV_RUN) pytest --cov=fastapi_users/
|
$(PIPENV_RUN) pytest --cov=fastapi_users/ --cov-report=term-missing
|
||||||
docker stop $(MONGODB_CONTAINER_NAME)
|
docker stop $(MONGODB_CONTAINER_NAME)
|
||||||
|
|
||||||
docs-serve:
|
docs-serve:
|
||||||
|
@ -12,11 +12,14 @@ Let's create a MongoDB connection and instantiate a collection.
|
|||||||
|
|
||||||
You can choose any name for the database and the collection.
|
You can choose any name for the database and the collection.
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
You may have noticed the `uuidRepresentation` parameter. It controls how the UUID values will be encoded in the database. By default, it's set to `pythonLegacy` but new applications should consider setting this to `standard` for cross language compatibility. [Read more about this](https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient).
|
||||||
|
|
||||||
## Create the database adapter
|
## Create the database adapter
|
||||||
|
|
||||||
The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. Create it like this.
|
The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. Create it like this.
|
||||||
|
|
||||||
```py hl_lines="32"
|
```py hl_lines="34"
|
||||||
{!./src/db_mongodb.py!}
|
{!./src/db_mongodb.py!}
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -26,7 +29,7 @@ Notice that we pass a reference to your [`UserDB` model](../model.md).
|
|||||||
The database adapter will automatically create a [unique index](https://docs.mongodb.com/manual/core/index-unique/) on `id` and `email`.
|
The database adapter will automatically create a [unique index](https://docs.mongodb.com/manual/core/index-unique/) on `id` and `email`.
|
||||||
|
|
||||||
!!! warning
|
!!! warning
|
||||||
**FastAPI Users** will use its defined [`id` UUID-string](../model.md) as unique identifier for the user, rather than the builtin MongoDB `_id`.
|
**FastAPI Users** will use its defined [`id` UUID](../model.md) as unique identifier for the user, rather than the builtin MongoDB `_id`.
|
||||||
|
|
||||||
## Next steps
|
## Next steps
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
**FastAPI Users** defines a minimal User model for authentication purposes. It is structured like this:
|
**FastAPI Users** defines a minimal User model for authentication purposes. It is structured like this:
|
||||||
|
|
||||||
* `id` (`str`) – Unique identifier of the user. Default to a **UUID4**.
|
* `id` (`UUID4`) – Unique identifier of the user. Default to a **UUID4**.
|
||||||
* `email` (`str`) – Email of the user. Validated by [`email-validator`](https://github.com/JoshData/python-email-validator).
|
* `email` (`str`) – Email of the user. Validated by [`email-validator`](https://github.com/JoshData/python-email-validator).
|
||||||
* `is_active` (`bool`) – Whether or not the user is active. If not, login and forgot password requests will be denied. Default to `True`.
|
* `is_active` (`bool`) – Whether or not the user is active. If not, login and forgot password requests will be denied. Default to `True`.
|
||||||
* `is_superuser` (`bool`) – Whether or not the user is a superuser. Useful to implement administration logic. Default to `False`.
|
* `is_superuser` (`bool`) – Whether or not the user is a superuser. Useful to implement administration logic. Default to `False`.
|
||||||
|
@ -56,7 +56,7 @@ class UserDB(User, models.BaseUserDB):
|
|||||||
|
|
||||||
Notice that we inherit from the `BaseOAuthAccountMixin`, which adds a `List` of `BaseOAuthAccount` objects. This object is structured like this:
|
Notice that we inherit from the `BaseOAuthAccountMixin`, which adds a `List` of `BaseOAuthAccount` objects. This object is structured like this:
|
||||||
|
|
||||||
* `id` (`str`) – Unique identifier of the user. Default to a **UUID4**.
|
* `id` (`UUID4`) – Unique identifier of the OAuth account information. Default to a **UUID4**.
|
||||||
* `oauth_name` (`str`) – Name of the OAuth service. It corresponds to the `name` property of the OAuth client.
|
* `oauth_name` (`str`) – Name of the OAuth service. It corresponds to the `name` property of the OAuth client.
|
||||||
* `access_token` (`str`) – Access token.
|
* `access_token` (`str`) – Access token.
|
||||||
* `expires_at` (`int`) - Timestamp at which the access token is expired.
|
* `expires_at` (`int`) - Timestamp at which the access token is expired.
|
||||||
@ -100,7 +100,7 @@ class OAuthAccount(TortoiseBaseOAuthAccountModel):
|
|||||||
```
|
```
|
||||||
|
|
||||||
!!! warning
|
!!! warning
|
||||||
Note that you shouls define the foreign key yourself, so that you can point it the user model in your namespace.
|
Note that you should define the foreign key yourself, so that you can point it the user model in your namespace.
|
||||||
|
|
||||||
Then, you should declare it on the database adapter:
|
Then, you should declare it on the database adapter:
|
||||||
|
|
||||||
|
82
docs/migration/08_to_10.md
Normal file
82
docs/migration/08_to_10.md
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# 0.8.x ➡️ 1.0.x
|
||||||
|
|
||||||
|
1.0 version introduces major breaking changes that need you to update some of your code and migrate your data.
|
||||||
|
|
||||||
|
## Id. are UUID
|
||||||
|
|
||||||
|
Users and OAuth accounts id. are now represented as real UUID objects instead of plain strings. This change was introduced to leverage efficient storage and indexing for DBMS that supports UUID (especially PostgreSQL and Mongo).
|
||||||
|
|
||||||
|
### In Python code
|
||||||
|
|
||||||
|
If you were doing comparison betwen a user id. and a string (in unit tests for example), you should now cast the id. to string:
|
||||||
|
|
||||||
|
```py
|
||||||
|
# Before
|
||||||
|
assert "d35d213e-f3d8-4f08-954a-7e0d1bea286f" == user.id
|
||||||
|
|
||||||
|
# Now
|
||||||
|
assert "d35d213e-f3d8-4f08-954a-7e0d1bea286f" == str(user.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
If you were refering to user id. in your Pydantic models, the field should now be of `UUID4` type instead of `str`:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from pydantic import BaseModel, UUID4
|
||||||
|
|
||||||
|
# Before
|
||||||
|
class Model(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
# After
|
||||||
|
class Model(BaseModel):
|
||||||
|
user_id: UUID4
|
||||||
|
```
|
||||||
|
|
||||||
|
### In database
|
||||||
|
|
||||||
|
Id. were before stored as strings in the database. You should make a migration to convert string data to UUID data.
|
||||||
|
|
||||||
|
!!! danger
|
||||||
|
Scripts below are provided as guidelines. Please **review them carefully**, **adapt them** and check that they are working on a test database before applying them to production. **BE CAREFUL. THEY CAN DESTROY YOUR DATA.**.
|
||||||
|
|
||||||
|
#### PostgreSQL
|
||||||
|
|
||||||
|
PostgreSQL supports UUID type. If not already, you should enable the `uuid-ossp` extension:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
```
|
||||||
|
|
||||||
|
To convert the existing id. string column, we can:
|
||||||
|
|
||||||
|
1. Create a new column with UUID type.
|
||||||
|
2. Fill it with the id. converted to UUID.
|
||||||
|
3. Drop the original id. column.
|
||||||
|
4. Make the new column a primary key and rename it.
|
||||||
|
|
||||||
|
```sql
|
||||||
|
ALTER TABLE "user" ADD uuid_id UUID;
|
||||||
|
UPDATE "user" SET uuid_id = uuid(id);
|
||||||
|
ALTER TABLE "user" DROP id;
|
||||||
|
ALTER TABLE "user" ADD PRIMARY KEY (uuid_id);
|
||||||
|
ALTER TABLE "user" RENAME COLUMN uuid_id TO id;
|
||||||
|
```
|
||||||
|
|
||||||
|
#### MySQL
|
||||||
|
|
||||||
|
MySQL doesn't support UUID type. We'll just convert the column to `CHAR(36)` type:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
ALTER TABLE "user" MODIFY id CHAR(36);
|
||||||
|
```
|
||||||
|
|
||||||
|
#### MongoDB
|
||||||
|
|
||||||
|
For MongoDB, we can use a `forEach` iterator to convert the id. for each document:
|
||||||
|
|
||||||
|
```js
|
||||||
|
db.getCollection('users').find().forEach(function(user) {
|
||||||
|
var uuid = UUID(user.id);
|
||||||
|
db.getCollection('users').update({_id: user._id}, [{$set: {id: uuid}}]);
|
||||||
|
});
|
||||||
|
```
|
@ -21,7 +21,9 @@ class UserDB(User, models.BaseUserDB):
|
|||||||
|
|
||||||
|
|
||||||
DATABASE_URL = "mongodb://localhost:27017"
|
DATABASE_URL = "mongodb://localhost:27017"
|
||||||
client = motor.motor_asyncio.AsyncIOMotorClient(DATABASE_URL)
|
client = motor.motor_asyncio.AsyncIOMotorClient(
|
||||||
|
DATABASE_URL, uuidRepresentation="standard"
|
||||||
|
)
|
||||||
db = client["database_name"]
|
db = client["database_name"]
|
||||||
collection = db["users"]
|
collection = db["users"]
|
||||||
|
|
||||||
|
@ -25,7 +25,9 @@ class UserDB(User, models.BaseUserDB):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
client = motor.motor_asyncio.AsyncIOMotorClient(DATABASE_URL)
|
client = motor.motor_asyncio.AsyncIOMotorClient(
|
||||||
|
DATABASE_URL, uuidRepresentation="standard"
|
||||||
|
)
|
||||||
db = client["database_name"]
|
db = client["database_name"]
|
||||||
collection = db["users"]
|
collection = db["users"]
|
||||||
user_db = MongoDBUserDatabase(UserDB, collection)
|
user_db = MongoDBUserDatabase(UserDB, collection)
|
||||||
|
@ -29,7 +29,9 @@ class UserDB(User, models.BaseUserDB):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
client = motor.motor_asyncio.AsyncIOMotorClient(DATABASE_URL)
|
client = motor.motor_asyncio.AsyncIOMotorClient(
|
||||||
|
DATABASE_URL, uuidRepresentation="standard"
|
||||||
|
)
|
||||||
db = client["database_name"]
|
db = client["database_name"]
|
||||||
collection = db["users"]
|
collection = db["users"]
|
||||||
user_db = MongoDBUserDatabase(UserDB, collection)
|
user_db = MongoDBUserDatabase(UserDB, collection)
|
||||||
|
@ -2,6 +2,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from pydantic import UUID4
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
@ -56,7 +57,12 @@ class JWTAuthentication(BaseAuthentication):
|
|||||||
return None
|
return None
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
return None
|
return None
|
||||||
return await user_db.get(user_id)
|
|
||||||
|
try:
|
||||||
|
user_uiid = UUID4(user_id)
|
||||||
|
return await user_db.get(user_uiid)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
async def get_login_response(self, user: BaseUserDB, response: Response) -> Any:
|
||||||
token = await self._generate_token(user)
|
token = await self._generate_token(user)
|
||||||
@ -66,5 +72,5 @@ class JWTAuthentication(BaseAuthentication):
|
|||||||
return await self.scheme.__call__(request)
|
return await self.scheme.__call__(request)
|
||||||
|
|
||||||
async def _generate_token(self, user: BaseUserDB) -> str:
|
async def _generate_token(self, user: BaseUserDB) -> str:
|
||||||
data = {"user_id": user.id, "aud": self.token_audience}
|
data = {"user_id": str(user.id), "aud": self.token_audience}
|
||||||
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
|
return generate_jwt(data, self.lifetime_seconds, self.secret, JWT_ALGORITHM)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Generic, Optional, Type
|
from typing import Generic, Optional, Type
|
||||||
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from pydantic import UUID4
|
||||||
|
|
||||||
from fastapi_users import password
|
from fastapi_users import password
|
||||||
from fastapi_users.models import UD
|
from fastapi_users.models import UD
|
||||||
@ -18,7 +19,7 @@ class BaseUserDatabase(Generic[UD]):
|
|||||||
def __init__(self, user_db_model: Type[UD]):
|
def __init__(self, user_db_model: Type[UD]):
|
||||||
self.user_db_model = user_db_model
|
self.user_db_model = user_db_model
|
||||||
|
|
||||||
async def get(self, id: str) -> Optional[UD]:
|
async def get(self, id: UUID4) -> Optional[UD]:
|
||||||
"""Get a single user by id."""
|
"""Get a single user by id."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorCollection
|
from motor.motor_asyncio import AsyncIOMotorCollection
|
||||||
|
from pydantic import UUID4
|
||||||
|
|
||||||
from fastapi_users.db.base import BaseUserDatabase
|
from fastapi_users.db.base import BaseUserDatabase
|
||||||
from fastapi_users.models import UD
|
from fastapi_users.models import UD
|
||||||
@ -22,7 +23,7 @@ class MongoDBUserDatabase(BaseUserDatabase[UD]):
|
|||||||
self.collection.create_index("id", unique=True)
|
self.collection.create_index("id", unique=True)
|
||||||
self.collection.create_index("email", unique=True)
|
self.collection.create_index("email", unique=True)
|
||||||
|
|
||||||
async def get(self, id: str) -> Optional[UD]:
|
async def get(self, id: UUID4) -> Optional[UD]:
|
||||||
user = await self.collection.find_one({"id": id})
|
user = await self.collection.find_one({"id": id})
|
||||||
return self.user_db_model(**user) if user else None
|
return self.user_db_model(**user) if user else None
|
||||||
|
|
||||||
|
@ -1,19 +1,58 @@
|
|||||||
|
import uuid
|
||||||
from typing import Mapping, Optional, Type
|
from typing import Mapping, Optional, Type
|
||||||
|
|
||||||
from databases import Database
|
from databases import Database
|
||||||
|
from pydantic import UUID4
|
||||||
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
|
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Table, select
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.ext.declarative import declared_attr
|
from sqlalchemy.ext.declarative import declared_attr
|
||||||
|
from sqlalchemy.types import CHAR, TypeDecorator
|
||||||
|
|
||||||
from fastapi_users.db.base import BaseUserDatabase
|
from fastapi_users.db.base import BaseUserDatabase
|
||||||
from fastapi_users.models import UD
|
from fastapi_users.models import UD
|
||||||
|
|
||||||
|
|
||||||
|
class GUID(TypeDecorator): # pragma: no cover
|
||||||
|
"""Platform-independent GUID type.
|
||||||
|
|
||||||
|
Uses PostgreSQL's UUID type, otherwise uses
|
||||||
|
CHAR(36), storing as regular strings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
impl = CHAR
|
||||||
|
|
||||||
|
def load_dialect_impl(self, dialect):
|
||||||
|
if dialect.name == "postgresql":
|
||||||
|
return dialect.type_descriptor(UUID())
|
||||||
|
else:
|
||||||
|
return dialect.type_descriptor(CHAR(36))
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
elif dialect.name == "postgresql":
|
||||||
|
return str(value)
|
||||||
|
else:
|
||||||
|
if not isinstance(value, uuid.UUID):
|
||||||
|
return str(uuid.UUID(value))
|
||||||
|
else:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
if not isinstance(value, uuid.UUID):
|
||||||
|
value = uuid.UUID(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class SQLAlchemyBaseUserTable:
|
class SQLAlchemyBaseUserTable:
|
||||||
"""Base SQLAlchemy users table definition."""
|
"""Base SQLAlchemy users table definition."""
|
||||||
|
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(GUID, primary_key=True)
|
||||||
email = Column(String, unique=True, index=True, nullable=False)
|
email = Column(String, unique=True, index=True, nullable=False)
|
||||||
hashed_password = Column(String, nullable=False)
|
hashed_password = Column(String, nullable=False)
|
||||||
is_active = Column(Boolean, default=True, nullable=False)
|
is_active = Column(Boolean, default=True, nullable=False)
|
||||||
@ -25,7 +64,7 @@ class SQLAlchemyBaseOAuthAccountTable:
|
|||||||
|
|
||||||
__tablename__ = "oauth_account"
|
__tablename__ = "oauth_account"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(GUID, primary_key=True)
|
||||||
oauth_name = Column(String, index=True, nullable=False)
|
oauth_name = Column(String, index=True, nullable=False)
|
||||||
access_token = Column(String, nullable=False)
|
access_token = Column(String, nullable=False)
|
||||||
expires_at = Column(Integer, nullable=False)
|
expires_at = Column(Integer, nullable=False)
|
||||||
@ -35,7 +74,7 @@ class SQLAlchemyBaseOAuthAccountTable:
|
|||||||
|
|
||||||
@declared_attr
|
@declared_attr
|
||||||
def user_id(cls):
|
def user_id(cls):
|
||||||
return Column(String, ForeignKey("user.id", ondelete="cascade"), nullable=False)
|
return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class NotSetOAuthAccountTableError(Exception):
|
class NotSetOAuthAccountTableError(Exception):
|
||||||
@ -75,7 +114,7 @@ class SQLAlchemyUserDatabase(BaseUserDatabase[UD]):
|
|||||||
self.users = users
|
self.users = users
|
||||||
self.oauth_accounts = oauth_accounts
|
self.oauth_accounts = oauth_accounts
|
||||||
|
|
||||||
async def get(self, id: str) -> Optional[UD]:
|
async def get(self, id: UUID4) -> Optional[UD]:
|
||||||
query = self.users.select().where(self.users.c.id == id)
|
query = self.users.select().where(self.users.c.id == id)
|
||||||
user = await self.database.fetch_one(query)
|
user = await self.database.fetch_one(query)
|
||||||
return await self._make_user(user) if user else None
|
return await self._make_user(user) if user else None
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
from pydantic import UUID4
|
||||||
from tortoise import fields, models
|
from tortoise import fields, models
|
||||||
from tortoise.exceptions import DoesNotExist
|
from tortoise.exceptions import DoesNotExist
|
||||||
|
|
||||||
@ -8,7 +9,7 @@ from fastapi_users.models import UD
|
|||||||
|
|
||||||
|
|
||||||
class TortoiseBaseUserModel(models.Model):
|
class TortoiseBaseUserModel(models.Model):
|
||||||
id = fields.CharField(pk=True, generated=False, max_length=255)
|
id = fields.UUIDField(pk=True, generated=False)
|
||||||
email = fields.CharField(index=True, unique=True, null=False, max_length=255)
|
email = fields.CharField(index=True, unique=True, null=False, max_length=255)
|
||||||
hashed_password = fields.CharField(null=False, max_length=255)
|
hashed_password = fields.CharField(null=False, max_length=255)
|
||||||
is_active = fields.BooleanField(default=True, null=False)
|
is_active = fields.BooleanField(default=True, null=False)
|
||||||
@ -27,7 +28,7 @@ class TortoiseBaseUserModel(models.Model):
|
|||||||
|
|
||||||
|
|
||||||
class TortoiseBaseOAuthAccountModel(models.Model):
|
class TortoiseBaseOAuthAccountModel(models.Model):
|
||||||
id = fields.CharField(pk=True, generated=False, max_length=255)
|
id = fields.UUIDField(pk=True, generated=False, max_length=255)
|
||||||
oauth_name = fields.CharField(null=False, max_length=255)
|
oauth_name = fields.CharField(null=False, max_length=255)
|
||||||
access_token = fields.CharField(null=False, max_length=255)
|
access_token = fields.CharField(null=False, max_length=255)
|
||||||
expires_at = fields.IntField(null=False)
|
expires_at = fields.IntField(null=False)
|
||||||
@ -61,7 +62,7 @@ class TortoiseUserDatabase(BaseUserDatabase[UD]):
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.oauth_account_model = oauth_account_model
|
self.oauth_account_model = oauth_account_model
|
||||||
|
|
||||||
async def get(self, id: str) -> Optional[UD]:
|
async def get(self, id: UUID4) -> Optional[UD]:
|
||||||
try:
|
try:
|
||||||
query = self.model.get(id=id)
|
query = self.model.get(id=id)
|
||||||
|
|
||||||
|
@ -1,21 +1,20 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional, TypeVar
|
from typing import List, Optional, TypeVar
|
||||||
|
|
||||||
import pydantic
|
from pydantic import UUID4, BaseModel, EmailStr, validator
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
|
|
||||||
|
|
||||||
class BaseUser(BaseModel):
|
class BaseUser(BaseModel):
|
||||||
"""Base User model."""
|
"""Base User model."""
|
||||||
|
|
||||||
id: Optional[str] = None
|
id: Optional[UUID4] = None
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
is_active: Optional[bool] = True
|
is_active: Optional[bool] = True
|
||||||
is_superuser: Optional[bool] = False
|
is_superuser: Optional[bool] = False
|
||||||
|
|
||||||
@pydantic.validator("id", pre=True, always=True)
|
@validator("id", pre=True, always=True)
|
||||||
def default_id(cls, v):
|
def default_id(cls, v):
|
||||||
return v or str(uuid.uuid4())
|
return v or uuid.uuid4()
|
||||||
|
|
||||||
def create_update_dict(self):
|
def create_update_dict(self):
|
||||||
return self.dict(
|
return self.dict(
|
||||||
@ -37,7 +36,7 @@ class BaseUserUpdate(BaseUser):
|
|||||||
|
|
||||||
|
|
||||||
class BaseUserDB(BaseUser):
|
class BaseUserDB(BaseUser):
|
||||||
id: str
|
id: UUID4
|
||||||
hashed_password: str
|
hashed_password: str
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -50,7 +49,7 @@ UD = TypeVar("UD", bound=BaseUserDB)
|
|||||||
class BaseOAuthAccount(BaseModel):
|
class BaseOAuthAccount(BaseModel):
|
||||||
"""Base OAuth account model."""
|
"""Base OAuth account model."""
|
||||||
|
|
||||||
id: Optional[str] = None
|
id: Optional[UUID4] = None
|
||||||
oauth_name: str
|
oauth_name: str
|
||||||
access_token: str
|
access_token: str
|
||||||
expires_at: int
|
expires_at: int
|
||||||
@ -58,9 +57,9 @@ class BaseOAuthAccount(BaseModel):
|
|||||||
account_id: str
|
account_id: str
|
||||||
account_email: str
|
account_email: str
|
||||||
|
|
||||||
@pydantic.validator("id", pre=True, always=True)
|
@validator("id", pre=True, always=True)
|
||||||
def default_id(cls, v):
|
def default_id(cls, v):
|
||||||
return v or str(uuid.uuid4())
|
return v or uuid.uuid4()
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
@ -3,7 +3,7 @@ from typing import Any, Dict, Type, cast
|
|||||||
import jwt
|
import jwt
|
||||||
from fastapi import Body, Depends, HTTPException
|
from fastapi import Body, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from pydantic import EmailStr
|
from pydantic import UUID4, EmailStr
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
@ -69,7 +69,7 @@ def get_user_router(
|
|||||||
get_current_active_user = authenticator.get_current_active_user
|
get_current_active_user = authenticator.get_current_active_user
|
||||||
get_current_superuser = authenticator.get_current_superuser
|
get_current_superuser = authenticator.get_current_superuser
|
||||||
|
|
||||||
async def _get_or_404(id: str) -> models.BaseUserDB:
|
async def _get_or_404(id: UUID4) -> models.BaseUserDB:
|
||||||
user = await user_db.get(id)
|
user = await user_db.get(id)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||||
@ -118,7 +118,7 @@ def get_user_router(
|
|||||||
user = await user_db.get_by_email(email)
|
user = await user_db.get_by_email(email)
|
||||||
|
|
||||||
if user is not None and user.is_active:
|
if user is not None and user.is_active:
|
||||||
token_data = {"user_id": user.id, "aud": reset_password_token_audience}
|
token_data = {"user_id": str(user.id), "aud": reset_password_token_audience}
|
||||||
token = generate_jwt(
|
token = generate_jwt(
|
||||||
token_data,
|
token_data,
|
||||||
reset_password_token_lifetime_seconds,
|
reset_password_token_lifetime_seconds,
|
||||||
@ -146,7 +146,15 @@ def get_user_router(
|
|||||||
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = await user_db.get(user_id)
|
try:
|
||||||
|
user_uiid = UUID4(user_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=ErrorCode.RESET_PASSWORD_BAD_TOKEN,
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await user_db.get(user_uiid)
|
||||||
if user is None or not user.is_active:
|
if user is None or not user.is_active:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@ -190,7 +198,7 @@ def get_user_router(
|
|||||||
response_model=user_model,
|
response_model=user_model,
|
||||||
dependencies=[Depends(get_current_superuser)],
|
dependencies=[Depends(get_current_superuser)],
|
||||||
)
|
)
|
||||||
async def get_user(id: str,):
|
async def get_user(id: UUID4):
|
||||||
return await _get_or_404(id)
|
return await _get_or_404(id)
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
@ -199,7 +207,7 @@ def get_user_router(
|
|||||||
dependencies=[Depends(get_current_superuser)],
|
dependencies=[Depends(get_current_superuser)],
|
||||||
)
|
)
|
||||||
async def update_user(
|
async def update_user(
|
||||||
id: str, updated_user: user_update_model, # type: ignore
|
id: UUID4, updated_user: user_update_model, # type: ignore
|
||||||
):
|
):
|
||||||
updated_user = cast(
|
updated_user = cast(
|
||||||
models.BaseUserUpdate, updated_user,
|
models.BaseUserUpdate, updated_user,
|
||||||
@ -213,7 +221,7 @@ def get_user_router(
|
|||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
dependencies=[Depends(get_current_superuser)],
|
dependencies=[Depends(get_current_superuser)],
|
||||||
)
|
)
|
||||||
async def delete_user(id: str):
|
async def delete_user(id: UUID4):
|
||||||
user = await _get_or_404(id)
|
user = await _get_or_404(id)
|
||||||
await user_db.delete(user)
|
await user_db.delete(user)
|
||||||
return None
|
return None
|
||||||
|
@ -45,3 +45,5 @@ nav:
|
|||||||
- usage/flow.md
|
- usage/flow.md
|
||||||
- usage/routes.md
|
- usage/routes.md
|
||||||
- usage/dependency-callables.md
|
- usage/dependency-callables.md
|
||||||
|
- Migration:
|
||||||
|
- migration/08_to_10.md
|
||||||
|
@ -8,6 +8,7 @@ from asgi_lifespan import LifespanManager
|
|||||||
from fastapi import Depends, FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from httpx_oauth.oauth2 import OAuth2
|
from httpx_oauth.oauth2 import OAuth2
|
||||||
|
from pydantic import UUID4
|
||||||
from starlette.applications import ASGIApp
|
from starlette.applications import ASGIApp
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
@ -58,16 +59,13 @@ def event_loop():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user() -> UserDB:
|
def user() -> UserDB:
|
||||||
return UserDB(
|
return UserDB(
|
||||||
id="aaa",
|
email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash,
|
||||||
email="king.arthur@camelot.bt",
|
|
||||||
hashed_password=guinevere_password_hash,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth:
|
def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth:
|
||||||
return UserDBOAuth(
|
return UserDBOAuth(
|
||||||
id="aaa",
|
|
||||||
email="king.arthur@camelot.bt",
|
email="king.arthur@camelot.bt",
|
||||||
hashed_password=guinevere_password_hash,
|
hashed_password=guinevere_password_hash,
|
||||||
oauth_accounts=[oauth_account1, oauth_account2],
|
oauth_accounts=[oauth_account1, oauth_account2],
|
||||||
@ -77,7 +75,6 @@ def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def inactive_user() -> UserDB:
|
def inactive_user() -> UserDB:
|
||||||
return UserDB(
|
return UserDB(
|
||||||
id="bbb",
|
|
||||||
email="percival@camelot.bt",
|
email="percival@camelot.bt",
|
||||||
hashed_password=angharad_password_hash,
|
hashed_password=angharad_password_hash,
|
||||||
is_active=False,
|
is_active=False,
|
||||||
@ -87,7 +84,6 @@ def inactive_user() -> UserDB:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def inactive_user_oauth(oauth_account3) -> UserDBOAuth:
|
def inactive_user_oauth(oauth_account3) -> UserDBOAuth:
|
||||||
return UserDBOAuth(
|
return UserDBOAuth(
|
||||||
id="bbb",
|
|
||||||
email="percival@camelot.bt",
|
email="percival@camelot.bt",
|
||||||
hashed_password=angharad_password_hash,
|
hashed_password=angharad_password_hash,
|
||||||
is_active=False,
|
is_active=False,
|
||||||
@ -98,7 +94,6 @@ def inactive_user_oauth(oauth_account3) -> UserDBOAuth:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def superuser() -> UserDB:
|
def superuser() -> UserDB:
|
||||||
return UserDB(
|
return UserDB(
|
||||||
id="ccc",
|
|
||||||
email="merlin@camelot.bt",
|
email="merlin@camelot.bt",
|
||||||
hashed_password=viviane_password_hash,
|
hashed_password=viviane_password_hash,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
@ -108,7 +103,6 @@ def superuser() -> UserDB:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def superuser_oauth() -> UserDBOAuth:
|
def superuser_oauth() -> UserDBOAuth:
|
||||||
return UserDBOAuth(
|
return UserDBOAuth(
|
||||||
id="ccc",
|
|
||||||
email="merlin@camelot.bt",
|
email="merlin@camelot.bt",
|
||||||
hashed_password=viviane_password_hash,
|
hashed_password=viviane_password_hash,
|
||||||
is_superuser=True,
|
is_superuser=True,
|
||||||
@ -119,7 +113,6 @@ def superuser_oauth() -> UserDBOAuth:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth_account1() -> BaseOAuthAccount:
|
def oauth_account1() -> BaseOAuthAccount:
|
||||||
return BaseOAuthAccount(
|
return BaseOAuthAccount(
|
||||||
id="aaa",
|
|
||||||
oauth_name="service1",
|
oauth_name="service1",
|
||||||
access_token="TOKEN",
|
access_token="TOKEN",
|
||||||
expires_at=1579000751,
|
expires_at=1579000751,
|
||||||
@ -131,7 +124,6 @@ def oauth_account1() -> BaseOAuthAccount:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth_account2() -> BaseOAuthAccount:
|
def oauth_account2() -> BaseOAuthAccount:
|
||||||
return BaseOAuthAccount(
|
return BaseOAuthAccount(
|
||||||
id="bbb",
|
|
||||||
oauth_name="service2",
|
oauth_name="service2",
|
||||||
access_token="TOKEN",
|
access_token="TOKEN",
|
||||||
expires_at=1579000751,
|
expires_at=1579000751,
|
||||||
@ -143,7 +135,6 @@ def oauth_account2() -> BaseOAuthAccount:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def oauth_account3() -> BaseOAuthAccount:
|
def oauth_account3() -> BaseOAuthAccount:
|
||||||
return BaseOAuthAccount(
|
return BaseOAuthAccount(
|
||||||
id="ccc",
|
|
||||||
oauth_name="service3",
|
oauth_name="service3",
|
||||||
access_token="TOKEN",
|
access_token="TOKEN",
|
||||||
expires_at=1579000751,
|
expires_at=1579000751,
|
||||||
@ -155,7 +146,7 @@ def oauth_account3() -> BaseOAuthAccount:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
|
def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase:
|
||||||
class MockUserDatabase(BaseUserDatabase[UserDB]):
|
class MockUserDatabase(BaseUserDatabase[UserDB]):
|
||||||
async def get(self, id: str) -> Optional[UserDB]:
|
async def get(self, id: UUID4) -> Optional[UserDB]:
|
||||||
if id == user.id:
|
if id == user.id:
|
||||||
return user
|
return user
|
||||||
if id == inactive_user.id:
|
if id == inactive_user.id:
|
||||||
@ -190,7 +181,7 @@ def mock_user_db_oauth(
|
|||||||
user_oauth, inactive_user_oauth, superuser_oauth
|
user_oauth, inactive_user_oauth, superuser_oauth
|
||||||
) -> BaseUserDatabase:
|
) -> BaseUserDatabase:
|
||||||
class MockUserDatabase(BaseUserDatabase[UserDBOAuth]):
|
class MockUserDatabase(BaseUserDatabase[UserDBOAuth]):
|
||||||
async def get(self, id: str) -> Optional[UserDBOAuth]:
|
async def get(self, id: UUID4) -> Optional[UserDBOAuth]:
|
||||||
if id == user_oauth.id:
|
if id == user_oauth.id:
|
||||||
return user_oauth
|
return user_oauth
|
||||||
if id == inactive_user_oauth.id:
|
if id == inactive_user_oauth.id:
|
||||||
@ -246,7 +237,11 @@ class MockAuthentication(BaseAuthentication):
|
|||||||
async def __call__(self, request: Request, user_db: BaseUserDatabase):
|
async def __call__(self, request: Request, user_db: BaseUserDatabase):
|
||||||
token = await self.scheme.__call__(request)
|
token = await self.scheme.__call__(request)
|
||||||
if token is not None:
|
if token is not None:
|
||||||
return await user_db.get(token)
|
try:
|
||||||
|
token_uuid = UUID4(token)
|
||||||
|
return await user_db.get(token_uuid)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||||
|
@ -31,7 +31,7 @@ def token():
|
|||||||
def _token(user=None, lifetime=LIFETIME):
|
def _token(user=None, lifetime=LIFETIME):
|
||||||
data = {"aud": "fastapi-users:auth"}
|
data = {"aud": "fastapi-users:auth"}
|
||||||
if user is not None:
|
if user is not None:
|
||||||
data["user_id"] = user.id
|
data["user_id"] = str(user.id)
|
||||||
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
||||||
|
|
||||||
return _token
|
return _token
|
||||||
@ -131,7 +131,7 @@ async def test_get_login_response(
|
|||||||
decoded = jwt.decode(
|
decoded = jwt.decode(
|
||||||
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
cookie_value, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
||||||
)
|
)
|
||||||
assert decoded["user_id"] == user.id
|
assert decoded["user_id"] == str(user.id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.authentication
|
@pytest.mark.authentication
|
||||||
@ -149,4 +149,4 @@ async def test_get_logout_response(user):
|
|||||||
|
|
||||||
cookie = cookies[0][1].decode("latin-1")
|
cookie = cookies[0][1].decode("latin-1")
|
||||||
|
|
||||||
assert f"Max-Age=0" in cookie
|
assert "Max-Age=0" in cookie
|
||||||
|
@ -17,10 +17,10 @@ def jwt_authentication():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def token():
|
def token():
|
||||||
def _token(user=None, lifetime=LIFETIME):
|
def _token(user_id=None, lifetime=LIFETIME):
|
||||||
data = {"aud": "fastapi-users:auth"}
|
data = {"aud": "fastapi-users:auth"}
|
||||||
if user is not None:
|
if user_id is not None:
|
||||||
data["user_id"] = user.id
|
data["user_id"] = str(user_id)
|
||||||
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
||||||
|
|
||||||
return _token
|
return _token
|
||||||
@ -57,11 +57,19 @@ class TestAuthenticate:
|
|||||||
authenticated_user = await jwt_authentication(request, mock_user_db)
|
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||||
assert authenticated_user is None
|
assert authenticated_user is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_token_invalid_uuid(
|
||||||
|
self, jwt_authentication, mock_user_db, request_builder, token
|
||||||
|
):
|
||||||
|
request = request_builder(headers={"Authorization": f"Bearer {token('foo')}"})
|
||||||
|
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||||
|
assert authenticated_user is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_valid_token(
|
async def test_valid_token(
|
||||||
self, jwt_authentication, mock_user_db, request_builder, token, user
|
self, jwt_authentication, mock_user_db, request_builder, token, user
|
||||||
):
|
):
|
||||||
request = request_builder(headers={"Authorization": f"Bearer {token(user)}"})
|
request = request_builder(headers={"Authorization": f"Bearer {token(user.id)}"})
|
||||||
authenticated_user = await jwt_authentication(request, mock_user_db)
|
authenticated_user = await jwt_authentication(request, mock_user_db)
|
||||||
assert authenticated_user.id == user.id
|
assert authenticated_user.id == user.id
|
||||||
|
|
||||||
@ -77,7 +85,7 @@ async def test_get_login_response(jwt_authentication, user):
|
|||||||
decoded = jwt.decode(
|
decoded = jwt.decode(
|
||||||
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
token, SECRET, audience="fastapi-users:auth", algorithms=[JWT_ALGORITHM]
|
||||||
)
|
)
|
||||||
assert decoded["user_id"] == user.id
|
assert decoded["user_id"] == str(user.id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.authentication
|
@pytest.mark.authentication
|
||||||
|
@ -15,7 +15,9 @@ def get_mongodb_user_db():
|
|||||||
user_model,
|
user_model,
|
||||||
) -> AsyncGenerator[MongoDBUserDatabase, None]:
|
) -> AsyncGenerator[MongoDBUserDatabase, None]:
|
||||||
client = motor.motor_asyncio.AsyncIOMotorClient(
|
client = motor.motor_asyncio.AsyncIOMotorClient(
|
||||||
"mongodb://localhost:27017", serverSelectionTimeoutMS=100
|
"mongodb://localhost:27017",
|
||||||
|
serverSelectionTimeoutMS=100,
|
||||||
|
uuidRepresentation="standard",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -50,9 +52,7 @@ async def mongodb_user_db_oauth(get_mongodb_user_db):
|
|||||||
@pytest.mark.db
|
@pytest.mark.db
|
||||||
async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
||||||
user = UserDB(
|
user = UserDB(
|
||||||
id="111",
|
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||||
email="lancelot@camelot.bt",
|
|
||||||
hashed_password=get_password_hash("guinevere"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create
|
# Create
|
||||||
@ -96,7 +96,6 @@ async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
|||||||
async def test_queries_custom_fields(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
async def test_queries_custom_fields(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
||||||
"""It should output custom fields in query result."""
|
"""It should output custom fields in query result."""
|
||||||
user = UserDB(
|
user = UserDB(
|
||||||
id="111",
|
|
||||||
email="lancelot@camelot.bt",
|
email="lancelot@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
first_name="Lancelot",
|
first_name="Lancelot",
|
||||||
@ -117,7 +116,6 @@ async def test_queries_oauth(
|
|||||||
oauth_account2,
|
oauth_account2,
|
||||||
):
|
):
|
||||||
user = UserDBOAuth(
|
user = UserDBOAuth(
|
||||||
id="111",
|
|
||||||
email="lancelot@camelot.bt",
|
email="lancelot@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
oauth_accounts=[oauth_account1, oauth_account2],
|
oauth_accounts=[oauth_account1, oauth_account2],
|
||||||
|
@ -70,9 +70,7 @@ async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, N
|
|||||||
@pytest.mark.db
|
@pytest.mark.db
|
||||||
async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
|
async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
|
||||||
user = UserDB(
|
user = UserDB(
|
||||||
id="111",
|
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||||
email="lancelot@camelot.bt",
|
|
||||||
hashed_password=get_password_hash("guinevere"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create
|
# Create
|
||||||
@ -103,7 +101,7 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
|
|||||||
|
|
||||||
# Exception when inserting non-nullable fields
|
# Exception when inserting non-nullable fields
|
||||||
with pytest.raises(sqlite3.IntegrityError):
|
with pytest.raises(sqlite3.IntegrityError):
|
||||||
wrong_user = UserDB(id="222", hashed_password="aaa")
|
wrong_user = UserDB(hashed_password="aaa")
|
||||||
await sqlalchemy_user_db.create(wrong_user)
|
await sqlalchemy_user_db.create(wrong_user)
|
||||||
|
|
||||||
# Unknown user
|
# Unknown user
|
||||||
@ -117,9 +115,7 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
|
|||||||
|
|
||||||
# Exception when creating/updating a OAuth user
|
# Exception when creating/updating a OAuth user
|
||||||
user_oauth = UserDBOAuth(
|
user_oauth = UserDBOAuth(
|
||||||
id="222",
|
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||||
email="lancelot@camelot.bt",
|
|
||||||
hashed_password=get_password_hash("guinevere"),
|
|
||||||
)
|
)
|
||||||
with pytest.raises(NotSetOAuthAccountTableError):
|
with pytest.raises(NotSetOAuthAccountTableError):
|
||||||
await sqlalchemy_user_db.create(user_oauth)
|
await sqlalchemy_user_db.create(user_oauth)
|
||||||
@ -138,7 +134,6 @@ async def test_queries_custom_fields(
|
|||||||
):
|
):
|
||||||
"""It should output custom fields in query result."""
|
"""It should output custom fields in query result."""
|
||||||
user = UserDB(
|
user = UserDB(
|
||||||
id="111",
|
|
||||||
email="lancelot@camelot.bt",
|
email="lancelot@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
first_name="Lancelot",
|
first_name="Lancelot",
|
||||||
@ -159,7 +154,6 @@ async def test_queries_oauth(
|
|||||||
oauth_account2,
|
oauth_account2,
|
||||||
):
|
):
|
||||||
user = UserDBOAuth(
|
user = UserDBOAuth(
|
||||||
id="111",
|
|
||||||
email="lancelot@camelot.bt",
|
email="lancelot@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
oauth_accounts=[oauth_account1, oauth_account2],
|
oauth_accounts=[oauth_account1, oauth_account2],
|
||||||
|
@ -55,9 +55,7 @@ async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None]
|
|||||||
@pytest.mark.db
|
@pytest.mark.db
|
||||||
async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
||||||
user = UserDB(
|
user = UserDB(
|
||||||
id="111",
|
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||||
email="lancelot@camelot.bt",
|
|
||||||
hashed_password=get_password_hash("guinevere"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create
|
# Create
|
||||||
@ -88,7 +86,7 @@ async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
|||||||
|
|
||||||
# Exception when inserting non-nullable fields
|
# Exception when inserting non-nullable fields
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
wrong_user = UserDB(id="222", hashed_password="aaa")
|
wrong_user = UserDB(hashed_password="aaa")
|
||||||
await tortoise_user_db.create(wrong_user)
|
await tortoise_user_db.create(wrong_user)
|
||||||
|
|
||||||
# Unknown user
|
# Unknown user
|
||||||
@ -106,7 +104,6 @@ async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
|||||||
async def test_queries_custom_fields(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
async def test_queries_custom_fields(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
||||||
"""It should output custom fields in query result."""
|
"""It should output custom fields in query result."""
|
||||||
user = UserDB(
|
user = UserDB(
|
||||||
id="111",
|
|
||||||
email="lancelot@camelot.bt",
|
email="lancelot@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
first_name="Lancelot",
|
first_name="Lancelot",
|
||||||
@ -127,7 +124,6 @@ async def test_queries_oauth(
|
|||||||
oauth_account2,
|
oauth_account2,
|
||||||
):
|
):
|
||||||
user = UserDBOAuth(
|
user = UserDBOAuth(
|
||||||
id="111",
|
|
||||||
email="lancelot@camelot.bt",
|
email="lancelot@camelot.bt",
|
||||||
hashed_password=get_password_hash("guinevere"),
|
hashed_password=get_password_hash("guinevere"),
|
||||||
oauth_accounts=[oauth_account1, oauth_account2],
|
oauth_accounts=[oauth_account1, oauth_account2],
|
||||||
|
@ -204,7 +204,7 @@ class TestCallback:
|
|||||||
user_update_mock.assert_awaited_once()
|
user_update_mock.assert_awaited_once()
|
||||||
data = cast(Dict[str, Any], response.json())
|
data = cast(Dict[str, Any], response.json())
|
||||||
|
|
||||||
assert data["token"] == user_oauth.id
|
assert data["token"] == str(user_oauth.id)
|
||||||
|
|
||||||
assert event_handler.called is False
|
assert event_handler.called is False
|
||||||
|
|
||||||
@ -242,7 +242,7 @@ class TestCallback:
|
|||||||
user_update_mock.assert_awaited_once()
|
user_update_mock.assert_awaited_once()
|
||||||
data = cast(Dict[str, Any], response.json())
|
data = cast(Dict[str, Any], response.json())
|
||||||
|
|
||||||
assert data["token"] == superuser_oauth.id
|
assert data["token"] == str(superuser_oauth.id)
|
||||||
|
|
||||||
assert event_handler.called is False
|
assert event_handler.called is False
|
||||||
|
|
||||||
@ -283,7 +283,7 @@ class TestCallback:
|
|||||||
|
|
||||||
assert event_handler.called is True
|
assert event_handler.called is True
|
||||||
actual_user = event_handler.call_args[0][0]
|
actual_user = event_handler.call_args[0][0]
|
||||||
assert actual_user.id == data["token"]
|
assert str(actual_user.id) == data["token"]
|
||||||
request = event_handler.call_args[0][1]
|
request = event_handler.call_args[0][1]
|
||||||
assert isinstance(request, Request)
|
assert isinstance(request, Request)
|
||||||
|
|
||||||
@ -348,4 +348,4 @@ class TestCallback:
|
|||||||
)
|
)
|
||||||
data = cast(Dict[str, Any], response.json())
|
data = cast(Dict[str, Any], response.json())
|
||||||
|
|
||||||
assert data["token"] == user_oauth.id
|
assert data["token"] == str(user_oauth.id)
|
||||||
|
@ -23,7 +23,7 @@ def forgot_password_token():
|
|||||||
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
|
def _forgot_password_token(user_id=None, lifetime=LIFETIME):
|
||||||
data = {"aud": "fastapi-users:reset"}
|
data = {"aud": "fastapi-users:reset"}
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
data["user_id"] = user_id
|
data["user_id"] = str(user_id)
|
||||||
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM)
|
||||||
|
|
||||||
return _forgot_password_token
|
return _forgot_password_token
|
||||||
@ -117,7 +117,7 @@ class TestRegister:
|
|||||||
assert data["id"] is not None
|
assert data["id"] is not None
|
||||||
|
|
||||||
actual_user = event_handler.call_args[0][0]
|
actual_user = event_handler.call_args[0][0]
|
||||||
assert actual_user.id == data["id"]
|
assert str(actual_user.id) == data["id"]
|
||||||
request = event_handler.call_args[0][1]
|
request = event_handler.call_args[0][1]
|
||||||
assert isinstance(request, Request)
|
assert isinstance(request, Request)
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ class TestLogin:
|
|||||||
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
data = {"username": "king.arthur@camelot.bt", "password": "guinevere"}
|
||||||
response = await test_app_client.post(path, data=data)
|
response = await test_app_client.post(path, data=data)
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert response.json() == {"token": user.id}
|
assert response.json() == {"token": str(user.id)}
|
||||||
|
|
||||||
async def test_inactive_user(self, path, test_app_client: httpx.AsyncClient):
|
async def test_inactive_user(self, path, test_app_client: httpx.AsyncClient):
|
||||||
data = {"username": "percival@camelot.bt", "password": "angharad"}
|
data = {"username": "percival@camelot.bt", "password": "angharad"}
|
||||||
@ -261,7 +261,7 @@ class TestForgotPassword:
|
|||||||
audience="fastapi-users:reset",
|
audience="fastapi-users:reset",
|
||||||
algorithms=[JWT_ALGORITHM],
|
algorithms=[JWT_ALGORITHM],
|
||||||
)
|
)
|
||||||
assert decoded_token["user_id"] == user.id
|
assert decoded_token["user_id"] == str(user.id)
|
||||||
request = event_handler.call_args[0][2]
|
request = event_handler.call_args[0][2]
|
||||||
assert isinstance(request, Request)
|
assert isinstance(request, Request)
|
||||||
|
|
||||||
@ -306,6 +306,22 @@ class TestResetPassword:
|
|||||||
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||||
assert mock_user_db.update.called is False
|
assert mock_user_db.update.called is False
|
||||||
|
|
||||||
|
async def test_valid_token_invalid_uuid(
|
||||||
|
self,
|
||||||
|
mocker,
|
||||||
|
mock_user_db,
|
||||||
|
test_app_client: httpx.AsyncClient,
|
||||||
|
forgot_password_token,
|
||||||
|
):
|
||||||
|
mocker.spy(mock_user_db, "update")
|
||||||
|
|
||||||
|
json = {"token": forgot_password_token("foo"), "password": "holygrail"}
|
||||||
|
response = await test_app_client.post("/reset-password", json=json)
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
data = cast(Dict[str, Any], response.json())
|
||||||
|
assert data["detail"] == ErrorCode.RESET_PASSWORD_BAD_TOKEN
|
||||||
|
assert mock_user_db.update.called is False
|
||||||
|
|
||||||
async def test_inactive_user(
|
async def test_inactive_user(
|
||||||
self,
|
self,
|
||||||
mocker,
|
mocker,
|
||||||
@ -368,7 +384,7 @@ class TestMe:
|
|||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
data = cast(Dict[str, Any], response.json())
|
data = cast(Dict[str, Any], response.json())
|
||||||
assert data["id"] == user.id
|
assert data["id"] == str(user.id)
|
||||||
assert data["email"] == user.email
|
assert data["email"] == user.email
|
||||||
|
|
||||||
|
|
||||||
@ -504,12 +520,13 @@ class TestUpdateMe:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestGetUser:
|
class TestGetUser:
|
||||||
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
|
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
|
||||||
response = await test_app_client.get("/000")
|
response = await test_app_client.get("/d35d213e-f3d8-4f08-954a-7e0d1bea286f")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
|
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
|
||||||
response = await test_app_client.get(
|
response = await test_app_client.get(
|
||||||
"/000", headers={"Authorization": f"Bearer {user.id}"}
|
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
|
||||||
|
headers={"Authorization": f"Bearer {user.id}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
@ -517,7 +534,8 @@ class TestGetUser:
|
|||||||
self, test_app_client: httpx.AsyncClient, superuser: UserDB
|
self, test_app_client: httpx.AsyncClient, superuser: UserDB
|
||||||
):
|
):
|
||||||
response = await test_app_client.get(
|
response = await test_app_client.get(
|
||||||
"/000", headers={"Authorization": f"Bearer {superuser.id}"}
|
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
|
||||||
|
headers={"Authorization": f"Bearer {superuser.id}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
@ -530,7 +548,7 @@ class TestGetUser:
|
|||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
data = cast(Dict[str, Any], response.json())
|
data = cast(Dict[str, Any], response.json())
|
||||||
assert data["id"] == user.id
|
assert data["id"] == str(user.id)
|
||||||
assert "hashed_password" not in data
|
assert "hashed_password" not in data
|
||||||
|
|
||||||
|
|
||||||
@ -538,12 +556,13 @@ class TestGetUser:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestUpdateUser:
|
class TestUpdateUser:
|
||||||
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
|
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
|
||||||
response = await test_app_client.patch("/000")
|
response = await test_app_client.patch("/d35d213e-f3d8-4f08-954a-7e0d1bea286f")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
|
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
|
||||||
response = await test_app_client.patch(
|
response = await test_app_client.patch(
|
||||||
"/000", headers={"Authorization": f"Bearer {user.id}"}
|
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
|
||||||
|
headers={"Authorization": f"Bearer {user.id}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
@ -551,7 +570,9 @@ class TestUpdateUser:
|
|||||||
self, test_app_client: httpx.AsyncClient, superuser: UserDB
|
self, test_app_client: httpx.AsyncClient, superuser: UserDB
|
||||||
):
|
):
|
||||||
response = await test_app_client.patch(
|
response = await test_app_client.patch(
|
||||||
"/000", json={}, headers={"Authorization": f"Bearer {superuser.id}"}
|
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
|
||||||
|
json={},
|
||||||
|
headers={"Authorization": f"Bearer {superuser.id}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
@ -636,12 +657,13 @@ class TestUpdateUser:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestDeleteUser:
|
class TestDeleteUser:
|
||||||
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
|
async def test_missing_token(self, test_app_client: httpx.AsyncClient):
|
||||||
response = await test_app_client.delete("/000")
|
response = await test_app_client.delete("/d35d213e-f3d8-4f08-954a-7e0d1bea286f")
|
||||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||||
|
|
||||||
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
|
async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB):
|
||||||
response = await test_app_client.delete(
|
response = await test_app_client.delete(
|
||||||
"/000", headers={"Authorization": f"Bearer {user.id}"}
|
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
|
||||||
|
headers={"Authorization": f"Bearer {user.id}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||||
|
|
||||||
@ -649,7 +671,8 @@ class TestDeleteUser:
|
|||||||
self, test_app_client: httpx.AsyncClient, superuser: UserDB
|
self, test_app_client: httpx.AsyncClient, superuser: UserDB
|
||||||
):
|
):
|
||||||
response = await test_app_client.delete(
|
response = await test_app_client.delete(
|
||||||
"/000", headers={"Authorization": f"Bearer {superuser.id}"}
|
"/d35d213e-f3d8-4f08-954a-7e0d1bea286f",
|
||||||
|
headers={"Authorization": f"Bearer {superuser.id}"},
|
||||||
)
|
)
|
||||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user