diff --git a/README.md b/README.md index a0b12df6..0394534d 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Add quickly a registration and authentication system to your [FastAPI](https://f ## Features * [X] Extensible base user model -* [X] Ready-to-use register, login, forgot and reset password routes +* [X] Ready-to-use register, login, reset password and verify e-mail routes * [X] Ready-to-use OAuth2 flow * [X] Dependency callables to inject current user in route * [X] Customizable database backend @@ -74,6 +74,12 @@ Alternatively, you can run `pytest` yourself. The MongoDB unit tests will be ski pytest ``` +There are quite a few unit tests, so you might run into ulimit issues where there are too many open file descriptors. You may be able to set a new, higher limit temporarily with: + +```bash +ulimit -n 2048 +``` + ### Format the code Execute the following command to apply `isort` and `black` formatting: diff --git a/docs/configuration/model.md b/docs/configuration/model.md index 9c6a9b03..8aca184e 100644 --- a/docs/configuration/model.md +++ b/docs/configuration/model.md @@ -5,6 +5,7 @@ * `id` (`UUID4`) – Unique identifier of the user. Default to a **UUID4**. * `email` (`str`) – Email of the user. Validated by [`email-validator`](https://github.com/JoshData/python-email-validator). * `is_active` (`bool`) – Whether or not the user is active. If not, login and forgot password requests will be denied. Default to `True`. +* `is_verified` (`bool`) – Whether or not the user is verified. Optional but helpful with the [`verify` router](./routers/verify.md) logic. Default to `False`. * `is_superuser` (`bool`) – Whether or not the user is a superuser. Useful to implement administration logic. Default to `False`. ## Define your models @@ -38,7 +39,7 @@ class UserDB(User, models.BaseUserDB): pass ``` -You can of course add you own properties there to fit to your needs! +You can of course add your own properties there to fit to your needs! ## Password validation @@ -59,7 +60,7 @@ class UserCreate(models.BaseUserCreate): ## Next steps -Depending on your database backend, database configuration will differ a bit. +Depending on your database backend, the database configuration will differ a bit. [I'm using SQLAlchemy](databases/sqlalchemy.md) diff --git a/docs/configuration/routers/auth.md b/docs/configuration/routers/auth.md index 8b10873e..a344f0d0 100644 --- a/docs/configuration/routers/auth.md +++ b/docs/configuration/routers/auth.md @@ -31,3 +31,15 @@ app.include_router( tags=["auth"], ) ``` + +### Optional: user verification + +You can require the user to be **verified** (i.e. `is_verified` property set to `True`) to allow login. You have to set the `requires_validation` parameter to `True` on the router instantiation method: + +```py +app.include_router( + fastapi_users.get_auth_router(jwt_authentication, requires_verification=True), + prefix="/auth/jwt", + tags=["auth"], +) +``` diff --git a/docs/configuration/routers/index.md b/docs/configuration/routers/index.md index 1a90b9a6..9942bc8e 100644 --- a/docs/configuration/routers/index.md +++ b/docs/configuration/routers/index.md @@ -33,6 +33,7 @@ This helper class will let you generate useful routers to setup the authenticati * [Auth router](./auth.md): Provides `/login` and `/logout` routes for a given [authentication backend](../authentication/index.md). * [Register router](./register.md): Provides `/register` routes to allow a user to create a new account. * [Reset password router](./reset.md): Provides `/forgot-password` and `/reset-password` routes to allow a user to reset its password. +* [Verify router](./verify.md): Provides `/request-verify-token` and `/verify` routes to manage user e-mail verification. * [Users router](./users.md): Provides routes to manage users. * [OAuth router](../oauth.md): Provides routes to perform an OAuth authentication against a service provider (like Google or Facebook). diff --git a/docs/configuration/routers/register.md b/docs/configuration/routers/register.md index e3cdfe6c..1c6208e5 100644 --- a/docs/configuration/routers/register.md +++ b/docs/configuration/routers/register.md @@ -51,4 +51,4 @@ app.include_router( prefix="/auth", tags=["auth"], ) -``` +``` \ No newline at end of file diff --git a/docs/configuration/routers/users.md b/docs/configuration/routers/users.md index ce17b39d..359cfffd 100644 --- a/docs/configuration/routers/users.md +++ b/docs/configuration/routers/users.md @@ -25,6 +25,18 @@ app.include_router( ) ``` +### Optional: user verification + +You can require the user to be **verified** (i.e. `is_verified` property set to `True`) to access those routes. You have to set the `requires_validation` parameter to `True` on the router instantiation method: + +```py +app.include_router( + fastapi_users.get_users_router(requires_verification=True), + prefix="/users", + tags=["users"], +) +``` + ## After update You can provide a custom function to be called after a successful update user request. It is called with **three arguments**: diff --git a/docs/configuration/routers/verify.md b/docs/configuration/routers/verify.md new file mode 100644 index 00000000..49c29967 --- /dev/null +++ b/docs/configuration/routers/verify.md @@ -0,0 +1,88 @@ +# Verify router + +!!! warning + This feature is not released yet. + +This router provides routes to manage user email verification. Check the [routes usage](../../usage/routes.md) to learn how to use them. + +!!! success "👏👏👏" + A big thank you to [Edd Salkield](https://github.com/eddsalkield) and [Mark Todd](https://github.com/mark-todd) who worked hard on this feature! + +## Setup + +```py +from fastapi import FastAPI +from fastapi_users import FastAPIUsers + +fastapi_users = FastAPIUsers( + user_db, + auth_backends, + User, + UserCreate, + UserUpdate, + UserDB, +) + +app = FastAPI() +app.include_router( + fastapi_users.get_verify_router("SECRET"), + prefix="/auth", + tags=["auth"], +) +``` + +Parameters: + +* `verification_token_secret`: Secret to encode verify token. +* `verification_token_lifetime_seconds`: Lifetime of verify token. **Defaults to 3600**. +* `after_verification_request`: Optional function called after a successful verify request. See below. +* `after_verification`: Optional function called after a successful verification. See below. + +## After verification request + +You can provide a custom function to be called after a successful verification request. It is called with **three arguments**: + +* The **user** for which the verification has been requested. +* A ready-to-use **JWT token** that will be accepted by the verify route. +* The original **`Request` object**. + +Typically, you'll want to **send an e-mail** with the link (and the token) that allows the user to verify their e-mail. + +You can define it as an `async` or standard method. + +Example: + +```py +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + +app.include_router( + fastapi_users.get_verify_router("SECRET", after_verification_request=after_verification_request), + prefix="/auth", + tags=["auth"], +) +``` + +## After verification + +You can provide a custom function to be called after a successful user verification. It is called with **two arguments**: + +* The **user** that has been verified. +* The original **`Request` object**. + +This may be useful if you wish to send another e-mail or store this information in a data analytics or customer success platform. + +You can define it as an `async` or standard method. + +Example: + +```py +def after_verification(user: UserDB, request: Request): + print(f"{user.id} is now verified.") + +app.include_router( + fastapi_users.get_verify_router("SECRET", after_verification=after_verification), + prefix="/auth", + tags=["auth"], +) +``` diff --git a/docs/src/full_mongodb.py b/docs/src/full_mongodb.py index 6ed38720..b93311e4 100644 --- a/docs/src/full_mongodb.py +++ b/docs/src/full_mongodb.py @@ -40,6 +40,10 @@ def on_after_forgot_password(user: UserDB, token: str, request: Request): print(f"User {user.id} has forgot their password. Reset token: {token}") +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + jwt_authentication = JWTAuthentication( secret=SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login" ) @@ -66,4 +70,11 @@ app.include_router( prefix="/auth", tags=["auth"], ) +app.include_router( + fastapi_users.get_verify_router( + SECRET, after_verification_request=after_verification_request + ), + prefix="/auth", + tags=["auth"], +) app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) diff --git a/docs/src/full_sqlalchemy.py b/docs/src/full_sqlalchemy.py index 2558823e..2b28c5d0 100644 --- a/docs/src/full_sqlalchemy.py +++ b/docs/src/full_sqlalchemy.py @@ -51,6 +51,10 @@ def on_after_forgot_password(user: UserDB, token: str, request: Request): print(f"User {user.id} has forgot their password. Reset token: {token}") +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + jwt_authentication = JWTAuthentication( secret=SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login" ) @@ -77,6 +81,13 @@ app.include_router( prefix="/auth", tags=["auth"], ) +app.include_router( + fastapi_users.get_verify_router( + SECRET, after_verification_request=after_verification_request + ), + prefix="/auth", + tags=["auth"], +) app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) diff --git a/docs/src/full_tortoise.py b/docs/src/full_tortoise.py index 5b2ba8d8..dd60b4cf 100644 --- a/docs/src/full_tortoise.py +++ b/docs/src/full_tortoise.py @@ -46,6 +46,10 @@ def on_after_forgot_password(user: UserDB, token: str, request: Request): print(f"User {user.id} has forgot their password. Reset token: {token}") +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + jwt_authentication = JWTAuthentication( secret=SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login" ) @@ -71,4 +75,11 @@ app.include_router( prefix="/auth", tags=["auth"], ) +app.include_router( + fastapi_users.get_verify_router( + SECRET, after_verification_request=after_verification_request + ), + prefix="/auth", + tags=["auth"], +) app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) diff --git a/docs/src/oauth_full_mongodb.py b/docs/src/oauth_full_mongodb.py index ee675a3b..0159e062 100644 --- a/docs/src/oauth_full_mongodb.py +++ b/docs/src/oauth_full_mongodb.py @@ -44,6 +44,10 @@ def on_after_forgot_password(user: UserDB, token: str, request: Request): print(f"User {user.id} has forgot their password. Reset token: {token}") +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + jwt_authentication = JWTAuthentication( secret=SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login" ) @@ -70,6 +74,13 @@ app.include_router( prefix="/auth", tags=["auth"], ) +app.include_router( + fastapi_users.get_verify_router( + SECRET, after_verification_request=after_verification_request + ), + prefix="/auth", + tags=["auth"], +) app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) google_oauth_router = fastapi_users.get_oauth_router( diff --git a/docs/src/oauth_full_sqlalchemy.py b/docs/src/oauth_full_sqlalchemy.py index 167d5ea1..fdd22540 100644 --- a/docs/src/oauth_full_sqlalchemy.py +++ b/docs/src/oauth_full_sqlalchemy.py @@ -64,6 +64,10 @@ def on_after_forgot_password(user: UserDB, token: str, request: Request): print(f"User {user.id} has forgot their password. Reset token: {token}") +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + jwt_authentication = JWTAuthentication( secret=SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login" ) @@ -90,6 +94,13 @@ app.include_router( prefix="/auth", tags=["auth"], ) +app.include_router( + fastapi_users.get_verify_router( + SECRET, after_verification_request=after_verification_request + ), + prefix="/auth", + tags=["auth"], +) app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) google_oauth_router = fastapi_users.get_oauth_router( diff --git a/docs/src/oauth_full_tortoise.py b/docs/src/oauth_full_tortoise.py index 704468b7..c03a16ee 100644 --- a/docs/src/oauth_full_tortoise.py +++ b/docs/src/oauth_full_tortoise.py @@ -59,6 +59,10 @@ def on_after_forgot_password(user: UserDB, token: str, request: Request): print(f"User {user.id} has forgot their password. Reset token: {token}") +def after_verification_request(user: UserDB, token: str, request: Request): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + jwt_authentication = JWTAuthentication( secret=SECRET, lifetime_seconds=3600, tokenUrl="/auth/jwt/login" ) @@ -84,6 +88,13 @@ app.include_router( prefix="/auth", tags=["auth"], ) +app.include_router( + fastapi_users.get_verify_router( + SECRET, after_verification_request=after_verification_request + ), + prefix="/auth", + tags=["auth"], +) app.include_router(fastapi_users.get_users_router(), prefix="/users", tags=["users"]) google_oauth_router = fastapi_users.get_oauth_router( diff --git a/docs/usage/dependency-callables.md b/docs/usage/dependency-callables.md index 090143ed..94617b23 100644 --- a/docs/usage/dependency-callables.md +++ b/docs/usage/dependency-callables.md @@ -25,6 +25,16 @@ def protected_route(user: User = Depends(fastapi_users.get_current_active_user)) return f"Hello, {user.email}" ``` +## `get_current_verified_user` + +Get the current active and verified user. Will throw a `401 Unauthorized` if missing or wrong credentials or if the user is not active and verified. + +```py +@app.get("/protected-route") +def protected_route(user: User = Depends(fastapi_users.get_current_verified_user)): + return f"Hello, {user.email}" +``` + ## `get_current_superuser` Get the current superuser. Will throw a `401 Unauthorized` if missing or wrong credentials or if the user is not active. Will throw a `403 Forbidden` if the user is not a superuser. @@ -35,6 +45,16 @@ def protected_route(user: User = Depends(fastapi_users.get_current_superuser)): return f"Hello, {user.email}" ``` +## `get_current_verified_superuser` + +Get the current verified superuser. Will throw a `401 Unauthorized` if missing or wrong credentials or if the user is not active and verified. Will throw a `403 Forbidden` if the user is not a superuser. + +```py +@app.get("/protected-route") +def protected_route(user: User = Depends(fastapi_users.get_current_verified_superuser)): + return f"Hello, {user.email}" +``` + ## `get_optional_current_user` Get the current user (**active or not**). Will return `None` if missing or wrong credentials. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. @@ -50,7 +70,7 @@ def optional_user_route(user: Optional[User] = Depends(fastapi_users.get_optiona ## `get_optional_current_active_user` -Get the current active user. Will return `None` if missing or wrong credentials. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. +Get the current active user. Will return `None` if missing or wrong credentials or if the user is not active. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. ```py @app.get("/optional-user-route") @@ -61,9 +81,22 @@ def optional_user_route(user: User = Depends(fastapi_users.get_optional_current_ return "Hello, anonymous" ``` +## `get_optional_current_verified_user` + +Get the current active and verified user. Will return `None` if missing or wrong credentials or if the user is not active and verified. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. + +```py +@app.get("/optional-user-route") +def optional_user_route(user: User = Depends(fastapi_users.get_optional_current_verified_user)): + if user: + return f"Hello, {user.email}" + else: + return "Hello, anonymous" +``` + ## `get_optional_current_superuser` -Get the current superuser. Will return `None` if missing or wrong credentials. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. +Get the current superuser. Will return `None` if missing or wrong credentials or if the user is not active. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. ```py @app.get("/optional-user-route") @@ -74,9 +107,22 @@ def optional_user_route(user: User = Depends(fastapi_users.get_optional_current_ return "Hello, anonymous" ``` +## `get_optional_current_verified_superuser` + +Get the current active and verified superuser. Will return `None` if missing or wrong credentials or if the user is not active and verified. It can be useful if you wish to change the behaviour of your endpoint if a user is logged in or not. + +```py +@app.get("/optional-user-route") +def optional_user_route(user: User = Depends(fastapi_users.get_optional_current_verified_superuser)): + if user: + return f"Hello, {user.email}" + else: + return "Hello, anonymous" +``` + ## In path operation -If you don't need a user, you can use more clear way: +If you don't need the user in the route logic, you can use this syntax: ```py @app.get("/protected-route", dependencies=[Depends(fastapi_users.get_current_superuser)]) diff --git a/docs/usage/helpers.md b/docs/usage/helpers.md index 623dc9e9..3dbda8a1 100644 --- a/docs/usage/helpers.md +++ b/docs/usage/helpers.md @@ -1,8 +1,10 @@ # Helpers +**FastAPI Users** provides some helper functions to perform some actions programmatically. They are available from your `FastAPIUsers` instance. + ## Create user -**FastAPI Users** provides a helper function to easily create a user programmatically. They are available from your `FastAPIUsers` instance. +Create a user. ```py regular_user = await fastapi_users.create_user( @@ -20,3 +22,20 @@ superuser = await fastapi_users.create_user( ) ) ``` + +## Verify user + +Verify a user. + +```py +verified_user = await fastapi_users.verify_user(non_verified_user) +assert verified_user.is_verified is True +``` + +## Get user + +Retrieve a user by e-mail. + +```py +user = await fastapi_users.get_user("king.arthur@camelot.bt") +``` diff --git a/docs/usage/routes.md b/docs/usage/routes.md index bfd40dca..46bd1719 100644 --- a/docs/usage/routes.md +++ b/docs/usage/routes.md @@ -79,9 +79,9 @@ Register a new user. Will call the `after_register` [handler](../configuration/r ### `POST /forgot-password` -Request a reset password procedure. Will generate a temporary token and call the `after_forgot_password` [handlers](../configuration/routers/reset.md#after-forgot-password) if the user exists. +Request a reset password procedure. Will generate a temporary token and call the `after_forgot_password` [handler](../configuration/routers/reset.md#after-forgot-password) if the user exists. -To prevent malicious users from guessing existing users in your databse, the route will always return a `202 Accepted` response, even if the user requested does not exist. +To prevent malicious users from guessing existing users in your database, the route will always return a `202 Accepted` response, even if the user requested does not exist. !!! abstract "Payload" ```json @@ -117,6 +117,68 @@ Reset a password. Requires the token generated by the `/forgot-password` route. } ``` +## Verify router + +!!! warning + This feature is not released yet. + +### `POST /request-verify-token` + +Request a user to verify their e-mail. Will generate a temporary token and call the `after_verification_request` [handler](../configuration/routers/verify.md#after-verification-request) if the user exists. + +To prevent malicious users from guessing existing users in your database, the route will always return a `202 Accepted` response, even if the user requested does not exist. + +!!! abstract "Payload" + ```json + { + "email": "king.arthur@camelot.bt" + } + ``` + +!!! success "`202 Accepted`" + +### `POST /verify` + +Verify a user. Requires the token generated by the `/request-verify-token` route. Will call the call the `after_verification` [handler](../configuration/routers/verify.md#after-verification) on success. + +!!! abstract "Payload" + ```json + { + "token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoiOTIyMWZmYzktNjQwZi00MzcyLTg2ZDMtY2U2NDJjYmE1NjAzIiwiYXVkIjoiZmFzdGFwaS11c2VyczphdXRoIiwiZXhwIjoxNTcxNTA0MTkzfQ.M10bjOe45I5Ncu_uXvOmVV8QxnL-nZfcH96U90JaocI" + } + ``` + +!!! success "`200 OK`" + +!!! fail "`422 Validation Error`" + +!!! fail "`400 Bad Request`" + Expired token. + + ```json + { + "detail": "VERIFY_USER_TOKEN_EXPIRED" + } + ``` + +!!! fail "`400 Bad Request`" + Bad token, not existing user or not the e-mail currently set for the user. + + ```json + { + "detail": "VERIFY_USER_BAD_TOKEN" + } + ``` + +!!! fail "`400 Bad Request`" + The user is already verified. + + ```json + { + "detail": "VERIFY_USER_ALREADY_VERIFIED" + } + ``` + ## OAuth router Each OAuth router you define will expose the two following routes. diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index 7ee8eef1..c7874ccd 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -75,6 +75,13 @@ class Authenticator: return None return user + @with_signature(signature, func_name="get_optional_current_verified_user") + async def get_optional_current_verified_user(*args, **kwargs): + user = await get_optional_current_active_user(*args, **kwargs) + if not user or not user.is_verified: + return None + return user + @with_signature(signature, func_name="get_optional_current_superuser") async def get_optional_current_superuser(*args, **kwargs): user = await get_optional_current_active_user(*args, **kwargs) @@ -82,6 +89,13 @@ class Authenticator: return None return user + @with_signature(signature, func_name="get_optional_current_verified_superuser") + async def get_optional_current_verified_superuser(*args, **kwargs): + user = await get_optional_current_verified_user(*args, **kwargs) + if not user or not user.is_superuser: + return None + return user + @with_signature(signature, func_name="get_current_user") async def get_current_user(*args, **kwargs): user = await get_optional_current_user(*args, **kwargs) @@ -96,6 +110,13 @@ class Authenticator: raise self._get_credentials_exception() return user + @with_signature(signature, func_name="get_current_verified_user") + async def get_current_verified_user(*args, **kwargs): + user = await get_optional_current_verified_user(*args, **kwargs) + if user is None: + raise self._get_credentials_exception() + return user + @with_signature(signature, func_name="get_current_superuser") async def get_current_superuser(*args, **kwargs): user = await get_optional_current_active_user(*args, **kwargs) @@ -105,12 +126,27 @@ class Authenticator: raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN) return user + @with_signature(signature, func_name="get_current_verified_superuser") + async def get_current_verified_superuser(*args, **kwargs): + user = await get_optional_current_verified_user(*args, **kwargs) + if user is None: + raise self._get_credentials_exception() + if not user.is_superuser: + raise self._get_credentials_exception(status.HTTP_403_FORBIDDEN) + return user + self.get_current_user = get_current_user self.get_current_active_user = get_current_active_user + self.get_current_verified_user = get_current_verified_user self.get_current_superuser = get_current_superuser + self.get_current_verified_superuser = get_current_verified_superuser self.get_optional_current_user = get_optional_current_user self.get_optional_current_active_user = get_optional_current_active_user + self.get_optional_current_verified_user = get_optional_current_verified_user self.get_optional_current_superuser = get_optional_current_superuser + self.get_optional_current_verified_superuser = ( + get_optional_current_verified_superuser + ) async def _authenticate(self, *args, **kwargs) -> Optional[BaseUserDB]: for backend in self.backends: diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index 1dfba7a5..afd61796 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -57,6 +57,7 @@ class SQLAlchemyBaseUserTable: hashed_password = Column(String(length=72), nullable=False) is_active = Column(Boolean, default=True, nullable=False) is_superuser = Column(Boolean, default=False, nullable=False) + is_verified = Column(Boolean, default=False, nullable=False) class SQLAlchemyBaseOAuthAccountTable: diff --git a/fastapi_users/db/tortoise.py b/fastapi_users/db/tortoise.py index 09e9151e..2f3021aa 100644 --- a/fastapi_users/db/tortoise.py +++ b/fastapi_users/db/tortoise.py @@ -14,6 +14,7 @@ class TortoiseBaseUserModel(models.Model): hashed_password = fields.CharField(null=False, max_length=255) is_active = fields.BooleanField(default=True, null=False) is_superuser = fields.BooleanField(default=False, null=False) + is_verified = fields.BooleanField(default=False, null=False) async def to_dict(self): d = {} diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py index 51d93c4d..d889866f 100644 --- a/fastapi_users/fastapi_users.py +++ b/fastapi_users/fastapi_users.py @@ -10,8 +10,16 @@ from fastapi_users.router import ( get_register_router, get_reset_password_router, get_users_router, + get_verify_router, +) +from fastapi_users.user import ( + CreateUserProtocol, + GetUserProtocol, + VerifyUserProtocol, + get_create_user, + get_get_user, + get_verify_user, ) -from fastapi_users.user import CreateUserProtocol, get_create_user try: from httpx_oauth.oauth2 import BaseOAuth2 @@ -35,12 +43,16 @@ class FastAPIUsers: :attribute create_user: Helper function to create a user programmatically. :attribute get_current_user: Dependency callable to inject authenticated user. :attribute get_current_active_user: Dependency callable to inject active user. + :attribute get_current_verified_user: Dependency callable to inject verified user. :attribute get_current_superuser: Dependency callable to inject superuser. + :attribute get_current_verified_superuser: Dependency callable to inject verified superuser. """ db: BaseUserDatabase authenticator: Authenticator create_user: CreateUserProtocol + verify_user: VerifyUserProtocol + get_user: GetUserProtocol _user_model: Type[models.BaseUser] _user_create_model: Type[models.BaseUserCreate] _user_update_model: Type[models.BaseUserUpdate] @@ -65,17 +77,29 @@ class FastAPIUsers: self._user_db_model = user_db_model self.create_user = get_create_user(db, user_db_model) + self.verify_user = get_verify_user(db) + self.get_user = get_get_user(db) self.get_current_user = self.authenticator.get_current_user self.get_current_active_user = self.authenticator.get_current_active_user + self.get_current_verified_user = self.authenticator.get_current_verified_user self.get_current_superuser = self.authenticator.get_current_superuser + self.get_current_verified_superuser = ( + self.authenticator.get_current_verified_superuser + ) self.get_optional_current_user = self.authenticator.get_optional_current_user self.get_optional_current_active_user = ( self.authenticator.get_optional_current_active_user ) + self.get_optional_current_verified_user = ( + self.authenticator.get_optional_current_verified_user + ) self.get_optional_current_superuser = ( self.authenticator.get_optional_current_superuser ) + self.get_optional_current_verified_superuser = ( + self.authenticator.get_optional_current_verified_superuser + ) def get_register_router( self, @@ -94,6 +118,31 @@ class FastAPIUsers: after_register, ) + def get_verify_router( + self, + verification_token_secret: str, + verification_token_lifetime_seconds: int = 3600, + after_verification_request: Optional[ + Callable[[models.UD, str, Request], None] + ] = None, + after_verification: Optional[Callable[[models.UD, Request], None]] = None, + ) -> APIRouter: + """ + Return a router with a register route. + + :param after_register: Optional function called + after a successful registration. + """ + return get_verify_router( + self.verify_user, + self.get_user, + self._user_model, + verification_token_secret, + verification_token_lifetime_seconds, + after_verification_request, + after_verification, + ) + def get_reset_password_router( self, reset_password_token_secret: str, @@ -117,13 +166,20 @@ class FastAPIUsers: after_forgot_password, ) - def get_auth_router(self, backend: BaseAuthentication) -> APIRouter: + def get_auth_router( + self, backend: BaseAuthentication, requires_verification: bool = False + ) -> APIRouter: """ Return an auth router for a given authentication backend. :param backend: The authentication backend instance. """ - return get_auth_router(backend, self.db, self.authenticator) + return get_auth_router( + backend, + self.db, + self.authenticator, + requires_verification, + ) def get_oauth_router( self, @@ -157,6 +213,7 @@ class FastAPIUsers: after_update: Optional[ Callable[[models.UD, Dict[str, Any], Request], None] ] = None, + requires_verification: bool = False, ) -> APIRouter: """ Return a router with routes to manage users. @@ -171,4 +228,5 @@ class FastAPIUsers: self._user_db_model, self.authenticator, after_update, + requires_verification, ) diff --git a/fastapi_users/models.py b/fastapi_users/models.py index f23912bf..2e509117 100644 --- a/fastapi_users/models.py +++ b/fastapi_users/models.py @@ -8,7 +8,13 @@ class CreateUpdateDictModel(BaseModel): def create_update_dict(self): return self.dict( exclude_unset=True, - exclude={"id", "is_superuser", "is_active", "oauth_accounts"}, + exclude={ + "id", + "is_superuser", + "is_active", + "is_verified", + "oauth_accounts", + }, ) def create_update_dict_superuser(self): @@ -22,6 +28,7 @@ class BaseUser(CreateUpdateDictModel): email: Optional[EmailStr] = None is_active: Optional[bool] = True is_superuser: Optional[bool] = False + is_verified: Optional[bool] = False @validator("id", pre=True, always=True) def default_id(cls, v): @@ -33,6 +40,7 @@ class BaseUserCreate(CreateUpdateDictModel): password: str is_active: Optional[bool] = True is_superuser: Optional[bool] = False + is_verified: Optional[bool] = False class BaseUserUpdate(BaseUser): diff --git a/fastapi_users/router/__init__.py b/fastapi_users/router/__init__.py index 162641e7..7469b9b6 100644 --- a/fastapi_users/router/__init__.py +++ b/fastapi_users/router/__init__.py @@ -3,6 +3,7 @@ from fastapi_users.router.common import ErrorCode # noqa: F401 from fastapi_users.router.register import get_register_router # noqa: F401 from fastapi_users.router.reset import get_reset_password_router # noqa: F401 from fastapi_users.router.users import get_users_router # noqa: F401 +from fastapi_users.router.verify import get_verify_router # noqa: F401 try: from fastapi_users.router.oauth import get_oauth_router # noqa: F401 diff --git a/fastapi_users/router/auth.py b/fastapi_users/router/auth.py index 83678072..67826386 100644 --- a/fastapi_users/router/auth.py +++ b/fastapi_users/router/auth.py @@ -11,9 +11,14 @@ def get_auth_router( backend: BaseAuthentication, user_db: BaseUserDatabase[models.BaseUserDB], authenticator: Authenticator, + requires_verification: bool = False, ) -> APIRouter: """Generate a router with login/logout routes for an authentication backend.""" router = APIRouter() + if requires_verification: + get_current_user = authenticator.get_current_verified_user + else: + get_current_user = authenticator.get_current_active_user @router.post("/login") async def login( @@ -26,15 +31,17 @@ def get_auth_router( status_code=status.HTTP_400_BAD_REQUEST, detail=ErrorCode.LOGIN_BAD_CREDENTIALS, ) - + if requires_verification and not user.is_verified: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.LOGIN_USER_NOT_VERIFIED, + ) return await backend.get_login_response(user, response) if backend.logout: @router.post("/logout") - async def logout( - response: Response, user=Depends(authenticator.get_current_active_user) - ): + async def logout(response: Response, user=Depends(get_current_user)): return await backend.get_logout_response(user, response) return router diff --git a/fastapi_users/router/common.py b/fastapi_users/router/common.py index fc2394f3..d6a6266d 100644 --- a/fastapi_users/router/common.py +++ b/fastapi_users/router/common.py @@ -5,7 +5,11 @@ from typing import Callable class ErrorCode: REGISTER_USER_ALREADY_EXISTS = "REGISTER_USER_ALREADY_EXISTS" LOGIN_BAD_CREDENTIALS = "LOGIN_BAD_CREDENTIALS" + LOGIN_USER_NOT_VERIFIED = "LOGIN_USER_NOT_VERIFIED" RESET_PASSWORD_BAD_TOKEN = "RESET_PASSWORD_BAD_TOKEN" + VERIFY_USER_BAD_TOKEN = "VERIFY_USER_BAD_TOKEN" + VERIFY_USER_ALREADY_VERIFIED = "VERIFY_USER_ALREADY_VERIFIED" + VERIFY_USER_TOKEN_EXPIRED = "VERIFY_USER_TOKEN_EXPIRED" async def run_handler(handler: Callable, *args, **kwargs): diff --git a/fastapi_users/router/users.py b/fastapi_users/router/users.py index a1107687..753b22b7 100644 --- a/fastapi_users/router/users.py +++ b/fastapi_users/router/users.py @@ -17,12 +17,17 @@ def get_users_router( user_db_model: Type[models.BaseUserDB], authenticator: Authenticator, after_update: Optional[Callable[[models.UD, Dict[str, Any], Request], None]] = None, + requires_verification: bool = False, ) -> APIRouter: """Generate a router with the authentication routes.""" router = APIRouter() - get_current_active_user = authenticator.get_current_active_user - get_current_superuser = authenticator.get_current_superuser + if requires_verification: + get_current_active_user = authenticator.get_current_verified_user + get_current_superuser = authenticator.get_current_verified_superuser + else: + get_current_active_user = authenticator.get_current_active_user + get_current_superuser = authenticator.get_current_superuser async def _get_or_404(id: UUID4) -> models.BaseUserDB: user = await user_db.get(id) diff --git a/fastapi_users/router/verify.py b/fastapi_users/router/verify.py new file mode 100644 index 00000000..7da8cedb --- /dev/null +++ b/fastapi_users/router/verify.py @@ -0,0 +1,127 @@ +from typing import Callable, Optional, Type, cast + +import jwt +from fastapi import APIRouter, Body, HTTPException, Request, status +from pydantic import UUID4, EmailStr + +from fastapi_users import models +from fastapi_users.router.common import ErrorCode, run_handler +from fastapi_users.user import ( + GetUserProtocol, + UserAlreadyVerified, + UserNotExists, + VerifyUserProtocol, +) +from fastapi_users.utils import JWT_ALGORITHM, generate_jwt + +VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify" + + +def get_verify_router( + verify_user: VerifyUserProtocol, + get_user: GetUserProtocol, + user_model: Type[models.BaseUser], + verification_token_secret: str, + verification_token_lifetime_seconds: int = 3600, + after_verification_request: Optional[ + Callable[[models.UD, str, Request], None] + ] = None, + after_verification: Optional[Callable[[models.UD, Request], None]] = None, +): + router = APIRouter() + + @router.post("/request-verify-token", status_code=status.HTTP_202_ACCEPTED) + async def request_verify_token( + request: Request, email: EmailStr = Body(..., embed=True) + ): + try: + user = await get_user(email) + if user.is_verified: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_ALREADY_VERIFIED, + ) + elif user.is_active: + token_data = { + "user_id": str(user.id), + "email": email, + "aud": VERIFY_USER_TOKEN_AUDIENCE, + } + token = generate_jwt( + token_data, + verification_token_lifetime_seconds, + verification_token_secret, + ) + + if after_verification_request: + await run_handler(after_verification_request, user, token, request) + except UserNotExists: + pass + + return None + + @router.post("/verify", response_model=user_model) + async def verify(request: Request, token: str = Body(..., embed=True)): + try: + data = jwt.decode( + token, + verification_token_secret, + audience=VERIFY_USER_TOKEN_AUDIENCE, + algorithms=[JWT_ALGORITHM], + ) + except jwt.exceptions.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_TOKEN_EXPIRED, + ) + except jwt.PyJWTError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_BAD_TOKEN, + ) + + user_id = data.get("user_id") + email = cast(EmailStr, data.get("email")) + + if user_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_BAD_TOKEN, + ) + + try: + user_check = await get_user(email) + except UserNotExists: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_BAD_TOKEN, + ) + + try: + user_uuid = UUID4(user_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_BAD_TOKEN, + ) + + if user_check.id != user_uuid: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_BAD_TOKEN, + ) + + try: + user = await verify_user(user_check) + except UserAlreadyVerified: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ErrorCode.VERIFY_USER_ALREADY_VERIFIED, + ) + + if after_verification: + await run_handler(after_verification, user, request) + + return user + + return router diff --git a/fastapi_users/user.py b/fastapi_users/user.py index 8bec4a8e..161dadb8 100644 --- a/fastapi_users/user.py +++ b/fastapi_users/user.py @@ -5,6 +5,8 @@ try: except ImportError: from typing_extensions import Protocol # type: ignore +from pydantic import EmailStr + from fastapi_users import models from fastapi_users.db import BaseUserDatabase from fastapi_users.password import get_password_hash @@ -14,9 +16,21 @@ class UserAlreadyExists(Exception): pass +class UserNotExists(Exception): + pass + + +class UserAlreadyVerified(Exception): + pass + + class CreateUserProtocol(Protocol): # pragma: no cover def __call__( - self, user: models.BaseUserCreate, safe: bool = False + self, + user: models.BaseUserCreate, + safe: bool = False, + is_active: bool = None, + is_verified: bool = None, ) -> Awaitable[models.BaseUserDB]: pass @@ -26,7 +40,10 @@ def get_create_user( user_db_model: Type[models.BaseUserDB], ) -> CreateUserProtocol: async def create_user( - user: models.BaseUserCreate, safe: bool = False + user: models.BaseUserCreate, + safe: bool = False, + is_active: bool = None, + is_verified: bool = None, ) -> models.BaseUserDB: existing_user = await user_db.get_by_email(user.email) @@ -41,3 +58,43 @@ def get_create_user( return await user_db.create(db_user) return create_user + + +class VerifyUserProtocol(Protocol): # pragma: no cover + def __call__(self, user: models.BaseUserDB) -> Awaitable[models.BaseUserDB]: + pass + + +def get_verify_user( + user_db: BaseUserDatabase[models.BaseUserDB], +) -> VerifyUserProtocol: + async def verify_user(user: models.BaseUserDB) -> models.BaseUserDB: + if user.is_verified: + raise UserAlreadyVerified() + + user.is_verified = True + return await user_db.update(user) + + return verify_user + + +class GetUserProtocol(Protocol): # pragma: no cover + def __call__(self, user_email: EmailStr) -> Awaitable[models.BaseUserDB]: + pass + + +def get_get_user( + user_db: BaseUserDatabase[models.BaseUserDB], +) -> GetUserProtocol: + async def get_user(user_email: EmailStr) -> models.BaseUserDB: + if not (user_email == EmailStr(user_email)): + raise UserNotExists() + + user = await user_db.get_by_email(user_email) + + if user is None: + raise UserNotExists() + + return user + + return get_user diff --git a/mkdocs.yml b/mkdocs.yml index 956f95f7..16931a71 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,6 +44,7 @@ nav: - configuration/routers/register.md - configuration/routers/reset.md - configuration/routers/users.md + - configuration/routers/verify.md - configuration/full_example.md - configuration/oauth.md - Usage: diff --git a/tests/conftest.py b/tests/conftest.py index 6f4e87bb..8bccb78d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ guinevere_password_hash = get_password_hash("guinevere") angharad_password_hash = get_password_hash("angharad") viviane_password_hash = get_password_hash("viviane") lancelot_password_hash = get_password_hash("lancelot") +excalibur_password_hash = get_password_hash("excalibur") class User(models.BaseUser): @@ -89,6 +90,26 @@ def inactive_user_oauth(oauth_account3) -> UserDBOAuth: ) +@pytest.fixture +def verified_user() -> UserDB: + return UserDB( + email="lake.lady@camelot.bt", + hashed_password=excalibur_password_hash, + is_active=True, + is_verified=True, + ) + + +@pytest.fixture +def verified_user_oauth(oauth_account4) -> UserDBOAuth: + return UserDBOAuth( + email="lake.lady@camelot.bt", + hashed_password=excalibur_password_hash, + is_active=False, + oauth_accounts=[oauth_account4], + ) + + @pytest.fixture def superuser() -> UserDB: return UserDB( @@ -108,6 +129,27 @@ def superuser_oauth() -> UserDBOAuth: ) +@pytest.fixture +def verified_superuser() -> UserDB: + return UserDB( + email="the.real.merlin@camelot.bt", + hashed_password=viviane_password_hash, + is_superuser=True, + is_verified=True, + ) + + +@pytest.fixture +def verified_superuser_oauth() -> UserDBOAuth: + return UserDBOAuth( + email="the.real.merlin@camelot.bt", + hashed_password=viviane_password_hash, + is_superuser=True, + is_verified=True, + oauth_accounts=[], + ) + + @pytest.fixture def oauth_account1() -> BaseOAuthAccount: return BaseOAuthAccount( @@ -142,25 +184,57 @@ def oauth_account3() -> BaseOAuthAccount: @pytest.fixture -def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase: +def oauth_account4() -> BaseOAuthAccount: + return BaseOAuthAccount( + oauth_name="service4", + access_token="TOKEN", + expires_at=1579000751, + account_id="verified_user_oauth1", + account_email="lake.lady@camelot.bt", + ) + + +@pytest.fixture +def oauth_account5() -> BaseOAuthAccount: + return BaseOAuthAccount( + oauth_name="service5", + access_token="TOKEN", + expires_at=1579000751, + account_id="verified_superuser_oauth1", + account_email="the.real.merlin@camelot.bt", + ) + + +@pytest.fixture +def mock_user_db( + user, verified_user, inactive_user, superuser, verified_superuser +) -> BaseUserDatabase: class MockUserDatabase(BaseUserDatabase[UserDB]): async def get(self, id: UUID4) -> Optional[UserDB]: if id == user.id: return user + if id == verified_user.id: + return verified_user if id == inactive_user.id: return inactive_user if id == superuser.id: return superuser + if id == verified_superuser.id: + return verified_superuser return None async def get_by_email(self, email: str) -> Optional[UserDB]: lower_email = email.lower() if lower_email == user.email.lower(): return user + if lower_email == verified_user.email.lower(): + return verified_user if lower_email == inactive_user.email.lower(): return inactive_user if lower_email == superuser.email.lower(): return superuser + if lower_email == verified_superuser.email.lower(): + return verified_superuser return None async def create(self, user: UserDB) -> UserDB: @@ -177,26 +251,38 @@ def mock_user_db(user, inactive_user, superuser) -> BaseUserDatabase: @pytest.fixture def mock_user_db_oauth( - user_oauth, inactive_user_oauth, superuser_oauth + user_oauth, + verified_user_oauth, + inactive_user_oauth, + superuser_oauth, + verified_superuser_oauth, ) -> BaseUserDatabase: class MockUserDatabase(BaseUserDatabase[UserDBOAuth]): async def get(self, id: UUID4) -> Optional[UserDBOAuth]: if id == user_oauth.id: return user_oauth + if id == verified_user_oauth.id: + return verified_user_oauth if id == inactive_user_oauth.id: return inactive_user_oauth if id == superuser_oauth.id: return superuser_oauth + if id == verified_superuser_oauth.id: + return verified_superuser_oauth return None async def get_by_email(self, email: str) -> Optional[UserDBOAuth]: lower_email = email.lower() if lower_email == user_oauth.email.lower(): return user_oauth + if lower_email == verified_user_oauth.email.lower(): + return verified_user_oauth if lower_email == inactive_user_oauth.email.lower(): return inactive_user_oauth if lower_email == superuser_oauth.email.lower(): return superuser_oauth + if lower_email == verified_superuser_oauth.email.lower(): + return verified_superuser_oauth return None async def get_by_oauth_account( diff --git a/tests/test_fastapi_users.py b/tests/test_fastapi_users.py index dfe8b913..8bf5cd0d 100644 --- a/tests/test_fastapi_users.py +++ b/tests/test_fastapi_users.py @@ -28,6 +28,7 @@ async def test_app_client( app.include_router(fastapi_users.get_auth_router(mock_authentication)) app.include_router(fastapi_users.get_oauth_router(oauth_client, "SECRET")) app.include_router(fastapi_users.get_users_router(), prefix="/users") + app.include_router(fastapi_users.get_verify_router("SECRET")) @app.get("/current-user") def current_user(user=Depends(fastapi_users.get_current_user)): @@ -37,10 +38,20 @@ async def test_app_client( def current_active_user(user=Depends(fastapi_users.get_current_active_user)): return user + @app.get("/current-verified-user") + def current_verified_user(user=Depends(fastapi_users.get_current_verified_user)): + return user + @app.get("/current-superuser") def current_superuser(user=Depends(fastapi_users.get_current_superuser)): return user + @app.get("/current-verified-superuser") + def current_verified_superuser( + user=Depends(fastapi_users.get_current_verified_superuser), + ): + return user + @app.get("/optional-current-user") def optional_current_user(user=Depends(fastapi_users.get_optional_current_user)): return user @@ -51,12 +62,24 @@ async def test_app_client( ): return user + @app.get("/optional-current-verified-user") + def optional_current_verified_user( + user=Depends(fastapi_users.get_optional_current_verified_user), + ): + return user + @app.get("/optional-current-superuser") def optional_current_superuser( user=Depends(fastapi_users.get_optional_current_superuser), ): return user + @app.get("/optional-current-verified-superuser") + def optional_current_verified_superuser( + user=Depends(fastapi_users.get_optional_current_verified_superuser), + ): + return user + async for client in get_test_client(app): yield client @@ -71,6 +94,18 @@ class TestRoutes: status.HTTP_405_METHOD_NOT_ALLOWED, ) + response = await test_app_client.post("/request-verify-token") + assert response.status_code not in ( + status.HTTP_404_NOT_FOUND, + status.HTTP_405_METHOD_NOT_ALLOWED, + ) + + response = await test_app_client.post("/verify") + assert response.status_code not in ( + status.HTTP_404_NOT_FOUND, + status.HTTP_405_METHOD_NOT_ALLOWED, + ) + response = await test_app_client.post("/forgot-password") assert response.status_code not in ( status.HTTP_404_NOT_FOUND, @@ -157,6 +192,38 @@ class TestGetCurrentActiveUser: assert response.status_code == status.HTTP_200_OK +@pytest.mark.fastapi_users +@pytest.mark.asyncio +class TestGetCurrentVerifiedUser: + async def test_missing_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get("/current-verified-user") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_invalid_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get( + "/current-verified-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_valid_token_unverified_user( + self, test_app_client: httpx.AsyncClient, user: UserDB + ): + response = await test_app_client.get( + "/current-verified-user", + headers={"Authorization": f"Bearer {user.id}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_valid_token_verified_user( + self, test_app_client: httpx.AsyncClient, verified_user: UserDB + ): + response = await test_app_client.get( + "/current-verified-user", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + + @pytest.mark.fastapi_users @pytest.mark.asyncio class TestGetCurrentSuperuser: @@ -187,6 +254,56 @@ class TestGetCurrentSuperuser: assert response.status_code == status.HTTP_200_OK +@pytest.mark.fastapi_users +@pytest.mark.asyncio +class TestGetCurrentVerifiedSuperuser: + async def test_missing_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get("/current-verified-superuser") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_invalid_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get( + "/current-verified-superuser", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_valid_token_regular_user( + self, test_app_client: httpx.AsyncClient, user: UserDB + ): + response = await test_app_client.get( + "/current-verified-superuser", + headers={"Authorization": f"Bearer {user.id}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_valid_token_verified_user( + self, test_app_client: httpx.AsyncClient, verified_user: UserDB + ): + response = await test_app_client.get( + "/current-verified-superuser", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_valid_token_superuser( + self, test_app_client: httpx.AsyncClient, superuser: UserDB + ): + response = await test_app_client.get( + "/current-verified-superuser", + headers={"Authorization": f"Bearer {superuser.id}"}, + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + async def test_valid_token_verified_superuser( + self, test_app_client: httpx.AsyncClient, verified_superuser: UserDB + ): + response = await test_app_client.get( + "/current-verified-superuser", + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + + @pytest.mark.fastapi_users @pytest.mark.asyncio class TestOptionalGetCurrentUser: @@ -210,6 +327,42 @@ class TestOptionalGetCurrentUser: assert response.json() is not None +@pytest.mark.fastapi_users +@pytest.mark.asyncio +class TestOptionalGetCurrentVerifiedUser: + async def test_missing_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get("/optional-current-verified-user") + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_invalid_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get( + "/optional-current-verified-user", headers={"Authorization": "Bearer foo"} + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_valid_token_unverified_user( + self, test_app_client: httpx.AsyncClient, user: UserDB + ): + response = await test_app_client.get( + "/optional-current-verified-user", + headers={"Authorization": f"Bearer {user.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_valid_token_verified_user( + self, test_app_client: httpx.AsyncClient, verified_user: UserDB + ): + response = await test_app_client.get( + "/optional-current-verified-user", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is not None + + @pytest.mark.fastapi_users @pytest.mark.asyncio class TestOptionalGetCurrentActiveUser: @@ -278,3 +431,60 @@ class TestOptionalGetCurrentSuperuser: ) assert response.status_code == status.HTTP_200_OK assert response.json() is not None + + +@pytest.mark.fastapi_users +@pytest.mark.asyncio +class TestOptionalGetCurrentVerifiedSuperuser: + async def test_missing_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get("/optional-current-verified-superuser") + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_invalid_token(self, test_app_client: httpx.AsyncClient): + response = await test_app_client.get( + "/optional-current-verified-superuser", + headers={"Authorization": "Bearer foo"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_valid_token_regular_user( + self, test_app_client: httpx.AsyncClient, user: UserDB + ): + response = await test_app_client.get( + "/optional-current-verified-superuser", + headers={"Authorization": f"Bearer {user.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_valid_token_verified_user( + self, test_app_client: httpx.AsyncClient, verified_user: UserDB + ): + response = await test_app_client.get( + "/optional-current-verified-superuser", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_valid_token_superuser( + self, test_app_client: httpx.AsyncClient, superuser: UserDB + ): + response = await test_app_client.get( + "/optional-current-verified-superuser", + headers={"Authorization": f"Bearer {superuser.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is None + + async def test_valid_token_verified_superuser( + self, test_app_client: httpx.AsyncClient, verified_superuser: UserDB + ): + response = await test_app_client.get( + "/optional-current-verified-superuser", + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + assert response.json() is not None diff --git a/tests/test_router_auth.py b/tests/test_router_auth.py index 9f8c49d5..21a05214 100644 --- a/tests/test_router_auth.py +++ b/tests/test_router_auth.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, cast +from typing import Any, AsyncGenerator, Dict, Tuple, cast import httpx import pytest @@ -10,56 +10,102 @@ from tests.conftest import MockAuthentication, UserDB @pytest.fixture +def app_factory(mock_user_db, mock_authentication): + def _app_factory(requires_verification: bool) -> FastAPI: + mock_authentication_bis = MockAuthentication(name="mock-bis") + authenticator = Authenticator( + [mock_authentication, mock_authentication_bis], mock_user_db + ) + + mock_auth_router = get_auth_router( + mock_authentication, + mock_user_db, + authenticator, + requires_verification=requires_verification, + ) + mock_bis_auth_router = get_auth_router( + mock_authentication_bis, + mock_user_db, + authenticator, + requires_verification=requires_verification, + ) + + app = FastAPI() + app.include_router(mock_auth_router, prefix="/mock") + app.include_router(mock_bis_auth_router, prefix="/mock-bis") + + return app + + return _app_factory + + +@pytest.fixture( + params=[True, False], ids=["required_verification", "not_required_verification"] +) @pytest.mark.asyncio async def test_app_client( - mock_user_db, mock_authentication, get_test_client -) -> AsyncGenerator[httpx.AsyncClient, None]: - mock_authentication_bis = MockAuthentication(name="mock-bis") - authenticator = Authenticator( - [mock_authentication, mock_authentication_bis], mock_user_db - ) - - mock_auth_router = get_auth_router(mock_authentication, mock_user_db, authenticator) - mock_bis_auth_router = get_auth_router( - mock_authentication_bis, mock_user_db, authenticator - ) - - app = FastAPI() - app.include_router(mock_auth_router, prefix="/mock") - app.include_router(mock_bis_auth_router, prefix="/mock-bis") + request, get_test_client, app_factory +) -> AsyncGenerator[Tuple[httpx.AsyncClient, bool], None]: + requires_verification = request.param + app = app_factory(requires_verification) async for client in get_test_client(app): - yield client + yield client, requires_verification @pytest.mark.router @pytest.mark.parametrize("path", ["/mock/login", "/mock-bis/login"]) @pytest.mark.asyncio class TestLogin: - async def test_empty_body(self, path, test_app_client: httpx.AsyncClient): - response = await test_app_client.post(path, data={}) + async def test_empty_body( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client + response = await client.post(path, data={}) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - async def test_missing_username(self, path, test_app_client: httpx.AsyncClient): + async def test_missing_username( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client data = {"password": "guinevere"} - response = await test_app_client.post(path, data=data) + response = await client.post(path, data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - async def test_missing_password(self, path, test_app_client: httpx.AsyncClient): + async def test_missing_password( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client data = {"username": "king.arthur@camelot.bt"} - response = await test_app_client.post(path, data=data) + response = await client.post(path, data=data) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - async def test_not_existing_user(self, path, test_app_client: httpx.AsyncClient): + async def test_not_existing_user( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client data = {"username": "lancelot@camelot.bt", "password": "guinevere"} - response = await test_app_client.post(path, data=data) + response = await client.post(path, data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS - async def test_wrong_password(self, path, test_app_client: httpx.AsyncClient): + async def test_wrong_password( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client data = {"username": "king.arthur@camelot.bt", "password": "percival"} - response = await test_app_client.post(path, data=data) + response = await client.post(path, data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS @@ -67,17 +113,46 @@ class TestLogin: @pytest.mark.parametrize( "email", ["king.arthur@camelot.bt", "King.Arthur@camelot.bt"] ) - async def test_valid_credentials( - self, path, email, test_app_client: httpx.AsyncClient, user: UserDB + async def test_valid_credentials_unverified( + self, + path, + email, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, ): + client, requires_verification = test_app_client data = {"username": email, "password": "guinevere"} - response = await test_app_client.post(path, data=data) - assert response.status_code == status.HTTP_200_OK - assert response.json() == {"token": str(user.id)} + response = await client.post(path, data=data) + if requires_verification: + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.LOGIN_USER_NOT_VERIFIED + else: + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"token": str(user.id)} - async def test_inactive_user(self, path, test_app_client: httpx.AsyncClient): + @pytest.mark.parametrize("email", ["lake.lady@camelot.bt", "Lake.Lady@camelot.bt"]) + async def test_valid_credentials_verified( + self, + path, + email, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + ): + client, _ = test_app_client + data = {"username": email, "password": "excalibur"} + response = await client.post(path, data=data) + assert response.status_code == status.HTTP_200_OK + assert response.json() == {"token": str(verified_user.id)} + + async def test_inactive_user( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client data = {"username": "percival@camelot.bt", "password": "angharad"} - response = await test_app_client.post(path, data=data) + response = await client.post(path, data=data) assert response.status_code == status.HTTP_400_BAD_REQUEST data = cast(Dict[str, Any], response.json()) assert data["detail"] == ErrorCode.LOGIN_BAD_CREDENTIALS @@ -87,14 +162,40 @@ class TestLogin: @pytest.mark.parametrize("path", ["/mock/logout", "/mock-bis/logout"]) @pytest.mark.asyncio class TestLogout: - async def test_missing_token(self, path, test_app_client: httpx.AsyncClient): - response = await test_app_client.post(path) + async def test_missing_token( + self, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + ): + client, _ = test_app_client + response = await client.post(path) assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_valid_credentials( - self, mocker, path, test_app_client: httpx.AsyncClient, user: UserDB + async def test_valid_credentials_unverified( + self, + mocker, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, ): - response = await test_app_client.post( + client, requires_verification = test_app_client + response = await client.post( path, headers={"Authorization": f"Bearer {user.id}"} ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + async def test_valid_credentials_verified( + self, + mocker, + path, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + ): + client, _ = test_app_client + response = await client.post( + path, headers={"Authorization": f"Bearer {verified_user.id}"} + ) assert response.status_code == status.HTTP_200_OK diff --git a/tests/test_router_reset.py b/tests/test_router_reset.py index e533f925..374c9816 100644 --- a/tests/test_router_reset.py +++ b/tests/test_router_reset.py @@ -191,7 +191,7 @@ class TestResetPassword: user: UserDB, ): mocker.spy(mock_user_db, "update") - current_hashed_passord = user.hashed_password + current_hashed_password = user.hashed_password json = {"token": forgot_password_token(user.id), "password": "holygrail"} response = await test_app_client.post("/reset-password", json=json) @@ -199,4 +199,4 @@ class TestResetPassword: assert mock_user_db.update.called is True updated_user = mock_user_db.update.call_args[0][0] - assert updated_user.hashed_password != current_hashed_passord + assert updated_user.hashed_password != current_hashed_password diff --git a/tests/test_router_users.py b/tests/test_router_users.py index ffe1f99c..e273963c 100644 --- a/tests/test_router_users.py +++ b/tests/test_router_users.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, cast +from typing import Any, AsyncGenerator, Dict, Tuple, cast from unittest.mock import MagicMock import asynctest @@ -28,101 +28,328 @@ def after_update(request): @pytest.fixture +def app_factory(mock_user_db, mock_authentication, after_update): + def _app_factory(requires_verification: bool) -> FastAPI: + mock_authentication_bis = MockAuthentication(name="mock-bis") + authenticator = Authenticator( + [mock_authentication, mock_authentication_bis], mock_user_db + ) + + user_router = get_users_router( + mock_user_db, + User, + UserUpdate, + UserDB, + authenticator, + after_update, + requires_verification=requires_verification, + ) + + app = FastAPI() + app.include_router(user_router) + + return app + + return _app_factory + + +@pytest.fixture( + params=[True, False], ids=["required_verification", "not_required_verification"] +) @pytest.mark.asyncio async def test_app_client( - mock_user_db, mock_authentication, after_update, get_test_client -) -> AsyncGenerator[httpx.AsyncClient, None]: - mock_authentication_bis = MockAuthentication(name="mock-bis") - authenticator = Authenticator( - [mock_authentication, mock_authentication_bis], mock_user_db - ) - - user_router = get_users_router( - mock_user_db, - User, - UserUpdate, - UserDB, - authenticator, - after_update, - ) - - app = FastAPI() - app.include_router(user_router) + request, get_test_client, app_factory +) -> AsyncGenerator[Tuple[httpx.AsyncClient, bool], None]: + requires_verification = request.param + app = app_factory(requires_verification) async for client in get_test_client(app): - yield client + yield client, requires_verification @pytest.mark.router @pytest.mark.asyncio class TestMe: - async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.get("/me") + async def test_missing_token(self, test_app_client: Tuple[httpx.AsyncClient, bool]): + client, _ = test_app_client + response = await client.get("/me") assert response.status_code == status.HTTP_401_UNAUTHORIZED async def test_inactive_user( - self, test_app_client: httpx.AsyncClient, inactive_user: UserDB + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + inactive_user: UserDB, ): - response = await test_app_client.get( + client, _ = test_app_client + response = await client.get( "/me", headers={"Authorization": f"Bearer {inactive_user.id}"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_active_user(self, test_app_client: httpx.AsyncClient, user: UserDB): - response = await test_app_client.get( + async def test_active_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + ): + client, requires_verification = test_app_client + response = await client.get( "/me", headers={"Authorization": f"Bearer {user.id}"} ) - assert response.status_code == status.HTTP_200_OK + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + data = cast(Dict[str, Any], response.json()) + assert data["id"] == str(user.id) + assert data["email"] == user.email + async def test_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + ): + client, _ = test_app_client + response = await client.get( + "/me", headers={"Authorization": f"Bearer {verified_user.id}"} + ) + assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) - assert data["id"] == str(user.id) - assert data["email"] == user.email + assert data["id"] == str(verified_user.id) + assert data["email"] == verified_user.email @pytest.mark.router @pytest.mark.asyncio class TestUpdateMe: async def test_missing_token( - self, test_app_client: httpx.AsyncClient, after_update + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + after_update, ): - response = await test_app_client.patch("/me") + client, _ = test_app_client + response = await client.patch("/me") assert response.status_code == status.HTTP_401_UNAUTHORIZED assert after_update.called is False async def test_inactive_user( - self, test_app_client: httpx.AsyncClient, inactive_user: UserDB, after_update + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + inactive_user: UserDB, + after_update, ): - response = await test_app_client.patch( + client, _ = test_app_client + response = await client.patch( "/me", headers={"Authorization": f"Bearer {inactive_user.id}"} ) assert response.status_code == status.HTTP_401_UNAUTHORIZED assert after_update.called is False async def test_empty_body( - self, test_app_client: httpx.AsyncClient, user: UserDB, after_update + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + after_update, ): - response = await test_app_client.patch( + client, requires_verification = test_app_client + response = await client.patch( "/me", json={}, headers={"Authorization": f"Bearer {user.id}"} ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert after_update.called is False + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["email"] == user.email + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_valid_body( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + after_update, + ): + client, requires_verification = test_app_client + json = {"email": "king.arthur@tintagel.bt"} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert after_update.called is False + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["email"] == "king.arthur@tintagel.bt" + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {"email": "king.arthur@tintagel.bt"} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_valid_body_is_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + after_update, + ): + client, requires_verification = test_app_client + json = {"is_superuser": True} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert after_update.called is False + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_superuser"] is False + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_valid_body_is_active( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + after_update, + ): + client, requires_verification = test_app_client + json = {"is_active": False} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert after_update.called is False + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_active"] is True + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_valid_body_is_verified( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + after_update, + ): + client, requires_verification = test_app_client + json = {"is_verified": True} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert after_update.called is False + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_verified"] is False + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_valid_body_password( + self, + mocker, + mock_user_db, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + after_update, + ): + client, requires_verification = test_app_client + mocker.spy(mock_user_db, "update") + current_hashed_password = user.hashed_password + + json = {"password": "merlin"} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert after_update.called is False + else: + assert response.status_code == status.HTTP_200_OK + assert mock_user_db.update.called is True + + updated_user = mock_user_db.update.call_args[0][0] + assert updated_user.hashed_password != current_hashed_password + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {"password": "merlin"} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_empty_body_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + after_update, + ): + client, _ = test_app_client + response = await client.patch( + "/me", json={}, headers={"Authorization": f"Bearer {verified_user.id}"} + ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) - assert data["email"] == user.email + assert data["email"] == verified_user.email assert after_update.called is True actual_user = after_update.call_args[0][0] - assert actual_user.id == user.id + assert actual_user.id == verified_user.id updated_fields = after_update.call_args[0][1] assert updated_fields == {} request = after_update.call_args[0][2] assert isinstance(request, Request) - async def test_valid_body( - self, test_app_client: httpx.AsyncClient, user: UserDB, after_update + async def test_valid_body_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + after_update, ): + client, _ = test_app_client json = {"email": "king.arthur@tintagel.bt"} - response = await test_app_client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} ) assert response.status_code == status.HTTP_200_OK @@ -131,18 +358,22 @@ class TestUpdateMe: assert after_update.called is True actual_user = after_update.call_args[0][0] - assert actual_user.id == user.id + assert actual_user.id == verified_user.id updated_fields = after_update.call_args[0][1] assert updated_fields == {"email": "king.arthur@tintagel.bt"} request = after_update.call_args[0][2] assert isinstance(request, Request) - async def test_valid_body_is_superuser( - self, test_app_client: httpx.AsyncClient, user: UserDB, after_update + async def test_valid_body_is_superuser_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + after_update, ): + client, _ = test_app_client json = {"is_superuser": True} - response = await test_app_client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} ) assert response.status_code == status.HTTP_200_OK @@ -151,18 +382,22 @@ class TestUpdateMe: assert after_update.called is True actual_user = after_update.call_args[0][0] - assert actual_user.id == user.id + assert actual_user.id == verified_user.id updated_fields = after_update.call_args[0][1] assert updated_fields == {} request = after_update.call_args[0][2] assert isinstance(request, Request) - async def test_valid_body_is_active( - self, test_app_client: httpx.AsyncClient, user: UserDB, after_update + async def test_valid_body_is_active_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + after_update, ): + client, _ = test_app_client json = {"is_active": False} - response = await test_app_client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} ) assert response.status_code == status.HTTP_200_OK @@ -171,36 +406,61 @@ class TestUpdateMe: assert after_update.called is True actual_user = after_update.call_args[0][0] - assert actual_user.id == user.id + assert actual_user.id == verified_user.id updated_fields = after_update.call_args[0][1] assert updated_fields == {} request = after_update.call_args[0][2] assert isinstance(request, Request) - async def test_valid_body_password( + async def test_valid_body_is_verified_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + after_update, + ): + client, _ = test_app_client + json = {"is_verified": False} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} + ) + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_verified"] is True + + assert after_update.called is True + actual_user = after_update.call_args[0][0] + assert actual_user.id == verified_user.id + updated_fields = after_update.call_args[0][1] + assert updated_fields == {} + request = after_update.call_args[0][2] + assert isinstance(request, Request) + + async def test_valid_body_password_verified_user( self, mocker, mock_user_db, - test_app_client: httpx.AsyncClient, - user: UserDB, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, after_update, ): + client, _ = test_app_client mocker.spy(mock_user_db, "update") - current_hashed_passord = user.hashed_password + current_hashed_password = verified_user.hashed_password json = {"password": "merlin"} - response = await test_app_client.patch( - "/me", json=json, headers={"Authorization": f"Bearer {user.id}"} + response = await client.patch( + "/me", json=json, headers={"Authorization": f"Bearer {verified_user.id}"} ) assert response.status_code == status.HTTP_200_OK assert mock_user_db.update.called is True updated_user = mock_user_db.update.call_args[0][0] - assert updated_user.hashed_password != current_hashed_passord + assert updated_user.hashed_password != current_hashed_password assert after_update.called is True actual_user = after_update.call_args[0][0] - assert actual_user.id == user.id + assert actual_user.id == verified_user.id updated_fields = after_update.call_args[0][1] assert updated_fields == {"password": "merlin"} request = after_update.call_args[0][2] @@ -210,32 +470,94 @@ class TestUpdateMe: @pytest.mark.router @pytest.mark.asyncio class TestGetUser: - async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.get("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") + async def test_missing_token(self, test_app_client: Tuple[httpx.AsyncClient, bool]): + client, _ = test_app_client + response = await client.get("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB): - response = await test_app_client.get( + async def test_regular_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + ): + client, requires_verification = test_app_client + response = await client.get( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", headers={"Authorization": f"Bearer {user.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + ): + client, _ = test_app_client + response = await client.get( + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) assert response.status_code == status.HTTP_403_FORBIDDEN - async def test_not_existing_user( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + async def test_not_existing_user_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + superuser: UserDB, ): - response = await test_app_client.get( + client, requires_verification = test_app_client + response = await client.get( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_404_NOT_FOUND + + async def test_not_existing_user_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_superuser: UserDB, + ): + client, _ = test_app_client + response = await client.get( + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_404_NOT_FOUND async def test_superuser( - self, test_app_client: httpx.AsyncClient, user: UserDB, superuser: UserDB + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + superuser: UserDB, ): - response = await test_app_client.get( + client, requires_verification = test_app_client + response = await client.get( f"/{user.id}", headers={"Authorization": f"Bearer {superuser.id}"} ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["id"] == str(user.id) + assert "hashed_password" not in data + + async def test_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + response = await client.get( + f"/{user.id}", headers={"Authorization": f"Bearer {verified_superuser.id}"} + ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) @@ -246,140 +568,410 @@ class TestGetUser: @pytest.mark.router @pytest.mark.asyncio class TestUpdateUser: - async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.patch("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") + async def test_missing_token(self, test_app_client: Tuple[httpx.AsyncClient, bool]): + client, _ = test_app_client + response = await client.patch("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB): - response = await test_app_client.patch( + async def test_regular_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + ): + client, requires_verification = test_app_client + response = await client.patch( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", headers={"Authorization": f"Bearer {user.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + ): + client, _ = test_app_client + response = await client.patch( + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) assert response.status_code == status.HTTP_403_FORBIDDEN - async def test_not_existing_user( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + async def test_not_existing_user_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + superuser: UserDB, ): - response = await test_app_client.patch( + client, requires_verification = test_app_client + response = await client.patch( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", json={}, headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_404_NOT_FOUND + + async def test_not_existing_user_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_superuser: UserDB, + ): + client, _ = test_app_client + response = await client.patch( + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + json={}, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_404_NOT_FOUND - async def test_empty_body( - self, test_app_client: httpx.AsyncClient, user: UserDB, superuser: UserDB + async def test_empty_body_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + superuser: UserDB, ): - response = await test_app_client.patch( + client, requires_verification = test_app_client + response = await client.patch( f"/{user.id}", json={}, headers={"Authorization": f"Bearer {superuser.id}"} ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["email"] == user.email + + async def test_empty_body_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + response = await client.patch( + f"/{user.id}", + json={}, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) assert data["email"] == user.email - async def test_valid_body( - self, test_app_client: httpx.AsyncClient, user: UserDB, superuser: UserDB + async def test_valid_body_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + superuser: UserDB, ): + client, requires_verification = test_app_client json = {"email": "king.arthur@tintagel.bt"} - response = await test_app_client.patch( + response = await client.patch( f"/{user.id}", json=json, headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["email"] == "king.arthur@tintagel.bt" + + async def test_valid_body_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + json = {"email": "king.arthur@tintagel.bt"} + response = await client.patch( + f"/{user.id}", + json=json, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) assert data["email"] == "king.arthur@tintagel.bt" - async def test_valid_body_is_superuser( - self, test_app_client: httpx.AsyncClient, user: UserDB, superuser: UserDB + async def test_valid_body_is_superuser_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + superuser: UserDB, ): + client, requires_verification = test_app_client json = {"is_superuser": True} - response = await test_app_client.patch( + response = await client.patch( f"/{user.id}", json=json, headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_superuser"] is True + + async def test_valid_body_is_superuser_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + json = {"is_superuser": True} + response = await client.patch( + f"/{user.id}", + json=json, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) assert data["is_superuser"] is True - async def test_valid_body_is_active( - self, test_app_client: httpx.AsyncClient, user: UserDB, superuser: UserDB + async def test_valid_body_is_active_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + superuser: UserDB, ): + client, requires_verification = test_app_client json = {"is_active": False} - response = await test_app_client.patch( + response = await client.patch( f"/{user.id}", json=json, headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_active"] is False + + async def test_valid_body_is_active_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + json = {"is_active": False} + response = await client.patch( + f"/{user.id}", + json=json, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_200_OK data = cast(Dict[str, Any], response.json()) assert data["is_active"] is False - async def test_valid_body_password( + async def test_valid_body_is_verified_unverified_superuser( self, - mocker, - mock_user_db, - test_app_client: httpx.AsyncClient, + test_app_client: Tuple[httpx.AsyncClient, bool], user: UserDB, superuser: UserDB, ): - mocker.spy(mock_user_db, "update") - current_hashed_passord = user.hashed_password - - json = {"password": "merlin"} - response = await test_app_client.patch( + client, requires_verification = test_app_client + json = {"is_verified": True} + response = await client.patch( f"/{user.id}", json=json, headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_verified"] is True + + async def test_valid_body_is_verified_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + json = {"is_verified": True} + response = await client.patch( + f"/{user.id}", + json=json, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) + assert response.status_code == status.HTTP_200_OK + + data = cast(Dict[str, Any], response.json()) + assert data["is_verified"] is True + + async def test_valid_body_password_unverified_superuser( + self, + mocker, + mock_user_db, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + superuser: UserDB, + ): + client, requires_verification = test_app_client + mocker.spy(mock_user_db, "update") + current_hashed_password = user.hashed_password + + json = {"password": "merlin"} + response = await client.patch( + f"/{user.id}", + json=json, + headers={"Authorization": f"Bearer {superuser.id}"}, + ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_200_OK + assert mock_user_db.update.called is True + + updated_user = mock_user_db.update.call_args[0][0] + assert updated_user.hashed_password != current_hashed_password + + async def test_valid_body_password_verified_superuser( + self, + mocker, + mock_user_db, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + mocker.spy(mock_user_db, "update") + current_hashed_password = user.hashed_password + + json = {"password": "merlin"} + response = await client.patch( + f"/{user.id}", + json=json, + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_200_OK assert mock_user_db.update.called is True updated_user = mock_user_db.update.call_args[0][0] - assert updated_user.hashed_password != current_hashed_passord + assert updated_user.hashed_password != current_hashed_password @pytest.mark.router @pytest.mark.asyncio class TestDeleteUser: - async def test_missing_token(self, test_app_client: httpx.AsyncClient): - response = await test_app_client.delete("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") + async def test_missing_token(self, test_app_client: Tuple[httpx.AsyncClient, bool]): + client, _ = test_app_client + response = await client.delete("/d35d213e-f3d8-4f08-954a-7e0d1bea286f") assert response.status_code == status.HTTP_401_UNAUTHORIZED - async def test_regular_user(self, test_app_client: httpx.AsyncClient, user: UserDB): - response = await test_app_client.delete( + async def test_regular_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + ): + client, requires_verification = test_app_client + response = await client.delete( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", headers={"Authorization": f"Bearer {user.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_403_FORBIDDEN + + async def test_verified_user( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_user: UserDB, + ): + client, _ = test_app_client + response = await client.delete( + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {verified_user.id}"}, + ) assert response.status_code == status.HTTP_403_FORBIDDEN - async def test_not_existing_user( - self, test_app_client: httpx.AsyncClient, superuser: UserDB + async def test_not_existing_user_unverified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + superuser: UserDB, ): - response = await test_app_client.delete( + client, requires_verification = test_app_client + response = await client.delete( "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", headers={"Authorization": f"Bearer {superuser.id}"}, ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_404_NOT_FOUND + + async def test_not_existing_user_verified_superuser( + self, + test_app_client: Tuple[httpx.AsyncClient, bool], + verified_superuser: UserDB, + ): + client, _ = test_app_client + response = await client.delete( + "/d35d213e-f3d8-4f08-954a-7e0d1bea286f", + headers={"Authorization": f"Bearer {verified_superuser.id}"}, + ) assert response.status_code == status.HTTP_404_NOT_FOUND - async def test_superuser( + async def test_unverified_superuser( self, mocker, mock_user_db, - test_app_client: httpx.AsyncClient, + test_app_client: Tuple[httpx.AsyncClient, bool], user: UserDB, superuser: UserDB, ): + client, requires_verification = test_app_client mocker.spy(mock_user_db, "delete") - response = await test_app_client.delete( + response = await client.delete( f"/{user.id}", headers={"Authorization": f"Bearer {superuser.id}"} ) + if requires_verification: + assert response.status_code == status.HTTP_401_UNAUTHORIZED + else: + assert response.status_code == status.HTTP_204_NO_CONTENT + assert response.json() is None + assert mock_user_db.delete.called is True + + deleted_user = mock_user_db.delete.call_args[0][0] + assert deleted_user.id == user.id + + async def test_verified_superuser( + self, + mocker, + mock_user_db, + test_app_client: Tuple[httpx.AsyncClient, bool], + user: UserDB, + verified_superuser: UserDB, + ): + client, _ = test_app_client + mocker.spy(mock_user_db, "delete") + + response = await client.delete( + f"/{user.id}", headers={"Authorization": f"Bearer {verified_superuser.id}"} + ) assert response.status_code == status.HTTP_204_NO_CONTENT assert response.json() is None assert mock_user_db.delete.called is True diff --git a/tests/test_router_verify.py b/tests/test_router_verify.py new file mode 100644 index 00000000..c54f9d8c --- /dev/null +++ b/tests/test_router_verify.py @@ -0,0 +1,335 @@ +from typing import Any, AsyncGenerator, Dict, cast +from unittest.mock import MagicMock + +import asynctest +import httpx +import pytest +from fastapi import FastAPI, status + +from fastapi_users.router import ErrorCode, get_verify_router +from fastapi_users.user import get_get_user, get_verify_user +from fastapi_users.utils import generate_jwt +from tests.conftest import User, UserDB + +SECRET = "SECRET" +LIFETIME = 3600 +VERIFY_USER_TOKEN_AUDIENCE = "fastapi-users:verify" +JWT_ALGORITHM = "HS256" + + +@pytest.fixture +def verify_token(): + def _verify_token(user_id=None, email=None, lifetime=LIFETIME): + data = {"aud": VERIFY_USER_TOKEN_AUDIENCE} + if user_id is not None: + data["user_id"] = str(user_id) + if email is not None: + data["email"] = email + return generate_jwt(data, lifetime, SECRET, JWT_ALGORITHM) + + return _verify_token + + +def after_verification_sync(): + return MagicMock(return_value=None) + + +def after_verification_async(): + return asynctest.CoroutineMock(return_value=None) + + +@pytest.fixture(params=[after_verification_sync, after_verification_async]) +def after_verification(request): + return request.param() + + +def after_verification_request_sync(): + return MagicMock(return_value=None) + + +def after_verification_request_async(): + return asynctest.CoroutineMock(return_value=None) + + +@pytest.fixture( + params=[after_verification_request_sync, after_verification_request_async] +) +def after_verification_request(request): + return request.param() + + +@pytest.fixture +@pytest.mark.asyncio +async def test_app_client( + mock_user_db, + after_verification_request, + after_verification, + get_test_client, +) -> AsyncGenerator[httpx.AsyncClient, None]: + verify_user = get_verify_user(mock_user_db) + get_user = get_get_user(mock_user_db) + verify_router = get_verify_router( + verify_user, + get_user, + User, + SECRET, + LIFETIME, + after_verification_request, + after_verification, + ) + + app = FastAPI() + app.include_router(verify_router) + + async for client in get_test_client(app): + yield client + + +@pytest.mark.router +@pytest.mark.asyncio +class TestVerifyTokenRequest: + async def test_empty_body( + self, + test_app_client: httpx.AsyncClient, + after_verification_request, + ): + response = await test_app_client.post("/request-verify-token", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert after_verification_request.called is False + + async def test_wrong_email( + self, + test_app_client: httpx.AsyncClient, + after_verification_request, + ): + json = {"email": "king.arthur"} + response = await test_app_client.post("/request-verify-token", json=json) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert after_verification_request.called is False + + async def test_user_not_exists( + self, + test_app_client: httpx.AsyncClient, + after_verification_request, + ): + json = {"email": "user@example.com"} + response = await test_app_client.post("/request-verify-token", json=json) + assert response.status_code == status.HTTP_202_ACCEPTED + assert after_verification_request.called is False + + async def test_user_verified_valid_request( + self, + test_app_client: httpx.AsyncClient, + verified_user: UserDB, + after_verification_request, + ): + input_user = verified_user + json = {"email": input_user.email} + response = await test_app_client.post("/request-verify-token", json=json) + assert after_verification_request.called is False + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_ALREADY_VERIFIED + + async def test_user_inactive_valid_request( + self, + test_app_client: httpx.AsyncClient, + inactive_user: UserDB, + after_verification_request, + ): + input_user = inactive_user + json = {"email": input_user.email} + response = await test_app_client.post("/request-verify-token", json=json) + assert after_verification_request.called is False + assert response.status_code == status.HTTP_202_ACCEPTED + + async def test_user_active_valid_request( + self, + test_app_client: httpx.AsyncClient, + user: UserDB, + after_verification_request, + ): + input_user = user + json = {"email": input_user.email} + response = await test_app_client.post("/request-verify-token", json=json) + assert response.status_code == status.HTTP_202_ACCEPTED + assert after_verification_request.called is True + + +@pytest.mark.router +@pytest.mark.asyncio +class TestVerify: + async def test_empty_body( + self, + test_app_client: httpx.AsyncClient, + after_verification_request, + after_verification, + ): + response = await test_app_client.post("/verify", json={}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_invalid_token( + self, + test_app_client: httpx.AsyncClient, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": "foo"} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_BAD_TOKEN + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_valid_token_missing_user_id( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(None, user.email)} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_BAD_TOKEN + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_valid_token_missing_email( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(user.id, None)} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_BAD_TOKEN + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_valid_token_invalid_uuid( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token("foo", user.email)} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_BAD_TOKEN + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_valid_token_invalid_email( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(user.id, "foo")} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_BAD_TOKEN + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_valid_token_email_id_mismatch( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + inactive_user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(user.id, inactive_user.email)} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_BAD_TOKEN + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_expired_token( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(user.id, user.email, -1)} + response = await test_app_client.post("/verify", json=json) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_TOKEN_EXPIRED + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_inactive_user( + self, + test_app_client: httpx.AsyncClient, + verify_token, + inactive_user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(inactive_user.id, inactive_user.email)} + response = await test_app_client.post("/verify", json=json) + + assert response.status_code == status.HTTP_200_OK + assert after_verification.called is True + assert after_verification_request.called is False + data = cast(Dict[str, Any], response.json()) + assert data["is_active"] is False + + async def test_verified_user( + self, + test_app_client: httpx.AsyncClient, + verify_token, + verified_user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(verified_user.id, verified_user.email)} + response = await test_app_client.post("/verify", json=json) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = cast(Dict[str, Any], response.json()) + assert data["detail"] == ErrorCode.VERIFY_USER_ALREADY_VERIFIED + + assert after_verification.called is False + assert after_verification_request.called is False + + async def test_active_user( + self, + test_app_client: httpx.AsyncClient, + verify_token, + user: UserDB, + after_verification_request, + after_verification, + ): + json = {"token": verify_token(user.id, user.email)} + response = await test_app_client.post("/verify", json=json) + + assert response.status_code == status.HTTP_200_OK + assert after_verification.called is True + assert after_verification_request.called is False + data = cast(Dict[str, Any], response.json()) + assert data["is_active"] is True diff --git a/tests/test_user.py b/tests/test_user.py index b21e14f3..16443ffb 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -1,6 +1,13 @@ import pytest -from fastapi_users.user import CreateUserProtocol, UserAlreadyExists, get_create_user +from fastapi_users.user import ( + CreateUserProtocol, + UserAlreadyExists, + UserAlreadyVerified, + VerifyUserProtocol, + get_create_user, + get_verify_user, +) from tests.conftest import UserCreate, UserDB @@ -11,7 +18,6 @@ def create_user( return get_create_user(mock_user_db, UserDB) -@pytest.mark.router @pytest.mark.asyncio class TestCreateUser: @pytest.mark.parametrize( @@ -45,3 +51,21 @@ class TestCreateUser: created_user = await create_user(user, safe) assert type(created_user) == UserDB assert created_user.is_active is result + + +@pytest.fixture +def verify_user( + mock_user_db, +) -> VerifyUserProtocol: + return get_verify_user(mock_user_db) + + +@pytest.mark.asyncio +class TestVerifyUser: + async def test_already_verified_user(self, verify_user, verified_user): + with pytest.raises(UserAlreadyVerified): + await verify_user(verified_user) + + async def test_non_verified_user(self, verify_user, user): + user = await verify_user(user) + assert user.is_verified