diff --git a/fastapi_users/__init__.py b/fastapi_users/__init__.py index 7ee793b4..76966abd 100644 --- a/fastapi_users/__init__.py +++ b/fastapi_users/__init__.py @@ -1,36 +1,2 @@ -from typing import Callable, Type - -from fastapi import APIRouter - -from fastapi_users.authentication import BaseAuthentication -from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import BaseUser, BaseUserDB -from fastapi_users.router import get_user_router - - -class FastAPIUsers: - """ - Main object that ties together the component for users authentication. - - :param db: Database adapter instance. - :param auth: Authentication logic instance. - :param user_model: Pydantic model of a user. - - :attribute router: FastAPI router exposing authentication routes. - :attribute get_current_user: Dependency callable to inject authenticated user. - """ - - db: BaseUserDatabase - auth: BaseAuthentication - router: APIRouter - get_current_user: Callable[..., BaseUserDB] - - def __init__( - self, db: BaseUserDatabase, auth: BaseAuthentication, user_model: Type[BaseUser] - ): - self.db = db - self.auth = auth - self.router = get_user_router(self.db, user_model, self.auth) - - get_current_user = self.auth.get_authentication_method(self.db) - self.get_current_user = get_current_user # type: ignore +from fastapi_users.fastapi_users import FastAPIUsers # noqa: F401 +from fastapi_users.models import BaseUser # noqa: F401 diff --git a/fastapi_users/authentication/__init__.py b/fastapi_users/authentication/__init__.py index a152b145..87ac17dc 100644 --- a/fastapi_users/authentication/__init__.py +++ b/fastapi_users/authentication/__init__.py @@ -1,18 +1,2 @@ -from typing import Callable - -from starlette.responses import Response - -from fastapi_users.db import BaseUserDatabase -from fastapi_users.models import BaseUserDB - - -class BaseAuthentication: - """Base adapter for generating and decoding authentication tokens.""" - - async def get_login_response(self, user: BaseUserDB, response: Response): - raise NotImplementedError() - - def get_authentication_method( - self, user_db: BaseUserDatabase - ) -> Callable[..., BaseUserDB]: - raise NotImplementedError() +from fastapi_users.authentication.base import BaseAuthentication # noqa: F401 +from fastapi_users.authentication.jwt import JWTAuthentication # noqa: F401 diff --git a/fastapi_users/authentication/base.py b/fastapi_users/authentication/base.py new file mode 100644 index 00000000..a152b145 --- /dev/null +++ b/fastapi_users/authentication/base.py @@ -0,0 +1,18 @@ +from typing import Callable + +from starlette.responses import Response + +from fastapi_users.db import BaseUserDatabase +from fastapi_users.models import BaseUserDB + + +class BaseAuthentication: + """Base adapter for generating and decoding authentication tokens.""" + + async def get_login_response(self, user: BaseUserDB, response: Response): + raise NotImplementedError() + + def get_authentication_method( + self, user_db: BaseUserDatabase + ) -> Callable[..., BaseUserDB]: + raise NotImplementedError() diff --git a/fastapi_users/authentication/jwt.py b/fastapi_users/authentication/jwt.py index 60f1e20f..73110ffd 100644 --- a/fastapi_users/authentication/jwt.py +++ b/fastapi_users/authentication/jwt.py @@ -6,8 +6,8 @@ from fastapi.security import OAuth2PasswordBearer from starlette import status from starlette.responses import Response -from fastapi_users.authentication import BaseAuthentication -from fastapi_users.db import BaseUserDatabase +from fastapi_users.authentication.base import BaseAuthentication +from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import BaseUserDB oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login") diff --git a/fastapi_users/db/__init__.py b/fastapi_users/db/__init__.py index 42204c8b..af515040 100644 --- a/fastapi_users/db/__init__.py +++ b/fastapi_users/db/__init__.py @@ -1,59 +1,5 @@ -from typing import List, Optional - -from fastapi.security import OAuth2PasswordRequestForm - -from fastapi_users.models import BaseUserDB -from fastapi_users.password import get_password_hash, verify_and_update_password - - -class BaseUserDatabase: - """Base adapter for retrieving, creating and updating users from a database.""" - - async def list(self) -> List[BaseUserDB]: - """List all users.""" - raise NotImplementedError() - - async def get(self, id: str) -> Optional[BaseUserDB]: - """Get a single user by id.""" - raise NotImplementedError() - - async def get_by_email(self, email: str) -> Optional[BaseUserDB]: - """Get a single user by email.""" - raise NotImplementedError() - - async def create(self, user: BaseUserDB) -> BaseUserDB: - """Create a user.""" - raise NotImplementedError() - - async def update(self, user: BaseUserDB) -> BaseUserDB: - """Update a user.""" - raise NotImplementedError() - - async def authenticate( - self, credentials: OAuth2PasswordRequestForm - ) -> Optional[BaseUserDB]: - """ - Authenticate and return a user following an email and a password. - - Will automatically upgrade password hash if necessary. - """ - user = await self.get_by_email(credentials.username) - - # Always run the hasher to mitigate timing attack - # Inspired from Django: https://code.djangoproject.com/ticket/20760 - get_password_hash(credentials.password) - - if user is None: - return None - else: - verified, updated_password_hash = verify_and_update_password( - credentials.password, user.hashed_password - ) - if not verified: - return None - # Update password hash to a more robust one if needed - if updated_password_hash is not None: - user.hashed_password = updated_password_hash - await self.update(user) - - return user +from fastapi_users.db.base import BaseUserDatabase # noqa: F401 +from fastapi_users.db.sqlalchemy import ( # noqa: F401 + SQLAlchemyBaseUserTable, + SQLAlchemyUserDatabase, +) diff --git a/fastapi_users/db/base.py b/fastapi_users/db/base.py new file mode 100644 index 00000000..42204c8b --- /dev/null +++ b/fastapi_users/db/base.py @@ -0,0 +1,59 @@ +from typing import List, Optional + +from fastapi.security import OAuth2PasswordRequestForm + +from fastapi_users.models import BaseUserDB +from fastapi_users.password import get_password_hash, verify_and_update_password + + +class BaseUserDatabase: + """Base adapter for retrieving, creating and updating users from a database.""" + + async def list(self) -> List[BaseUserDB]: + """List all users.""" + raise NotImplementedError() + + async def get(self, id: str) -> Optional[BaseUserDB]: + """Get a single user by id.""" + raise NotImplementedError() + + async def get_by_email(self, email: str) -> Optional[BaseUserDB]: + """Get a single user by email.""" + raise NotImplementedError() + + async def create(self, user: BaseUserDB) -> BaseUserDB: + """Create a user.""" + raise NotImplementedError() + + async def update(self, user: BaseUserDB) -> BaseUserDB: + """Update a user.""" + raise NotImplementedError() + + async def authenticate( + self, credentials: OAuth2PasswordRequestForm + ) -> Optional[BaseUserDB]: + """ + Authenticate and return a user following an email and a password. + + Will automatically upgrade password hash if necessary. + """ + user = await self.get_by_email(credentials.username) + + # Always run the hasher to mitigate timing attack + # Inspired from Django: https://code.djangoproject.com/ticket/20760 + get_password_hash(credentials.password) + + if user is None: + return None + else: + verified, updated_password_hash = verify_and_update_password( + credentials.password, user.hashed_password + ) + if not verified: + return None + # Update password hash to a more robust one if needed + if updated_password_hash is not None: + user.hashed_password = updated_password_hash + await self.update(user) + + return user diff --git a/fastapi_users/db/sqlalchemy.py b/fastapi_users/db/sqlalchemy.py index 8efaec0e..4d376ef6 100644 --- a/fastapi_users/db/sqlalchemy.py +++ b/fastapi_users/db/sqlalchemy.py @@ -3,11 +3,11 @@ from typing import List, cast from databases import Database from sqlalchemy import Boolean, Column, String, Table -from fastapi_users.db import BaseUserDatabase +from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import BaseUserDB -class BaseUserTable: +class SQLAlchemyBaseUserTable: """Base SQLAlchemy users table definition.""" __tablename__ = "user" diff --git a/fastapi_users/fastapi_users.py b/fastapi_users/fastapi_users.py new file mode 100644 index 00000000..a1df00b4 --- /dev/null +++ b/fastapi_users/fastapi_users.py @@ -0,0 +1,38 @@ +"""Ready-to-use and customizable users management for FastAPI.""" + +from typing import Callable, Type + +from fastapi import APIRouter + +from fastapi_users.authentication import BaseAuthentication +from fastapi_users.db import BaseUserDatabase +from fastapi_users.models import BaseUser, BaseUserDB +from fastapi_users.router import get_user_router + + +class FastAPIUsers: + """ + Main object that ties together the component for users authentication. + + :param db: Database adapter instance. + :param auth: Authentication logic instance. + :param user_model: Pydantic model of a user. + + :attribute router: FastAPI router exposing authentication routes. + :attribute get_current_user: Dependency callable to inject authenticated user. + """ + + db: BaseUserDatabase + auth: BaseAuthentication + router: APIRouter + get_current_user: Callable[..., BaseUserDB] + + def __init__( + self, db: BaseUserDatabase, auth: BaseAuthentication, user_model: Type[BaseUser] + ): + self.db = db + self.auth = auth + self.router = get_user_router(self.db, user_model, self.auth) + + get_current_user = self.auth.get_authentication_method(self.db) + self.get_current_user = get_current_user # type: ignore diff --git a/tests/test_db_sqlalchemy.py b/tests/test_db_sqlalchemy.py index aaf5641f..51da0c3f 100644 --- a/tests/test_db_sqlalchemy.py +++ b/tests/test_db_sqlalchemy.py @@ -6,14 +6,14 @@ import sqlalchemy from databases import Database from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base -from fastapi_users.db.sqlalchemy import BaseUserTable, SQLAlchemyUserDatabase +from fastapi_users.db.sqlalchemy import SQLAlchemyBaseUserTable, SQLAlchemyUserDatabase @pytest.fixture async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: Base: DeclarativeMeta = declarative_base() - class User(BaseUserTable, Base): + class User(SQLAlchemyBaseUserTable, Base): pass DATABASE_URL = "sqlite:///./test.db"