mirror of
https://github.com/fastapi-users/fastapi-users.git
synced 2026-03-13 07:49:55 +08:00
added samesite option for cookie authentication (#337)
* added samesite option for cookie authentication * formatted with black and added documentation (grabbed from starlette.io documentation)
This commit is contained in:
@@ -25,6 +25,7 @@ You can also define the parameters for the generated cookie:
|
||||
* `cookie_domain` (`None`): Cookie domain.
|
||||
* `cookie_secure` (`True`): Whether to only send the cookie to the server via SSL request.
|
||||
* `cookie_httponly` (`True`): Whether to prevent access to the cookie via JavaScript.
|
||||
* `cookie_samesite` (`lax`): A string that specifies the samesite strategy for the cookie. Valid values are 'lax', 'strict' and 'none'. Defaults to 'lax'.
|
||||
|
||||
!!! tip
|
||||
You can also optionally define the `name`. It's useful in the case you wish to have several backends of the same class. Each backend should have a unique name. **Defaults to `cookie`**.
|
||||
|
||||
@@ -46,7 +46,12 @@ jwt_authentication = JWTAuthentication(
|
||||
|
||||
app = FastAPI()
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]
|
||||
|
||||
@@ -57,7 +57,12 @@ jwt_authentication = JWTAuthentication(
|
||||
|
||||
app = FastAPI()
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]
|
||||
|
||||
@@ -46,7 +46,12 @@ jwt_authentication = JWTAuthentication(
|
||||
)
|
||||
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]
|
||||
|
||||
@@ -50,7 +50,12 @@ jwt_authentication = JWTAuthentication(
|
||||
|
||||
app = FastAPI()
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]
|
||||
|
||||
@@ -70,7 +70,12 @@ jwt_authentication = JWTAuthentication(
|
||||
|
||||
app = FastAPI()
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]
|
||||
|
||||
@@ -59,7 +59,12 @@ jwt_authentication = JWTAuthentication(
|
||||
)
|
||||
|
||||
fastapi_users = FastAPIUsers(
|
||||
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
user_db,
|
||||
[jwt_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
app.include_router(
|
||||
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]
|
||||
|
||||
@@ -36,6 +36,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
cookie_domain: Optional[str]
|
||||
cookie_secure: bool
|
||||
cookie_httponly: bool
|
||||
cookie_samesite: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,6 +47,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
cookie_domain: str = None,
|
||||
cookie_secure: bool = True,
|
||||
cookie_httponly: bool = True,
|
||||
cookie_samesite: str = "lax",
|
||||
name: str = "cookie",
|
||||
):
|
||||
super().__init__(name, logout=True)
|
||||
@@ -56,10 +58,13 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
self.cookie_domain = cookie_domain
|
||||
self.cookie_secure = cookie_secure
|
||||
self.cookie_httponly = cookie_httponly
|
||||
self.cookie_samesite = cookie_samesite
|
||||
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
|
||||
|
||||
async def __call__(
|
||||
self, credentials: Optional[str], user_db: BaseUserDatabase,
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
user_db: BaseUserDatabase,
|
||||
) -> Optional[BaseUserDB]:
|
||||
if credentials is None:
|
||||
return None
|
||||
@@ -93,6 +98,7 @@ class CookieAuthentication(BaseAuthentication[str]):
|
||||
domain=self.cookie_domain,
|
||||
secure=self.cookie_secure,
|
||||
httponly=self.cookie_httponly,
|
||||
samesite=self.cookie_samesite,
|
||||
)
|
||||
|
||||
# We shouldn't return directly the response
|
||||
|
||||
@@ -39,7 +39,9 @@ class JWTAuthentication(BaseAuthentication[str]):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def __call__(
|
||||
self, credentials: Optional[str], user_db: BaseUserDatabase,
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
user_db: BaseUserDatabase,
|
||||
) -> Optional[BaseUserDB]:
|
||||
if credentials is None:
|
||||
return None
|
||||
|
||||
@@ -73,7 +73,8 @@ class FastAPIUsers:
|
||||
)
|
||||
|
||||
def get_register_router(
|
||||
self, after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
self,
|
||||
after_register: Optional[Callable[[models.UD, Request], None]] = None,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Return a router with a register route.
|
||||
|
||||
@@ -24,7 +24,10 @@ def generate_state_token(
|
||||
|
||||
def decode_state_token(token: str, secret: str) -> Dict[str, str]:
|
||||
return jwt.decode(
|
||||
token, secret, audience=STATE_TOKEN_AUDIENCE, algorithms=[JWT_ALGORITHM],
|
||||
token,
|
||||
secret,
|
||||
audience=STATE_TOKEN_AUDIENCE,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
)
|
||||
|
||||
|
||||
@@ -43,16 +46,20 @@ def get_oauth_router(
|
||||
|
||||
if redirect_url is not None:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client, redirect_url=redirect_url,
|
||||
oauth_client,
|
||||
redirect_url=redirect_url,
|
||||
)
|
||||
else:
|
||||
oauth2_authorize_callback = OAuth2AuthorizeCallback(
|
||||
oauth_client, route_name=callback_route_name,
|
||||
oauth_client,
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
@router.get("/authorize")
|
||||
async def authorize(
|
||||
request: Request, authentication_backend: str, scopes: List[str] = Query(None),
|
||||
request: Request,
|
||||
authentication_backend: str,
|
||||
scopes: List[str] = Query(None),
|
||||
):
|
||||
# Check that authentication_backend exists
|
||||
backend_exists = False
|
||||
@@ -73,7 +80,9 @@ def get_oauth_router(
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url, state, scopes,
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
return {"authorization_url": authorization_url}
|
||||
|
||||
@@ -57,7 +57,8 @@ def get_users_router(
|
||||
user: user_db_model = Depends(get_current_active_user), # type: ignore
|
||||
):
|
||||
updated_user = cast(
|
||||
models.BaseUserUpdate, updated_user,
|
||||
models.BaseUserUpdate,
|
||||
updated_user,
|
||||
) # Prevent mypy complain
|
||||
updated_user_data = updated_user.create_update_dict()
|
||||
updated_user = await _update_user(user, updated_user_data, request)
|
||||
@@ -81,7 +82,8 @@ def get_users_router(
|
||||
id: UUID4, updated_user: user_update_model, request: Request # type: ignore
|
||||
):
|
||||
updated_user = cast(
|
||||
models.BaseUserUpdate, updated_user,
|
||||
models.BaseUserUpdate,
|
||||
updated_user,
|
||||
) # Prevent mypy complain
|
||||
user = await _get_or_404(id)
|
||||
updated_user_data = updated_user.create_update_dict_superuser()
|
||||
|
||||
@@ -56,7 +56,8 @@ def event_loop():
|
||||
@pytest.fixture
|
||||
def user() -> UserDB:
|
||||
return UserDB(
|
||||
email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash,
|
||||
email="king.arthur@camelot.bt",
|
||||
hashed_password=guinevere_password_hash,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ async def mongodb_user_db_oauth(get_mongodb_user_db):
|
||||
@pytest.mark.db
|
||||
async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
||||
user = UserDB(
|
||||
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||
email="lancelot@camelot.bt",
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
|
||||
# Create
|
||||
@@ -116,7 +117,10 @@ async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
|
||||
async def test_email_query(
|
||||
mongodb_user_db: MongoDBUserDatabase[UserDB], email: str, query: str, found: bool
|
||||
):
|
||||
user = UserDB(email=email, hashed_password=get_password_hash("guinevere"),)
|
||||
user = UserDB(
|
||||
email=email,
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
await mongodb_user_db.create(user)
|
||||
|
||||
email_user = await mongodb_user_db.get_by_email(query)
|
||||
|
||||
@@ -71,7 +71,8 @@ async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, N
|
||||
@pytest.mark.db
|
||||
async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
|
||||
user = UserDB(
|
||||
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||
email="lancelot@camelot.bt",
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
|
||||
# Create
|
||||
@@ -121,7 +122,8 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
|
||||
|
||||
# Exception when creating/updating a OAuth user
|
||||
user_oauth = UserDBOAuth(
|
||||
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||
email="lancelot@camelot.bt",
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
with pytest.raises(NotSetOAuthAccountTableError):
|
||||
await sqlalchemy_user_db.create(user_oauth)
|
||||
|
||||
@@ -55,7 +55,8 @@ async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None]
|
||||
@pytest.mark.db
|
||||
async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
|
||||
user = UserDB(
|
||||
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
|
||||
email="lancelot@camelot.bt",
|
||||
hashed_password=get_password_hash("guinevere"),
|
||||
)
|
||||
|
||||
# Create
|
||||
|
||||
@@ -12,7 +12,12 @@ async def test_app_client(
|
||||
mock_user_db, mock_authentication, oauth_client, get_test_client
|
||||
) -> httpx.AsyncClient:
|
||||
fastapi_users = FastAPIUsers(
|
||||
mock_user_db, [mock_authentication], User, UserCreate, UserUpdate, UserDB,
|
||||
mock_user_db,
|
||||
[mock_authentication],
|
||||
User,
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -82,7 +82,8 @@ class TestAuthorize:
|
||||
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
|
||||
mock.return_value = "AUTHORIZATION_URL"
|
||||
response = await test_app_client.get(
|
||||
"/authorize", params={"scopes": ["scope1", "scope2"]},
|
||||
"/authorize",
|
||||
params={"scopes": ["scope1", "scope2"]},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
@@ -162,7 +163,8 @@ class TestCallback:
|
||||
) as get_id_email_mock:
|
||||
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
|
||||
response = await test_app_client.get(
|
||||
"/callback", params={"code": "CODE", "state": "STATE"},
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": "STATE"},
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
@@ -194,7 +196,8 @@ class TestCallback:
|
||||
) as user_update_mock:
|
||||
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
|
||||
response = await test_app_client.get(
|
||||
"/callback", params={"code": "CODE", "state": state_jwt},
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
@@ -232,7 +235,8 @@ class TestCallback:
|
||||
superuser_oauth.email,
|
||||
)
|
||||
response = await test_app_client.get(
|
||||
"/callback", params={"code": "CODE", "state": state_jwt},
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
@@ -269,7 +273,8 @@ class TestCallback:
|
||||
"galahad@camelot.bt",
|
||||
)
|
||||
response = await test_app_client.get(
|
||||
"/callback", params={"code": "CODE", "state": state_jwt},
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
get_id_email_mock.assert_awaited_once_with("TOKEN")
|
||||
@@ -308,7 +313,8 @@ class TestCallback:
|
||||
inactive_user_oauth.email,
|
||||
)
|
||||
response = await test_app_client.get(
|
||||
"/callback", params={"code": "CODE", "state": state_jwt},
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||
@@ -337,7 +343,8 @@ class TestCallback:
|
||||
) as get_id_email_mock:
|
||||
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
|
||||
response = await test_app_client_redirect_url.get(
|
||||
"/callback", params={"code": "CODE", "state": state_jwt},
|
||||
"/callback",
|
||||
params={"code": "CODE", "state": state_jwt},
|
||||
)
|
||||
|
||||
get_access_token_mock.assert_awaited_once_with(
|
||||
|
||||
@@ -32,7 +32,11 @@ async def test_app_client(
|
||||
mock_user_db, mock_authentication, after_register, get_test_client
|
||||
) -> httpx.AsyncClient:
|
||||
register_router = get_register_router(
|
||||
mock_user_db, User, UserCreate, UserDB, after_register,
|
||||
mock_user_db,
|
||||
User,
|
||||
UserCreate,
|
||||
UserDB,
|
||||
after_register,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -38,7 +38,12 @@ async def test_app_client(
|
||||
)
|
||||
|
||||
user_router = get_users_router(
|
||||
mock_user_db, User, UserUpdate, UserDB, authenticator, after_update,
|
||||
mock_user_db,
|
||||
User,
|
||||
UserUpdate,
|
||||
UserDB,
|
||||
authenticator,
|
||||
after_update,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
Reference in New Issue
Block a user