mirror of
				https://github.com/fastapi-users/fastapi-users.git
				synced 2025-10-31 17:38:30 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			150 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			150 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import dataclasses
 | |
| import uuid
 | |
| from datetime import datetime, timezone
 | |
| from typing import Any
 | |
| 
 | |
| import pytest
 | |
| 
 | |
| from fastapi_users.authentication.strategy import (
 | |
|     AccessTokenDatabase,
 | |
|     AccessTokenProtocol,
 | |
|     DatabaseStrategy,
 | |
| )
 | |
| from tests.conftest import IDType, UserModel
 | |
| 
 | |
| 
 | |
@dataclasses.dataclass
class AccessTokenModel(AccessTokenProtocol[IDType]):
    """In-memory access token record used by the database strategy tests.

    Only ``token`` and ``user_id`` must be supplied; ``id`` and
    ``created_at`` are generated per instance when omitted.
    """

    # Token string; the mock database in this module keys records by it.
    token: str
    # Identifier of the user that owns this token.
    user_id: uuid.UUID
    # Record primary key, a fresh random UUID per instance by default.
    id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
    # Timezone-aware UTC creation time; compared against ``max_age`` on lookup.
    created_at: datetime = dataclasses.field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
 | |
| 
 | |
| 
 | |
class AccessTokenDatabaseMock(AccessTokenDatabase[AccessTokenModel]):
    """Dictionary-backed stand-in for an access token database adapter.

    Records live in ``store`` keyed by their ``token`` string. No I/O is
    performed; every coroutine completes immediately.
    """

    # Maps token string -> stored token record.
    store: dict[str, AccessTokenModel]

    def __init__(self):
        self.store = {}

    async def get_by_token(
        self, token: str, max_age: datetime | None = None
    ) -> AccessTokenModel | None:
        """Return the record stored under ``token``, or ``None``.

        :param token: Token string to look up.
        :param max_age: Optional cutoff instant. A record whose ``created_at``
            is strictly earlier than this is treated as expired and ``None``
            is returned.
        :return: The stored record, or ``None`` if it is unknown or expired.
        """
        access_token = self.store.get(token)
        if access_token is None:
            return None
        if max_age is not None and access_token.created_at < max_age:
            return None
        return access_token

    async def create(self, create_dict: dict[str, Any]) -> AccessTokenModel:
        """Build a record from ``create_dict`` and store it under its token.

        :param create_dict: Keyword arguments for ``AccessTokenModel``.
        :return: The newly stored record.
        """
        access_token = AccessTokenModel(**create_dict)
        self.store[access_token.token] = access_token
        return access_token

    async def update(
        self, access_token: AccessTokenModel, update_dict: dict[str, Any]
    ) -> AccessTokenModel:
        """Apply ``update_dict`` to ``access_token`` in place and re-store it.

        :param access_token: Record to modify; it is mutated in place.
        :param update_dict: Attribute names mapped to their new values.
        :return: The same, now updated, record.
        """
        previous_token = access_token.token
        for field, value in update_dict.items():
            setattr(access_token, field, value)
        # If the token string itself changed, drop the entry under the old key
        # so the record is not still reachable through a stale token. Only
        # remove it when that key actually refers to this same record.
        if (
            access_token.token != previous_token
            and self.store.get(previous_token) is access_token
        ):
            del self.store[previous_token]
        self.store[access_token.token] = access_token
        return access_token

    async def delete(self, access_token: AccessTokenModel) -> None:
        """Remove ``access_token`` from the store; an absent entry is ignored."""
        self.store.pop(access_token.token, None)
 | |
| 
 | |
| 
 | |
@pytest.fixture
def access_token_database() -> AccessTokenDatabaseMock:
    """Provide a new, empty in-memory token database for the requesting test."""
    database = AccessTokenDatabaseMock()
    return database
 | |
| 
 | |
| 
 | |
@pytest.fixture
def database_strategy(access_token_database: AccessTokenDatabaseMock):
    """Provide a ``DatabaseStrategy`` wired to the mock token database.

    Within one test, pytest hands the same ``access_token_database`` instance
    to this fixture and to the test, so tokens the test seeds are visible to
    the strategy.
    """
    # Second positional argument of DatabaseStrategy; presumably the token
    # lifetime in seconds (3600 = one hour) — confirm against DatabaseStrategy.
    lifetime_seconds = 3600
    return DatabaseStrategy(access_token_database, lifetime_seconds)
 | |
| 
 | |
| 
 | |
@pytest.mark.authentication
class TestReadToken:
    """Tests for ``DatabaseStrategy.read_token`` resolving a token to a user."""

    @pytest.mark.asyncio
    async def test_missing_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        user_manager,
    ):
        # A ``None`` token must not authenticate anyone.
        authenticated_user = await database_strategy.read_token(None, user_manager)
        assert authenticated_user is None

    @pytest.mark.asyncio
    async def test_invalid_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        user_manager,
    ):
        # Nothing was stored, so the token is unknown to the database.
        authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
        assert authenticated_user is None

    @pytest.mark.asyncio
    async def test_valid_token_not_existing_user(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        access_token_database: AccessTokenDatabaseMock,
        user_manager,
    ):
        # The token exists but points at a user id that the user manager
        # presumably does not know about.
        await access_token_database.create(
            {
                "token": "TOKEN",
                "user_id": uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"),
            }
        )
        authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
        assert authenticated_user is None

    @pytest.mark.asyncio
    async def test_valid_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        access_token_database: AccessTokenDatabaseMock,
        user_manager,
        user: UserModel,
    ):
        # A fresh token belonging to an existing user resolves to that user.
        await access_token_database.create({"token": "TOKEN", "user_id": user.id})
        authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
        assert authenticated_user is not None
        assert authenticated_user.id == user.id

    @pytest.mark.asyncio
    async def test_expired_token(
        self,
        database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
        access_token_database: AccessTokenDatabaseMock,
        user_manager,
        user: UserModel,
    ):
        # A token for an existing user, created far in the past, should be
        # rejected. Assumes read_token derives a max_age cutoff from the
        # 3600-second lifetime configured in the database_strategy fixture and
        # forwards it to get_by_token.
        await access_token_database.create(
            {
                "token": "TOKEN",
                "user_id": user.id,
                "created_at": datetime(2000, 1, 1, tzinfo=timezone.utc),
            }
        )
        authenticated_user = await database_strategy.read_token("TOKEN", user_manager)
        assert authenticated_user is None
 | |
| 
 | |
| 
 | |
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_write_token(
    database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
    access_token_database: AccessTokenDatabaseMock,
    user: UserModel,
):
    """Writing a token persists a record that belongs to the given user."""
    issued = await database_strategy.write_token(user)

    stored = await access_token_database.get_by_token(issued)
    assert stored is not None
    assert stored.user_id == user.id
 | |
| 
 | |
| 
 | |
@pytest.mark.authentication
@pytest.mark.asyncio
async def test_destroy_token(
    database_strategy: DatabaseStrategy[UserModel, IDType, AccessTokenModel],
    access_token_database: AccessTokenDatabaseMock,
    user: UserModel,
):
    """Destroying a token removes its record from the token database."""
    token_value = "TOKEN"
    await access_token_database.create({"token": token_value, "user_id": user.id})

    await database_strategy.destroy_token(token_value, user)

    lookup = await access_token_database.get_by_token(token_value)
    assert lookup is None
 | 
