mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2025-08-14 18:58:10 +08:00
Implement user-facing API
This commit is contained in:
6
Makefile
6
Makefile
@ -6,3 +6,9 @@ format:
|
||||
|
||||
test:
|
||||
$(PIPENV_RUN) pytest
|
||||
|
||||
docs-serve:
|
||||
$(PIPENV_RUN) mkdocs serve
|
||||
|
||||
docs-publish:
|
||||
$(PIPENV_RUN) mkdocs gh-deploy
|
||||
|
@ -0,0 +1,36 @@
|
||||
from typing import Callable, Type
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUser, BaseUserDB
|
||||
from fastapi_users.router import get_user_router
|
||||
|
||||
|
||||
class FastAPIUsers:
|
||||
"""
|
||||
Main object that ties together the component for users authentication.
|
||||
|
||||
:param db: Database adapter instance.
|
||||
:param auth: Authentication logic instance.
|
||||
:param user_model: Pydantic model of a user.
|
||||
|
||||
:attribute router: FastAPI router exposing authentication routes.
|
||||
:attribute get_current_user: Dependency callable to inject authenticated user.
|
||||
"""
|
||||
|
||||
db: BaseUserDatabase
|
||||
auth: BaseAuthentication
|
||||
router: APIRouter
|
||||
get_current_user: Callable[..., BaseUserDB]
|
||||
|
||||
def __init__(
|
||||
self, db: BaseUserDatabase, auth: BaseAuthentication, user_model: Type[BaseUser]
|
||||
):
|
||||
self.db = db
|
||||
self.auth = auth
|
||||
self.router = get_user_router(self.db, user_model, self.auth)
|
||||
|
||||
get_current_user = self.auth.get_authentication_method(self.db)
|
||||
self.get_current_user = get_current_user # type: ignore
|
||||
|
@ -7,14 +7,10 @@ from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class BaseAuthentication:
|
||||
|
||||
user_db: BaseUserDatabase
|
||||
|
||||
def __init__(self, user_db: BaseUserDatabase):
|
||||
self.user_db = user_db
|
||||
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_authentication_method(self) -> Callable[..., BaseUserDB]:
|
||||
def get_authentication_method(
|
||||
self, user_db: BaseUserDatabase
|
||||
) -> Callable[..., BaseUserDB]:
|
||||
raise NotImplementedError()
|
||||
|
@ -7,6 +7,7 @@ from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
from fastapi_users.authentication import BaseAuthentication
|
||||
from fastapi_users.db import BaseUserDatabase
|
||||
from fastapi_users.models import BaseUserDB
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
@ -25,8 +26,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
secret: str
|
||||
lifetime_seconds: int
|
||||
|
||||
def __init__(self, secret: str, lifetime_seconds: int, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, secret: str, lifetime_seconds: int):
|
||||
self.secret = secret
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
@ -36,7 +36,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
|
||||
return {"token": token}
|
||||
|
||||
def get_authentication_method(self):
|
||||
def get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
@ -50,7 +50,7 @@ class JWTAuthentication(BaseAuthentication):
|
||||
except jwt.PyJWTError:
|
||||
raise credentials_exception
|
||||
|
||||
user = await self.user_db.get(user_id)
|
||||
user = await user_db.get(user_id)
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
|
||||
|
@ -7,29 +7,36 @@ from fastapi_users.password import get_password_hash, verify_and_update_password
|
||||
|
||||
|
||||
class BaseUserDatabase:
|
||||
"""
|
||||
Common interface exposing methods to list, get, create and update users in
|
||||
the database.
|
||||
"""
|
||||
"""Base adapter for retrieving, creating and updating users from a database."""
|
||||
|
||||
async def list(self) -> List[BaseUserDB]:
|
||||
"""List all users."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get(self, id: str) -> Optional[BaseUserDB]:
|
||||
"""Get a single user by id."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[BaseUserDB]:
|
||||
"""Get a single user by email."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create(self, user: BaseUserDB) -> BaseUserDB:
|
||||
"""Create a user."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def update(self, user: BaseUserDB) -> BaseUserDB:
|
||||
"""Update a user."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[BaseUserDB]:
|
||||
"""
|
||||
Authenticate and return a user following an email and a password.
|
||||
|
||||
Will automatically upgrade password hash if necessary.
|
||||
"""
|
||||
user = await self.get_by_email(credentials.username)
|
||||
|
||||
# Always run the hasher to mitigate timing attack
|
||||
|
@ -8,6 +8,8 @@ from fastapi_users.models import BaseUserDB
|
||||
|
||||
|
||||
class BaseUserTable:
|
||||
"""Base SQLAlchemy users table definition."""
|
||||
|
||||
__tablename__ = "user"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
@ -18,6 +20,12 @@ class BaseUserTable:
|
||||
|
||||
|
||||
class SQLAlchemyUserDatabase(BaseUserDatabase):
|
||||
"""
|
||||
Database adapter for SQLAlchemy.
|
||||
|
||||
:param database: `Database` instance from `encode/databases`.
|
||||
:param users: SQLAlchemy users table instance.
|
||||
"""
|
||||
|
||||
database: Database
|
||||
users: Table
|
||||
|
@ -1,7 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from starlette import status
|
||||
from starlette.responses import Response
|
||||
|
||||
@ -62,13 +63,21 @@ class MockAuthentication(BaseAuthentication):
|
||||
async def get_login_response(self, user: BaseUserDB, response: Response):
|
||||
return {"token": user.id}
|
||||
|
||||
async def authenticate(self, token: str) -> BaseUserDB:
|
||||
user = await self.user_db.get(token)
|
||||
if user is None or not user.is_active:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
|
||||
return user
|
||||
def get_authentication_method(self, user_db: BaseUserDatabase):
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/login")
|
||||
|
||||
async def authentication_method(token: str = Depends(oauth2_scheme)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED
|
||||
)
|
||||
user = await user_db.get(token)
|
||||
if user is None or not user.is_active:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
return authentication_method
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_authentication(mock_user_db) -> MockAuthentication:
|
||||
return MockAuthentication(mock_user_db)
|
||||
def mock_authentication() -> MockAuthentication:
|
||||
return MockAuthentication()
|
||||
|
@ -14,8 +14,8 @@ LIFETIME = 3600
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_authentication(mock_user_db):
|
||||
return JWTAuthentication(SECRET, LIFETIME, mock_user_db)
|
||||
def jwt_authentication():
|
||||
return JWTAuthentication(SECRET, LIFETIME)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -28,12 +28,14 @@ def token():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_auth_client(jwt_authentication):
|
||||
def test_auth_client(jwt_authentication, mock_user_db):
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test-auth")
|
||||
def test_auth(
|
||||
user: BaseUserDB = Depends(jwt_authentication.get_authentication_method())
|
||||
user: BaseUserDB = Depends(
|
||||
jwt_authentication.get_authentication_method(mock_user_db)
|
||||
)
|
||||
):
|
||||
return user
|
||||
|
||||
|
54
tests/test_fastapi_users.py
Normal file
54
tests/test_fastapi_users.py
Normal file
@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI
|
||||
from starlette import status
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from fastapi_users import FastAPIUsers
|
||||
from fastapi_users.models import BaseUser, BaseUserDB
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fastapi_users(mock_user_db, mock_authentication) -> FastAPIUsers:
|
||||
class User(BaseUser):
|
||||
pass
|
||||
|
||||
return FastAPIUsers(mock_user_db, mock_authentication, User)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app_client(fastapi_users: FastAPIUsers) -> TestClient:
|
||||
app = FastAPI()
|
||||
app.include_router(fastapi_users.router)
|
||||
|
||||
@app.get("/authenticated")
|
||||
def authenticated(user=Depends(fastapi_users.get_current_user)):
|
||||
return user
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestRouter:
|
||||
def test_routes_exist(self, test_app_client: TestClient):
|
||||
response = test_app_client.post("/register")
|
||||
assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||
|
||||
response = test_app_client.post("/login")
|
||||
assert response.status_code != status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
def test_missing_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get("/authenticated")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_invalid_token(self, test_app_client: TestClient):
|
||||
response = test_app_client.get(
|
||||
"/authenticated", headers={"Authorization": "Bearer foo"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
def test_valid_token(self, test_app_client: TestClient, user: BaseUserDB):
|
||||
response = test_app_client.get(
|
||||
"/authenticated", headers={"Authorization": f"Bearer {user.id}"}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
Reference in New Issue
Block a user