diff --git a/README.md b/README.md index 9778c10d..7bf5d796 100644 --- a/README.md +++ b/README.md @@ -37,10 +37,8 @@ Add quickly a registration and authentication system to your [FastAPI](https://f * [X] Dependency callables to inject current user in route * [X] Pluggable password validation * [X] Customizable database backend - * [X] [SQLAlchemy ORM async](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) backend included - * [X] MongoDB async backend included thanks to [mongodb/motor](https://github.com/mongodb/motor) - * [X] [Tortoise ORM](https://tortoise-orm.readthedocs.io/en/latest/) backend included - * [X] [ormar](https://collerek.github.io/ormar/) backend included + * [X] [SQLAlchemy ORM async](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html) included + * [X] [MongoDB with Beanie ODM](https://github.com/roman-right/beanie/) included * [X] Multiple customizable authentication backends * [X] Transports: Authorization header, Cookie * [X] Strategies: JWT, Database, Redis diff --git a/docs/configuration/authentication/strategies/database.md b/docs/configuration/authentication/strategies/database.md index 58bd5972..dca79964 100644 --- a/docs/configuration/authentication/strategies/database.md +++ b/docs/configuration/authentication/strategies/database.md @@ -4,68 +4,72 @@ The most natural way for storing tokens is of course the very same database you' ## Configuration -The configuration of this strategy is a bit more complex than the others as it requires you to configure models and a database adapter, [exactly like we did for users](../../overview.md#database-adapters). +The configuration of this strategy is a bit more complex than the others as it requires you to configure models and a database adapter, [exactly like we did for users](../../overview.md#user-model-and-database-adapters). -### Model +### Database adapters -You should define an `AccessToken` Pydantic model inheriting from `BaseAccessToken`. - -```py -from fastapi_users.authentication.strategy.db import BaseAccessToken - - -class AccessToken(BaseAccessToken): - pass -``` - -It is structured like this: +An access token will be structured like this in your database: * `token` (`str`) – Unique identifier of the token. It's generated automatically upon login by the strategy. -* `user_id` (`UUID4`) – User id. of the user associated to this token. +* `user_id` (`ID`) – User id. of the user associated to this token. * `created_at` (`datetime`) – Date and time of creation of the token. It's used to determine if the token is expired or not. -### Database adapter +We are providing a base model with those fields for each database we are supporting. -=== "SQLAlchemy" +#### SQLAlchemy - ```py hl_lines="5-8 13 23-24 45-46" - --8<-- "docs/src/db_sqlalchemy_access_tokens.py" +We'll expand from the basic SQLAlchemy configuration. + +```py hl_lines="5-8 21-22 43-46" +--8<-- "docs/src/db_sqlalchemy_access_tokens.py" +``` + +1. We define an `AccessToken` ORM model inheriting from `SQLAlchemyBaseAccessTokenTableUUID`. + +2. We define a dependency to instantiate the `SQLAlchemyAccessTokenDatabase` class. Just like the user database adapter, it expects a fresh SQLAlchemy session and the `AccessToken` model class we defined above. + +!!! tip "`user_id` foreign key is defined as UUID" + By default, we use UUID as a primary key ID for your user, so we follow the same convention to define the foreign key pointing to the user. + + If you want to use another type, like an auto-incremented integer, you can use `SQLAlchemyBaseAccessTokenTable` as base class and define your own `user_id` column. + + ```py + class AccessToken(SQLAlchemyBaseAccessTokenTable[int], Base): + @declared_attr + def user_id(cls): + return Column(Integer, ForeignKey("user.id", ondelete="cascade"), nullable=False) ``` -=== "Tortoise ORM" + Notice that `SQLAlchemyBaseAccessTokenTable` expects a generic type to define the actual type of ID you use. - With Tortoise ORM, you need to define a proper Tortoise model for `AccessToken` and manually specify the user foreign key. Besides, you need to modify the Pydantic model a bit so that it works well with this Tortoise model. +#### Beanie - === "model.py" - ```py hl_lines="2 4 31-38" - --8<-- "docs/src/db_tortoise_access_tokens_model.py" - ``` +We'll expand from the basic Beanie configuration. - === "adapter.py" - ```py hl_lines="2 4 13-14" - --8<-- "docs/src/db_tortoise_access_tokens_adapter.py" - ``` +```py hl_lines="4-7 20-21 28-29" +--8<-- "docs/src/db_beanie_access_tokens.py" +``` -=== "MongoDB" +1. We define an `AccessToken` ODM model inheriting from `BeanieBaseAccessToken`. Notice that we set a generic type to define the type of the `user_id` reference. By default, it's a standard MongoDB ObjectID. - ```py hl_lines="3 5 13 20-21" - --8<-- "docs/src/db_mongodb_access_tokens.py" - ``` +2. We define a dependency to instantiate the `BeanieAccessTokenDatabase` class. Just like the user database adapter, it expects the `AccessToken` model class we defined above. ### Strategy ```py +import uuid + from fastapi import Depends from fastapi_users.authentication.strategy.db import AccessTokenDatabase, DatabaseStrategy -from .models import AccessToken, UserCreate, UserDB +from .db import AccessToken, User def get_database_strategy( access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db), -) -> DatabaseStrategy[UserCreate, UserDB, AccessToken]: +) -> DatabaseStrategy: return DatabaseStrategy(access_token_db, lifetime_seconds=3600) ``` diff --git a/docs/configuration/databases/beanie.md b/docs/configuration/databases/beanie.md new file mode 100644 index 00000000..fac5086c --- /dev/null +++ b/docs/configuration/databases/beanie.md @@ -0,0 +1,75 @@ +# Beanie + +**FastAPI Users** provides the necessary tools to work with MongoDB databases using the [Beanie ODM](https://github.com/roman-right/beanie). + +## Setup database connection and collection + +The first thing to do is to create a MongoDB connection using [mongodb/motor](https://github.com/mongodb/motor) (automatically installed with Beanie). + +```py hl_lines="5-9" +--8<-- "docs/src/db_beanie.py" +``` + +You can choose any name for the database. + +## Create the User model + +As for any Beanie ODM model, we'll create a `User` model. + +```py hl_lines="12-13" +--8<-- "docs/src/db_beanie.py" +``` + +As you can see, **FastAPI Users** provides a base class that will include base fields for our `User` table. You can of course add you own fields there to fit to your needs! + +!!! tip "Document ID is a MongoDB ObjectID" + Beanie [automatically manages document ID](https://roman-right.github.io/beanie/tutorial/defining-a-document/#id) by encoding/decoding MongoDB ObjectID. + + If you want to use another type, like UUID, you can override the `id` field: + + ```py + import uuid + + from pydantic import Field + + + class User(BeanieBaseUser[uuid.UUID]): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + ``` + + Notice that `BeanieBaseUser` expects a generic type to define the actual type of ID you use. + +!!! info + The base class is configured to automatically create a [unique index](https://roman-right.github.io/beanie/tutorial/defining-a-document/#indexes) on `id` and `email`. + +## Create the database adapter + +The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. It should be generated by a FastAPI dependency. + +```py hl_lines="16-17" +--8<-- "docs/src/db_beanie.py" +``` + +Notice that we pass a reference to the `User` model we defined above. + +## Initialize Beanie + +When initializing your FastAPI app, it's important that you [**initialize Beanie**](https://roman-right.github.io/beanie/tutorial/initialization/) so it can discover your models. We can achieve this using a startup event handler on the FastAPI app: + +```py +from beanie import init_beanie + + +@app.on_event("startup") +async def on_startup(): + await init_beanie( + database=db, # (1)! + document_models=[ + User, # (2)! + ], + ) +``` + +1. This is the `db` Motor database instance we defined above. + +2. This is the Beanie `User` model we defined above. Don't forget to also add your very own models! diff --git a/docs/configuration/databases/mongodb.md b/docs/configuration/databases/mongodb.md deleted file mode 100644 index d4bfe4a1..00000000 --- a/docs/configuration/databases/mongodb.md +++ /dev/null @@ -1,32 +0,0 @@ -# MongoDB - -**FastAPI Users** provides the necessary tools to work with MongoDB databases thanks to [mongodb/motor](https://github.com/mongodb/motor) package for full async support. - -## Setup database connection and collection - -Let's create a MongoDB connection and instantiate a collection. - -```py hl_lines="6 7 8 9 10 11" ---8<-- "docs/src/db_mongodb.py" -``` - -You can choose any name for the database and the collection. - -!!! warning - You may have noticed the `uuidRepresentation` parameter. It controls how the UUID values will be encoded in the database. By default, it's set to `pythonLegacy` but new applications should consider setting this to `standard` for cross language compatibility. [Read more about this](https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient). - -## Create the database adapter - -The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. It should be generated by a FastAPI dependency. - -```py hl_lines="14 15" ---8<-- "docs/src/db_mongodb.py" -``` - -Notice that we pass a reference to your [`UserDB` model](../models.md). - -!!! info - The database adapter will automatically create a [unique index](https://docs.mongodb.com/manual/core/index-unique/) on `id` and `email`. - -!!! warning - **FastAPI Users** will use its defined [`id` UUID](../models.md) as unique identifier for the user, rather than the builtin MongoDB `_id`. diff --git a/docs/configuration/databases/ormar.md b/docs/configuration/databases/ormar.md deleted file mode 100644 index 4538395c..00000000 --- a/docs/configuration/databases/ormar.md +++ /dev/null @@ -1,49 +0,0 @@ -# Ormar - -**FastAPI Users** provides the necessary tools to work with ormar. - -## Installation - -Install the database driver that corresponds to your DBMS: - -```sh -pip install asyncpg psycopg2 -``` - -```sh -pip install aiomysql pymysql -``` - -```sh -pip install aiosqlite -``` - -For the sake of this tutorial from now on, we'll use a simple SQLite database. - -## Setup User table - -Let's declare our User ORM model. - -```py hl_lines="12-16" ---8<-- "docs/src/db_ormar.py" -``` - -As you can see, **FastAPI Users** provides an abstract model that will -include base fields for our User table. You can of course add you own fields -there to fit to your needs! - -## Create the database adapter - -The database adapter of **FastAPI Users** makes the link between your -database configuration and the users logic. It should be generated by a FastAPI dependency. - -```py hl_lines="23-24" ---8<-- "docs/src/db_ormar.py" -``` - -Notice that we pass a reference to your [`UserDB` model](../models.md). - -!!! warning - In production, it's strongly recommended to setup a migration system to - update your SQL schemas. See - [Alembic](https://alembic.sqlalchemy.org/en/latest/). diff --git a/docs/configuration/databases/sqlalchemy.md b/docs/configuration/databases/sqlalchemy.md index ccfc6d36..bdf18728 100644 --- a/docs/configuration/databases/sqlalchemy.md +++ b/docs/configuration/databases/sqlalchemy.md @@ -2,9 +2,6 @@ **FastAPI Users** provides the necessary tools to work with SQL databases thanks to [SQLAlchemy ORM with asyncio](https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html). -!!! warning - The previous adapter using `encode/databases` is now deprecated but can still be installed using `fastapi-users[sqlalchemy]`. - ## Asynchronous driver To work with your DBMS, you'll need to install the corresponding asyncio driver. The common choices are: @@ -14,21 +11,31 @@ To work with your DBMS, you'll need to install the corresponding asyncio driver. For the sake of this tutorial from now on, we'll use a simple SQLite databse. -## Setup User table +## Create the User model -Let's declare our SQLAlchemy `User` table. +As for any SQLAlchemy ORM model, we'll create a `User` model. -```py hl_lines="15 16" +```py hl_lines="13-14" --8<-- "docs/src/db_sqlalchemy.py" ``` -As you can see, **FastAPI Users** provides a mixin that will include base fields for our `User` table. You can of course add you own fields there to fit to your needs! +As you can see, **FastAPI Users** provides a base class that will include base fields for our `User` table. You can of course add you own fields there to fit to your needs! + +!!! tip "Primary key is defined as UUID" + By default, we use UUID as a primary key ID for your user. If you want to use another type, like an auto-incremented integer, you can use `SQLAlchemyBaseUserTable` as base class and define your own `id` column. + + ```py + class User(SQLAlchemyBaseUserTable[int], Base): + id = Column(Integer, primary_key=True) + ``` + + Notice that `SQLAlchemyBaseUserTable` expects a generic type to define the actual type of ID you use. ## Implement a function to create the tables We'll now create an utility function to create all the defined tables. -```py hl_lines="23-25" +```py hl_lines="21-23" --8<-- "docs/src/db_sqlalchemy.py" ``` @@ -41,14 +48,13 @@ This function can be called, for example, during the initialization of your Fast The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. It should be generated by a FastAPI dependency. -```py hl_lines="28-34" +```py hl_lines="26-33" --8<-- "docs/src/db_sqlalchemy.py" ``` Notice that we define first a `get_async_session` dependency returning us a fresh SQLAlchemy session to interact with the database. -It's then used inside the `get_user_db` dependency to generate our adapter. Notice that we pass it three things: +It's then used inside the `get_user_db` dependency to generate our adapter. Notice that we pass it two things: -* A reference to your [`UserDB` model](../models.md). * The `session` instance we just injected. -* The `UserTable` variable, which is the actual SQLAlchemy model. +* The `User` class, which is the actual SQLAlchemy model. diff --git a/docs/configuration/databases/tortoise.md b/docs/configuration/databases/tortoise.md deleted file mode 100644 index ef318ca6..00000000 --- a/docs/configuration/databases/tortoise.md +++ /dev/null @@ -1,69 +0,0 @@ -# Tortoise ORM - -**FastAPI Users** provides the necessary tools to work with Tortoise ORM. - -## Installation - -Install the database driver that corresponds to your DBMS: - -```sh -pip install asyncpg -``` - -```sh -pip install aiomysql -``` - -```sh -pip install aiosqlite -``` - -For the sake of this tutorial from now on, we'll use a simple SQLite databse. - -## Setup User table and model - -Let's declare our User ORM model. - -```py hl_lines="18 19" ---8<-- "docs/src/db_tortoise_model.py" -``` - -As you can see, **FastAPI Users** provides an abstract model that will include base fields for our User table. You can of course add you own fields there to fit to your needs! - -In order to make the Pydantic model and the Tortoise ORM model working well together, you'll have to add a mixin and some configuration options to your `UserDB` model. Tortoise ORM provides [utilities to ease the integration with Pydantic](https://tortoise-orm.readthedocs.io/en/latest/contrib/pydantic.html) and we'll use them here. - -```py hl_lines="22 23 24 25 26" ---8<-- "docs/src/db_tortoise_model.py" -``` - -The `PydanticModel` mixin adds methods used internally by Tortoise ORM to the Pydantic model so that it can easily transform it back to an ORM model. It expects then that you provide the property `orig_model` which should point to the **User ORM model we defined just above**. - -## Create the database adapter - -The database adapter of **FastAPI Users** makes the link between your database configuration and the users logic. It should be generated by a FastAPI dependency. - -```py hl_lines="8 9" ---8<-- "docs/src/db_tortoise_adapter.py" -``` - -Notice that we pass a reference to your [`UserDB` model](../models.md). - -## Register Tortoise - -For using Tortoise ORM we must register our models and database. - -Tortoise ORM supports integration with FastAPI out-of-the-box. It will automatically bind startup and shutdown events. - -```py -from tortoise.contrib.fastapi import register_tortoise - -register_tortoise( - app, - db_url=DATABASE_URL, - modules={"models": ["models"]}, - generate_schemas=True, -) -``` - -!!! warning - In production, it's strongly recommended to setup a migration system to update your SQL schemas. See [https://tortoise-orm.readthedocs.io/en/latest/migration.html](https://tortoise-orm.readthedocs.io/en/latest/migration.html). diff --git a/docs/configuration/full-example.md b/docs/configuration/full-example.md index a5cb56b6..133f5cb7 100644 --- a/docs/configuration/full-example.md +++ b/docs/configuration/full-example.md @@ -34,10 +34,10 @@ Here is a full working example with JWT authentication to help get you started. --8<-- "examples/sqlalchemy/app/db.py" ``` -=== "app/models.py" +=== "app/schemas.py" ```py - --8<-- "examples/sqlalchemy/app/models.py" + --8<-- "examples/sqlalchemy/app/schemas.py" ``` === "app/users.py" @@ -46,84 +46,44 @@ Here is a full working example with JWT authentication to help get you started. --8<-- "examples/sqlalchemy/app/users.py" ``` -## MongoDB +## Beanie -[Open :material-open-in-new:](https://github.com/fastapi-users/fastapi-users/tree/master/examples/mongodb) +[Open :material-open-in-new:](https://github.com/fastapi-users/fastapi-users/tree/master/examples/beanie) === "requirements.txt" ``` - --8<-- "examples/mongodb/requirements.txt" + --8<-- "examples/beanie/requirements.txt" ``` === "main.py" ```py - --8<-- "examples/mongodb/main.py" + --8<-- "examples/beanie/main.py" ``` === "app/app.py" ```py - --8<-- "examples/mongodb/app/app.py" + --8<-- "examples/beanie/app/app.py" ``` === "app/db.py" ```py - --8<-- "examples/mongodb/app/db.py" + --8<-- "examples/beanie/app/db.py" ``` -=== "app/models.py" +=== "app/schemas.py" ```py - --8<-- "examples/mongodb/app/models.py" + --8<-- "examples/beanie/app/schemas.py" ``` === "app/users.py" ```py - --8<-- "examples/mongodb/app/users.py" - ``` - -## Tortoise ORM - -[Open :material-open-in-new:](https://github.com/fastapi-users/fastapi-users/tree/master/examples/tortoise) - -=== "requirements.txt" - - ``` - --8<-- "examples/tortoise/requirements.txt" - ``` - -=== "main.py" - - ```py - --8<-- "examples/tortoise/main.py" - ``` - -=== "app/app.py" - - ```py - --8<-- "examples/tortoise/app/app.py" - ``` - -=== "app/db.py" - - ```py - --8<-- "examples/tortoise/app/db.py" - ``` - -=== "app/models.py" - - ```py - --8<-- "examples/tortoise/app/models.py" - ``` - -=== "app/users.py" - - ```py - --8<-- "examples/tortoise/app/users.py" + --8<-- "examples/beanie/app/users.py" ``` ## What now? diff --git a/docs/configuration/models.md b/docs/configuration/models.md deleted file mode 100644 index 9cfdea6c..00000000 --- a/docs/configuration/models.md +++ /dev/null @@ -1,69 +0,0 @@ -# Models - -**FastAPI Users** defines a minimal User model for authentication purposes. It is structured like this: - -* `id` (`UUID4`) – Unique identifier of the user. Defaults to a **UUID4**. -* `email` (`str`) – Email of the user. Validated by [`email-validator`](https://github.com/JoshData/python-email-validator). -* `is_active` (`bool`) – Whether or not the user is active. If not, login and forgot password requests will be denied. Defaults to `True`. -* `is_verified` (`bool`) – Whether or not the user is verified. Optional but helpful with the [`verify` router](./routers/verify.md) logic. Defaults to `False`. -* `is_superuser` (`bool`) – Whether or not the user is a superuser. Useful to implement administration logic. Defaults to `False`. - -## Define your models - -There are four Pydantic models variations provided as mixins: - -* `BaseUser`, which provides the basic fields and validation ; -* `BaseCreateUser`, dedicated to user registration, which consists of compulsory `email` and `password` fields ; -* `BaseUpdateUser`, dedicated to user profile update, which adds an optional `password` field ; -* `BaseUserDB`, which is a representation of the user in database, adding a `hashed_password` field. - -You should define each of those variations, inheriting from each mixin: - -```py -from fastapi_users import models - - -class User(models.BaseUser): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass -``` - -### Adding your own fields - -You can of course add your own properties there to fit to your needs. In the example below, we add a required string property, `first_name`, and an optional date property, `birthdate`. - -```py -import datetime - -from fastapi_users import models - - -class User(models.BaseUser): - first_name: str - birthdate: Optional[datetime.date] - - -class UserCreate(models.BaseUserCreate): - first_name: str - birthdate: Optional[datetime.date] - - -class UserUpdate(models.BaseUserUpdate): - first_name: Optional[str] - birthdate: Optional[datetime.date] - - -class UserDB(User, models.BaseUserDB): - pass -``` diff --git a/docs/configuration/oauth.md b/docs/configuration/oauth.md index d10f9b98..5adf4504 100644 --- a/docs/configuration/oauth.md +++ b/docs/configuration/oauth.md @@ -7,15 +7,11 @@ FastAPI Users provides an optional OAuth2 authentication support. It relies on [ You should install the library with the optional dependencies for OAuth: ```sh -pip install 'fastapi-users[sqlalchemy2,oauth]' +pip install 'fastapi-users[sqlalchemy,oauth]' ``` ```sh -pip install 'fastapi-users[mongodb,oauth]' -``` - -```sh -pip install 'fastapi-users[tortoise-orm,oauth]' +pip install 'fastapi-users[beanie,oauth]' ``` ## Configuration @@ -30,78 +26,44 @@ from httpx_oauth.clients.google import GoogleOAuth2 google_oauth_client = GoogleOAuth2("CLIENT_ID", "CLIENT_SECRET") ``` -### Setup the models - -The user models differ a bit from the standard one as we have to have a way to store the OAuth information (access tokens, account ids...). - -```py -from fastapi_users import models - - -class User(models.BaseUser, models.BaseOAuthAccountMixin): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass -``` - -Notice that we inherit from the `BaseOAuthAccountMixin`, which adds a `List` of `BaseOAuthAccount` objects. This object is structured like this: - -* `id` (`UUID4`) – Unique identifier of the OAuth account information. Defaults to a **UUID4**. -* `oauth_name` (`str`) – Name of the OAuth service. It corresponds to the `name` property of the OAuth client. -* `access_token` (`str`) – Access token. -* `expires_at` (`Optional[int]`) - Timestamp at which the access token is expired. -* `refresh_token` (`Optional[str]`) – On services that support it, a token to get a fresh access token. -* `account_id` (`str`) - Identifier of the OAuth account on the corresponding service. -* `account_email` (`str`) - Email address of the OAuth account on the corresponding service. - ### Setup the database adapter #### SQLAlchemy You'll need to define the SQLAlchemy model for storing OAuth accounts. We provide a base one for this: -```py hl_lines="19-24" +```py hl_lines="5 17-18 22 39-40" --8<-- "docs/src/db_sqlalchemy_oauth.py" ``` Notice that we also manually added a `relationship` on the `UserTable` so that SQLAlchemy can properly retrieve the OAuth accounts of the user. -When instantiating the database adapter, you should pass this SQLAlchemy model: +Besides, when instantiating the database adapter, we need pass this SQLAlchemy model as third argument. -```py hl_lines="41-42" ---8<-- "docs/src/db_sqlalchemy_oauth.py" +!!! tip "Primary key is defined as UUID" + By default, we use UUID as a primary key ID for your user. If you want to use another type, like an auto-incremented integer, you can use `SQLAlchemyBaseOAuthAccountTable` as base class and define your own `id` and `user_id` column. + + ```py + class OAuthAccount(SQLAlchemyBaseOAuthAccountTable[int], Base): + id = Column(Integer, primary_key=True) + + @declared_attr + def user_id(cls): + return Column(Integer, ForeignKey("user.id", ondelete="cascade"), nullable=False) + + ``` + + Notice that `SQLAlchemyBaseOAuthAccountTable` expects a generic type to define the actual type of ID you use. + +#### Beanie + +The advantage of MongoDB is that you can easily embed sub-objects in a single document. That's why the configuration for Beanie is quite simple. All we need to do is to define another class to structure an OAuth account object. + +```py hl_lines="5 15-16 20" +--8<-- "docs/src/db_beanie_oauth.py" ``` -#### MongoDB - -Nothing to do, the [basic configuration](./databases/mongodb.md) is enough. - -#### Tortoise ORM - -You'll need to define the Tortoise model for storing the OAuth account model. We provide a base one for this: - -```py hl_lines="29 30" ---8<-- "docs/src/db_tortoise_oauth_model.py" -``` - -!!! warning - Note that you should define the foreign key yourself, so that you can point it the user model in your namespace. - -Then, you should declare it on the database adapter: - -```py hl_lines="8 9" ---8<-- "docs/src/db_tortoise_oauth_adapter.py" -``` +It's worth to note that `OAuthAccount` is **not a Beanie document** but a Pydantic model that we'll embed inside the `User` document, through the `oauth_accounts` array. ### Generate a router @@ -109,9 +71,9 @@ Once you have a `FastAPIUsers` instance, you can make it generate a single OAuth ```py app.include_router( - fastapi_users.get_oauth_router(google_oauth_client, auth_backend, "SECRET"), - prefix="/auth/google", - tags=["auth"], + fastapi_users.get_oauth_router(google_oauth_client, auth_backend, "SECRET"), + prefix="/auth/google", + tags=["auth"], ) ``` @@ -149,10 +111,10 @@ app.include_router( --8<-- "examples/sqlalchemy-oauth/app/db.py" ``` -=== "app/models.py" +=== "app/schemas.py" ```py - --8<-- "examples/sqlalchemy-oauth/app/models.py" + --8<-- "examples/sqlalchemy-oauth/app/schemas.py" ``` === "app/users.py" @@ -161,82 +123,42 @@ app.include_router( --8<-- "examples/sqlalchemy-oauth/app/users.py" ``` -#### MongoDB +#### Beanie -[Open :material-open-in-new:](https://github.com/fastapi-users/fastapi-users/tree/master/examples/mongodb-oauth) +[Open :material-open-in-new:](https://github.com/fastapi-users/fastapi-users/tree/master/examples/beanie-oauth) === "requirements.txt" ``` - --8<-- "examples/mongodb-oauth/requirements.txt" + --8<-- "examples/beanie-oauth/requirements.txt" ``` === "main.py" ```py - --8<-- "examples/mongodb-oauth/main.py" + --8<-- "examples/beanie-oauth/main.py" ``` === "app/app.py" ```py - --8<-- "examples/mongodb-oauth/app/app.py" + --8<-- "examples/beanie-oauth/app/app.py" ``` === "app/db.py" ```py - --8<-- "examples/mongodb-oauth/app/db.py" + --8<-- "examples/beanie-oauth/app/db.py" ``` -=== "app/models.py" +=== "app/schemas.py" ```py - --8<-- "examples/mongodb-oauth/app/models.py" + --8<-- "examples/beanie-oauth/app/schemas.py" ``` === "app/users.py" ```py - --8<-- "examples/mongodb-oauth/app/users.py" - ``` - -#### Tortoise ORM - -[Open :material-open-in-new:](https://github.com/fastapi-users/fastapi-users/tree/master/examples/tortoise-oauth) - -=== "requirements.txt" - - ``` - --8<-- "examples/tortoise-oauth/requirements.txt" - ``` - -=== "main.py" - - ```py - --8<-- "examples/tortoise-oauth/main.py" - ``` - -=== "app/app.py" - - ```py - --8<-- "examples/tortoise-oauth/app/app.py" - ``` - -=== "app/db.py" - - ```py - --8<-- "examples/tortoise-oauth/app/db.py" - ``` - -=== "app/models.py" - - ```py - --8<-- "examples/tortoise-oauth/app/models.py" - ``` - -=== "app/users.py" - - ```py - --8<-- "examples/tortoise-oauth/app/users.py" + --8<-- "examples/beanie-oauth/app/users.py" ``` diff --git a/docs/configuration/overview.md b/docs/configuration/overview.md index 44b1a611..a14392e6 100644 --- a/docs/configuration/overview.md +++ b/docs/configuration/overview.md @@ -4,28 +4,23 @@ The schema below shows you how the library is structured and how each part fit t ```mermaid -flowchart LR +flowchart TB FASTAPI_USERS{FastAPIUsers} USER_MANAGER{UserManager} + USER_MODEL{User model} DATABASE_DEPENDENCY[[get_user_db]] USER_MANAGER_DEPENDENCY[[get_user_manager]] CURRENT_USER[[current_user]] - subgraph MODELS[Models] - direction RL + subgraph SCHEMAS[Schemas] USER[User] USER_CREATE[UserCreate] USER_UPDATE[UserUpdate] - USER_DB[UserDB] end subgraph DATABASE[Database adapters] - direction RL SQLALCHEMY[SQLAlchemy] - MONGODB[MongoDB] - TORTOISE[Tortoise ORM] - ORMAR[Ormar] + BEANIE[Beanie] end subgraph ROUTERS[Routers] - direction RL AUTH[[get_auth_router]] OAUTH[[get_oauth_router]] REGISTER[[get_register_router]] @@ -34,24 +29,21 @@ flowchart LR USERS[[get_users_router]] end subgraph AUTH_BACKENDS[Authentication] - direction RL subgraph TRANSPORTS[Transports] - direction RL COOKIE[CookieTransport] BEARER[BearerTransport] end subgraph STRATEGIES[Strategies] - direction RL + DB[DatabaseStrategy] JWT[JWTStrategy] + REDIS[RedisStrategy] end AUTH_BACKEND{AuthenticationBackend} end DATABASE --> DATABASE_DEPENDENCY + USER_MODEL --> DATABASE_DEPENDENCY DATABASE_DEPENDENCY --> USER_MANAGER - MODELS --> USER_MANAGER - MODELS --> FASTAPI_USERS - USER_MANAGER --> USER_MANAGER_DEPENDENCY USER_MANAGER_DEPENDENCY --> FASTAPI_USERS @@ -60,30 +52,21 @@ flowchart LR TRANSPORTS --> AUTH_BACKEND STRATEGIES --> AUTH_BACKEND - AUTH_BACKEND --> AUTH - AUTH_BACKEND --> OAUTH + AUTH_BACKEND --> ROUTERS AUTH_BACKEND --> FASTAPI_USERS FASTAPI_USERS --> CURRENT_USER + + SCHEMAS --> ROUTERS ``` -## Models +## User model and database adapters -Pydantic models representing the data structure of a user. Base classes are provided with the required fields to make authentication work. You should sub-class each of them and add your own fields there. - -➡️ [Configure the models](./models.md) - -## Database adapters - -FastAPI Users is compatible with various databases and ORM. To build the interface between those database tools and the library, we provide database adapters classes that you need to instantiate and configure. +FastAPI Users is compatible with various **databases and ORM**. To build the interface between those database tools and the library, we provide database adapters classes that you need to instantiate and configure. ➡️ [I'm using SQLAlchemy](databases/sqlalchemy.md) -➡️ [I'm using MongoDB](databases/mongodb.md) - -➡️ [I'm using Tortoise ORM](databases/tortoise.md) - -➡️ [I'm using ormar](databases/ormar.md) +➡️ [I'm using Beanie](databases/beanie.md) ## Authentication backends @@ -101,6 +84,12 @@ This `UserManager` object should be provided through a FastAPI dependency, `get_ ➡️ [Configure `UserManager`](./user-manager.md) +## Schemas + +FastAPI is heavily using [Pydantic models](https://pydantic-docs.helpmanual.io/) to validate request payloads and serialize responses. **FastAPI Users** is no exception and will expect you to provide Pydantic schemas representing a user when it's read, created and updated. + +➡️ [Configure schemas](./schemas.md) + ## `FastAPIUsers` and routers Finally, `FastAPIUsers` object is the main class from which you'll be able to generate routers for classic routes like registration or login, but also get the `current_user` dependency factory to inject the authenticated user in your own routes. diff --git a/docs/configuration/routers/auth.md b/docs/configuration/routers/auth.md index c15616ce..fb9828b3 100644 --- a/docs/configuration/routers/auth.md +++ b/docs/configuration/routers/auth.md @@ -7,16 +7,16 @@ Check the [routes usage](../../usage/routes.md) to learn how to use them. ## Setup ```py +import uuid + from fastapi import FastAPI from fastapi_users import FastAPIUsers -fastapi_users = FastAPIUsers( +from .db import User + +fastapi_users = FastAPIUsers[User, uuid.UUID]( get_user_manager, [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, ) app = FastAPI() diff --git a/docs/configuration/routers/index.md b/docs/configuration/routers/index.md index c1799992..7b0aecd9 100644 --- a/docs/configuration/routers/index.md +++ b/docs/configuration/routers/index.md @@ -4,29 +4,33 @@ We're almost there! The last step is to configure the `FastAPIUsers` object that ## Configure `FastAPIUsers` -Configure `FastAPIUsers` object with all the elements we defined before. More precisely: +Configure `FastAPIUsers` object with the elements we defined before. More precisely: * `get_user_manager`: Dependency callable getter to inject the user manager class instance. See [UserManager](../user-manager.md). * `auth_backends`: List of authentication backends. See [Authentication](../authentication/index.md). -* `user_model`: Pydantic model of a user. -* `user_create_model`: Pydantic model for creating a user. -* `user_update_model`: Pydantic model for updating a user. -* `user_db_model`: Pydantic model of a DB representation of a user. ```py +import uuid + from fastapi_users import FastAPIUsers -fastapi_users = FastAPIUsers( +from .db import User + +fastapi_users = FastAPIUsers[User, uuid.UUID]( get_user_manager, [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, ) ``` +!!! note "Typing: User and ID generic types are expected" + You can see that we define two generic types when instantiating: + + * `User`, which is the user model we defined in the database part + * The ID, which should correspond to the type of ID you use on your model. Here, we chose UUID, but it can be anything, like an integer or a MongoDB ObjectID. + + It'll help you to have **good type-checking and auto-completion**. + ## Available routers This helper class will let you generate useful routers to setup the authentication system. Each of them is **optional**, so you can pick only the one that you are interested in! Here are the routers provided: diff --git a/docs/configuration/routers/register.md b/docs/configuration/routers/register.md index 8d243bfe..fff34b8a 100644 --- a/docs/configuration/routers/register.md +++ b/docs/configuration/routers/register.md @@ -7,23 +7,22 @@ Check the [routes usage](../../usage/routes.md) to learn how to use them. ## Setup ```py +import uuid + from fastapi import FastAPI from fastapi_users import FastAPIUsers -SECRET = "SECRET" +from .db import User +from .schemas import UserCreate, UserRead -fastapi_users = FastAPIUsers( +fastapi_users = FastAPIUsers[User, uuid.UUID]( get_user_manager, [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, ) app = FastAPI() app.include_router( - fastapi_users.get_register_router(), + fastapi_users.get_register_router(UserRead, UserCreate), prefix="/auth", tags=["auth"], ) diff --git a/docs/configuration/routers/reset.md b/docs/configuration/routers/reset.md index d91a4882..ac99c5d7 100644 --- a/docs/configuration/routers/reset.md +++ b/docs/configuration/routers/reset.md @@ -7,18 +7,16 @@ Check the [routes usage](../../usage/routes.md) to learn how to use them. ## Setup ```py +import uuid + from fastapi import FastAPI from fastapi_users import FastAPIUsers -SECRET = "SECRET" +from .db import User -fastapi_users = FastAPIUsers( +fastapi_users = FastAPIUsers[User, uuid.UUID]( get_user_manager, [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, ) app = FastAPI() diff --git a/docs/configuration/routers/users.md b/docs/configuration/routers/users.md index 61edc4ef..d055138b 100644 --- a/docs/configuration/routers/users.md +++ b/docs/configuration/routers/users.md @@ -5,23 +5,22 @@ This router provides routes to manage users. Check the [routes usage](../../usag ## Setup ```py +import uuid + from fastapi import FastAPI from fastapi_users import FastAPIUsers -SECRET = "SECRET" +from .db import User +from .schemas import UserRead, UserUpdate -fastapi_users = FastAPIUsers( +fastapi_users = FastAPIUsers[User, uuid.UUID]( get_user_manager, [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, ) app = FastAPI() app.include_router( - fastapi_users.get_users_router(), + fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", tags=["users"], ) @@ -33,7 +32,7 @@ You can require the user to be **verified** (i.e. `is_verified` property set to ```py app.include_router( - fastapi_users.get_users_router(requires_verification=True), + fastapi_users.get_users_router(UserRead, UserUpdate, requires_verification=True), prefix="/users", tags=["users"], ) diff --git a/docs/configuration/routers/verify.md b/docs/configuration/routers/verify.md index 49f99235..2cc048ba 100644 --- a/docs/configuration/routers/verify.md +++ b/docs/configuration/routers/verify.md @@ -8,23 +8,22 @@ This router provides routes to manage user email verification. Check the [routes ## Setup ```py +import uuid + from fastapi import FastAPI from fastapi_users import FastAPIUsers -SECRET = "SECRET" +from .db import User +from .schemas import UserRead -fastapi_users = FastAPIUsers( +fastapi_users = FastAPIUsers[User, uuid.UUID]( get_user_manager, [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, ) app = FastAPI() app.include_router( - fastapi_users.get_verify_router(), + fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"], ) diff --git a/docs/configuration/schemas.md b/docs/configuration/schemas.md new file mode 100644 index 00000000..3cb98e7a --- /dev/null +++ b/docs/configuration/schemas.md @@ -0,0 +1,73 @@ +# Schemas + +FastAPI is heavily using [Pydantic models](https://pydantic-docs.helpmanual.io/) to validate request payloads and serialize responses. **FastAPI Users** is no exception and will expect you to provide Pydantic schemas representing a user when it's read, created and updated. + +It's **different from your `User` model**, which is an object that actually interacts with the database. Those schemas on the other hand are here to validate data and serialize correct it in the API. + +**FastAPI Users** provides a base structure to cover its needs. It is structured like this: + +* `id` (`ID`) – Unique identifier of the user. It matches the type of your ID, like UUID or integer. +* `email` (`str`) – Email of the user. Validated by [`email-validator`](https://github.com/JoshData/python-email-validator). +* `is_active` (`bool`) – Whether or not the user is active. If not, login and forgot password requests will be denied. Defaults to `True`. +* `is_verified` (`bool`) – Whether or not the user is verified. Optional but helpful with the [`verify` router](./routers/verify.md) logic. Defaults to `False`. +* `is_superuser` (`bool`) – Whether or not the user is a superuser. Useful to implement administration logic. Defaults to `False`. + +## Define your schemas + +There are four Pydantic models variations provided as mixins: + +* `BaseUser`, which provides the basic fields and validation; +* `BaseCreateUser`, dedicated to user registration, which consists of compulsory `email` and `password` fields; +* `BaseUpdateUser`, dedicated to user profile update, which adds an optional `password` field; + +You should define each of those variations, inheriting from each mixin: + +```py +import uuid + +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[uuid.UUID]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass +``` + +!!! note "Typing: ID generic type is expected" + You can see that we define a generic type when extending the `BaseUser` class. It should correspond to the type of ID you use on your model. Here, we chose UUID, but it can be anything, like an integer or a MongoDB ObjectID. + +### Adding your own fields + +You can of course add your own properties there to fit to your needs. In the example below, we add a required string property, `first_name`, and an optional date property, `birthdate`. + +```py +import datetime +import uuid + +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[uuid.UUID]): + first_name: str + birthdate: Optional[datetime.date] + + +class UserCreate(schemas.BaseUserCreate): + first_name: str + birthdate: Optional[datetime.date] + + +class UserUpdate(schemas.BaseUserUpdate): + first_name: Optional[str] + birthdate: Optional[datetime.date] +``` + +!!! warning "Make sure to mirror this in your database model" + The `User` model you defined earlier for your specific database will be the central object that will actually store the data. Therefore, you need to define the very same fields in it so the data can be actually stored. diff --git a/docs/configuration/user-manager.md b/docs/configuration/user-manager.md index 6d5f753b..5e33b6cb 100644 --- a/docs/configuration/user-manager.md +++ b/docs/configuration/user-manager.md @@ -2,23 +2,62 @@ The `UserManager` class is the core logic of FastAPI Users. We provide the `BaseUserManager` class which you should extend to set some parameters and define logic, for example when a user just registered or forgot its password. -It's designed to be easily extensible and customizable so that you can integrate less generic logic. +It's designed to be easily extensible and customizable so that you can integrate your very own logic. ## Create your `UserManager` class You should define your own version of the `UserManager` class to set various parameters. -```py hl_lines="12-28" +```py hl_lines="12-27" --8<-- "docs/src/user_manager.py" ``` As you can see, you have to define here various attributes and methods. You can find the complete list of those below. +!!! note "Typing: User and ID generic types are expected" + You can see that we define two generic types when extending the base class: + + * `User`, which is the user model we defined in the database part + * The ID, which should correspond to the type of ID you use on your model. Here, we chose UUID, but it can be anything, like an integer or a MongoDB ObjectID. + + It'll help you to have **good type-checking and auto-completion** when implementing the custom methods. + +### The ID parser mixin + +Since the user ID is fully generic, we need a way to **parse it reliably when it'll come from API requests**, typically as URL path attributes. + +That's why we added the `UUIDIDMixin` in the example above. It implements the `parse_id` method, ensuring UUID are valid and correctly parsed. + +Of course, it's important that this logic **matches the type of your ID**. To help you with this, we provide mixins for the most common cases: + +* `UUIDIDMixin`, for UUID ID. +* `IntegerIDMixin`, for integer ID. +* `ObjectIDIDMixin` (provided by `fastapi_users_db_beanie`), for MongoDB ObjectID. + +!!! tip "Inheritance order matters" + Notice in your example that **the mixin comes first in our `UserManager` inheritance**. Because of the Method-Resolution-Order (MRO) of Python, the left-most element takes precedence. + +If you need another type of ID, you can simply overload the `parse_id` method on your `UserManager` class: + +```py +from fastapi_users import BaseUserManager, InvalidID + + +class UserManager(BaseUserManager[User, MyCustomID]): + def parse_id(self, value: Any) -> MyCustomID: + try: + return MyCustomID(value) + except ValueError as e: + raise InvalidID() from e # (1)! +``` + +1. If the ID can't be parsed into the desired type, you'll need to raise an `InvalidID` exception. + ## Create `get_user_manager` dependency The `UserManager` class will be injected at runtime using a FastAPI dependency. This way, you can run it in a database session or swap it with a mock during testing. -```py hl_lines="31-32" +```py hl_lines="30-31" --8<-- "docs/src/user_manager.py" ``` @@ -28,7 +67,6 @@ Notice that we use the `get_user_db` dependency we defined earlier to inject the ### Attributes -* `user_db_model`: Pydantic model of a DB representation of a user. * `reset_password_token_secret`: Secret to encode reset password token. **Use a strong passphrase and keep it secure.** * `reset_password_token_lifetime_seconds`: Lifetime of reset password token. Defaults to 3600. * `reset_password_token_audience`: JWT audience of reset password token. Defaults to `fastapi-users:reset`. @@ -54,15 +92,15 @@ This function should return `None` if the password is valid or raise `InvalidPas **Example** ```py -from fastapi_users import BaseUserManager, InvalidPasswordException +from fastapi_users import BaseUserManager, InvalidPasswordException, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... async def validate_password( self, password: str, - user: Union[UserCreate, UserDB], + user: Union[UserCreate, User], ) -> None: if len(password) < 8: raise InvalidPasswordException( @@ -82,18 +120,18 @@ Typically, you'll want to **send a welcome e-mail** or add it to your marketing **Arguments** -* `user` (`UserDB`): the registered user. +* `user` (`User`): the registered user. * `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. **Example** ```py -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") ``` @@ -105,21 +143,21 @@ It may be useful, for example, if you wish to update your user in a data analyti **Arguments** -* `user` (`UserDB`): the updated user. +* `user` (`User`): the updated user. * `update_dict` (`Dict[str, Any]`): dictionary with the updated user fields. * `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. **Example** ```py -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... async def on_after_update( self, - user: UserDB, + user: User, update_dict: Dict[str, Any], request: Optional[Request] = None, ): @@ -134,20 +172,20 @@ Typically, you'll want to **send an e-mail** with the link (and the token) that **Arguments** -* `user` (`UserDB`): the user to verify. +* `user` (`User`): the user to verify. * `token` (`str`): the verification token. * `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. **Example** ```py -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"Verification requested for user {user.id}. Verification token: {token}") ``` @@ -160,19 +198,19 @@ This may be useful if you wish to send another e-mail or store this information **Arguments** -* `user` (`UserDB`): the verified user. +* `user` (`User`): the verified user. * `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. **Example** ```py -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... async def on_after_verify( - self, user: UserDB, request: Optional[Request] = None + self, user: User, request: Optional[Request] = None ): print(f"User {user.id} has been verified") ``` @@ -185,20 +223,20 @@ Typically, you'll want to **send an e-mail** with the link (and the token) that **Arguments** -* `user` (`UserDB`): the user that forgot its password. +* `user` (`User`): the user that forgot its password. * `token` (`str`): the forgot password token * `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. **Example** ```py -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") ``` @@ -211,17 +249,17 @@ For example, you may want to **send an e-mail** to the concerned user to warn hi **Arguments** -* `user` (`UserDB`): the user that reset its password. +* `user` (`User`): the user that reset its password. * `request` (`Optional[Request]`): optional FastAPI request object that triggered the operation. Defaults to None. **Example** ```py -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -class UserManager(BaseUserManager[UserCreate, UserDB]): +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # ... - async def on_after_reset_password(self, user: UserDB, request: Optional[Request] = None): + async def on_after_reset_password(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has reset their password.") ``` diff --git a/docs/installation.md b/docs/installation.md index e7e41ae0..29ff1c68 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -5,27 +5,15 @@ You can add **FastAPI Users** to your FastAPI project in a few easy steps. First ## With SQLAlchemy support ```sh -pip install 'fastapi-users[sqlalchemy2]' +pip install 'fastapi-users[sqlalchemy]' ``` -## With MongoDB support +## With Beanie support ```sh pip install 'fastapi-users[mongodb]' ``` -## With Tortoise ORM support - -```sh -pip install 'fastapi-users[tortoise-orm]' -``` - -## With ormar support - -```sh -pip install 'fastapi-users[ormar]' -``` - --- -That's it! Now, let's have a look at our [User model](./configuration/models.md). +That's it! In the next section, we'll have an [overview](./configuration/overview.md) of how things work. diff --git a/docs/migration/9x_to_10x.md b/docs/migration/9x_to_10x.md new file mode 100644 index 00000000..4f031b79 --- /dev/null +++ b/docs/migration/9x_to_10x.md @@ -0,0 +1,361 @@ +# 9.x.x ➡️ 10.x.x + +Version 10 marks important changes in how we manage User models and their ID. + +Before, we were relying only on Pydantic models to work with users. In particular the [`current_user` dependency](../usage/current-user.md) would return you an instance of `UserDB`, a Pydantic model. This proved to be quite problematic with some ORM if you ever needed to **retrieve relationship data** or make specific requests. + +Now, FastAPI Users is designed to always return you a **native object for your ORM model**, whether it's an SQLAlchemy model or a Beanie document. Pydantic models are now only used for validation and serialization inside the API. + +Before, we were forcing the use of UUID as primary key ID; a consequence of the design above. This proved to be quite problematic on some databases, like MongoDB which uses a special ObjectID format by default. Some SQL folks also prefer to use traditional auto-increment integers. + +Now, FastAPI Users is designed to use **generic ID type**. It means that you can use any type you want for your user's ID. By default, SQLAlchemy adapter still use UUID; but you can quite easily switch to another thing, like an integer. Beanie adapter for MongoDB will use native ObjectID by default, but it also can be overriden. + +As you may have guessed, those changes imply quite a lot of **breaking changes**. + +## User models and database adapter + +### SQLAlchemy ORM + +We've removed the old SQLAlchemy dependency support, so the dependency is now `fastapi-users[sqlalchemy]`. + +=== "Before" + + ```txt + fastapi + fastapi-users[sqlalchemy2] + uvicorn[standard] + aiosqlite + ``` + +=== "After" + + + ```txt + fastapi + fastapi-users[sqlalchemy] + uvicorn[standard] + aiosqlite + ``` + +The User model base class for SQLAlchemy slightly changed to support UUID by default. + +We changed the name of the class from `UserTable` to `User`: it's not a compulsory change, but since there is no risk of confusion with Pydantic models anymore, it's probably a more idiomatic naming. + +=== "Before" + + ```py + class UserTable(Base, SQLAlchemyBaseUserTable): + pass + ``` + +=== "After" + + ```py + class User(SQLAlchemyBaseUserTableUUID, Base): + pass + ``` + +Instantiating the `SQLAlchemyUserDatabase` adapter now only expects this `User` model. `UserDB` is removed. + +=== "Before" + + ```py + async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(UserDB, session, UserTable) + ``` + +=== "After" + + ```py + async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(session, User) + ``` + +### MongoDB + +MongoDB support is now only provided through [Beanie ODM](https://github.com/roman-right/beanie/). Even if you don't use it for the rest of your project, it's a very light addition that shouldn't interfere much. + +=== "Before" + + ```txt + fastapi + fastapi-users[mongodb] + uvicorn[standard] + aiosqlite + ``` + +=== "After" + + + ```txt + fastapi + fastapi-users[beanie] + uvicorn[standard] + aiosqlite + ``` + +You now need to define a proper User model using Beanie. + +=== "Before" + + ```py + import os + + import motor.motor_asyncio + from fastapi_users.db import MongoDBUserDatabase + + from app.models import UserDB + + DATABASE_URL = os.environ["DATABASE_URL"] + client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" + ) + db = client["database_name"] + collection = db["users"] + + + async def get_user_db(): + yield MongoDBUserDatabase(UserDB, collection) + ``` + +=== "After" + + ```py + import motor.motor_asyncio + from beanie import PydanticObjectId + from fastapi_users.db import BeanieBaseUser, BeanieUserDatabase + + DATABASE_URL = "mongodb://localhost:27017" + client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" + ) + db = client["database_name"] + + + class User(BeanieBaseUser[PydanticObjectId]): + pass + + + async def get_user_db(): + yield BeanieUserDatabase(User) + ``` + +!!! danger "ID are now ObjectID by default" + By default, User ID will now be native MongoDB ObjectID. If you don't want to make the transition and keep UUID you can do so by overriding the `id` field: + + ```py + import uuid + + from pydantic import Field + + + class User(BeanieBaseUser[uuid.UUID]): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + ``` + +Beanie also needs to be initialized in a startup event handler of your FastAPI app: + +```py +from beanie import init_beanie + + +@app.on_event("startup") +async def on_startup(): + await init_beanie( + database=db, + document_models=[ + User, + ], + ) +``` + +### Tortoise ORM and ormar + +Unfortunately, we sometimes need to make difficult choices to keep things sustainable. That's why we decided to **not support Tortoise ORM and ormar** anymore. It appeared they were not widely used. + +You can still add support for those ORM yourself by implementing the necessary adapter. You can take inspiration from [the SQLAlchemy one](https://github.com/fastapi-users/fastapi-users-db-sqlalchemy). + +## `UserManager` + +There is some slight changes on the `UserManager` class. In particular, it now needs a `parse_id` method that can be provided through built-in mixins. + +Generic typing now expects your **native User model class** and the **type of ID**. + +The `user_db_model` class property is **removed**. + +=== "Before" + + ```py + class UserManager(BaseUserManager[UserCreate, UserDB]): + user_db_model = UserDB + reset_password_token_secret = SECRET + verification_token_secret = SECRET + + async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + + async def on_after_forgot_password( + self, user: UserDB, token: str, request: Optional[Request] = None + ): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify( + self, user: UserDB, token: str, request: Optional[Request] = None + ): + print(f"Verification requested for user {user.id}. Verification token: {token}") + ``` + +=== "After" + + ```py + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + reset_password_token_secret = SECRET + verification_token_secret = SECRET + + async def on_after_register(self, user: User, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + + async def on_after_forgot_password( + self, user: User, token: str, request: Optional[Request] = None + ): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify( + self, user: User, token: str, request: Optional[Request] = None + ): + print(f"Verification requested for user {user.id}. Verification token: {token}") + ``` + +If you need to support other types of ID, you can read more about it [in the dedicated section](../configuration/user-manager.md#the-id-parser-mixin). + +## Pydantic models + +To better distinguish them from the ORM models, Pydantic models are now called **schemas**. + +**`UserDB` has been removed** in favor of native models. + +We changed the name of `User` to `UserRead`: it's not a compulsory change, but since there is a **risk of confusion** with the native model, it's highly recommended. + +Besides, the `BaseUser` schema now accepts a generic type to specify the type of ID you use. + +=== "Before" + + ```py + from fastapi_users import models + + + class User(models.BaseUser): + pass + + + class UserCreate(models.BaseUserCreate): + pass + + + class UserUpdate(models.BaseUserUpdate): + pass + + + class UserDB(User, models.BaseUserDB): + pass + ``` + +=== "After" + + ```py + import uuid + + from fastapi_users import schemas + + + class UserRead(schemas.BaseUser[uuid.UUID]): + pass + + + class UserCreate(schemas.BaseUserCreate): + pass + + + class UserUpdate(schemas.BaseUserUpdate): + pass + ``` + +## FastAPI Users and routers + +Pydantic schemas are now way less important in this new design. As such, you don't need to pass them when initializing the `FastAPIUsers` class: + +=== "Before" + + ```py + fastapi_users = FastAPIUsers( + get_user_manager, + [auth_backend], + User, + UserCreate, + UserUpdate, + UserDB, + ) + ``` + +=== "After" + + ```py + fastapi_users = FastAPIUsers[User, uuid.UUID]( + get_user_manager, + [auth_backend], + ) + ``` + +As a consequence, those schemas need to be passed when initializing the router that needs them: `get_register_router`, `get_verify_router` and `get_users_router`. + +=== "Before" + + ```py + app.include_router( + fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] + ) + app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) + app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], + ) + app.include_router( + fastapi_users.get_verify_router(), + prefix="/auth", + tags=["auth"], + ) + app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) + ``` + +=== "After" + + ```py + app.include_router( + fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] + ) + app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + ) + app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], + ) + app.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], + ) + app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], + ) + ``` + +## Lost? + +If you're unsure or a bit lost, make sure to check the [full working examples](../configuration/full-example.md). diff --git a/docs/src/cookbook_create_user_programmatically.py b/docs/src/cookbook_create_user_programmatically.py index 449a81ad..70724c06 100644 --- a/docs/src/cookbook_create_user_programmatically.py +++ b/docs/src/cookbook_create_user_programmatically.py @@ -1,7 +1,7 @@ import contextlib from app.db import get_async_session, get_user_db -from app.models import UserCreate +from app.schemas import UserCreate from app.users import get_user_manager from fastapi_users.manager import UserAlreadyExists diff --git a/docs/src/db_mongodb.py b/docs/src/db_beanie.py similarity index 54% rename from docs/src/db_mongodb.py rename to docs/src/db_beanie.py index 12b026f8..58accc17 100644 --- a/docs/src/db_mongodb.py +++ b/docs/src/db_beanie.py @@ -1,15 +1,17 @@ import motor.motor_asyncio -from fastapi_users.db import MongoDBUserDatabase - -from .models import UserDB +from beanie import PydanticObjectId +from fastapi_users.db import BeanieBaseUser, BeanieUserDatabase DATABASE_URL = "mongodb://localhost:27017" client = motor.motor_asyncio.AsyncIOMotorClient( DATABASE_URL, uuidRepresentation="standard" ) db = client["database_name"] -collection = db["users"] + + +class User(BeanieBaseUser[PydanticObjectId]): + pass async def get_user_db(): - yield MongoDBUserDatabase(UserDB, collection) + yield BeanieUserDatabase(User) diff --git a/docs/src/db_beanie_access_tokens.py b/docs/src/db_beanie_access_tokens.py new file mode 100644 index 00000000..098f7432 --- /dev/null +++ b/docs/src/db_beanie_access_tokens.py @@ -0,0 +1,29 @@ +import motor.motor_asyncio +from beanie import PydanticObjectId +from fastapi_users.db import BeanieBaseUser, BeanieUserDatabase +from fastapi_users_db_beanie.access_token import ( + BeanieAccessTokenDatabase, + BeanieBaseAccessToken, +) + +DATABASE_URL = "mongodb://localhost:27017" +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) +db = client["database_name"] + + +class User(BeanieBaseUser): + pass + + +class AccessToken(BeanieBaseAccessToken[PydanticObjectId]): # (1)! + pass + + +async def get_user_db(): + yield BeanieUserDatabase(User) + + +async def get_access_token_db(): # (2)! + yield BeanieAccessTokenDatabase(AccessToken) diff --git a/docs/src/db_beanie_oauth.py b/docs/src/db_beanie_oauth.py new file mode 100644 index 00000000..835ddd21 --- /dev/null +++ b/docs/src/db_beanie_oauth.py @@ -0,0 +1,24 @@ +from typing import List + +import motor.motor_asyncio +from beanie import PydanticObjectId +from fastapi_users.db import BaseOAuthAccount, BeanieBaseUser, BeanieUserDatabase +from pydantic import Field + +DATABASE_URL = "mongodb://localhost:27017" +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) +db = client["database_name"] + + +class OAuthAccount(BaseOAuthAccount): + pass + + +class User(BeanieBaseUser[PydanticObjectId]): + oauth_accounts: List[OAuthAccount] = Field(default_factory=list) + + +async def get_user_db(): + yield BeanieUserDatabase(User) diff --git a/docs/src/db_mongodb_access_tokens.py b/docs/src/db_mongodb_access_tokens.py deleted file mode 100644 index fded3797..00000000 --- a/docs/src/db_mongodb_access_tokens.py +++ /dev/null @@ -1,21 +0,0 @@ -import motor.motor_asyncio -from fastapi_users.db import MongoDBUserDatabase -from fastapi_users_db_mongodb.access_token import MongoDBAccessTokenDatabase - -from .models import AccessToken, UserDB - -DATABASE_URL = "mongodb://localhost:27017" -client = motor.motor_asyncio.AsyncIOMotorClient( - DATABASE_URL, uuidRepresentation="standard" -) -db = client["database_name"] -users_collection = db["users"] -access_tokens_collection = db["access_tokens"] - - -async def get_user_db(): - yield MongoDBUserDatabase(UserDB, users_collection) - - -async def get_access_token_db(): - yield MongoDBAccessTokenDatabase(AccessToken, access_tokens_collection) diff --git a/docs/src/db_ormar.py b/docs/src/db_ormar.py deleted file mode 100644 index 430f3f8d..00000000 --- a/docs/src/db_ormar.py +++ /dev/null @@ -1,24 +0,0 @@ -import databases -import sqlalchemy -from fastapi_users.db import OrmarBaseUserModel, OrmarUserDatabase - -from .models import UserDB - -DATABASE_URL = "sqlite:///test.db" -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL) - - -class UserModel(OrmarBaseUserModel): - class Meta: - tablename = "users" - metadata = metadata - database = database - - -engine = sqlalchemy.create_engine(DATABASE_URL) -metadata.create_all(engine) - - -async def get_user_db(): - yield OrmarUserDatabase(UserDB, UserModel) diff --git a/docs/src/db_sqlalchemy.py b/docs/src/db_sqlalchemy.py index fc2baa6d..05d0f38b 100644 --- a/docs/src/db_sqlalchemy.py +++ b/docs/src/db_sqlalchemy.py @@ -1,18 +1,16 @@ from typing import AsyncGenerator from fastapi import Depends -from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase +from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import sessionmaker -from .models import UserDB - DATABASE_URL = "sqlite+aiosqlite:///./test.db" Base: DeclarativeMeta = declarative_base() -class UserTable(Base, SQLAlchemyBaseUserTable): +class User(SQLAlchemyBaseUserTableUUID, Base): pass @@ -31,4 +29,4 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(UserDB, session, UserTable) + yield SQLAlchemyUserDatabase(session, User) diff --git a/docs/src/db_sqlalchemy_access_tokens.py b/docs/src/db_sqlalchemy_access_tokens.py index cc83c4e7..257fd831 100644 --- a/docs/src/db_sqlalchemy_access_tokens.py +++ b/docs/src/db_sqlalchemy_access_tokens.py @@ -1,26 +1,24 @@ from typing import AsyncGenerator from fastapi import Depends -from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase +from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from fastapi_users_db_sqlalchemy.access_token import ( SQLAlchemyAccessTokenDatabase, - SQLAlchemyBaseAccessTokenTable, + SQLAlchemyBaseAccessTokenTableUUID, ) from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import sessionmaker -from .models import AccessToken, UserDB - DATABASE_URL = "sqlite+aiosqlite:///./test.db" Base: DeclarativeMeta = declarative_base() -class UserTable(Base, SQLAlchemyBaseUserTable): +class User(SQLAlchemyBaseUserTableUUID, Base): pass -class AccessTokenTable(SQLAlchemyBaseAccessTokenTable, Base): +class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): # (1)! pass @@ -39,8 +37,10 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(UserDB, session, UserTable) + yield SQLAlchemyUserDatabase(session, User) -async def get_access_token_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyAccessTokenDatabase(AccessToken, session, AccessTokenTable) +async def get_access_token_db( + session: AsyncSession = Depends(get_async_session), +): # (2)! + yield SQLAlchemyAccessTokenDatabase(session, AccessToken) diff --git a/docs/src/db_sqlalchemy_oauth.py b/docs/src/db_sqlalchemy_oauth.py index 7ed34081..87d6fbbd 100644 --- a/docs/src/db_sqlalchemy_oauth.py +++ b/docs/src/db_sqlalchemy_oauth.py @@ -1,29 +1,27 @@ -from typing import AsyncGenerator +from typing import AsyncGenerator, List from fastapi import Depends from fastapi_users.db import ( - SQLAlchemyBaseOAuthAccountTable, - SQLAlchemyBaseUserTable, + SQLAlchemyBaseOAuthAccountTableUUID, + SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase, ) from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import relationship, sessionmaker -from .models import UserDB - DATABASE_URL = "sqlite+aiosqlite:///./test.db" Base: DeclarativeMeta = declarative_base() -class UserTable(Base, SQLAlchemyBaseUserTable): - oauth_accounts = relationship("OAuthAccountTable") - - -class OAuthAccountTable(SQLAlchemyBaseOAuthAccountTable, Base): +class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): pass +class User(SQLAlchemyBaseUserTableUUID, Base): + oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined") + + engine = create_async_engine(DATABASE_URL) async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) @@ -39,4 +37,4 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(UserDB, session, UserTable, OAuthAccountTable) + yield SQLAlchemyUserDatabase(session, User, OAuthAccount) diff --git a/docs/src/db_tortoise_access_tokens_adapter.py b/docs/src/db_tortoise_access_tokens_adapter.py deleted file mode 100644 index 1d46242f..00000000 --- a/docs/src/db_tortoise_access_tokens_adapter.py +++ /dev/null @@ -1,14 +0,0 @@ -from fastapi_users.db import TortoiseUserDatabase -from fastapi_users_db_tortoise.access_token import TortoiseAccessTokenDatabase - -from .models import AccessToken, AccessTokenModel, UserDB, UserModel - -DATABASE_URL = "sqlite://./test.db" - - -async def get_user_db(): - yield TortoiseUserDatabase(UserDB, UserModel) - - -async def get_access_token_db(): - yield TortoiseAccessTokenDatabase(AccessToken, AccessTokenModel) diff --git a/docs/src/db_tortoise_access_tokens_model.py b/docs/src/db_tortoise_access_tokens_model.py deleted file mode 100644 index 0c5176be..00000000 --- a/docs/src/db_tortoise_access_tokens_model.py +++ /dev/null @@ -1,38 +0,0 @@ -from fastapi_users import models -from fastapi_users.authentication.strategy.db.models import BaseAccessToken -from fastapi_users.db import TortoiseBaseUserModel -from fastapi_users_db_tortoise.access_token import TortoiseBaseAccessTokenModel -from tortoise import fields -from tortoise.contrib.pydantic import PydanticModel - - -class User(models.BaseUser): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserModel(TortoiseBaseUserModel): - pass - - -class UserDB(User, models.BaseUserDB, PydanticModel): - class Config: - orm_mode = True - orig_model = UserModel - - -class AccessTokenModel(TortoiseBaseAccessTokenModel): - user = fields.ForeignKeyField("models.UserModel", related_name="access_tokens") - - -class AccessToken(BaseAccessToken, PydanticModel): - class Config: - orm_mode = True - orig_model = AccessTokenModel diff --git a/docs/src/db_tortoise_adapter.py b/docs/src/db_tortoise_adapter.py deleted file mode 100644 index 7506c886..00000000 --- a/docs/src/db_tortoise_adapter.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastapi_users.db import TortoiseUserDatabase - -from .models import UserDB, UserModel - -DATABASE_URL = "sqlite://./test.db" - - -async def get_user_db(): - yield TortoiseUserDatabase(UserDB, UserModel) diff --git a/docs/src/db_tortoise_model.py b/docs/src/db_tortoise_model.py deleted file mode 100644 index e55ecc42..00000000 --- a/docs/src/db_tortoise_model.py +++ /dev/null @@ -1,25 +0,0 @@ -from fastapi_users import models -from fastapi_users.db import TortoiseBaseUserModel -from tortoise.contrib.pydantic import PydanticModel - - -class User(models.BaseUser): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserModel(TortoiseBaseUserModel): - pass - - -class UserDB(User, models.BaseUserDB, PydanticModel): - class Config: - orm_mode = True - orig_model = UserModel diff --git a/docs/src/db_tortoise_oauth_adapter.py b/docs/src/db_tortoise_oauth_adapter.py deleted file mode 100644 index bf0b4af7..00000000 --- a/docs/src/db_tortoise_oauth_adapter.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastapi_users.db import TortoiseUserDatabase - -from .models import OAuthAccount, UserDB, UserModel - -DATABASE_URL = "sqlite://./test.db" - - -async def get_user_db(): - yield TortoiseUserDatabase(UserDB, UserModel, OAuthAccount) diff --git a/docs/src/db_tortoise_oauth_model.py b/docs/src/db_tortoise_oauth_model.py deleted file mode 100644 index 9d470192..00000000 --- a/docs/src/db_tortoise_oauth_model.py +++ /dev/null @@ -1,30 +0,0 @@ -from fastapi_users import models -from fastapi_users.db import TortoiseBaseOAuthAccountModel, TortoiseBaseUserModel -from tortoise import fields -from tortoise.contrib.pydantic import PydanticModel - - -class User(models.BaseUser, models.BaseOAuthAccountMixin): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserModel(TortoiseBaseUserModel): - pass - - -class UserDB(User, models.BaseUserDB, PydanticModel): - class Config: - orm_mode = True - orig_model = UserModel - - -class OAuthAccount(TortoiseBaseOAuthAccountModel): - user = fields.ForeignKeyField("models.UserModel", related_name="oauth_accounts") diff --git a/docs/src/user_manager.py b/docs/src/user_manager.py index 2fd9f5b5..55f435fd 100644 --- a/docs/src/user_manager.py +++ b/docs/src/user_manager.py @@ -1,29 +1,28 @@ +import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users import BaseUserManager +from fastapi_users import BaseUserManager, UUIDIDMixin -from .db import get_user_db -from .models import UserCreate, UserDB +from .db import User, get_user_db SECRET = "SECRET" -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"Verification requested for user {user.id}. Verification token: {token}") diff --git a/examples/mongodb-oauth/app/__init__.py b/examples/beanie-oauth/app/__init__.py similarity index 100% rename from examples/mongodb-oauth/app/__init__.py rename to examples/beanie-oauth/app/__init__.py diff --git a/examples/tortoise-oauth/app/app.py b/examples/beanie-oauth/app/app.py similarity index 51% rename from examples/tortoise-oauth/app/app.py rename to examples/beanie-oauth/app/app.py index 0991be46..a939f8fc 100644 --- a/examples/tortoise-oauth/app/app.py +++ b/examples/beanie-oauth/app/app.py @@ -1,8 +1,8 @@ +from beanie import init_beanie from fastapi import Depends, FastAPI -from tortoise.contrib.fastapi import register_tortoise -from app.db import DATABASE_URL -from app.models import UserDB +from app.db import User, db +from app.schemas import UserCreate, UserRead, UserUpdate from app.users import ( auth_backend, current_active_user, @@ -15,18 +15,26 @@ app = FastAPI() app.include_router( fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] ) -app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) app.include_router( fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"], ) app.include_router( - fastapi_users.get_verify_router(), + fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"], ) -app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) app.include_router( fastapi_users.get_oauth_router(google_oauth_client, auth_backend, "SECRET"), prefix="/auth/google", @@ -35,13 +43,15 @@ app.include_router( @app.get("/authenticated-route") -async def authenticated_route(user: UserDB = Depends(current_active_user)): +async def authenticated_route(user: User = Depends(current_active_user)): return {"message": f"Hello {user.email}!"} -register_tortoise( - app, - db_url=DATABASE_URL, - modules={"models": ["app.models"]}, - generate_schemas=True, -) +@app.on_event("startup") +async def on_startup(): + await init_beanie( + database=db, + document_models=[ + User, + ], + ) diff --git a/examples/beanie-oauth/app/db.py b/examples/beanie-oauth/app/db.py new file mode 100644 index 00000000..835ddd21 --- /dev/null +++ b/examples/beanie-oauth/app/db.py @@ -0,0 +1,24 @@ +from typing import List + +import motor.motor_asyncio +from beanie import PydanticObjectId +from fastapi_users.db import BaseOAuthAccount, BeanieBaseUser, BeanieUserDatabase +from pydantic import Field + +DATABASE_URL = "mongodb://localhost:27017" +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) +db = client["database_name"] + + +class OAuthAccount(BaseOAuthAccount): + pass + + +class User(BeanieBaseUser[PydanticObjectId]): + oauth_accounts: List[OAuthAccount] = Field(default_factory=list) + + +async def get_user_db(): + yield BeanieUserDatabase(User) diff --git a/examples/beanie-oauth/app/schemas.py b/examples/beanie-oauth/app/schemas.py new file mode 100644 index 00000000..f9b2b9a1 --- /dev/null +++ b/examples/beanie-oauth/app/schemas.py @@ -0,0 +1,14 @@ +from beanie import PydanticObjectId +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[PydanticObjectId]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass diff --git a/examples/mongodb-oauth/app/users.py b/examples/beanie-oauth/app/users.py similarity index 60% rename from examples/mongodb-oauth/app/users.py rename to examples/beanie-oauth/app/users.py index 2ba395ed..81ad5ca4 100644 --- a/examples/mongodb-oauth/app/users.py +++ b/examples/beanie-oauth/app/users.py @@ -1,6 +1,7 @@ import os from typing import Optional +from beanie import PydanticObjectId from fastapi import Depends, Request from fastapi_users import BaseUserManager, FastAPIUsers from fastapi_users.authentication import ( @@ -8,41 +9,38 @@ from fastapi_users.authentication import ( BearerTransport, JWTStrategy, ) -from fastapi_users.db import MongoDBUserDatabase +from fastapi_users.db import BeanieUserDatabase, ObjectIDIDMixin from httpx_oauth.clients.google import GoogleOAuth2 -from app.db import get_user_db -from app.models import User, UserCreate, UserDB, UserUpdate +from app.db import User, get_user_db SECRET = "SECRET" - google_oauth_client = GoogleOAuth2( - os.environ["GOOGLE_OAUTH_CLIENT_ID"], - os.environ["GOOGLE_OAUTH_CLIENT_SECRET"], + os.getenv("GOOGLE_OAUTH_CLIENT_ID", ""), + os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", ""), ) -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(ObjectIDIDMixin, BaseUserManager[User, PydanticObjectId]): reset_password_token_secret = SECRET verification_token_secret = SECRET - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"Verification requested for user {user.id}. Verification token: {token}") -async def get_user_manager(user_db: MongoDBUserDatabase = Depends(get_user_db)): +async def get_user_manager(user_db: BeanieUserDatabase = Depends(get_user_db)): yield UserManager(user_db) @@ -58,13 +56,7 @@ auth_backend = AuthenticationBackend( transport=bearer_transport, get_strategy=get_jwt_strategy, ) -fastapi_users = FastAPIUsers( - get_user_manager, - [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, -) + +fastapi_users = FastAPIUsers[User, PydanticObjectId](get_user_manager, [auth_backend]) current_active_user = fastapi_users.current_user(active=True) diff --git a/examples/beanie-oauth/main.py b/examples/beanie-oauth/main.py new file mode 100644 index 00000000..eb1edd12 --- /dev/null +++ b/examples/beanie-oauth/main.py @@ -0,0 +1,4 @@ +import uvicorn + +if __name__ == "__main__": + uvicorn.run("app.app:app", host="0.0.0.0", log_level="info") diff --git a/examples/mongodb/requirements.txt b/examples/beanie-oauth/requirements.txt similarity index 53% rename from examples/mongodb/requirements.txt rename to examples/beanie-oauth/requirements.txt index 745d009c..ea30f98f 100644 --- a/examples/mongodb/requirements.txt +++ b/examples/beanie-oauth/requirements.txt @@ -1,3 +1,3 @@ fastapi -fastapi-users[mongodb] +fastapi-users[beanie] uvicorn[standard] diff --git a/examples/mongodb/app/__init__.py b/examples/beanie/app/__init__.py similarity index 100% rename from examples/mongodb/app/__init__.py rename to examples/beanie/app/__init__.py diff --git a/examples/beanie/app/app.py b/examples/beanie/app/app.py new file mode 100644 index 00000000..5550983a --- /dev/null +++ b/examples/beanie/app/app.py @@ -0,0 +1,47 @@ +from beanie import init_beanie +from fastapi import Depends, FastAPI + +from app.db import User, db +from app.schemas import UserCreate, UserRead, UserUpdate +from app.users import auth_backend, current_active_user, fastapi_users + +app = FastAPI() + +app.include_router( + fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] +) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], +) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) + + +@app.get("/authenticated-route") +async def authenticated_route(user: User = Depends(current_active_user)): + return {"message": f"Hello {user.email}!"} + + +@app.on_event("startup") +async def on_startup(): + await init_beanie( + database=db, + document_models=[ + User, + ], + ) diff --git a/examples/beanie/app/db.py b/examples/beanie/app/db.py new file mode 100644 index 00000000..58accc17 --- /dev/null +++ b/examples/beanie/app/db.py @@ -0,0 +1,17 @@ +import motor.motor_asyncio +from beanie import PydanticObjectId +from fastapi_users.db import BeanieBaseUser, BeanieUserDatabase + +DATABASE_URL = "mongodb://localhost:27017" +client = motor.motor_asyncio.AsyncIOMotorClient( + DATABASE_URL, uuidRepresentation="standard" +) +db = client["database_name"] + + +class User(BeanieBaseUser[PydanticObjectId]): + pass + + +async def get_user_db(): + yield BeanieUserDatabase(User) diff --git a/examples/beanie/app/schemas.py b/examples/beanie/app/schemas.py new file mode 100644 index 00000000..f9b2b9a1 --- /dev/null +++ b/examples/beanie/app/schemas.py @@ -0,0 +1,14 @@ +from beanie import PydanticObjectId +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[PydanticObjectId]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass diff --git a/examples/mongodb/app/users.py b/examples/beanie/app/users.py similarity index 60% rename from examples/mongodb/app/users.py rename to examples/beanie/app/users.py index c38242ad..a96772e6 100644 --- a/examples/mongodb/app/users.py +++ b/examples/beanie/app/users.py @@ -1,5 +1,6 @@ from typing import Optional +from beanie import PydanticObjectId from fastapi import Depends, Request from fastapi_users import BaseUserManager, FastAPIUsers from fastapi_users.authentication import ( @@ -7,34 +8,32 @@ from fastapi_users.authentication import ( BearerTransport, JWTStrategy, ) -from fastapi_users.db import MongoDBUserDatabase +from fastapi_users.db import BeanieUserDatabase, ObjectIDIDMixin -from app.db import get_user_db -from app.models import User, UserCreate, UserDB, UserUpdate +from app.db import User, get_user_db SECRET = "SECRET" -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(ObjectIDIDMixin, BaseUserManager[User, PydanticObjectId]): reset_password_token_secret = SECRET verification_token_secret = SECRET - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"Verification requested for user {user.id}. Verification token: {token}") -async def get_user_manager(user_db: MongoDBUserDatabase = Depends(get_user_db)): +async def get_user_manager(user_db: BeanieUserDatabase = Depends(get_user_db)): yield UserManager(user_db) @@ -50,13 +49,7 @@ auth_backend = AuthenticationBackend( transport=bearer_transport, get_strategy=get_jwt_strategy, ) -fastapi_users = FastAPIUsers( - get_user_manager, - [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, -) + +fastapi_users = FastAPIUsers[User, PydanticObjectId](get_user_manager, [auth_backend]) current_active_user = fastapi_users.current_user(active=True) diff --git a/examples/beanie/main.py b/examples/beanie/main.py new file mode 100644 index 00000000..eb1edd12 --- /dev/null +++ b/examples/beanie/main.py @@ -0,0 +1,4 @@ +import uvicorn + +if __name__ == "__main__": + uvicorn.run("app.app:app", host="0.0.0.0", log_level="info") diff --git a/examples/beanie/requirements.txt b/examples/beanie/requirements.txt new file mode 100644 index 00000000..ea30f98f --- /dev/null +++ b/examples/beanie/requirements.txt @@ -0,0 +1,3 @@ +fastapi +fastapi-users[beanie] +uvicorn[standard] diff --git a/examples/mongodb-oauth/app/app.py b/examples/mongodb-oauth/app/app.py deleted file mode 100644 index 2d73c0fa..00000000 --- a/examples/mongodb-oauth/app/app.py +++ /dev/null @@ -1,37 +0,0 @@ -from fastapi import Depends, FastAPI - -from app.models import UserDB -from app.users import ( - auth_backend, - current_active_user, - fastapi_users, - google_oauth_client, -) - -app = FastAPI() - -app.include_router( - fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] -) -app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) -app.include_router( - fastapi_users.get_reset_password_router(), - prefix="/auth", - tags=["auth"], -) -app.include_router( - fastapi_users.get_verify_router(), - prefix="/auth", - tags=["auth"], -) -app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) -app.include_router( - fastapi_users.get_oauth_router(google_oauth_client, auth_backend, "SECRET"), - prefix="/auth/google", - tags=["auth"], -) - - -@app.get("/authenticated-route") -async def authenticated_route(user: UserDB = Depends(current_active_user)): - return {"message": f"Hello {user.email}!"} diff --git a/examples/mongodb-oauth/app/db.py b/examples/mongodb-oauth/app/db.py deleted file mode 100644 index e9e3ce77..00000000 --- a/examples/mongodb-oauth/app/db.py +++ /dev/null @@ -1,17 +0,0 @@ -import os - -import motor.motor_asyncio -from fastapi_users.db import MongoDBUserDatabase - -from app.models import UserDB - -DATABASE_URL = os.environ["DATABASE_URL"] -client = motor.motor_asyncio.AsyncIOMotorClient( - DATABASE_URL, uuidRepresentation="standard" -) -db = client["database_name"] -collection = db["users"] - - -async def get_user_db(): - yield MongoDBUserDatabase(UserDB, collection) diff --git a/examples/mongodb-oauth/app/models.py b/examples/mongodb-oauth/app/models.py deleted file mode 100644 index 04050f8a..00000000 --- a/examples/mongodb-oauth/app/models.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi_users import models - - -class User(models.BaseUser, models.BaseOAuthAccountMixin): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass diff --git a/examples/mongodb-oauth/main.py b/examples/mongodb-oauth/main.py deleted file mode 100644 index 1b47137b..00000000 --- a/examples/mongodb-oauth/main.py +++ /dev/null @@ -1,4 +0,0 @@ -import uvicorn - -if __name__ == "__main__": - uvicorn.run("app.app:app", host="0.0.0.0", port=5000, log_level="info") diff --git a/examples/mongodb-oauth/requirements.txt b/examples/mongodb-oauth/requirements.txt deleted file mode 100644 index 8f0ff9d3..00000000 --- a/examples/mongodb-oauth/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -fastapi -fastapi-users[mongodb,oauth] -uvicorn[standard] diff --git a/examples/mongodb/app/app.py b/examples/mongodb/app/app.py deleted file mode 100644 index a2171027..00000000 --- a/examples/mongodb/app/app.py +++ /dev/null @@ -1,27 +0,0 @@ -from fastapi import Depends, FastAPI - -from app.models import UserDB -from app.users import auth_backend, current_active_user, fastapi_users - -app = FastAPI() - -app.include_router( - fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] -) -app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) -app.include_router( - fastapi_users.get_reset_password_router(), - prefix="/auth", - tags=["auth"], -) -app.include_router( - fastapi_users.get_verify_router(), - prefix="/auth", - tags=["auth"], -) -app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) - - -@app.get("/authenticated-route") -async def authenticated_route(user: UserDB = Depends(current_active_user)): - return {"message": f"Hello {user.email}!"} diff --git a/examples/mongodb/app/db.py b/examples/mongodb/app/db.py deleted file mode 100644 index e9e3ce77..00000000 --- a/examples/mongodb/app/db.py +++ /dev/null @@ -1,17 +0,0 @@ -import os - -import motor.motor_asyncio -from fastapi_users.db import MongoDBUserDatabase - -from app.models import UserDB - -DATABASE_URL = os.environ["DATABASE_URL"] -client = motor.motor_asyncio.AsyncIOMotorClient( - DATABASE_URL, uuidRepresentation="standard" -) -db = client["database_name"] -collection = db["users"] - - -async def get_user_db(): - yield MongoDBUserDatabase(UserDB, collection) diff --git a/examples/mongodb/app/models.py b/examples/mongodb/app/models.py deleted file mode 100644 index e2392e25..00000000 --- a/examples/mongodb/app/models.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi_users import models - - -class User(models.BaseUser): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass diff --git a/examples/mongodb/main.py b/examples/mongodb/main.py deleted file mode 100644 index 1b47137b..00000000 --- a/examples/mongodb/main.py +++ /dev/null @@ -1,4 +0,0 @@ -import uvicorn - -if __name__ == "__main__": - uvicorn.run("app.app:app", host="0.0.0.0", port=5000, log_level="info") diff --git a/examples/sqlalchemy-oauth/app/app.py b/examples/sqlalchemy-oauth/app/app.py index d9ffac76..ccd061fc 100644 --- a/examples/sqlalchemy-oauth/app/app.py +++ b/examples/sqlalchemy-oauth/app/app.py @@ -1,7 +1,7 @@ from fastapi import Depends, FastAPI -from app.db import create_db_and_tables -from app.models import UserDB +from app.db import User, create_db_and_tables +from app.schemas import UserCreate, UserRead, UserUpdate from app.users import ( auth_backend, current_active_user, @@ -14,18 +14,26 @@ app = FastAPI() app.include_router( fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] ) -app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) app.include_router( fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"], ) app.include_router( - fastapi_users.get_verify_router(), + fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"], ) -app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) app.include_router( fastapi_users.get_oauth_router(google_oauth_client, auth_backend, "SECRET"), prefix="/auth/google", @@ -34,7 +42,7 @@ app.include_router( @app.get("/authenticated-route") -async def authenticated_route(user: UserDB = Depends(current_active_user)): +async def authenticated_route(user: User = Depends(current_active_user)): return {"message": f"Hello {user.email}!"} diff --git a/examples/sqlalchemy-oauth/app/db.py b/examples/sqlalchemy-oauth/app/db.py index ea56bcb2..87d6fbbd 100644 --- a/examples/sqlalchemy-oauth/app/db.py +++ b/examples/sqlalchemy-oauth/app/db.py @@ -1,29 +1,27 @@ -from typing import AsyncGenerator +from typing import AsyncGenerator, List from fastapi import Depends from fastapi_users.db import ( - SQLAlchemyBaseOAuthAccountTable, - SQLAlchemyBaseUserTable, + SQLAlchemyBaseOAuthAccountTableUUID, + SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase, ) from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import relationship, sessionmaker -from app.models import UserDB - DATABASE_URL = "sqlite+aiosqlite:///./test.db" Base: DeclarativeMeta = declarative_base() -class UserTable(Base, SQLAlchemyBaseUserTable): - oauth_accounts = relationship("OAuthAccountTable") - - -class OAuthAccountTable(SQLAlchemyBaseOAuthAccountTable, Base): +class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): pass +class User(SQLAlchemyBaseUserTableUUID, Base): + oauth_accounts: List[OAuthAccount] = relationship("OAuthAccount", lazy="joined") + + engine = create_async_engine(DATABASE_URL) async_session_maker = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) @@ -39,4 +37,4 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(UserDB, session, UserTable, OAuthAccountTable) + yield SQLAlchemyUserDatabase(session, User, OAuthAccount) diff --git a/examples/sqlalchemy-oauth/app/models.py b/examples/sqlalchemy-oauth/app/models.py deleted file mode 100644 index 04050f8a..00000000 --- a/examples/sqlalchemy-oauth/app/models.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi_users import models - - -class User(models.BaseUser, models.BaseOAuthAccountMixin): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass diff --git a/examples/sqlalchemy-oauth/app/schemas.py b/examples/sqlalchemy-oauth/app/schemas.py new file mode 100644 index 00000000..de1169e4 --- /dev/null +++ b/examples/sqlalchemy-oauth/app/schemas.py @@ -0,0 +1,15 @@ +import uuid + +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[uuid.UUID]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass diff --git a/examples/sqlalchemy-oauth/app/users.py b/examples/sqlalchemy-oauth/app/users.py index d615d32b..0b61b2a5 100644 --- a/examples/sqlalchemy-oauth/app/users.py +++ b/examples/sqlalchemy-oauth/app/users.py @@ -1,8 +1,9 @@ import os +import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -11,33 +12,30 @@ from fastapi_users.authentication import ( from fastapi_users.db import SQLAlchemyUserDatabase from httpx_oauth.clients.google import GoogleOAuth2 -from app.db import get_user_db -from app.models import User, UserCreate, UserDB, UserUpdate +from app.db import User, get_user_db SECRET = "SECRET" - google_oauth_client = GoogleOAuth2( os.getenv("GOOGLE_OAUTH_CLIENT_ID", ""), os.getenv("GOOGLE_OAUTH_CLIENT_SECRET", ""), ) -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"Verification requested for user {user.id}. Verification token: {token}") @@ -58,13 +56,7 @@ auth_backend = AuthenticationBackend( transport=bearer_transport, get_strategy=get_jwt_strategy, ) -fastapi_users = FastAPIUsers( - get_user_manager, - [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, -) + +fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) current_active_user = fastapi_users.current_user(active=True) diff --git a/examples/sqlalchemy-oauth/requirements.txt b/examples/sqlalchemy-oauth/requirements.txt index 226c8ced..6407e81e 100644 --- a/examples/sqlalchemy-oauth/requirements.txt +++ b/examples/sqlalchemy-oauth/requirements.txt @@ -1,4 +1,4 @@ fastapi -fastapi-users[sqlalchemy2,oauth] +fastapi-users[sqlalchemy] uvicorn[standard] aiosqlite diff --git a/examples/sqlalchemy/app/app.py b/examples/sqlalchemy/app/app.py index c898d816..034089cb 100644 --- a/examples/sqlalchemy/app/app.py +++ b/examples/sqlalchemy/app/app.py @@ -1,7 +1,7 @@ from fastapi import Depends, FastAPI -from app.db import create_db_and_tables -from app.models import UserDB +from app.db import User, create_db_and_tables +from app.schemas import UserCreate, UserRead, UserUpdate from app.users import auth_backend, current_active_user, fastapi_users app = FastAPI() @@ -9,22 +9,30 @@ app = FastAPI() app.include_router( fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] ) -app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) +app.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], +) app.include_router( fastapi_users.get_reset_password_router(), prefix="/auth", tags=["auth"], ) app.include_router( - fastapi_users.get_verify_router(), + fastapi_users.get_verify_router(UserRead), prefix="/auth", tags=["auth"], ) -app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) +app.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], +) @app.get("/authenticated-route") -async def authenticated_route(user: UserDB = Depends(current_active_user)): +async def authenticated_route(user: User = Depends(current_active_user)): return {"message": f"Hello {user.email}!"} diff --git a/examples/sqlalchemy/app/db.py b/examples/sqlalchemy/app/db.py index 58ac1822..05d0f38b 100644 --- a/examples/sqlalchemy/app/db.py +++ b/examples/sqlalchemy/app/db.py @@ -1,18 +1,16 @@ from typing import AsyncGenerator from fastapi import Depends -from fastapi_users.db import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase +from fastapi_users.db import SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base from sqlalchemy.orm import sessionmaker -from app.models import UserDB - DATABASE_URL = "sqlite+aiosqlite:///./test.db" Base: DeclarativeMeta = declarative_base() -class UserTable(Base, SQLAlchemyBaseUserTable): +class User(SQLAlchemyBaseUserTableUUID, Base): pass @@ -31,4 +29,4 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(UserDB, session, UserTable) + yield SQLAlchemyUserDatabase(session, User) diff --git a/examples/sqlalchemy/app/models.py b/examples/sqlalchemy/app/models.py deleted file mode 100644 index e2392e25..00000000 --- a/examples/sqlalchemy/app/models.py +++ /dev/null @@ -1,17 +0,0 @@ -from fastapi_users import models - - -class User(models.BaseUser): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(User, models.BaseUserDB): - pass diff --git a/examples/sqlalchemy/app/schemas.py b/examples/sqlalchemy/app/schemas.py new file mode 100644 index 00000000..de1169e4 --- /dev/null +++ b/examples/sqlalchemy/app/schemas.py @@ -0,0 +1,15 @@ +import uuid + +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[uuid.UUID]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass diff --git a/examples/sqlalchemy/app/users.py b/examples/sqlalchemy/app/users.py index 04678ac0..479c49e2 100644 --- a/examples/sqlalchemy/app/users.py +++ b/examples/sqlalchemy/app/users.py @@ -1,7 +1,8 @@ +import uuid from typing import Optional from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers +from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, @@ -9,27 +10,25 @@ from fastapi_users.authentication import ( ) from fastapi_users.db import SQLAlchemyUserDatabase -from app.db import get_user_db -from app.models import User, UserCreate, UserDB, UserUpdate +from app.db import User, get_user_db SECRET = "SECRET" -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): + async def on_after_register(self, user: User, request: Optional[Request] = None): print(f"User {user.id} has registered.") async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"User {user.id} has forgot their password. Reset token: {token}") async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None + self, user: User, token: str, request: Optional[Request] = None ): print(f"Verification requested for user {user.id}. Verification token: {token}") @@ -50,13 +49,7 @@ auth_backend = AuthenticationBackend( transport=bearer_transport, get_strategy=get_jwt_strategy, ) -fastapi_users = FastAPIUsers( - get_user_manager, - [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, -) + +fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) current_active_user = fastapi_users.current_user(active=True) diff --git a/examples/sqlalchemy/requirements.txt b/examples/sqlalchemy/requirements.txt index 9b18bbc8..6407e81e 100644 --- a/examples/sqlalchemy/requirements.txt +++ b/examples/sqlalchemy/requirements.txt @@ -1,4 +1,4 @@ fastapi -fastapi-users[sqlalchemy2] +fastapi-users[sqlalchemy] uvicorn[standard] aiosqlite diff --git a/examples/tortoise-oauth/app/__init__.py b/examples/tortoise-oauth/app/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/tortoise-oauth/app/db.py b/examples/tortoise-oauth/app/db.py deleted file mode 100644 index 83b1571f..00000000 --- a/examples/tortoise-oauth/app/db.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastapi_users.db import TortoiseUserDatabase - -from app.models import OAuthAccount, UserDB, UserModel - -DATABASE_URL = "sqlite://./test.db" - - -async def get_user_db(): - yield TortoiseUserDatabase(UserDB, UserModel, OAuthAccount) diff --git a/examples/tortoise-oauth/app/models.py b/examples/tortoise-oauth/app/models.py deleted file mode 100644 index 9d470192..00000000 --- a/examples/tortoise-oauth/app/models.py +++ /dev/null @@ -1,30 +0,0 @@ -from fastapi_users import models -from fastapi_users.db import TortoiseBaseOAuthAccountModel, TortoiseBaseUserModel -from tortoise import fields -from tortoise.contrib.pydantic import PydanticModel - - -class User(models.BaseUser, models.BaseOAuthAccountMixin): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserModel(TortoiseBaseUserModel): - pass - - -class UserDB(User, models.BaseUserDB, PydanticModel): - class Config: - orm_mode = True - orig_model = UserModel - - -class OAuthAccount(TortoiseBaseOAuthAccountModel): - user = fields.ForeignKeyField("models.UserModel", related_name="oauth_accounts") diff --git a/examples/tortoise-oauth/app/users.py b/examples/tortoise-oauth/app/users.py deleted file mode 100644 index 49f672f7..00000000 --- a/examples/tortoise-oauth/app/users.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -from typing import Optional - -from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers -from fastapi_users.authentication import ( - AuthenticationBackend, - BearerTransport, - JWTStrategy, -) -from fastapi_users.db import TortoiseUserDatabase -from httpx_oauth.clients.google import GoogleOAuth2 - -from app.db import get_user_db -from app.models import User, UserCreate, UserDB, UserUpdate - -SECRET = "SECRET" - - -google_oauth_client = GoogleOAuth2( - os.environ["GOOGLE_OAUTH_CLIENT_ID"], - os.environ["GOOGLE_OAUTH_CLIENT_SECRET"], -) - - -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB - reset_password_token_secret = SECRET - verification_token_secret = SECRET - - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): - print(f"User {user.id} has registered.") - - async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None - ): - print(f"User {user.id} has forgot their password. Reset token: {token}") - - async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None - ): - print(f"Verification requested for user {user.id}. Verification token: {token}") - - -async def get_user_manager(user_db: TortoiseUserDatabase = Depends(get_user_db)): - yield UserManager(user_db) - - -bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") - - -def get_jwt_strategy() -> JWTStrategy: - return JWTStrategy(secret=SECRET, lifetime_seconds=3600) - - -auth_backend = AuthenticationBackend( - name="jwt", - transport=bearer_transport, - get_strategy=get_jwt_strategy, -) -fastapi_users = FastAPIUsers( - get_user_manager, - [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, -) - -current_active_user = fastapi_users.current_user(active=True) diff --git a/examples/tortoise-oauth/main.py b/examples/tortoise-oauth/main.py deleted file mode 100644 index 1b47137b..00000000 --- a/examples/tortoise-oauth/main.py +++ /dev/null @@ -1,4 +0,0 @@ -import uvicorn - -if __name__ == "__main__": - uvicorn.run("app.app:app", host="0.0.0.0", port=5000, log_level="info") diff --git a/examples/tortoise-oauth/requirements.txt b/examples/tortoise-oauth/requirements.txt deleted file mode 100644 index d5642ec1..00000000 --- a/examples/tortoise-oauth/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -fastapi -fastapi-users[tortoise-orm,oauth] -uvicorn[standard] diff --git a/examples/tortoise/app/__init__.py b/examples/tortoise/app/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/examples/tortoise/app/app.py b/examples/tortoise/app/app.py deleted file mode 100644 index aca846a2..00000000 --- a/examples/tortoise/app/app.py +++ /dev/null @@ -1,37 +0,0 @@ -from fastapi import Depends, FastAPI -from tortoise.contrib.fastapi import register_tortoise - -from app.db import DATABASE_URL -from app.models import UserDB -from app.users import auth_backend, current_active_user, fastapi_users - -app = FastAPI() - -app.include_router( - fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"] -) -app.include_router(fastapi_users.get_register_router(), prefix="/auth", tags=["auth"]) -app.include_router( - fastapi_users.get_reset_password_router(), - prefix="/auth", - tags=["auth"], -) -app.include_router( - fastapi_users.get_verify_router(), - prefix="/auth", - tags=["auth"], -) -app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) - - -@app.get("/authenticated-route") -async def authenticated_route(user: UserDB = Depends(current_active_user)): - return {"message": f"Hello {user.email}!"} - - -register_tortoise( - app, - db_url=DATABASE_URL, - modules={"models": ["app.models"]}, - generate_schemas=True, -) diff --git a/examples/tortoise/app/db.py b/examples/tortoise/app/db.py deleted file mode 100644 index d6fd9349..00000000 --- a/examples/tortoise/app/db.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastapi_users.db import TortoiseUserDatabase - -from app.models import UserDB, UserModel - -DATABASE_URL = "sqlite://./test.db" - - -async def get_user_db(): - yield TortoiseUserDatabase(UserDB, UserModel) diff --git a/examples/tortoise/app/models.py b/examples/tortoise/app/models.py deleted file mode 100644 index e55ecc42..00000000 --- a/examples/tortoise/app/models.py +++ /dev/null @@ -1,25 +0,0 @@ -from fastapi_users import models -from fastapi_users.db import TortoiseBaseUserModel -from tortoise.contrib.pydantic import PydanticModel - - -class User(models.BaseUser): - pass - - -class UserCreate(models.BaseUserCreate): - pass - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserModel(TortoiseBaseUserModel): - pass - - -class UserDB(User, models.BaseUserDB, PydanticModel): - class Config: - orm_mode = True - orig_model = UserModel diff --git a/examples/tortoise/app/users.py b/examples/tortoise/app/users.py deleted file mode 100644 index 20f4c565..00000000 --- a/examples/tortoise/app/users.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Optional - -from fastapi import Depends, Request -from fastapi_users import BaseUserManager, FastAPIUsers -from fastapi_users.authentication import ( - AuthenticationBackend, - BearerTransport, - JWTStrategy, -) -from fastapi_users.db import TortoiseUserDatabase - -from app.db import get_user_db -from app.models import User, UserCreate, UserDB, UserUpdate - -SECRET = "SECRET" - - -class UserManager(BaseUserManager[UserCreate, UserDB]): - user_db_model = UserDB - reset_password_token_secret = SECRET - verification_token_secret = SECRET - - async def on_after_register(self, user: UserDB, request: Optional[Request] = None): - print(f"User {user.id} has registered.") - - async def on_after_forgot_password( - self, user: UserDB, token: str, request: Optional[Request] = None - ): - print(f"User {user.id} has forgot their password. Reset token: {token}") - - async def on_after_request_verify( - self, user: UserDB, token: str, request: Optional[Request] = None - ): - print(f"Verification requested for user {user.id}. Verification token: {token}") - - -async def get_user_manager(user_db: TortoiseUserDatabase = Depends(get_user_db)): - yield UserManager(user_db) - - -bearer_transport = BearerTransport(tokenUrl="auth/jwt/login") - - -def get_jwt_strategy() -> JWTStrategy: - return JWTStrategy(secret=SECRET, lifetime_seconds=3600) - - -auth_backend = AuthenticationBackend( - name="jwt", - transport=bearer_transport, - get_strategy=get_jwt_strategy, -) -fastapi_users = FastAPIUsers( - get_user_manager, - [auth_backend], - User, - UserCreate, - UserUpdate, - UserDB, -) - -current_active_user = fastapi_users.current_user(active=True) diff --git a/examples/tortoise/main.py b/examples/tortoise/main.py deleted file mode 100644 index 1b47137b..00000000 --- a/examples/tortoise/main.py +++ /dev/null @@ -1,4 +0,0 @@ -import uvicorn - -if __name__ == "__main__": - uvicorn.run("app.app:app", host="0.0.0.0", port=5000, log_level="info") diff --git a/examples/tortoise/requirements.txt b/examples/tortoise/requirements.txt deleted file mode 100644 index 7aec70ac..00000000 --- a/examples/tortoise/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -fastapi -fastapi-users[tortoise-orm] -uvicorn[standard] diff --git a/fastapi_users/__init__.py b/fastapi_users/__init__.py index 9ba4d40c..8583f648 100644 --- a/fastapi_users/__init__.py +++ b/fastapi_users/__init__.py @@ -2,16 +2,22 @@ __version__ = "9.3.2" -from fastapi_users import models # noqa: F401 +from fastapi_users import models, schemas # noqa: F401 from fastapi_users.fastapi_users import FastAPIUsers # noqa: F401 from fastapi_users.manager import ( # noqa: F401 BaseUserManager, + IntegerIDMixin, + InvalidID, InvalidPasswordException, + UUIDIDMixin, ) __all__ = [ - "models", + "schemas", "FastAPIUsers", "BaseUserManager", "InvalidPasswordException", + "InvalidID", + "UUIDIDMixin", + "IntegerIDMixin", ] diff --git a/fastapi_users/authentication/authenticator.py b/fastapi_users/authentication/authenticator.py index 4d876713..7fab4b78 100644 --- a/fastapi_users/authentication/authenticator.py +++ b/fastapi_users/authentication/authenticator.py @@ -51,7 +51,7 @@ class Authenticator: def __init__( self, backends: Sequence[AuthenticationBackend], - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP, models.ID], ): self.backends = backends self.get_user_manager = get_user_manager @@ -148,14 +148,14 @@ class Authenticator: async def _authenticate( self, *args, - user_manager: BaseUserManager[models.UC, models.UD], + user_manager: BaseUserManager[models.UP, models.ID], optional: bool = False, active: bool = False, verified: bool = False, superuser: bool = False, **kwargs, - ) -> Tuple[Optional[models.UD], Optional[str]]: - user: Optional[models.UD] = None + ) -> Tuple[Optional[models.UP], Optional[str]]: + user: Optional[models.UP] = None token: Optional[str] = None enabled_backends: Sequence[AuthenticationBackend] = kwargs.get( "enabled_backends", self.backends @@ -163,7 +163,7 @@ class Authenticator: for backend in self.backends: if backend in enabled_backends: token = kwargs[name_to_variable_name(backend.name)] - strategy: Strategy[models.UC, models.UD] = kwargs[ + strategy: Strategy[models.UP, models.ID] = kwargs[ name_to_strategy_variable_name(backend.name) ] if token is not None: diff --git a/fastapi_users/authentication/backend.py b/fastapi_users/authentication/backend.py index 95f44345..09861210 100644 --- a/fastapi_users/authentication/backend.py +++ b/fastapi_users/authentication/backend.py @@ -14,7 +14,7 @@ from fastapi_users.authentication.transport import ( from fastapi_users.types import DependencyCallable -class AuthenticationBackend(Generic[models.UC, models.UD]): +class AuthenticationBackend(Generic[models.UP]): """ Combination of an authentication transport and strategy. @@ -33,7 +33,7 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): self, name: str, transport: Transport, - get_strategy: DependencyCallable[Strategy[models.UC, models.UD]], + get_strategy: DependencyCallable[Strategy[models.UP, models.ID]], ): self.name = name self.transport = transport @@ -41,8 +41,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): async def login( self, - strategy: Strategy[models.UC, models.UD], - user: models.UD, + strategy: Strategy[models.UP, models.ID], + user: models.UP, response: Response, ) -> Any: token = await strategy.write_token(user) @@ -50,8 +50,8 @@ class AuthenticationBackend(Generic[models.UC, models.UD]): async def logout( self, - strategy: Strategy[models.UC, models.UD], - user: models.UD, + strategy: Strategy[models.UP, models.ID], + user: models.UP, token: str, response: Response, ) -> Any: diff --git a/fastapi_users/authentication/strategy/__init__.py b/fastapi_users/authentication/strategy/__init__.py index 51796fdb..ba7526f0 100644 --- a/fastapi_users/authentication/strategy/__init__.py +++ b/fastapi_users/authentication/strategy/__init__.py @@ -3,9 +3,9 @@ from fastapi_users.authentication.strategy.base import ( StrategyDestroyNotSupportedError, ) from fastapi_users.authentication.strategy.db import ( - A, + AP, AccessTokenDatabase, - BaseAccessToken, + AccessTokenProtocol, DatabaseStrategy, ) from fastapi_users.authentication.strategy.jwt import JWTStrategy @@ -16,9 +16,9 @@ except ImportError: # pragma: no cover pass __all__ = [ - "A", + "AP", "AccessTokenDatabase", - "BaseAccessToken", + "AccessTokenProtocol", "DatabaseStrategy", "JWTStrategy", "Strategy", diff --git a/fastapi_users/authentication/strategy/base.py b/fastapi_users/authentication/strategy/base.py index 367b5204..ce60db13 100644 --- a/fastapi_users/authentication/strategy/base.py +++ b/fastapi_users/authentication/strategy/base.py @@ -14,14 +14,14 @@ class StrategyDestroyNotSupportedError(Exception): pass -class Strategy(Protocol, Generic[models.UC, models.UD]): +class Strategy(Protocol, Generic[models.UP, models.ID]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: ... # pragma: no cover - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: ... # pragma: no cover - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: ... # pragma: no cover diff --git a/fastapi_users/authentication/strategy/db/__init__.py b/fastapi_users/authentication/strategy/db/__init__.py index d0b9dc3a..0f55616d 100644 --- a/fastapi_users/authentication/strategy/db/__init__.py +++ b/fastapi_users/authentication/strategy/db/__init__.py @@ -1,5 +1,5 @@ from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase -from fastapi_users.authentication.strategy.db.models import A, BaseAccessToken +from fastapi_users.authentication.strategy.db.models import AP, AccessTokenProtocol from fastapi_users.authentication.strategy.db.strategy import DatabaseStrategy -__all__ = ["A", "AccessTokenDatabase", "BaseAccessToken", "DatabaseStrategy"] +__all__ = ["AP", "AccessTokenDatabase", "AccessTokenProtocol", "DatabaseStrategy"] diff --git a/fastapi_users/authentication/strategy/db/adapter.py b/fastapi_users/authentication/strategy/db/adapter.py index eb06984d..18fac269 100644 --- a/fastapi_users/authentication/strategy/db/adapter.py +++ b/fastapi_users/authentication/strategy/db/adapter.py @@ -1,38 +1,32 @@ import sys from datetime import datetime -from typing import Generic, Optional, Type +from typing import Any, Dict, Generic, Optional if sys.version_info < (3, 8): from typing_extensions import Protocol # pragma: no cover else: from typing import Protocol # pragma: no cover -from fastapi_users.authentication.strategy.db.models import A +from fastapi_users.authentication.strategy.db.models import AP -class AccessTokenDatabase(Protocol, Generic[A]): - """ - Protocol for retrieving, creating and updating access tokens from a database. - - :param access_token_model: Pydantic model of an access token. - """ - - access_token_model: Type[A] +class AccessTokenDatabase(Protocol, Generic[AP]): + """Protocol for retrieving, creating and updating access tokens from a database.""" async def get_by_token( self, token: str, max_age: Optional[datetime] = None - ) -> Optional[A]: + ) -> Optional[AP]: """Get a single access token by token.""" ... # pragma: no cover - async def create(self, access_token: A) -> A: + async def create(self, create_dict: Dict[str, Any]) -> AP: """Create an access token.""" ... # pragma: no cover - async def update(self, access_token: A) -> A: + async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: """Update an access token.""" ... # pragma: no cover - async def delete(self, access_token: A) -> None: + async def delete(self, access_token: AP) -> None: """Delete an access token.""" ... # pragma: no cover diff --git a/fastapi_users/authentication/strategy/db/models.py b/fastapi_users/authentication/strategy/db/models.py index 94b8864a..6c3b58be 100644 --- a/fastapi_users/authentication/strategy/db/models.py +++ b/fastapi_users/authentication/strategy/db/models.py @@ -1,22 +1,24 @@ -from datetime import datetime, timezone +import sys +from datetime import datetime from typing import TypeVar -from pydantic import UUID4, BaseModel, Field +if sys.version_info < (3, 8): + from typing_extensions import Protocol # pragma: no cover +else: + from typing import Protocol # pragma: no cover + +from fastapi_users import models -def now_utc(): - return datetime.now(timezone.utc) - - -class BaseAccessToken(BaseModel): - """Base access token model.""" +class AccessTokenProtocol(Protocol[models.ID]): + """Access token protocol that ORM model should follow.""" token: str - user_id: UUID4 - created_at: datetime = Field(default_factory=now_utc) + user_id: models.ID + created_at: datetime - class Config: - orm_mode = True + def __init__(self, *args, **kwargs) -> None: + ... # pragma: no cover -A = TypeVar("A", bound=BaseAccessToken) +AP = TypeVar("AP", bound=AccessTokenProtocol) diff --git a/fastapi_users/authentication/strategy/db/strategy.py b/fastapi_users/authentication/strategy/db/strategy.py index d7aef0e9..c9e40ebc 100644 --- a/fastapi_users/authentication/strategy/db/strategy.py +++ b/fastapi_users/authentication/strategy/db/strategy.py @@ -1,24 +1,26 @@ import secrets from datetime import datetime, timedelta, timezone -from typing import Generic, Optional +from typing import Any, Dict, Generic, Optional from fastapi_users import models from fastapi_users.authentication.strategy.base import Strategy from fastapi_users.authentication.strategy.db.adapter import AccessTokenDatabase -from fastapi_users.authentication.strategy.db.models import A -from fastapi_users.manager import BaseUserManager, UserNotExists +from fastapi_users.authentication.strategy.db.models import AP +from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists -class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): +class DatabaseStrategy( + Strategy[models.UP, models.ID], Generic[models.UP, models.ID, AP] +): def __init__( - self, database: AccessTokenDatabase[A], lifetime_seconds: Optional[int] = None + self, database: AccessTokenDatabase[AP], lifetime_seconds: Optional[int] = None ): self.database = database self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: if token is None: return None @@ -33,21 +35,21 @@ class DatabaseStrategy(Strategy, Generic[models.UC, models.UD, A]): return None try: - user_id = access_token.user_id - return await user_manager.get(user_id) - except UserNotExists: + parsed_id = user_manager.parse_id(access_token.user_id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID): return None - async def write_token(self, user: models.UD) -> str: - access_token = self._create_access_token(user) - await self.database.create(access_token) + async def write_token(self, user: models.UP) -> str: + access_token_dict = self._create_access_token_dict(user) + access_token = await self.database.create(access_token_dict) return access_token.token - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: access_token = await self.database.get_by_token(token) if access_token is not None: await self.database.delete(access_token) - def _create_access_token(self, user: models.UD) -> A: + def _create_access_token_dict(self, user: models.UP) -> Dict[str, Any]: token = secrets.token_urlsafe() - return self.database.access_token_model(token=token, user_id=user.id) + return {"token": token, "user_id": user.id} diff --git a/fastapi_users/authentication/strategy/jwt.py b/fastapi_users/authentication/strategy/jwt.py index 8d761ad7..4bad548d 100644 --- a/fastapi_users/authentication/strategy/jwt.py +++ b/fastapi_users/authentication/strategy/jwt.py @@ -1,7 +1,6 @@ from typing import Generic, List, Optional import jwt -from pydantic import UUID4 from fastapi_users import models from fastapi_users.authentication.strategy.base import ( @@ -9,10 +8,10 @@ from fastapi_users.authentication.strategy.base import ( StrategyDestroyNotSupportedError, ) from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt -from fastapi_users.manager import BaseUserManager, UserNotExists +from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists -class JWTStrategy(Strategy, Generic[models.UC, models.UD]): +class JWTStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]): def __init__( self, secret: SecretType, @@ -36,8 +35,8 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]): return self.public_key or self.secret async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: if token is None: return None @@ -52,20 +51,18 @@ class JWTStrategy(Strategy, Generic[models.UC, models.UD]): return None try: - user_uiid = UUID4(user_id) - return await user_manager.get(user_uiid) - except ValueError: - return None - except UserNotExists: + parsed_id = user_manager.parse_id(user_id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID): return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: data = {"user_id": str(user.id), "aud": self.token_audience} return generate_jwt( data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm ) - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: raise StrategyDestroyNotSupportedError( "A JWT can't be invalidated: it's valid until it expires." ) diff --git a/fastapi_users/authentication/strategy/redis.py b/fastapi_users/authentication/strategy/redis.py index 3f32e22e..616ac7e5 100644 --- a/fastapi_users/authentication/strategy/redis.py +++ b/fastapi_users/authentication/strategy/redis.py @@ -2,21 +2,20 @@ import secrets from typing import Generic, Optional import aioredis -from pydantic import UUID4 from fastapi_users import models from fastapi_users.authentication.strategy.base import Strategy -from fastapi_users.manager import BaseUserManager, UserNotExists +from fastapi_users.manager import BaseUserManager, InvalidID, UserNotExists -class RedisStrategy(Strategy, Generic[models.UC, models.UD]): +class RedisStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID]): def __init__(self, redis: aioredis.Redis, lifetime_seconds: Optional[int] = None): self.redis = redis self.lifetime_seconds = lifetime_seconds async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: if token is None: return None @@ -25,17 +24,15 @@ class RedisStrategy(Strategy, Generic[models.UC, models.UD]): return None try: - user_uiid = UUID4(user_id) - return await user_manager.get(user_uiid) - except ValueError: - return None - except UserNotExists: + parsed_id = user_manager.parse_id(user_id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID): return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: token = secrets.token_urlsafe() await self.redis.set(token, str(user.id), ex=self.lifetime_seconds) return token - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: await self.redis.delete(token) diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 1856e059..5e3382f7 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -1,52 +1,36 @@ from fastapi_users.db.base import BaseUserDatabase, UserDatabaseDependency -__all__ = [ - "BaseUserDatabase", - "UserDatabaseDependency", -] +__all__ = ["BaseUserDatabase", "UserDatabaseDependency"] -try: # pragma: no cover - from fastapi_users_db_mongodb import MongoDBUserDatabase # noqa: F401 - - __all__.append("MongoDBUserDatabase") -except ImportError: # pragma: no cover - pass try: # pragma: no cover from fastapi_users_db_sqlalchemy import ( # noqa: F401 SQLAlchemyBaseOAuthAccountTable, + SQLAlchemyBaseOAuthAccountTableUUID, SQLAlchemyBaseUserTable, + SQLAlchemyBaseUserTableUUID, SQLAlchemyUserDatabase, ) - __all__.append("SQLAlchemyBaseOAuthAccountTable") __all__.append("SQLAlchemyBaseUserTable") + __all__.append("SQLAlchemyBaseUserTableUUID") + __all__.append("SQLAlchemyBaseOAuthAccountTable") + __all__.append("SQLAlchemyBaseOAuthAccountTableUUID") __all__.append("SQLAlchemyUserDatabase") except ImportError: # pragma: no cover pass try: # pragma: no cover - from fastapi_users_db_tortoise import ( # noqa: F401 - TortoiseBaseOAuthAccountModel, - TortoiseBaseUserModel, - TortoiseUserDatabase, + from fastapi_users_db_beanie import ( # noqa: F401 + BaseOAuthAccount, + BeanieBaseUser, + BeanieUserDatabase, + ObjectIDIDMixin, ) - __all__.append("TortoiseBaseOAuthAccountModel") - __all__.append("TortoiseBaseUserModel") - __all__.append("TortoiseUserDatabase") -except ImportError: # pragma: no cover - pass - -try: # pragma: no cover - from fastapi_users_db_ormar import ( # noqa: F401 - OrmarBaseOAuthAccountModel, - OrmarBaseUserModel, - OrmarUserDatabase, - ) - - __all__.append("OrmarBaseOAuthAccountModel") - __all__.append("OrmarBaseUserModel") - __all__.append("OrmarUserDatabase") + __all__.append("BeanieBaseUser") + __all__.append("BaseOAuthAccount") + __all__.append("BeanieUserDatabase") + __all__.append("ObjectIDIDMixin") except ImportError: # pragma: no cover pass diff --git a/fastapi_users/db/base.py b/fastapi_users/db/base.py index b6e92006..5c0b675c 100644 --- a/fastapi_users/db/base.py +++ b/fastapi_users/db/base.py @@ -1,46 +1,50 @@ -from typing import Generic, Optional, Type +from typing import Any, Dict, Generic, Optional -from pydantic import UUID4 - -from fastapi_users.models import UD +from fastapi_users.models import ID, OAP, UOAP, UP from fastapi_users.types import DependencyCallable -class BaseUserDatabase(Generic[UD]): - """ - Base adapter for retrieving, creating and updating users from a database. +class BaseUserDatabase(Generic[UP, ID]): + """Base adapter for retrieving, creating and updating users from a database.""" - :param user_db_model: Pydantic model of a DB representation of a user. - """ - - user_db_model: Type[UD] - - def __init__(self, user_db_model: Type[UD]): - self.user_db_model = user_db_model - - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: ID) -> Optional[UP]: """Get a single user by id.""" raise NotImplementedError() - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" raise NotImplementedError() - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: """Get a single user by OAuth account id.""" raise NotImplementedError() - async def create(self, user: UD) -> UD: + async def create(self, create_dict: Dict[str, Any]) -> UP: """Create a user.""" raise NotImplementedError() - async def update(self, user: UD) -> UD: + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: """Update a user.""" raise NotImplementedError() - async def delete(self, user: UD) -> None: + async def delete(self, user: UP) -> None: """Delete a user.""" raise NotImplementedError() + async def add_oauth_account( + self: "BaseUserDatabase[UOAP, ID]", user: UOAP, create_dict: Dict[str, Any] + ) -> UOAP: + """Create an OAuth account and add it to the user.""" + raise NotImplementedError() -UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UD]] + async def update_oauth_account( + self: "BaseUserDatabase[UOAP, ID]", + user: UOAP, + oauth_account: OAP, + update_dict: Dict[str, Any], + ) -> UOAP: + """Update an OAuth account on a user.""" + raise NotImplementedError() + + +UserDatabaseDependency = DependencyCallable[BaseUserDatabase[UP, ID]] diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 674b5062..a0d45243 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -2,7 +2,7 @@ from typing import Generic, Sequence, Type from fastapi import APIRouter -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, Authenticator from fastapi_users.jwt import SecretType from fastapi_users.manager import UserManagerDependency @@ -22,58 +22,49 @@ except ModuleNotFoundError: # pragma: no cover BaseOAuth2 = Type # type: ignore -class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): +class FastAPIUsers(Generic[models.UP, models.ID]): """ Main object that ties together the component for users authentication. :param get_user_manager: Dependency callable getter to inject the user manager class instance. :param auth_backends: List of authentication backends. - :param user_model: Pydantic model of a user. - :param user_create_model: Pydantic model for creating a user. - :param user_update_model: Pydantic model for updating a user. - :param user_db_model: Pydantic model of a DB representation of a user. :attribute current_user: Dependency callable getter to inject authenticated user with a specific set of parameters. """ authenticator: Authenticator - _user_model: Type[models.U] - _user_create_model: Type[models.UC] - _user_update_model: Type[models.UU] - _user_db_model: Type[models.UD] def __init__( self, - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP, models.ID], auth_backends: Sequence[AuthenticationBackend], - user_model: Type[models.U], - user_create_model: Type[models.UC], - user_update_model: Type[models.UU], - user_db_model: Type[models.UD], ): self.authenticator = Authenticator(auth_backends, get_user_manager) - - self._user_model = user_model - self._user_db_model = user_db_model - self._user_create_model = user_create_model - self._user_update_model = user_update_model - self.get_user_manager = get_user_manager self.current_user = self.authenticator.current_user - def get_register_router(self) -> APIRouter: - """Return a router with a register route.""" + def get_register_router( + self, user_schema: Type[schemas.U], user_create_schema: Type[schemas.UC] + ) -> APIRouter: + """ + Return a router with a register route. + + :param user_schema: Pydantic schema of a public user. + :param user_create_schema: Pydantic schema for creating a user. + """ return get_register_router( - self.get_user_manager, - self._user_model, - self._user_create_model, + self.get_user_manager, user_schema, user_create_schema ) - def get_verify_router(self) -> APIRouter: - """Return a router with e-mail verification routes.""" - return get_verify_router(self.get_user_manager, self._user_model) + def get_verify_router(self, user_schema: Type[schemas.U]) -> APIRouter: + """ + Return a router with e-mail verification routes. + + :param user_schema: Pydantic schema of a public user. + """ + return get_verify_router(self.get_user_manager, user_schema) def get_reset_password_router(self) -> APIRouter: """Return a reset password process router.""" @@ -122,19 +113,22 @@ class FastAPIUsers(Generic[models.U, models.UC, models.UU, models.UD]): def get_users_router( self, + user_schema: Type[schemas.U], + user_update_schema: Type[schemas.UU], requires_verification: bool = False, ) -> APIRouter: """ Return a router with routes to manage users. + :param user_schema: Pydantic schema of a public user. + :param user_update_schema: Pydantic schema for updating a user. :param requires_verification: Whether the endpoints require the users to be verified or not. """ return get_users_router( self.get_user_manager, - self._user_model, - self._user_update_model, - self._user_db_model, + user_schema, + user_update_schema, self.authenticator, requires_verification, ) diff --git a/fastapi_users/manager.py b/fastapi_users/manager.py index 3693ece3..a4d1c1cd 100644 --- a/fastapi_users/manager.py +++ b/fastapi_users/manager.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, Generic, Optional, Type, Union +import uuid +from typing import Any, Dict, Generic, Optional, Union import jwt from fastapi import Request from fastapi.security import OAuth2PasswordRequestForm -from pydantic import UUID4 -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt from fastapi_users.password import PasswordHelper, PasswordHelperProtocol @@ -19,6 +19,10 @@ class FastAPIUsersException(Exception): pass +class InvalidID(FastAPIUsersException): + pass + + class UserAlreadyExists(FastAPIUsersException): pass @@ -48,11 +52,10 @@ class InvalidPasswordException(FastAPIUsersException): self.reason = reason -class BaseUserManager(Generic[models.UC, models.UD]): +class BaseUserManager(Generic[models.UP, models.ID]): """ User management logic. - :attribute user_db_model: Pydantic model of a DB representation of a user. :attribute reset_password_token_secret: Secret to encode reset password token. :attribute reset_password_token_lifetime_seconds: Lifetime of reset password token. :attribute reset_password_token_audience: JWT audience of reset password token. @@ -63,7 +66,6 @@ class BaseUserManager(Generic[models.UC, models.UD]): :param user_db: Database adapter instance. """ - user_db_model: Type[models.UD] reset_password_token_secret: SecretType reset_password_token_lifetime_seconds: int = 3600 reset_password_token_audience: str = RESET_PASSWORD_TOKEN_AUDIENCE @@ -72,12 +74,12 @@ class BaseUserManager(Generic[models.UC, models.UD]): verification_token_lifetime_seconds: int = 3600 verification_token_audience: str = VERIFY_USER_TOKEN_AUDIENCE - user_db: BaseUserDatabase[models.UD] + user_db: BaseUserDatabase[models.UP, models.ID] password_helper: PasswordHelperProtocol def __init__( self, - user_db: BaseUserDatabase[models.UD], + user_db: BaseUserDatabase[models.UP, models.ID], password_helper: Optional[PasswordHelperProtocol] = None, ): self.user_db = user_db @@ -86,7 +88,17 @@ class BaseUserManager(Generic[models.UC, models.UD]): else: self.password_helper = password_helper # pragma: no cover - async def get(self, id: UUID4) -> models.UD: + def parse_id(self, value: Any) -> models.ID: + """ + Parse a value into a correct models.ID instance. + + :param value: The value to parse. + :raises InvalidID: The models.ID value is invalid. + :return: An models.ID object. + """ + raise NotImplementedError() # pragma: no cover + + async def get(self, id: models.ID) -> models.UP: """ Get a user by id. @@ -101,7 +113,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user - async def get_by_email(self, user_email: str) -> models.UD: + async def get_by_email(self, user_email: str) -> models.UP: """ Get a user by e-mail. @@ -116,7 +128,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user - async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UD: + async def get_by_oauth_account(self, oauth: str, account_id: str) -> models.UP: """ Get a user by OAuth account. @@ -133,14 +145,17 @@ class BaseUserManager(Generic[models.UC, models.UD]): return user async def create( - self, user: models.UC, safe: bool = False, request: Optional[Request] = None - ) -> models.UD: + self, + user_create: schemas.UC, + safe: bool = False, + request: Optional[Request] = None, + ) -> models.UP: """ Create a user in database. Triggers the on_after_register handler on success. - :param user: The UserCreate model to create. + :param user_create: The UserCreate model to create. :param safe: If True, sensitive values like is_superuser or is_verified will be ignored during the creation, defaults to False. :param request: Optional FastAPI request that @@ -148,27 +163,36 @@ class BaseUserManager(Generic[models.UC, models.UD]): :raises UserAlreadyExists: A user already exists with the same e-mail. :return: A new user. """ - await self.validate_password(user.password, user) + await self.validate_password(user_create.password, user_create) - existing_user = await self.user_db.get_by_email(user.email) + existing_user = await self.user_db.get_by_email(user_create.email) if existing_user is not None: raise UserAlreadyExists() - hashed_password = self.password_helper.hash(user.password) user_dict = ( - user.create_update_dict() if safe else user.create_update_dict_superuser() + user_create.create_update_dict() + if safe + else user_create.create_update_dict_superuser() ) - db_user = self.user_db_model(**user_dict, hashed_password=hashed_password) + password = user_dict.pop("password") + user_dict["hashed_password"] = self.password_helper.hash(password) - created_user = await self.user_db.create(db_user) + created_user = await self.user_db.create(user_dict) await self.on_after_register(created_user, request) return created_user async def oauth_callback( - self, oauth_account: models.BaseOAuthAccount, request: Optional[Request] = None - ) -> models.UD: + self: "BaseUserManager[models.UOAP, models.ID]", + oauth_name: str, + access_token: str, + account_id: str, + account_email: str, + expires_at: Optional[int] = None, + refresh_token: Optional[str] = None, + request: Optional[Request] = None, + ) -> models.UOAP: """ Handle the callback after a successful OAuth authentication. @@ -180,50 +204,58 @@ class BaseUserManager(Generic[models.UC, models.UD]): If the user does not exist, it is created and the on_after_register handler is triggered. - :param oauth_account: The new OAuth account to create. + :param oauth_name: Name of the OAuth client. + :param access_token: Valid access token for the service provider. + :param account_id: models.ID of the user on the service provider. + :param account_email: E-mail of the user on the service provider. + :param expires_at: Optional timestamp at which the access token expires. + :param refresh_token: Optional refresh token to get a + fresh access token from the service provider. :param request: Optional FastAPI request that triggered the operation, defaults to None :return: A user. """ + oauth_account_dict = { + "oauth_name": oauth_name, + "access_token": access_token, + "account_id": account_id, + "account_email": account_email, + "expires_at": expires_at, + "refresh_token": refresh_token, + } + try: - user = await self.get_by_oauth_account( - oauth_account.oauth_name, oauth_account.account_id - ) + user = await self.get_by_oauth_account(oauth_name, account_id) except UserNotExists: try: # Link account - user = await self.get_by_email(oauth_account.account_email) - user.oauth_accounts.append(oauth_account) # type: ignore - await self.user_db.update(user) + user = await self.get_by_email(account_email) + user = await self.user_db.add_oauth_account(user, oauth_account_dict) except UserNotExists: # Create account password = self.password_helper.generate() - user = self.user_db_model( - email=oauth_account.account_email, - hashed_password=self.password_helper.hash(password), - oauth_accounts=[oauth_account], - ) - await self.user_db.create(user) + user_dict = { + "email": account_email, + "hashed_password": self.password_helper.hash(password), + } + user = await self.user_db.create(user_dict) + user = await self.user_db.add_oauth_account(user, oauth_account_dict) await self.on_after_register(user, request) else: # Update oauth - updated_oauth_accounts = [] - for existing_oauth_account in user.oauth_accounts: # type: ignore + for existing_oauth_account in user.oauth_accounts: if ( - existing_oauth_account.account_id == oauth_account.account_id - and existing_oauth_account.oauth_name == oauth_account.oauth_name + existing_oauth_account.account_id == account_id + and existing_oauth_account.oauth_name == oauth_name ): - oauth_account.id = existing_oauth_account.id - updated_oauth_accounts.append(oauth_account) - else: - updated_oauth_accounts.append(existing_oauth_account) - user.oauth_accounts = updated_oauth_accounts # type: ignore - await self.user_db.update(user) + user = await self.user_db.update_oauth_account( + user, existing_oauth_account, oauth_account_dict + ) return user async def request_verify( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Start a verification request. @@ -253,7 +285,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): ) await self.on_after_request_verify(user, token, request) - async def verify(self, token: str, request: Optional[Request] = None) -> models.UD: + async def verify(self, token: str, request: Optional[Request] = None) -> models.UP: """ Validate a verification request. @@ -289,11 +321,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): raise InvalidVerifyToken() try: - user_uuid = UUID4(user_id) - except ValueError: + parsed_id = self.parse_id(user_id) + except InvalidID: raise InvalidVerifyToken() - if user_uuid != user.id: + if parsed_id != user.id: raise InvalidVerifyToken() if user.is_verified: @@ -306,7 +338,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return verified_user async def forgot_password( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Start a forgot password request. @@ -334,7 +366,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def reset_password( self, token: str, password: str, request: Optional[Request] = None - ) -> models.UD: + ) -> models.UP: """ Reset the password of a user. @@ -364,11 +396,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): raise InvalidResetPasswordToken() try: - user_uuid = UUID4(user_id) - except ValueError: + parsed_id = self.parse_id(user_id) + except InvalidID: raise InvalidResetPasswordToken() - user = await self.get(user_uuid) + user = await self.get(parsed_id) if not user.is_active: raise UserInactive() @@ -381,11 +413,11 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def update( self, - user_update: models.UU, - user: models.UD, + user_update: schemas.UU, + user: models.UP, safe: bool = False, request: Optional[Request] = None, - ) -> models.UD: + ) -> models.UP: """ Update a user. @@ -408,7 +440,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): await self.on_after_update(updated_user, updated_user_data, request) return updated_user - async def delete(self, user: models.UD) -> None: + async def delete(self, user: models.UP) -> None: """ Delete a user. @@ -417,7 +449,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): await self.user_db.delete(user) async def validate_password( - self, password: str, user: Union[models.UC, models.UD] + self, password: str, user: Union[schemas.UC, models.UP] ) -> None: """ Validate a password. @@ -432,7 +464,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_register( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Perform logic after successful user registration. @@ -447,7 +479,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def on_after_update( self, - user: models.UD, + user: models.UP, update_dict: Dict[str, Any], request: Optional[Request] = None, ) -> None: @@ -464,7 +496,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_request_verify( - self, user: models.UD, token: str, request: Optional[Request] = None + self, user: models.UP, token: str, request: Optional[Request] = None ) -> None: """ Perform logic after successful verification request. @@ -479,7 +511,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_verify( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Perform logic after successful user verification. @@ -493,7 +525,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_forgot_password( - self, user: models.UD, token: str, request: Optional[Request] = None + self, user: models.UP, token: str, request: Optional[Request] = None ) -> None: """ Perform logic after successful forgot password request. @@ -508,7 +540,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): return # pragma: no cover async def on_after_reset_password( - self, user: models.UD, request: Optional[Request] = None + self, user: models.UP, request: Optional[Request] = None ) -> None: """ Perform logic after successful password reset. @@ -523,7 +555,7 @@ class BaseUserManager(Generic[models.UC, models.UD]): async def authenticate( self, credentials: OAuth2PasswordRequestForm - ) -> Optional[models.UD]: + ) -> Optional[models.UP]: """ Authenticate and return a user following an email and a password. @@ -546,27 +578,48 @@ class BaseUserManager(Generic[models.UC, models.UD]): return None # Update password hash to a more robust one if needed if updated_password_hash is not None: - user.hashed_password = updated_password_hash - await self.user_db.update(user) + await self.user_db.update(user, {"hashed_password": updated_password_hash}) return user - async def _update(self, user: models.UD, update_dict: Dict[str, Any]) -> models.UD: + async def _update(self, user: models.UP, update_dict: Dict[str, Any]) -> models.UP: + validated_update_dict = {} for field, value in update_dict.items(): if field == "email" and value != user.email: try: await self.get_by_email(value) raise UserAlreadyExists() except UserNotExists: - user.email = value - user.is_verified = False + validated_update_dict["email"] = value + validated_update_dict["is_verified"] = False elif field == "password": await self.validate_password(value, user) - hashed_password = self.password_helper.hash(value) - user.hashed_password = hashed_password + validated_update_dict["hashed_password"] = self.password_helper.hash( + value + ) else: - setattr(user, field, value) - return await self.user_db.update(user) + validated_update_dict[field] = value + return await self.user_db.update(user, validated_update_dict) -UserManagerDependency = DependencyCallable[BaseUserManager[models.UC, models.UD]] +class UUIDIDMixin: + def parse_id(self, value: Any) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + try: + return uuid.UUID(value) + except ValueError as e: + raise InvalidID() from e + + +class IntegerIDMixin: + def parse_id(self, value: Any) -> int: + if isinstance(value, float): + raise InvalidID() + try: + return int(value) + except ValueError as e: + raise InvalidID() from e + + +UserManagerDependency = DependencyCallable[BaseUserManager[models.UP, models.ID]] diff --git a/fastapi_users/models.py b/fastapi_users/models.py index 6529a3f5..f6bbe1ab 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -1,81 +1,51 @@ -import uuid -from typing import List, Optional, TypeVar +import sys +from typing import Generic, List, Optional, TypeVar -from pydantic import UUID4, BaseModel, EmailStr, Field +if sys.version_info < (3, 8): + from typing_extensions import Protocol # pragma: no cover +else: + from typing import Protocol # pragma: no cover + +ID = TypeVar("ID") -class CreateUpdateDictModel(BaseModel): - def create_update_dict(self): - return self.dict( - exclude_unset=True, - exclude={ - "id", - "is_superuser", - "is_active", - "is_verified", - "oauth_accounts", - }, - ) +class UserProtocol(Protocol[ID]): + """User protocol that ORM model should follow.""" - def create_update_dict_superuser(self): - return self.dict(exclude_unset=True, exclude={"id"}) - - -class BaseUser(CreateUpdateDictModel): - """Base User model.""" - - id: UUID4 = Field(default_factory=uuid.uuid4) - email: EmailStr - is_active: bool = True - is_superuser: bool = False - is_verified: bool = False - - -class BaseUserCreate(CreateUpdateDictModel): - email: EmailStr - password: str - is_active: Optional[bool] = True - is_superuser: Optional[bool] = False - is_verified: Optional[bool] = False - - -class BaseUserUpdate(CreateUpdateDictModel): - password: Optional[str] - email: Optional[EmailStr] - is_active: Optional[bool] - is_superuser: Optional[bool] - is_verified: Optional[bool] - - -class BaseUserDB(BaseUser): + id: ID + email: str hashed_password: str + is_active: bool + is_superuser: bool + is_verified: bool - class Config: - orm_mode = True + def __init__(self, *args, **kwargs) -> None: + ... # pragma: no cover -U = TypeVar("U", bound=BaseUser) -UC = TypeVar("UC", bound=BaseUserCreate) -UU = TypeVar("UU", bound=BaseUserUpdate) -UD = TypeVar("UD", bound=BaseUserDB) +class OAuthAccountProtocol(Protocol[ID]): + """OAuth account protocol that ORM model should follow.""" - -class BaseOAuthAccount(BaseModel): - """Base OAuth account model.""" - - id: UUID4 = Field(default_factory=uuid.uuid4) + id: ID oauth_name: str access_token: str - expires_at: Optional[int] = None - refresh_token: Optional[str] = None + expires_at: Optional[int] + refresh_token: Optional[str] account_id: str account_email: str - class Config: - orm_mode = True + def __init__(self, *args, **kwargs) -> None: + ... # pragma: no cover -class BaseOAuthAccountMixin(BaseModel): - """Adds OAuth accounts list to a User model.""" +UP = TypeVar("UP", bound=UserProtocol) +OAP = TypeVar("OAP", bound=OAuthAccountProtocol) - oauth_accounts: List[BaseOAuthAccount] = [] + +class UserOAuthProtocol(UserProtocol[ID], Generic[ID, OAP]): + """User protocol including a list of OAuth accounts.""" + + oauth_accounts: List[OAP] + + +UOAP = TypeVar("UOAP", bound=UserOAuthProtocol) diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index 16ef5e07..cda0fa3e 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -12,7 +12,7 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_auth_router( backend: AuthenticationBackend, - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP, models.ID], authenticator: Authenticator, requires_verification: bool = False, ) -> APIRouter: @@ -51,8 +51,8 @@ def get_auth_router( async def login( response: Response, credentials: OAuth2PasswordRequestForm = Depends(), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), - strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): user = await user_manager.authenticate(credentials) @@ -82,8 +82,8 @@ def get_auth_router( ) async def logout( response: Response, - user_token: Tuple[models.UD, str] = Depends(get_current_user_token), - strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), + user_token: Tuple[models.UP, str] = Depends(get_current_user_token), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): user, token = user_token return await backend.logout(strategy, user, token, response) diff --git a/fastapi_users/router/oauth.py b/fastapi_users/router/oauth.py index 2232be55..9ec1ac26 100644 --- a/fastapi_users/router/oauth.py +++ b/fastapi_users/router/oauth.py @@ -29,7 +29,7 @@ def generate_state_token( def get_oauth_router( oauth_client: BaseOAuth2, backend: AuthenticationBackend, - get_user_manager: UserManagerDependency[models.UC, models.UD], + get_user_manager: UserManagerDependency[models.UP, models.ID], state_secret: SecretType, redirect_url: str = None, ) -> APIRouter: @@ -101,8 +101,8 @@ def get_oauth_router( access_token_state: Tuple[OAuth2Token, str] = Depends( oauth2_authorize_callback ), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), - strategy: Strategy[models.UC, models.UD] = Depends(backend.get_strategy), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy), ): token, state = access_token_state account_id, account_email = await oauth_client.get_id_email( @@ -114,17 +114,16 @@ def get_oauth_router( except jwt.DecodeError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - new_oauth_account = models.BaseOAuthAccount( - oauth_name=oauth_client.name, - access_token=token["access_token"], - expires_at=token.get("expires_at"), - refresh_token=token.get("refresh_token"), - account_id=account_id, - account_email=account_email, + user = await user_manager.oauth_callback( + oauth_client.name, + token["access_token"], + account_id, + account_email, + token.get("expires_at"), + token.get("refresh_token"), + request, ) - user = await user_manager.oauth_callback(new_oauth_account, request) - if not user.is_active: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/fastapi_users/router/register.py b/fastapi_users/router/register.py index 567d1a3c..3eab139e 100644 --- a/fastapi_users/router/register.py +++ b/fastapi_users/router/register.py @@ -2,7 +2,7 @@ from typing import Type from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.manager import ( BaseUserManager, InvalidPasswordException, @@ -13,16 +13,16 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_register_router( - get_user_manager: UserManagerDependency[models.UC, models.UD], - user_model: Type[models.U], - user_create_model: Type[models.UC], + get_user_manager: UserManagerDependency[models.UP, models.ID], + user_schema: Type[schemas.U], + user_create_schema: Type[schemas.UC], ) -> APIRouter: """Generate a router with the register route.""" router = APIRouter() @router.post( "/register", - response_model=user_model, + response_model=user_schema, status_code=status.HTTP_201_CREATED, name="register:register", responses={ @@ -55,11 +55,13 @@ def get_register_router( ) async def register( request: Request, - user: user_create_model, # type: ignore - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_create: user_create_schema, # type: ignore + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: - created_user = await user_manager.create(user, safe=True, request=request) + created_user = await user_manager.create( + user_create, safe=True, request=request + ) except UserAlreadyExists: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/fastapi_users/router/reset.py b/fastapi_users/router/reset.py index 46d6e924..95e2e93d 100644 --- a/fastapi_users/router/reset.py +++ b/fastapi_users/router/reset.py @@ -40,7 +40,7 @@ RESET_PASSWORD_RESPONSES: OpenAPIResponseType = { def get_reset_password_router( - get_user_manager: UserManagerDependency[models.UC, models.UD] + get_user_manager: UserManagerDependency[models.UP, models.ID], ) -> APIRouter: """Generate a router with the reset password routes.""" router = APIRouter() @@ -53,7 +53,7 @@ def get_reset_password_router( async def forgot_password( request: Request, email: EmailStr = Body(..., embed=True), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: user = await user_manager.get_by_email(email) @@ -76,7 +76,7 @@ def get_reset_password_router( request: Request, token: str = Body(...), password: str = Body(...), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: await user_manager.reset_password(token, password, request) diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index aed2eaad..b6a48caf 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -1,12 +1,12 @@ -from typing import Type +from typing import Any, Type from fastapi import APIRouter, Depends, HTTPException, Request, Response, status -from pydantic import UUID4 -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import Authenticator from fastapi_users.manager import ( BaseUserManager, + InvalidID, InvalidPasswordException, UserAlreadyExists, UserManagerDependency, @@ -16,10 +16,9 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_users_router( - get_user_manager: UserManagerDependency[models.UC, models.UD], - user_model: Type[models.U], - user_update_model: Type[models.UU], - user_db_model: Type[models.UD], + get_user_manager: UserManagerDependency[models.UP, models.ID], + user_schema: Type[schemas.U], + user_update_schema: Type[schemas.UU], authenticator: Authenticator, requires_verification: bool = False, ) -> APIRouter: @@ -34,17 +33,18 @@ def get_users_router( ) async def get_user_or_404( - id: UUID4, - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), - ) -> models.UD: + id: Any, + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), + ) -> models.UP: try: - return await user_manager.get(id) - except UserNotExists: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + parsed_id = user_manager.parse_id(id) + return await user_manager.get(parsed_id) + except (UserNotExists, InvalidID) as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from e @router.get( "/me", - response_model=user_model, + response_model=user_schema, name="users:current_user", responses={ status.HTTP_401_UNAUTHORIZED: { @@ -53,13 +53,13 @@ def get_users_router( }, ) async def me( - user: user_db_model = Depends(get_current_active_user), # type: ignore + user: models.UP = Depends(get_current_active_user), ): return user @router.patch( "/me", - response_model=user_model, + response_model=user_schema, dependencies=[Depends(get_current_active_user)], name="users:patch_current_user", responses={ @@ -95,9 +95,9 @@ def get_users_router( ) async def update_me( request: Request, - user_update: user_update_model, # type: ignore - user: user_db_model = Depends(get_current_active_user), # type: ignore - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_update: user_update_schema, # type: ignore + user: models.UP = Depends(get_current_active_user), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: return await user_manager.update( @@ -118,8 +118,8 @@ def get_users_router( ) @router.get( - "/{id:uuid}", - response_model=user_model, + "/{id}", + response_model=user_schema, dependencies=[Depends(get_current_superuser)], name="users:user", responses={ @@ -138,8 +138,8 @@ def get_users_router( return user @router.patch( - "/{id:uuid}", - response_model=user_model, + "/{id}", + response_model=user_schema, dependencies=[Depends(get_current_superuser)], name="users:patch_user", responses={ @@ -180,10 +180,10 @@ def get_users_router( }, ) async def update_user( - user_update: user_update_model, # type: ignore + user_update: user_update_schema, # type: ignore request: Request, user=Depends(get_user_or_404), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: return await user_manager.update( @@ -204,7 +204,7 @@ def get_users_router( ) @router.delete( - "/{id:uuid}", + "/{id}", status_code=status.HTTP_204_NO_CONTENT, response_class=Response, dependencies=[Depends(get_current_superuser)], @@ -223,7 +223,7 @@ def get_users_router( ) async def delete_user( user=Depends(get_user_or_404), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): await user_manager.delete(user) return None diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py index d4046040..af6578f5 100644 --- a/fastapi_users/router/verify.py +++ b/fastapi_users/router/verify.py @@ -3,7 +3,7 @@ from typing import Type from fastapi import APIRouter, Body, Depends, HTTPException, Request, status from pydantic import EmailStr -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.manager import ( BaseUserManager, InvalidVerifyToken, @@ -16,8 +16,8 @@ from fastapi_users.router.common import ErrorCode, ErrorModel def get_verify_router( - get_user_manager: UserManagerDependency[models.UC, models.UD], - user_model: Type[models.U], + get_user_manager: UserManagerDependency[models.UP, models.ID], + user_schema: Type[schemas.U], ): router = APIRouter() @@ -29,7 +29,7 @@ def get_verify_router( async def request_verify_token( request: Request, email: EmailStr = Body(..., embed=True), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: user = await user_manager.get_by_email(email) @@ -41,7 +41,7 @@ def get_verify_router( @router.post( "/verify", - response_model=user_model, + response_model=user_schema, name="verify:verify", responses={ status.HTTP_400_BAD_REQUEST: { @@ -69,7 +69,7 @@ def get_verify_router( async def verify( request: Request, token: str = Body(..., embed=True), - user_manager: BaseUserManager[models.UC, models.UD] = Depends(get_user_manager), + user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager), ): try: return await user_manager.verify(token, request) diff --git a/fastapi_users/schemas.py b/fastapi_users/schemas.py new file mode 100644 index 00000000..6ab2994d --- /dev/null +++ b/fastapi_users/schemas.py @@ -0,0 +1,74 @@ +from typing import Generic, List, Optional, TypeVar + +from pydantic import BaseModel, EmailStr + +from fastapi_users import models + + +class CreateUpdateDictModel(BaseModel): + def create_update_dict(self): + return self.dict( + exclude_unset=True, + exclude={ + "id", + "is_superuser", + "is_active", + "is_verified", + "oauth_accounts", + }, + ) + + def create_update_dict_superuser(self): + return self.dict(exclude_unset=True, exclude={"id"}) + + +class BaseUser(Generic[models.ID], CreateUpdateDictModel): + """Base User model.""" + + id: models.ID + email: EmailStr + is_active: bool = True + is_superuser: bool = False + is_verified: bool = False + + +class BaseUserCreate(CreateUpdateDictModel): + email: EmailStr + password: str + is_active: Optional[bool] = True + is_superuser: Optional[bool] = False + is_verified: Optional[bool] = False + + +class BaseUserUpdate(CreateUpdateDictModel): + password: Optional[str] + email: Optional[EmailStr] + is_active: Optional[bool] + is_superuser: Optional[bool] + is_verified: Optional[bool] + + +U = TypeVar("U", bound=BaseUser) +UC = TypeVar("UC", bound=BaseUserCreate) +UU = TypeVar("UU", bound=BaseUserUpdate) + + +class BaseOAuthAccount(Generic[models.ID], BaseModel): + """Base OAuth account model.""" + + id: models.ID + oauth_name: str + access_token: str + expires_at: Optional[int] = None + refresh_token: Optional[str] = None + account_id: str + account_email: str + + class Config: + orm_mode = True + + +class BaseOAuthAccountMixin(BaseModel): + """Adds OAuth accounts list to a User model.""" + + oauth_accounts: List[BaseOAuthAccount] = [] diff --git a/mkdocs.yml b/mkdocs.yml index 81fe714d..56344532 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,7 +24,6 @@ theme: - navigation.instant - navigation.top - navigation.sections - - navigation.indexes - search.suggest - search.highlight - content.code.annotate @@ -71,12 +70,9 @@ nav: - installation.md - Configuration: - configuration/overview.md - - configuration/models.md - - Database adapters: + - User model and databases: - configuration/databases/sqlalchemy.md - - configuration/databases/mongodb.md - - configuration/databases/tortoise.md - - configuration/databases/ormar.md + - configuration/databases/beanie.md - Authentication backends: - Introduction: configuration/authentication/index.md - Transports: @@ -88,6 +84,7 @@ nav: - configuration/authentication/strategies/redis.md - configuration/authentication/backend.md - configuration/user-manager.md + - configuration/schemas.md - Routers: - Introduction: configuration/routers/index.md - configuration/routers/auth.md @@ -113,3 +110,4 @@ nav: - migration/6x_to_7x.md - migration/7x_to_8x.md - migration/8x_to_9x.md + - migration/9x_to_10x.md diff --git a/pyproject.toml b/pyproject.toml index e8d0f28a..29726349 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,21 +10,13 @@ module = "passlib.*" ignore_missing_imports = true [[tool.mypy.overrides]] -module = "fastapi_users_db_mongodb.*" +module = "fastapi_users_db_beanie.*" ignore_missing_imports = true [[tool.mypy.overrides]] module = "fastapi_users_db_sqlalchemy.*" ignore_missing_imports = true -[[tool.mypy.overrides]] -module = "fastapi_users_db_tortoise.*" -ignore_missing_imports = true - -[[tool.mypy.overrides]] -module = "fastapi_users_db_ormar.*" -ignore_missing_imports = true - [tool.pytest.ini_options] asyncio_mode = "auto" addopts = "--ignore=test_build.py" @@ -106,19 +98,10 @@ dev = [ "uvicorn", ] sqlalchemy = [ - "fastapi-users-db-sqlalchemy >=1.1.0,<2.0.0", + "fastapi-users-db-sqlalchemy >=4.0.0", ] -sqlalchemy2 = [ - "fastapi-users-db-sqlalchemy >=2.0.0,<4.0.0", -] -mongodb = [ - "fastapi-users-db-mongodb >=1.1.0,<2.0.0", -] -tortoise-orm = [ - "fastapi-users-db-tortoise >=1.1.0,<2.0.0", -] -ormar = [ - "fastapi-users-db-ormar >=1.0.0,<2.0.0", +beanie = [ + "fastapi-users-db-beanie >=1.0.0", ] oauth = [ "httpx-oauth >=0.4,<0.7" diff --git a/tests/conftest.py b/tests/conftest.py index d94bc94e..3a6ddd48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,17 @@ import asyncio -from typing import Any, AsyncGenerator, Callable, Generic, Optional, Type, Union +import dataclasses +import uuid +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + Generic, + List, + Optional, + Type, + Union, +) from unittest.mock import MagicMock import httpx @@ -10,17 +22,18 @@ from httpx_oauth.oauth2 import OAuth2 from pydantic import UUID4, SecretStr from pytest_mock import MockerFixture -from fastapi_users import models +from fastapi_users import models, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport from fastapi_users.authentication.strategy import Strategy from fastapi_users.db import BaseUserDatabase from fastapi_users.jwt import SecretType from fastapi_users.manager import ( BaseUserManager, + InvalidID, InvalidPasswordException, UserNotExists, + UUIDIDMixin, ) -from fastapi_users.models import BaseOAuthAccount, BaseOAuthAccountMixin from fastapi_users.openapi import OpenAPIResponseType from fastapi_users.password import PasswordHelper @@ -32,38 +45,60 @@ lancelot_password_hash = password_helper.hash("lancelot") excalibur_password_hash = password_helper.hash("excalibur") -class User(models.BaseUser): +IDType = uuid.UUID + + +@dataclasses.dataclass +class UserModel(models.UserProtocol[IDType]): + email: str + hashed_password: str + id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) + is_active: bool = True + is_superuser: bool = False + is_verified: bool = False + first_name: Optional[str] = None + + +@dataclasses.dataclass +class OAuthAccountModel(models.OAuthAccountProtocol[IDType]): + oauth_name: str + access_token: str + account_id: str + account_email: str + id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) + expires_at: Optional[int] = None + refresh_token: Optional[str] = None + + +@dataclasses.dataclass +class UserOAuthModel(UserModel): + oauth_accounts: List[OAuthAccountModel] = dataclasses.field(default_factory=list) + + +class User(schemas.BaseUser[IDType]): first_name: Optional[str] -class UserCreate(models.BaseUserCreate): +class UserCreate(schemas.BaseUserCreate): first_name: Optional[str] -class UserUpdate(models.BaseUserUpdate): +class UserUpdate(schemas.BaseUserUpdate): first_name: Optional[str] -class UserDB(User, models.BaseUserDB): - pass - - -class UserOAuth(User, BaseOAuthAccountMixin): - pass - - -class UserDBOAuth(UserOAuth, UserDB): +class UserOAuth(User, schemas.BaseOAuthAccountMixin): pass class BaseTestUserManager( - Generic[models.UC, models.UD], BaseUserManager[models.UC, models.UD] + Generic[models.UP], UUIDIDMixin, BaseUserManager[models.UP, IDType] ): reset_password_token_secret = "SECRET" verification_token_secret = "SECRET" async def validate_password( - self, password: str, user: Union[models.UC, models.UD] + self, password: str, user: Union[schemas.UC, models.UP] ) -> None: if len(password) < 3: raise InvalidPasswordException( @@ -71,15 +106,15 @@ class BaseTestUserManager( ) -class UserManager(BaseTestUserManager[UserCreate, UserDB]): - user_db_model = UserDB +class UserManager(BaseTestUserManager[UserModel]): + pass -class UserManagerOAuth(BaseTestUserManager[UserCreate, UserDBOAuth]): - user_db_model = UserDBOAuth +class UserManagerOAuth(BaseTestUserManager[UserOAuthModel]): + pass -class UserManagerMock(UserManager): +class UserManagerMock(BaseTestUserManager[models.UP]): get_by_email: MagicMock request_verify: MagicMock verify: MagicMock @@ -131,16 +166,18 @@ def secret(request) -> SecretType: @pytest.fixture -def user() -> UserDB: - return UserDB( +def user() -> UserModel: + return UserModel( email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash, ) @pytest.fixture -def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth: - return UserDBOAuth( +def user_oauth( + oauth_account1: OAuthAccountModel, oauth_account2: OAuthAccountModel +) -> UserOAuthModel: + return UserOAuthModel( email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash, oauth_accounts=[oauth_account1, oauth_account2], @@ -148,8 +185,8 @@ def user_oauth(oauth_account1, oauth_account2) -> UserDBOAuth: @pytest.fixture -def inactive_user() -> UserDB: - return UserDB( +def inactive_user() -> UserModel: + return UserModel( email="percival@camelot.bt", hashed_password=angharad_password_hash, is_active=False, @@ -157,8 +194,8 @@ def inactive_user() -> UserDB: @pytest.fixture -def inactive_user_oauth(oauth_account3) -> UserDBOAuth: - return UserDBOAuth( +def inactive_user_oauth(oauth_account3: OAuthAccountModel) -> UserOAuthModel: + return UserOAuthModel( email="percival@camelot.bt", hashed_password=angharad_password_hash, is_active=False, @@ -167,8 +204,8 @@ def inactive_user_oauth(oauth_account3) -> UserDBOAuth: @pytest.fixture -def verified_user() -> UserDB: - return UserDB( +def verified_user() -> UserModel: + return UserModel( email="lake.lady@camelot.bt", hashed_password=excalibur_password_hash, is_active=True, @@ -177,8 +214,8 @@ def verified_user() -> UserDB: @pytest.fixture -def verified_user_oauth(oauth_account4) -> UserDBOAuth: - return UserDBOAuth( +def verified_user_oauth(oauth_account4: OAuthAccountModel) -> UserOAuthModel: + return UserOAuthModel( email="lake.lady@camelot.bt", hashed_password=excalibur_password_hash, is_active=False, @@ -187,8 +224,8 @@ def verified_user_oauth(oauth_account4) -> UserDBOAuth: @pytest.fixture -def superuser() -> UserDB: - return UserDB( +def superuser() -> UserModel: + return UserModel( email="merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -196,8 +233,8 @@ def superuser() -> UserDB: @pytest.fixture -def superuser_oauth() -> UserDBOAuth: - return UserDBOAuth( +def superuser_oauth() -> UserOAuthModel: + return UserOAuthModel( email="merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -206,8 +243,8 @@ def superuser_oauth() -> UserDBOAuth: @pytest.fixture -def verified_superuser() -> UserDB: - return UserDB( +def verified_superuser() -> UserModel: + return UserModel( email="the.real.merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -216,8 +253,8 @@ def verified_superuser() -> UserDB: @pytest.fixture -def verified_superuser_oauth() -> UserDBOAuth: - return UserDBOAuth( +def verified_superuser_oauth() -> UserOAuthModel: + return UserOAuthModel( email="the.real.merlin@camelot.bt", hashed_password=viviane_password_hash, is_superuser=True, @@ -227,8 +264,8 @@ def verified_superuser_oauth() -> UserDBOAuth: @pytest.fixture -def oauth_account1() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account1() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service1", access_token="TOKEN", expires_at=1579000751, @@ -238,8 +275,8 @@ def oauth_account1() -> BaseOAuthAccount: @pytest.fixture -def oauth_account2() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account2() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service2", access_token="TOKEN", expires_at=1579000751, @@ -249,8 +286,8 @@ def oauth_account2() -> BaseOAuthAccount: @pytest.fixture -def oauth_account3() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account3() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service3", access_token="TOKEN", expires_at=1579000751, @@ -260,8 +297,8 @@ def oauth_account3() -> BaseOAuthAccount: @pytest.fixture -def oauth_account4() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account4() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service4", access_token="TOKEN", expires_at=1579000751, @@ -271,8 +308,8 @@ def oauth_account4() -> BaseOAuthAccount: @pytest.fixture -def oauth_account5() -> BaseOAuthAccount: - return BaseOAuthAccount( +def oauth_account5() -> OAuthAccountModel: + return OAuthAccountModel( oauth_name="service5", access_token="TOKEN", expires_at=1579000751, @@ -283,10 +320,14 @@ def oauth_account5() -> BaseOAuthAccount: @pytest.fixture def mock_user_db( - user, verified_user, inactive_user, superuser, verified_superuser -) -> BaseUserDatabase: - class MockUserDatabase(BaseUserDatabase[UserDB]): - async def get(self, id: UUID4) -> Optional[UserDB]: + user: UserModel, + verified_user: UserModel, + inactive_user: UserModel, + superuser: UserModel, + verified_superuser: UserModel, +) -> BaseUserDatabase[UserModel, IDType]: + class MockUserDatabase(BaseUserDatabase[UserModel, IDType]): + async def get(self, id: UUID4) -> Optional[UserModel]: if id == user.id: return user if id == verified_user.id: @@ -299,7 +340,7 @@ def mock_user_db( return verified_superuser return None - async def get_by_email(self, email: str) -> Optional[UserDB]: + async def get_by_email(self, email: str) -> Optional[UserModel]: lower_email = email.lower() if lower_email == user.email.lower(): return user @@ -313,28 +354,32 @@ def mock_user_db( return verified_superuser return None - async def create(self, user: UserDB) -> UserDB: + async def create(self, create_dict: Dict[str, Any]) -> UserModel: + return UserModel(**create_dict) + + async def update( + self, user: UserModel, update_dict: Dict[str, Any] + ) -> UserModel: + for field, value in update_dict.items(): + setattr(user, field, value) return user - async def update(self, user: UserDB) -> UserDB: - return user - - async def delete(self, user: UserDB) -> None: + async def delete(self, user: UserModel) -> None: pass - return MockUserDatabase(UserDB) + return MockUserDatabase() @pytest.fixture def mock_user_db_oauth( - user_oauth, - verified_user_oauth, - inactive_user_oauth, - superuser_oauth, - verified_superuser_oauth, -) -> BaseUserDatabase: - class MockUserDatabase(BaseUserDatabase[UserDBOAuth]): - async def get(self, id: UUID4) -> Optional[UserDBOAuth]: + user_oauth: UserOAuthModel, + verified_user_oauth: UserOAuthModel, + inactive_user_oauth: UserOAuthModel, + superuser_oauth: UserOAuthModel, + verified_superuser_oauth: UserOAuthModel, +) -> BaseUserDatabase[UserOAuthModel, IDType]: + class MockUserDatabase(BaseUserDatabase[UserOAuthModel, IDType]): + async def get(self, id: UUID4) -> Optional[UserOAuthModel]: if id == user_oauth.id: return user_oauth if id == verified_user_oauth.id: @@ -347,7 +392,7 @@ def mock_user_db_oauth( return verified_superuser_oauth return None - async def get_by_email(self, email: str) -> Optional[UserDBOAuth]: + async def get_by_email(self, email: str) -> Optional[UserOAuthModel]: lower_email = email.lower() if lower_email == user_oauth.email.lower(): return user_oauth @@ -363,7 +408,7 @@ def mock_user_db_oauth( async def get_by_oauth_account( self, oauth: str, account_id: str - ) -> Optional[UserDBOAuth]: + ) -> Optional[UserOAuthModel]: user_oauth_account = user_oauth.oauth_accounts[0] if ( user_oauth_account.oauth_name == oauth @@ -379,16 +424,46 @@ def mock_user_db_oauth( return inactive_user_oauth return None - async def create(self, user: UserDBOAuth) -> UserDBOAuth: - return user_oauth + async def create(self, create_dict: Dict[str, Any]) -> UserOAuthModel: + return UserOAuthModel(**create_dict) - async def update(self, user: UserDBOAuth) -> UserDBOAuth: - return user_oauth + async def update( + self, user: UserOAuthModel, update_dict: Dict[str, Any] + ) -> UserOAuthModel: + for field, value in update_dict.items(): + setattr(user, field, value) + return user - async def delete(self, user: UserDBOAuth) -> None: + async def delete(self, user: UserOAuthModel) -> None: pass - return MockUserDatabase(UserDBOAuth) + async def add_oauth_account( + self, user: UserOAuthModel, create_dict: Dict[str, Any] + ) -> UserOAuthModel: + oauth_account = OAuthAccountModel(**create_dict) + user.oauth_accounts.append(oauth_account) + return user + + async def update_oauth_account( # type: ignore + self, + user: UserOAuthModel, + oauth_account: OAuthAccountModel, + update_dict: Dict[str, Any], + ) -> UserOAuthModel: + for field, value in update_dict.items(): + setattr(oauth_account, field, value) + updated_oauth_accounts = [] + for existing_oauth_account in user.oauth_accounts: + if ( + existing_oauth_account.account_id == oauth_account.account_id + and existing_oauth_account.oauth_name == oauth_account.oauth_name + ): + updated_oauth_accounts.append(oauth_account) + else: + updated_oauth_accounts.append(existing_oauth_account) + return user + + return MockUserDatabase() @pytest.fixture @@ -450,24 +525,22 @@ class MockTransport(BearerTransport): return {} -class MockStrategy(Strategy[UserCreate, UserDB]): +class MockStrategy(Strategy[UserModel, IDType]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[UserCreate, UserDB] - ) -> Optional[UserDB]: + self, token: Optional[str], user_manager: BaseUserManager[UserModel, IDType] + ) -> Optional[UserModel]: if token is not None: try: - token_uuid = UUID4(token) - return await user_manager.get(token_uuid) - except ValueError: - return None - except UserNotExists: + parsed_id = user_manager.parse_id(token) + return await user_manager.get(parsed_id) + except (InvalidID, UserNotExists): return None return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: UserModel) -> str: return str(user.id) - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: UserModel) -> None: return None diff --git a/tests/test_authentication_authenticator.py b/tests/test_authentication_authenticator.py index 39982ead..d2ee48e2 100644 --- a/tests/test_authentication_authenticator.py +++ b/tests/test_authentication_authenticator.py @@ -12,7 +12,7 @@ from fastapi_users.authentication.strategy import Strategy from fastapi_users.authentication.transport import Transport from fastapi_users.manager import BaseUserManager from fastapi_users.types import DependencyCallable -from tests.conftest import UserDB +from tests.conftest import User, UserModel class MockSecurityScheme(SecurityBase): @@ -29,18 +29,18 @@ class MockTransport(Transport): class NoneStrategy(Strategy): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: return None -class UserStrategy(Strategy, Generic[models.UC, models.UD]): - def __init__(self, user: models.UD): +class UserStrategy(Strategy, Generic[models.UP]): + def __init__(self, user: models.UP): self.user = user async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: return self.user @@ -55,7 +55,7 @@ def get_backend_none(): @pytest.fixture -def get_backend_user(user: UserDB): +def get_backend_user(user: UserModel): def _get_backend_user(name: str = "user"): return AuthenticationBackend( name=name, @@ -78,17 +78,17 @@ def get_test_auth_client(get_user_manager, get_test_client): app = FastAPI() authenticator = Authenticator(backends, get_user_manager) - @app.get("/test-current-user") + @app.get("/test-current-user", response_model=User) def test_current_user( - user: UserDB = Depends( + user: UserModel = Depends( authenticator.current_user(get_enabled_backends=get_enabled_backends) ), ): return user - @app.get("/test-current-active-user") + @app.get("/test-current-active-user", response_model=User) def test_current_active_user( - user: UserDB = Depends( + user: UserModel = Depends( authenticator.current_user( active=True, get_enabled_backends=get_enabled_backends ) @@ -96,9 +96,9 @@ def get_test_auth_client(get_user_manager, get_test_client): ): return user - @app.get("/test-current-superuser") + @app.get("/test-current-superuser", response_model=User) def test_current_superuser( - user: UserDB = Depends( + user: UserModel = Depends( authenticator.current_user( active=True, superuser=True, diff --git a/tests/test_authentication_backend.py b/tests/test_authentication_backend.py index ddd4cec9..b7d99795 100644 --- a/tests/test_authentication_backend.py +++ b/tests/test_authentication_backend.py @@ -12,23 +12,23 @@ from fastapi_users.authentication import ( from fastapi_users.authentication.strategy import StrategyDestroyNotSupportedError from fastapi_users.authentication.transport.base import Transport from fastapi_users.manager import BaseUserManager -from tests.conftest import MockStrategy, MockTransport, UserDB +from tests.conftest import MockStrategy, MockTransport, UserModel class MockTransportLogoutNotSupported(BearerTransport): pass -class MockStrategyDestroyNotSupported(Strategy, Generic[models.UC, models.UD]): +class MockStrategyDestroyNotSupported(Strategy, Generic[models.UP]): async def read_token( - self, token: Optional[str], user_manager: BaseUserManager[models.UC, models.UD] - ) -> Optional[models.UD]: + self, token: Optional[str], user_manager: BaseUserManager[models.UP, models.ID] + ) -> Optional[models.UP]: return None - async def write_token(self, user: models.UD) -> str: + async def write_token(self, user: models.UP) -> str: return "TOKEN" - async def destroy_token(self, token: str, user: models.UD) -> None: + async def destroy_token(self, token: str, user: models.UP) -> None: raise StrategyDestroyNotSupportedError @@ -55,7 +55,7 @@ def backend( @pytest.mark.asyncio @pytest.mark.authentication -async def test_logout(backend: AuthenticationBackend, user: UserDB): +async def test_logout(backend: AuthenticationBackend, user: UserModel): strategy = cast(Strategy, backend.get_strategy()) result = await backend.logout(strategy, user, "TOKEN", Response()) assert result is None diff --git a/tests/test_authentication_strategy_db.py b/tests/test_authentication_strategy_db.py index 8e539eae..201f8949 100644 --- a/tests/test_authentication_strategy_db.py +++ b/tests/test_authentication_strategy_db.py @@ -1,30 +1,37 @@ +import dataclasses import uuid -from datetime import datetime -from typing import Dict, Optional +from datetime import datetime, timezone +from typing import Any, Dict, Optional import pytest from fastapi_users.authentication.strategy import ( AccessTokenDatabase, - BaseAccessToken, + AccessTokenProtocol, DatabaseStrategy, ) +from tests.conftest import IDType, UserModel -class AccessToken(BaseAccessToken): - pass +@dataclasses.dataclass +class AccessTokenModel(AccessTokenProtocol[IDType]): + token: str + user_id: uuid.UUID + id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4) + created_at: datetime = dataclasses.field( + default_factory=lambda: datetime.now(timezone.utc) + ) -class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]): - store: Dict[str, AccessToken] +class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]): + store: Dict[str, AccessTokenModel] def __init__(self): - self.access_token_model = AccessToken self.store = {} async def get_by_token( self, token: str, max_age: Optional[datetime] = None - ) -> Optional[AccessToken]: + ) -> Optional[AccessTokenModel]: try: access_token = self.store[token] if max_age is not None and access_token.created_at < max_age: @@ -33,15 +40,20 @@ class AccessTokenDatabaseMock(AccessTokenDatabase[AccessToken]): except KeyError: return None - async def create(self, access_token: AccessToken) -> AccessToken: + async def create(self, create_dict: Dict[str, Any]) -> AccessTokenModel: + access_token = AccessTokenModel(**create_dict) self.store[access_token.token] = access_token return access_token - async def update(self, access_token: AccessToken) -> AccessToken: + async def update( + self, access_token: AccessTokenModel, update_dict: Dict[str, Any] + ) -> AccessTokenModel: + for field, value in update_dict.items(): + setattr(access_token, field, value) self.store[access_token.token] = access_token return access_token - async def delete(self, access_token: AccessToken) -> None: + async def delete(self, access_token: AccessTokenModel) -> None: try: del self.store[access_token.token] except KeyError: @@ -62,14 +74,18 @@ def database_strategy(access_token_database: AccessTokenDatabaseMock): class TestReadToken: @pytest.mark.asyncio async def test_missing_token( - self, database_strategy: DatabaseStrategy, user_manager + self, + database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], + user_manager, ): authenticated_user = await database_strategy.read_token(None, user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_invalid_token( - self, database_strategy: DatabaseStrategy, user_manager + self, + database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], + user_manager, ): authenticated_user = await database_strategy.read_token("TOKEN", user_manager) assert authenticated_user is None @@ -77,14 +93,15 @@ class TestReadToken: @pytest.mark.asyncio async def test_valid_token_not_existing_user( self, - database_strategy: DatabaseStrategy, + database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], access_token_database: AccessTokenDatabaseMock, user_manager, ): await access_token_database.create( - AccessToken( - token="TOKEN", user_id=uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f") - ) + { + "token": "TOKEN", + "user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"), + } ) authenticated_user = await database_strategy.read_token("TOKEN", user_manager) assert authenticated_user is None @@ -92,12 +109,12 @@ class TestReadToken: @pytest.mark.asyncio async def test_valid_token( self, - database_strategy: DatabaseStrategy, + database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], access_token_database: AccessTokenDatabaseMock, user_manager, - user, + user: UserModel, ): - await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id)) + await access_token_database.create({"token": "TOKEN", "user_id": user.id}) authenticated_user = await database_strategy.read_token("TOKEN", user_manager) assert authenticated_user is not None assert authenticated_user.id == user.id @@ -106,9 +123,9 @@ class TestReadToken: @pytest.mark.authentication @pytest.mark.asyncio async def test_write_token( - database_strategy: DatabaseStrategy, + database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], access_token_database: AccessTokenDatabaseMock, - user, + user: UserModel, ): token = await database_strategy.write_token(user) @@ -120,11 +137,11 @@ async def test_write_token( @pytest.mark.authentication @pytest.mark.asyncio async def test_destroy_token( - database_strategy: DatabaseStrategy, + database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel], access_token_database: AccessTokenDatabaseMock, - user, + user: UserModel, ): - await access_token_database.create(AccessToken(token="TOKEN", user_id=user.id)) + await access_token_database.create({"token": "TOKEN", "user_id": user.id}) await database_strategy.destroy_token("TOKEN", user) diff --git a/tests/test_authentication_strategy_jwt.py b/tests/test_authentication_strategy_jwt.py index 08e5fd7a..45aea5a5 100644 --- a/tests/test_authentication_strategy_jwt.py +++ b/tests/test_authentication_strategy_jwt.py @@ -5,6 +5,7 @@ from fastapi_users.authentication.strategy import ( StrategyDestroyNotSupportedError, ) from fastapi_users.jwt import SecretType, decode_jwt, generate_jwt +from tests.conftest import IDType, UserModel LIFETIME = 3600 @@ -74,7 +75,7 @@ def jwt_strategy(request, secret: SecretType): @pytest.fixture -def token(jwt_strategy: JWTStrategy): +def token(jwt_strategy: JWTStrategy[UserModel, IDType]): def _token(user_id=None, lifetime=LIFETIME): data = {"aud": "fastapi-users:auth"} if user_id is not None: @@ -90,32 +91,36 @@ def token(jwt_strategy: JWTStrategy): @pytest.mark.authentication class TestReadToken: @pytest.mark.asyncio - async def test_missing_token(self, jwt_strategy: JWTStrategy, user_manager): + async def test_missing_token( + self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager + ): authenticated_user = await jwt_strategy.read_token(None, user_manager) assert authenticated_user is None @pytest.mark.asyncio - async def test_invalid_token(self, jwt_strategy: JWTStrategy, user_manager): + async def test_invalid_token( + self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager + ): authenticated_user = await jwt_strategy.read_token("foo", user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_valid_token_missing_user_payload( - self, jwt_strategy: JWTStrategy, user_manager, token + self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token ): authenticated_user = await jwt_strategy.read_token(token(), user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_valid_token_invalid_uuid( - self, jwt_strategy: JWTStrategy, user_manager, token + self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token ): authenticated_user = await jwt_strategy.read_token(token("foo"), user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_valid_token_not_existing_user( - self, jwt_strategy: JWTStrategy, user_manager, token + self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token ): authenticated_user = await jwt_strategy.read_token( token("d35d213e-f3d8-4f08-954a-7e0d1bea286f"), user_manager @@ -124,7 +129,7 @@ class TestReadToken: @pytest.mark.asyncio async def test_valid_token( - self, jwt_strategy: JWTStrategy, user_manager, token, user + self, jwt_strategy: JWTStrategy[UserModel, IDType], user_manager, token, user ): authenticated_user = await jwt_strategy.read_token(token(user.id), user_manager) assert authenticated_user is not None @@ -134,7 +139,7 @@ class TestReadToken: @pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) @pytest.mark.authentication @pytest.mark.asyncio -async def test_write_token(jwt_strategy: JWTStrategy, user): +async def test_write_token(jwt_strategy: JWTStrategy[UserModel, IDType], user): token = await jwt_strategy.write_token(user) decoded = decode_jwt( @@ -149,6 +154,6 @@ async def test_write_token(jwt_strategy: JWTStrategy, user): @pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) @pytest.mark.authentication @pytest.mark.asyncio -async def test_destroy_token(jwt_strategy: JWTStrategy, user): +async def test_destroy_token(jwt_strategy: JWTStrategy[UserModel, IDType], user): with pytest.raises(StrategyDestroyNotSupportedError): await jwt_strategy.destroy_token("TOKEN", user) diff --git a/tests/test_authentication_strategy_redis.py b/tests/test_authentication_strategy_redis.py index cdddde1a..61fe250f 100644 --- a/tests/test_authentication_strategy_redis.py +++ b/tests/test_authentication_strategy_redis.py @@ -4,6 +4,7 @@ from typing import Dict, Optional, Tuple import pytest from fastapi_users.authentication.strategy import RedisStrategy +from tests.conftest import IDType, UserModel class RedisMock: @@ -47,19 +48,23 @@ def redis_strategy(redis): @pytest.mark.authentication class TestReadToken: @pytest.mark.asyncio - async def test_missing_token(self, redis_strategy: RedisStrategy, user_manager): + async def test_missing_token( + self, redis_strategy: RedisStrategy[UserModel, IDType], user_manager + ): authenticated_user = await redis_strategy.read_token(None, user_manager) assert authenticated_user is None @pytest.mark.asyncio - async def test_invalid_token(self, redis_strategy: RedisStrategy, user_manager): + async def test_invalid_token( + self, redis_strategy: RedisStrategy[UserModel, IDType], user_manager + ): authenticated_user = await redis_strategy.read_token("TOKEN", user_manager) assert authenticated_user is None @pytest.mark.asyncio async def test_valid_token_invalid_uuid( self, - redis_strategy: RedisStrategy, + redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user_manager, ): @@ -70,7 +75,7 @@ class TestReadToken: @pytest.mark.asyncio async def test_valid_token_not_existing_user( self, - redis_strategy: RedisStrategy, + redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user_manager, ): @@ -81,7 +86,7 @@ class TestReadToken: @pytest.mark.asyncio async def test_valid_token( self, - redis_strategy: RedisStrategy, + redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user_manager, user, @@ -94,7 +99,9 @@ class TestReadToken: @pytest.mark.authentication @pytest.mark.asyncio -async def test_write_token(redis_strategy: RedisStrategy, redis: RedisMock, user): +async def test_write_token( + redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user +): token = await redis_strategy.write_token(user) value = await redis.get(token) @@ -103,7 +110,9 @@ async def test_write_token(redis_strategy: RedisStrategy, redis: RedisMock, user @pytest.mark.authentication @pytest.mark.asyncio -async def test_destroy_token(redis_strategy: RedisStrategy, redis: RedisMock, user): +async def test_destroy_token( + redis_strategy: RedisStrategy[UserModel, IDType], redis: RedisMock, user +): await redis.set("TOKEN", str(user.id)) await redis_strategy.destroy_token("TOKEN", user) diff --git a/tests/test_db_base.py b/tests/test_db_base.py index 9fab9fa8..23e9e44b 100644 --- a/tests/test_db_base.py +++ b/tests/test_db_base.py @@ -1,16 +1,20 @@ +import uuid + import pytest from fastapi_users.db import BaseUserDatabase -from tests.conftest import UserDB +from tests.conftest import IDType, OAuthAccountModel, UserModel @pytest.mark.asyncio @pytest.mark.db -async def test_not_implemented_methods(user): - base_user_db = BaseUserDatabase(UserDB) +async def test_not_implemented_methods( + user: UserModel, oauth_account1: OAuthAccountModel +): + base_user_db = BaseUserDatabase[UserModel, IDType]() with pytest.raises(NotImplementedError): - await base_user_db.get("aaa") + await base_user_db.get(uuid.uuid4()) with pytest.raises(NotImplementedError): await base_user_db.get_by_email("lancelot@camelot.bt") @@ -19,10 +23,16 @@ async def test_not_implemented_methods(user): await base_user_db.get_by_oauth_account("google", "user_oauth1") with pytest.raises(NotImplementedError): - await base_user_db.create(user) + await base_user_db.create({}) with pytest.raises(NotImplementedError): - await base_user_db.update(user) + await base_user_db.update(user, {}) with pytest.raises(NotImplementedError): await base_user_db.delete(user) + + with pytest.raises(NotImplementedError): + await base_user_db.add_oauth_account(user, {}) + + with pytest.raises(NotImplementedError): + await base_user_db.update_oauth_account(user, oauth_account1, {}) diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index 179ee7d7..f4c60858 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -5,7 +5,7 @@ import pytest from fastapi import Depends, FastAPI, status from fastapi_users import FastAPIUsers -from tests.conftest import User, UserCreate, UserDB, UserUpdate +from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate @pytest.fixture @@ -17,82 +17,92 @@ async def test_app_client( oauth_client, get_test_client, ) -> AsyncGenerator[httpx.AsyncClient, None]: - fastapi_users = FastAPIUsers[User, UserCreate, UserUpdate, UserDB]( - get_user_manager, - [mock_authentication], - User, - UserCreate, - UserUpdate, - UserDB, + fastapi_users = FastAPIUsers[UserModel, IDType]( + get_user_manager, [mock_authentication] ) app = FastAPI() - app.include_router(fastapi_users.get_register_router()) + app.include_router(fastapi_users.get_register_router(User, UserCreate)) app.include_router(fastapi_users.get_reset_password_router()) app.include_router(fastapi_users.get_auth_router(mock_authentication)) app.include_router( fastapi_users.get_oauth_router(oauth_client, mock_authentication, secret) ) - app.include_router(fastapi_users.get_users_router(), prefix="/users") - app.include_router(fastapi_users.get_verify_router()) @app.delete("/users/me") def custom_users_route(): return None - @app.get("/current-user") - def current_user(user=Depends(fastapi_users.current_user())): + app.include_router( + fastapi_users.get_users_router(User, UserUpdate), prefix="/users" + ) + app.include_router(fastapi_users.get_verify_router(User)) + + @app.get("/current-user", response_model=User) + def current_user(user: UserModel = Depends(fastapi_users.current_user())): return user - @app.get("/current-active-user") - def current_active_user(user=Depends(fastapi_users.current_user(active=True))): - return user - - @app.get("/current-verified-user") - def current_verified_user(user=Depends(fastapi_users.current_user(verified=True))): - return user - - @app.get("/current-superuser") - def current_superuser( - user=Depends(fastapi_users.current_user(active=True, superuser=True)) + @app.get("/current-active-user", response_model=User) + def current_active_user( + user: UserModel = Depends(fastapi_users.current_user(active=True)), ): return user - @app.get("/current-verified-superuser") + @app.get("/current-verified-user", response_model=User) + def current_verified_user( + user: UserModel = Depends(fastapi_users.current_user(verified=True)), + ): + return user + + @app.get("/current-superuser", response_model=User) + def current_superuser( + user: UserModel = Depends( + fastapi_users.current_user(active=True, superuser=True) + ) + ): + return user + + @app.get("/current-verified-superuser", response_model=User) def current_verified_superuser( - user=Depends( + user: UserModel = Depends( fastapi_users.current_user(active=True, verified=True, superuser=True) ), ): return user - @app.get("/optional-current-user") - def optional_current_user(user=Depends(fastapi_users.current_user(optional=True))): + @app.get("/optional-current-user", response_model=User) + def optional_current_user( + user: UserModel = Depends(fastapi_users.current_user(optional=True)), + ): return user - @app.get("/optional-current-active-user") + @app.get("/optional-current-active-user", response_model=User) def optional_current_active_user( - user=Depends(fastapi_users.current_user(optional=True, active=True)), + user: UserModel = Depends( + fastapi_users.current_user(optional=True, active=True) + ), ): return user - @app.get("/optional-current-verified-user") + @app.get("/optional-current-verified-user", response_model=User) def optional_current_verified_user( - user=Depends(fastapi_users.current_user(optional=True, verified=True)), + user: UserModel = Depends( + fastapi_users.current_user(optional=True, verified=True) + ), ): return user - @app.get("/optional-current-superuser") + @app.get("/optional-current-superuser", response_model=User) def optional_current_superuser( - user=Depends( + user: UserModel = Depends( fastapi_users.current_user(optional=True, active=True, superuser=True) ), ): return user - @app.get("/optional-current-verified-superuser") + @app.get("/optional-current-verified-superuser", response_model=User) def optional_current_verified_superuser( - user=Depends( + user: UserModel = Depends( fastapi_users.current_user( optional=True, active=True, verified=True, superuser=True ) @@ -150,7 +160,9 @@ class TestGetCurrentUser: ) assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/current-user", headers={"Authorization": f"Bearer {user.id}"} ) @@ -171,7 +183,7 @@ class TestGetCurrentActiveUser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_inactive_user( - self, test_app_client: httpx.AsyncClient, inactive_user: UserDB + self, test_app_client: httpx.AsyncClient, inactive_user: UserModel ): response = await test_app_client.get( "/current-active-user", @@ -179,7 +191,9 @@ class TestGetCurrentActiveUser: ) assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/current-active-user", headers={"Authorization": f"Bearer {user.id}"} ) @@ -200,7 +214,7 @@ class TestGetCurrentVerifiedUser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_unverified_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/current-verified-user", @@ -209,7 +223,7 @@ class TestGetCurrentVerifiedUser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/current-verified-user", @@ -232,7 +246,7 @@ class TestGetCurrentSuperuser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/current-superuser", headers={"Authorization": f"Bearer {user.id}"} @@ -240,7 +254,7 @@ class TestGetCurrentSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/current-superuser", headers={"Authorization": f"Bearer {superuser.id}"} @@ -262,7 +276,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -271,7 +285,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -280,7 +294,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -289,7 +303,7 @@ class TestGetCurrentVerifiedSuperuser: assert response.status_code == status.HTTP_403_FORBIDDEN async def test_valid_token_verified_superuser( - self, test_app_client: httpx.AsyncClient, verified_superuser: UserDB + self, test_app_client: httpx.AsyncClient, verified_superuser: UserModel ): response = await test_app_client.get( "/current-verified-superuser", @@ -313,7 +327,9 @@ class TestOptionalGetCurrentUser: assert response.status_code == status.HTTP_200_OK assert response.json() is None - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/optional-current-user", headers={"Authorization": f"Bearer {user.id}"} ) @@ -337,7 +353,7 @@ class TestOptionalGetCurrentVerifiedUser: assert response.json() is None async def test_valid_token_unverified_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/optional-current-verified-user", @@ -347,7 +363,7 @@ class TestOptionalGetCurrentVerifiedUser: assert response.json() is None async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/optional-current-verified-user", @@ -373,7 +389,7 @@ class TestOptionalGetCurrentActiveUser: assert response.json() is None async def test_valid_token_inactive_user( - self, test_app_client: httpx.AsyncClient, inactive_user: UserDB + self, test_app_client: httpx.AsyncClient, inactive_user: UserModel ): response = await test_app_client.get( "/optional-current-active-user", @@ -382,7 +398,9 @@ class TestOptionalGetCurrentActiveUser: assert response.status_code == status.HTTP_200_OK assert response.json() is None - async def test_valid_token(self, test_app_client: httpx.AsyncClient, user: UserDB): + async def test_valid_token( + self, test_app_client: httpx.AsyncClient, user: UserModel + ): response = await test_app_client.get( "/optional-current-active-user", headers={"Authorization": f"Bearer {user.id}"}, @@ -407,7 +425,7 @@ class TestOptionalGetCurrentSuperuser: assert response.json() is None async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/optional-current-superuser", @@ -417,7 +435,7 @@ class TestOptionalGetCurrentSuperuser: assert response.json() is None async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/optional-current-superuser", @@ -444,7 +462,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_regular_user( - self, test_app_client: httpx.AsyncClient, user: UserDB + self, test_app_client: httpx.AsyncClient, user: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", @@ -454,7 +472,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_verified_user( - self, test_app_client: httpx.AsyncClient, verified_user: UserDB + self, test_app_client: httpx.AsyncClient, verified_user: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", @@ -464,7 +482,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_superuser( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + self, test_app_client: httpx.AsyncClient, superuser: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", @@ -474,7 +492,7 @@ class TestOptionalGetCurrentVerifiedSuperuser: assert response.json() is None async def test_valid_token_verified_superuser( - self, test_app_client: httpx.AsyncClient, verified_superuser: UserDB + self, test_app_client: httpx.AsyncClient, verified_superuser: UserModel ): response = await test_app_client.get( "/optional-current-verified-superuser", diff --git a/tests/test_manager.py b/tests/test_manager.py index 97f8cd1b..634668cf 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,13 +1,14 @@ -from typing import Callable, cast +from typing import Callable import pytest from fastapi.security import OAuth2PasswordRequestForm from pydantic import UUID4 from pytest_mock import MockerFixture -from fastapi_users import models from fastapi_users.jwt import decode_jwt, generate_jwt from fastapi_users.manager import ( + IntegerIDMixin, + InvalidID, InvalidPasswordException, InvalidResetPasswordToken, InvalidVerifyToken, @@ -16,11 +17,17 @@ from fastapi_users.manager import ( UserInactive, UserNotExists, ) -from tests.conftest import UserCreate, UserDB, UserDBOAuth, UserManagerMock, UserUpdate +from tests.conftest import ( + UserCreate, + UserManagerMock, + UserModel, + UserOAuthModel, + UserUpdate, +) @pytest.fixture -def verify_token(user_manager: UserManagerMock): +def verify_token(user_manager: UserManagerMock[UserModel]): def _verify_token( user_id=None, email=None, @@ -37,7 +44,7 @@ def verify_token(user_manager: UserManagerMock): @pytest.fixture -def forgot_password_token(user_manager: UserManagerMock): +def forgot_password_token(user_manager: UserManagerMock[UserModel]): def _forgot_password_token( user_id=None, lifetime=user_manager.reset_password_token_lifetime_seconds ): @@ -62,11 +69,13 @@ def create_oauth2_password_request_form() -> Callable[ @pytest.mark.asyncio @pytest.mark.manager class TestGet: - async def test_not_existing_user(self, user_manager: UserManagerMock): + async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(UserNotExists): await user_manager.get(UUID4("d35d213e-f3d8-4f08-954a-7e0d1bea286f")) - async def test_existing_user(self, user_manager: UserManagerMock, user: UserDB): + async def test_existing_user( + self, user_manager: UserManagerMock[UserModel], user: UserModel + ): retrieved_user = await user_manager.get(user.id) assert retrieved_user.id == user.id @@ -74,11 +83,13 @@ class TestGet: @pytest.mark.asyncio @pytest.mark.manager class TestGetByEmail: - async def test_not_existing_user(self, user_manager: UserManagerMock): + async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(UserNotExists): await user_manager.get_by_email("lancelot@camelot.bt") - async def test_existing_user(self, user_manager: UserManagerMock, user: UserDB): + async def test_existing_user( + self, user_manager: UserManagerMock[UserModel], user: UserModel + ): retrieved_user = await user_manager.get_by_email(user.email) assert retrieved_user.id == user.id @@ -86,12 +97,14 @@ class TestGetByEmail: @pytest.mark.asyncio @pytest.mark.manager class TestGetByOAuthAccount: - async def test_not_existing_user(self, user_manager_oauth: UserManagerMock): + async def test_not_existing_user( + self, user_manager_oauth: UserManagerMock[UserModel] + ): with pytest.raises(UserNotExists): await user_manager_oauth.get_by_oauth_account("service1", "foo") async def test_existing_user( - self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth + self, user_manager_oauth: UserManagerMock[UserModel], user_oauth: UserOAuthModel ): oauth_account = user_oauth.oauth_accounts[0] retrieved_user = await user_manager_oauth.get_by_oauth_account( @@ -106,42 +119,46 @@ class TestCreateUser: @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] ) - async def test_existing_user(self, email: str, user_manager: UserManagerMock): + async def test_existing_user( + self, email: str, user_manager: UserManagerMock[UserModel] + ): user = UserCreate(email=email, password="guinevere") with pytest.raises(UserAlreadyExists): await user_manager.create(user) assert user_manager.on_after_register.called is False @pytest.mark.parametrize("email", ["lancelot@camelot.bt", "Lancelot@camelot.bt"]) - async def test_regular_user(self, email: str, user_manager: UserManagerMock): + async def test_regular_user( + self, email: str, user_manager: UserManagerMock[UserModel] + ): user = UserCreate(email=email, password="guinevere") created_user = await user_manager.create(user) - assert type(created_user) == UserDB + assert type(created_user) == UserModel assert user_manager.on_after_register.called is True @pytest.mark.parametrize("safe,result", [(True, False), (False, True)]) async def test_superuser( - self, user_manager: UserManagerMock, safe: bool, result: bool + self, user_manager: UserManagerMock[UserModel], safe: bool, result: bool ): user = UserCreate( email="lancelot@camelot.b", password="guinevere", is_superuser=True ) created_user = await user_manager.create(user, safe) - assert type(created_user) == UserDB + assert type(created_user) == UserModel assert created_user.is_superuser is result assert user_manager.on_after_register.called is True @pytest.mark.parametrize("safe,result", [(True, True), (False, False)]) async def test_is_active( - self, user_manager: UserManagerMock, safe: bool, result: bool + self, user_manager: UserManagerMock[UserModel], safe: bool, result: bool ): user = UserCreate( email="lancelot@camelot.b", password="guinevere", is_active=False ) created_user = await user_manager.create(user, safe) - assert type(created_user) == UserDB + assert type(created_user) == UserModel assert created_user.is_active is result assert user_manager.on_after_register.called is True @@ -151,16 +168,22 @@ class TestCreateUser: @pytest.mark.manager class TestOAuthCallback: async def test_existing_user_with_oauth( - self, user_manager_oauth: UserManagerMock, user_oauth: UserDBOAuth + self, + user_manager_oauth: UserManagerMock[UserOAuthModel], + user_oauth: UserOAuthModel, ): - oauth_account = models.BaseOAuthAccount( - **user_oauth.oauth_accounts[0].dict(exclude={"id", "access_token"}), - access_token="UPDATED_TOKEN" + oauth_account = user_oauth.oauth_accounts[0] + + user = await user_manager_oauth.oauth_callback( + oauth_account.oauth_name, + "UPDATED_TOKEN", + oauth_account.account_id, + oauth_account.account_email, ) - user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) assert user.id == user_oauth.id assert len(user.oauth_accounts) == 2 + assert user.oauth_accounts[0].id == oauth_account.id assert user.oauth_accounts[0].oauth_name == "service1" assert user.oauth_accounts[0].access_token == "UPDATED_TOKEN" assert user.oauth_accounts[1].access_token == "TOKEN" @@ -169,36 +192,28 @@ class TestOAuthCallback: assert user_manager_oauth.on_after_register.called is False async def test_existing_user_without_oauth( - self, user_manager_oauth: UserManagerMock, superuser_oauth: UserDBOAuth + self, + user_manager_oauth: UserManagerMock[UserOAuthModel], + superuser_oauth: UserOAuthModel, ): - oauth_account = models.BaseOAuthAccount( - oauth_name="service1", - access_token="TOKEN", - expires_at=1579000751, - account_id="superuser_oauth1", - account_email=superuser_oauth.email, + user = await user_manager_oauth.oauth_callback( + "service1", "TOKEN", "superuser_oauth1", superuser_oauth.email, 1579000751 ) - user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) assert user.id == superuser_oauth.id assert len(user.oauth_accounts) == 1 - assert user.oauth_accounts[0].id == oauth_account.id + assert user.oauth_accounts[0].id is not None assert user_manager_oauth.on_after_register.called is False - async def test_new_user(self, user_manager_oauth: UserManagerMock): - oauth_account = models.BaseOAuthAccount( - oauth_name="service1", - access_token="TOKEN", - expires_at=1579000751, - account_id="new_user_oauth1", - account_email="galahad@camelot.bt", + async def test_new_user(self, user_manager_oauth: UserManagerMock[UserOAuthModel]): + user = await user_manager_oauth.oauth_callback( + "service1", "TOKEN", "new_user_oauth1", "galahad@camelot.bt", 1579000751 ) - user = cast(UserDBOAuth, await user_manager_oauth.oauth_callback(oauth_account)) assert user.email == "galahad@camelot.bt" assert len(user.oauth_accounts) == 1 - assert user.oauth_accounts[0].id == oauth_account.id + assert user.oauth_accounts[0].id is not None assert user_manager_oauth.on_after_register.called is True @@ -207,19 +222,19 @@ class TestOAuthCallback: @pytest.mark.manager class TestRequestVerifyUser: async def test_user_inactive( - self, user_manager: UserManagerMock, inactive_user: UserDB + self, user_manager: UserManagerMock[UserModel], inactive_user: UserModel ): with pytest.raises(UserInactive): await user_manager.request_verify(inactive_user) async def test_user_verified( - self, user_manager: UserManagerMock, verified_user: UserDB + self, user_manager: UserManagerMock[UserModel], verified_user: UserModel ): with pytest.raises(UserAlreadyVerified): await user_manager.request_verify(verified_user) async def test_user_active_not_verified( - self, user_manager: UserManagerMock, user: UserDB + self, user_manager: UserManagerMock[UserModel], user: UserModel ): await user_manager.request_verify(user) assert user_manager.on_after_request_verify.called is True @@ -240,40 +255,40 @@ class TestRequestVerifyUser: @pytest.mark.asyncio @pytest.mark.manager class TestVerifyUser: - async def test_invalid_token(self, user_manager: UserManagerMock): + async def test_invalid_token(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(InvalidVerifyToken): await user_manager.verify("foo") async def test_token_expired( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id=user.id, email=user.email, lifetime=-1) await user_manager.verify(token) async def test_missing_user_id( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(email=user.email) await user_manager.verify(token) async def test_missing_user_email( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id=user.id) await user_manager.verify(token) async def test_invalid_user_id( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id="foo", email=user.email) await user_manager.verify(token) async def test_invalid_email( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): with pytest.raises(InvalidVerifyToken): token = verify_token(user_id=user.id, email="foo") @@ -281,9 +296,9 @@ class TestVerifyUser: async def test_email_id_mismatch( self, - user_manager: UserManagerMock, - user: UserDB, - inactive_user: UserDB, + user_manager: UserManagerMock[UserModel], + user: UserModel, + inactive_user: UserModel, verify_token, ): with pytest.raises(InvalidVerifyToken): @@ -291,14 +306,20 @@ class TestVerifyUser: await user_manager.verify(token) async def test_verified_user( - self, user_manager: UserManagerMock, verified_user: UserDB, verify_token + self, + user_manager: UserManagerMock[UserModel], + verified_user: UserModel, + verify_token, ): with pytest.raises(UserAlreadyVerified): token = verify_token(user_id=verified_user.id, email=verified_user.email) await user_manager.verify(token) async def test_inactive_user( - self, user_manager: UserManagerMock, inactive_user: UserDB, verify_token + self, + user_manager: UserManagerMock[UserModel], + inactive_user: UserModel, + verify_token, ): token = verify_token(user_id=inactive_user.id, email=inactive_user.email) verified_user = await user_manager.verify(token) @@ -307,7 +328,7 @@ class TestVerifyUser: assert verified_user.is_active is False async def test_active_user( - self, user_manager: UserManagerMock, user: UserDB, verify_token + self, user_manager: UserManagerMock[UserModel], user: UserModel, verify_token ): token = verify_token(user_id=user.id, email=user.email) verified_user = await user_manager.verify(token) @@ -320,13 +341,15 @@ class TestVerifyUser: @pytest.mark.manager class TestForgotPassword: async def test_user_inactive( - self, user_manager: UserManagerMock, inactive_user: UserDB + self, user_manager: UserManagerMock[UserModel], inactive_user: UserModel ): with pytest.raises(UserInactive): await user_manager.forgot_password(inactive_user) assert user_manager.on_after_forgot_password.called is False - async def test_user_active(self, user_manager: UserManagerMock, user: UserDB): + async def test_user_active( + self, user_manager: UserManagerMock[UserModel], user: UserModel + ): await user_manager.forgot_password(user) assert user_manager.on_after_forgot_password.called is True @@ -345,14 +368,17 @@ class TestForgotPassword: @pytest.mark.asyncio @pytest.mark.manager class TestResetPassword: - async def test_invalid_token(self, user_manager: UserManagerMock): + async def test_invalid_token(self, user_manager: UserManagerMock[UserModel]): with pytest.raises(InvalidResetPasswordToken): await user_manager.reset_password("foo", "guinevere") assert user_manager._update.called is False assert user_manager.on_after_reset_password.called is False async def test_token_expired( - self, user_manager: UserManagerMock, user: UserDB, forgot_password_token + self, + user_manager: UserManagerMock[UserModel], + user: UserModel, + forgot_password_token, ): with pytest.raises(InvalidResetPasswordToken): await user_manager.reset_password( @@ -363,7 +389,10 @@ class TestResetPassword: @pytest.mark.parametrize("user_id", [None, "foo"]) async def test_valid_token_bad_payload( - self, user_id: str, user_manager: UserManagerMock, forgot_password_token + self, + user_id: str, + user_manager: UserManagerMock[UserModel], + forgot_password_token, ): with pytest.raises(InvalidResetPasswordToken): await user_manager.reset_password( @@ -373,7 +402,7 @@ class TestResetPassword: assert user_manager.on_after_reset_password.called is False async def test_not_existing_user( - self, user_manager: UserManagerMock, forgot_password_token + self, user_manager: UserManagerMock[UserModel], forgot_password_token ): with pytest.raises(UserNotExists): await user_manager.reset_password( @@ -385,8 +414,8 @@ class TestResetPassword: async def test_inactive_user( self, - inactive_user: UserDB, - user_manager: UserManagerMock, + inactive_user: UserModel, + user_manager: UserManagerMock[UserModel], forgot_password_token, ): with pytest.raises(UserInactive): @@ -398,7 +427,10 @@ class TestResetPassword: assert user_manager.on_after_reset_password.called is False async def test_invalid_password( - self, user: UserDB, user_manager: UserManagerMock, forgot_password_token + self, + user: UserModel, + user_manager: UserManagerMock[UserModel], + forgot_password_token, ): with pytest.raises(InvalidPasswordException): await user_manager.reset_password( @@ -408,7 +440,10 @@ class TestResetPassword: assert user_manager.on_after_reset_password.called is False async def test_valid_user_password( - self, user: UserDB, user_manager: UserManagerMock, forgot_password_token + self, + user: UserModel, + user_manager: UserManagerMock[UserModel], + forgot_password_token, ): await user_manager.reset_password(forgot_password_token(user.id), "holygrail") @@ -424,7 +459,9 @@ class TestResetPassword: @pytest.mark.asyncio @pytest.mark.manager class TestUpdateUser: - async def test_safe_update(self, user: UserDB, user_manager: UserManagerMock): + async def test_safe_update( + self, user: UserModel, user_manager: UserManagerMock[UserModel] + ): user_update = UserUpdate(first_name="Arthur", is_superuser=True) updated_user = await user_manager.update(user_update, user, safe=True) @@ -433,7 +470,9 @@ class TestUpdateUser: assert user_manager.on_after_update.called is True - async def test_unsafe_update(self, user: UserDB, user_manager: UserManagerMock): + async def test_unsafe_update( + self, user: UserModel, user_manager: UserManagerMock[UserModel] + ): user_update = UserUpdate(first_name="Arthur", is_superuser=True) updated_user = await user_manager.update(user_update, user, safe=False) @@ -443,7 +482,7 @@ class TestUpdateUser: assert user_manager.on_after_update.called is True async def test_password_update_invalid( - self, user: UserDB, user_manager: UserManagerMock + self, user: UserModel, user_manager: UserManagerMock[UserModel] ): user_update = UserUpdate(password="h") with pytest.raises(InvalidPasswordException): @@ -452,7 +491,7 @@ class TestUpdateUser: assert user_manager.on_after_update.called is False async def test_password_update_valid( - self, user: UserDB, user_manager: UserManagerMock + self, user: UserModel, user_manager: UserManagerMock[UserModel] ): old_hashed_password = user.hashed_password user_update = UserUpdate(password="holygrail") @@ -463,7 +502,10 @@ class TestUpdateUser: assert user_manager.on_after_update.called is True async def test_email_update_already_existing( - self, user: UserDB, superuser: UserDB, user_manager: UserManagerMock + self, + user: UserModel, + superuser: UserModel, + user_manager: UserManagerMock[UserModel], ): user_update = UserUpdate(email=superuser.email) with pytest.raises(UserAlreadyExists): @@ -472,7 +514,7 @@ class TestUpdateUser: assert user_manager.on_after_update.called is False async def test_email_update_with_same_email( - self, user: UserDB, user_manager: UserManagerMock + self, user: UserModel, user_manager: UserManagerMock[UserModel] ): user_update = UserUpdate(email=user.email) updated_user = await user_manager.update(user_update, user, safe=True) @@ -485,7 +527,9 @@ class TestUpdateUser: @pytest.mark.asyncio @pytest.mark.manager class TestDelete: - async def test_delete(self, user: UserDB, user_manager: UserManagerMock): + async def test_delete( + self, user: UserModel, user_manager: UserManagerMock[UserModel] + ): await user_manager.delete(user) @@ -497,7 +541,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): form = create_oauth2_password_request_form("lancelot@camelot.bt", "guinevere") user = await user_manager.authenticate(form) @@ -508,7 +552,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): form = create_oauth2_password_request_form("king.arthur@camelot.bt", "percival") user = await user_manager.authenticate(form) @@ -519,7 +563,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): form = create_oauth2_password_request_form( "king.arthur@camelot.bt", "guinevere" @@ -534,7 +578,7 @@ class TestAuthenticate: create_oauth2_password_request_form: Callable[ [str, str], OAuth2PasswordRequestForm ], - user_manager: UserManagerMock, + user_manager: UserManagerMock[UserModel], ): verify_and_update_password_patch = mocker.patch.object( user_manager.password_helper, "verify_and_update" @@ -549,3 +593,19 @@ class TestAuthenticate: assert user is not None assert user.email == "king.arthur@camelot.bt" assert update_spy.called is True + + +def test_integer_id_mixin(): + integer_id_mixin = IntegerIDMixin() + + assert integer_id_mixin.parse_id("123") == 123 + assert integer_id_mixin.parse_id(123) == 123 + + with pytest.raises(InvalidID): + integer_id_mixin.parse_id("123.42") + + with pytest.raises(InvalidID): + integer_id_mixin.parse_id(123.42) + + with pytest.raises(InvalidID): + integer_id_mixin.parse_id("abc") diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 205cf746..439da25d 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -3,19 +3,12 @@ import pytest from fastapi import FastAPI, status from fastapi_users.fastapi_users import FastAPIUsers -from tests.conftest import User, UserCreate, UserDB, UserUpdate +from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate @pytest.fixture def fastapi_users(get_user_manager, mock_authentication) -> FastAPIUsers: - return FastAPIUsers[User, UserCreate, UserUpdate, UserDB]( - get_user_manager, - [mock_authentication], - User, - UserCreate, - UserUpdate, - UserDB, - ) + return FastAPIUsers[UserModel, IDType](get_user_manager, [mock_authentication]) @pytest.fixture @@ -23,14 +16,14 @@ def test_app( fastapi_users: FastAPIUsers, secret, mock_authentication, oauth_client ) -> FastAPI: app = FastAPI() - app.include_router(fastapi_users.get_register_router()) + app.include_router(fastapi_users.get_register_router(User, UserCreate)) app.include_router(fastapi_users.get_reset_password_router()) app.include_router(fastapi_users.get_auth_router(mock_authentication)) app.include_router( fastapi_users.get_oauth_router(oauth_client, mock_authentication, secret) ) - app.include_router(fastapi_users.get_users_router()) - app.include_router(fastapi_users.get_verify_router()) + app.include_router(fastapi_users.get_users_router(User, UserUpdate)) + app.include_router(fastapi_users.get_verify_router(User)) return app diff --git a/tests/test_router_auth.py b/tests/test_router_auth.py index 067c3743..1bbc8026 100644 --- a/tests/test_router_auth.py +++ b/tests/test_router_auth.py @@ -6,7 +6,7 @@ from fastapi import FastAPI, status from fastapi_users.authentication import Authenticator from fastapi_users.router import ErrorCode, get_auth_router -from tests.conftest import UserDB, get_mock_authentication +from tests.conftest import UserModel, get_mock_authentication @pytest.fixture @@ -118,7 +118,7 @@ class TestLogin: path, email, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client data = {"username": email, "password": "guinevere"} @@ -140,7 +140,7 @@ class TestLogin: path, email, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client data = {"username": email, "password": "excalibur"} @@ -182,7 +182,7 @@ class TestLogout: mocker, path, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.post( @@ -198,7 +198,7 @@ class TestLogout: mocker, path, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.post( diff --git a/tests/test_router_oauth.py b/tests/test_router_oauth.py index 89194b3f..bc30648b 100644 --- a/tests/test_router_oauth.py +++ b/tests/test_router_oauth.py @@ -7,7 +7,7 @@ from httpx_oauth.oauth2 import BaseOAuth2, OAuth2 from fastapi_users.authentication import AuthenticationBackend from fastapi_users.router.oauth import generate_state_token, get_oauth_router -from tests.conftest import AsyncMethodMocker, UserDB, UserManagerMock +from tests.conftest import AsyncMethodMocker, UserManagerMock, UserOAuthModel @pytest.fixture @@ -112,7 +112,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, oauth_client: BaseOAuth2, - user_oauth: UserDB, + user_oauth: UserOAuthModel, access_token: str, ): async_method_mocker(oauth_client, "get_access_token", return_value=access_token) @@ -133,7 +133,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, oauth_client: BaseOAuth2, - user_oauth: UserDB, + user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, ): @@ -161,7 +161,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, oauth_client: BaseOAuth2, - inactive_user_oauth: UserDB, + inactive_user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, ): @@ -188,7 +188,7 @@ class TestCallback: async_method_mocker: AsyncMethodMocker, test_app_client_redirect_url: httpx.AsyncClient, oauth_client: BaseOAuth2, - user_oauth: UserDB, + user_oauth: UserOAuthModel, user_manager_oauth: UserManagerMock, access_token: str, ): diff --git a/tests/test_router_users.py b/tests/test_router_users.py index 1d4752c9..cc4364a9 100644 --- a/tests/test_router_users.py +++ b/tests/test_router_users.py @@ -6,7 +6,7 @@ from fastapi import FastAPI, status from fastapi_users.authentication import Authenticator from fastapi_users.router import ErrorCode, get_users_router -from tests.conftest import User, UserDB, UserUpdate, get_mock_authentication +from tests.conftest import User, UserModel, UserUpdate, get_mock_authentication @pytest.fixture @@ -21,7 +21,6 @@ def app_factory(get_user_manager, mock_authentication): get_user_manager, User, UserUpdate, - UserDB, authenticator, requires_verification=requires_verification, ) @@ -59,7 +58,7 @@ class TestMe: async def test_inactive_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - inactive_user: UserDB, + inactive_user: UserModel, ): client, _ = test_app_client response = await client.get( @@ -70,7 +69,7 @@ class TestMe: async def test_active_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -87,7 +86,7 @@ class TestMe: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.get( @@ -116,7 +115,7 @@ class TestUpdateMe: async def test_inactive_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - inactive_user: UserDB, + inactive_user: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -127,8 +126,8 @@ class TestUpdateMe: async def test_existing_email( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_user: UserDB, + user: UserModel, + verified_user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -146,7 +145,7 @@ class TestUpdateMe: async def test_invalid_password( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -167,7 +166,7 @@ class TestUpdateMe: async def test_empty_body( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -184,7 +183,7 @@ class TestUpdateMe: async def test_valid_body( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -202,7 +201,7 @@ class TestUpdateMe: async def test_unverified_after_email_change( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -217,7 +216,7 @@ class TestUpdateMe: async def test_valid_body_is_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"is_superuser": True} @@ -235,7 +234,7 @@ class TestUpdateMe: async def test_valid_body_is_active( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"is_active": False} @@ -253,7 +252,7 @@ class TestUpdateMe: async def test_valid_body_is_verified( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client json = {"is_verified": True} @@ -273,7 +272,7 @@ class TestUpdateMe: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client mocker.spy(mock_user_db, "update") @@ -295,7 +294,7 @@ class TestUpdateMe: async def test_empty_body_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -309,7 +308,7 @@ class TestUpdateMe: async def test_valid_body_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -324,7 +323,7 @@ class TestUpdateMe: async def test_valid_body_is_superuser_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"is_superuser": True} @@ -339,7 +338,7 @@ class TestUpdateMe: async def test_valid_body_is_active_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"is_active": False} @@ -354,7 +353,7 @@ class TestUpdateMe: async def test_valid_body_is_verified_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client json = {"is_verified": False} @@ -371,7 +370,7 @@ class TestUpdateMe: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client mocker.spy(mock_user_db, "update") @@ -399,7 +398,7 @@ class TestGetUser: async def test_regular_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -412,7 +411,7 @@ class TestGetUser: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.get( @@ -424,7 +423,7 @@ class TestGetUser: async def test_not_existing_user_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - superuser: UserDB, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -439,7 +438,7 @@ class TestGetUser: async def test_not_existing_user_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_superuser: UserDB, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.get( @@ -451,8 +450,8 @@ class TestGetUser: async def test_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.get( @@ -470,8 +469,8 @@ class TestGetUser: async def test_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.get( @@ -483,7 +482,7 @@ class TestGetUser: assert data["id"] == str(user.id) assert "hashed_password" not in data - async def test_get_user_namespace(self, app_factory, user: UserDB): + async def test_get_user_namespace(self, app_factory, user: UserModel): assert app_factory(True).url_path_for("users:user", id=user.id) == f"/{user.id}" @@ -498,7 +497,7 @@ class TestUpdateUser: async def test_regular_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -511,7 +510,7 @@ class TestUpdateUser: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -523,7 +522,7 @@ class TestUpdateUser: async def test_not_existing_user_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - superuser: UserDB, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -539,7 +538,7 @@ class TestUpdateUser: async def test_not_existing_user_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_superuser: UserDB, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -552,8 +551,8 @@ class TestUpdateUser: async def test_empty_body_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.patch( @@ -570,8 +569,8 @@ class TestUpdateUser: async def test_empty_body_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -587,8 +586,8 @@ class TestUpdateUser: async def test_valid_body_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -608,9 +607,9 @@ class TestUpdateUser: async def test_existing_email_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -625,8 +624,8 @@ class TestUpdateUser: async def test_invalid_password_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.patch( @@ -644,8 +643,8 @@ class TestUpdateUser: async def test_valid_body_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} @@ -662,8 +661,8 @@ class TestUpdateUser: async def test_valid_body_is_superuser_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"is_superuser": True} @@ -683,8 +682,8 @@ class TestUpdateUser: async def test_valid_body_is_superuser_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"is_superuser": True} @@ -701,8 +700,8 @@ class TestUpdateUser: async def test_valid_body_is_active_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"is_active": False} @@ -722,8 +721,8 @@ class TestUpdateUser: async def test_valid_body_is_active_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"is_active": False} @@ -740,8 +739,8 @@ class TestUpdateUser: async def test_valid_body_is_verified_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client json = {"is_verified": True} @@ -761,8 +760,8 @@ class TestUpdateUser: async def test_valid_body_is_verified_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client json = {"is_verified": True} @@ -781,8 +780,8 @@ class TestUpdateUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client mocker.spy(mock_user_db, "update") @@ -808,8 +807,8 @@ class TestUpdateUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client mocker.spy(mock_user_db, "update") @@ -839,7 +838,7 @@ class TestDeleteUser: async def test_regular_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, + user: UserModel, ): client, requires_verification = test_app_client response = await client.delete( @@ -852,7 +851,7 @@ class TestDeleteUser: async def test_verified_user( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_user: UserDB, + verified_user: UserModel, ): client, _ = test_app_client response = await client.delete( @@ -864,7 +863,7 @@ class TestDeleteUser: async def test_not_existing_user_unverified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - superuser: UserDB, + superuser: UserModel, ): client, requires_verification = test_app_client response = await client.delete( @@ -879,7 +878,7 @@ class TestDeleteUser: async def test_not_existing_user_verified_superuser( self, test_app_client: Tuple[httpx.AsyncClient, bool], - verified_superuser: UserDB, + verified_superuser: UserModel, ): client, _ = test_app_client response = await client.delete( @@ -893,8 +892,8 @@ class TestDeleteUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - superuser: UserDB, + user: UserModel, + superuser: UserModel, ): client, requires_verification = test_app_client mocker.spy(mock_user_db, "delete") @@ -917,8 +916,8 @@ class TestDeleteUser: mocker, mock_user_db, test_app_client: Tuple[httpx.AsyncClient, bool], - user: UserDB, - verified_superuser: UserDB, + user: UserModel, + verified_superuser: UserModel, ): client, _ = test_app_client mocker.spy(mock_user_db, "delete") diff --git a/tests/test_router_verify.py b/tests/test_router_verify.py index c954a0fd..f7619c45 100644 --- a/tests/test_router_verify.py +++ b/tests/test_router_verify.py @@ -11,7 +11,7 @@ from fastapi_users.manager import ( UserNotExists, ) from fastapi_users.router import ErrorCode, get_verify_router -from tests.conftest import AsyncMethodMocker, User, UserDB, UserManagerMock +from tests.conftest import AsyncMethodMocker, User, UserManagerMock, UserModel @pytest.fixture @@ -20,10 +20,7 @@ async def test_app_client( get_user_manager, get_test_client, ) -> AsyncGenerator[httpx.AsyncClient, None]: - verify_router = get_verify_router( - get_user_manager, - User, - ) + verify_router = get_verify_router(get_user_manager, User) app = FastAPI() app.include_router(verify_router) @@ -70,7 +67,7 @@ class TestVerifyTokenRequest: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "get_by_email", return_value=user) user_manager.request_verify.side_effect = UserInactive() @@ -83,7 +80,7 @@ class TestVerifyTokenRequest: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "get_by_email", return_value=user) user_manager.request_verify.side_effect = UserAlreadyVerified() @@ -96,7 +93,7 @@ class TestVerifyTokenRequest: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "get_by_email", return_value=user) async_method_mocker(user_manager, "request_verify", return_value=None) @@ -171,7 +168,7 @@ class TestVerify: async_method_mocker: AsyncMethodMocker, test_app_client: httpx.AsyncClient, user_manager: UserManagerMock, - user: UserDB, + user: UserModel, ): async_method_mocker(user_manager, "verify", return_value=user) response = await test_app_client.post("/verify", json={"token": "foo"})