added samesite option for cookie authentication (#337)

* added samesite option for cookie authentication

* formatted with black and added documentation (grabbed from starlette.io documentation)
This commit is contained in:
SelfhostedPro
2020-09-29 22:27:25 -07:00
committed by GitHub
parent c5f53b73d1
commit 8adce74cd9
20 changed files with 112 additions and 32 deletions

View File

@@ -25,6 +25,7 @@ You can also define the parameters for the generated cookie:
* `cookie_domain` (`None`): Cookie domain.
* `cookie_secure` (`True`): Whether to only send the cookie to the server via SSL request.
* `cookie_httponly` (`True`): Whether to prevent access to the cookie via JavaScript.
* `cookie_samesite` (`lax`): A string that specifies the samesite strategy for the cookie. Valid values are 'lax', 'strict' and 'none'. Defaults to 'lax'.
!!! tip
You can also optionally define the `name`. It's useful in the case you wish to have several backends of the same class. Each backend should have a unique name. **Defaults to `cookie`**.

View File

@@ -46,7 +46,12 @@ jwt_authentication = JWTAuthentication(
app = FastAPI()
fastapi_users = FastAPIUsers(
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
user_db,
[jwt_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app.include_router(
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]

View File

@@ -57,7 +57,12 @@ jwt_authentication = JWTAuthentication(
app = FastAPI()
fastapi_users = FastAPIUsers(
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
user_db,
[jwt_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app.include_router(
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]

View File

@@ -46,7 +46,12 @@ jwt_authentication = JWTAuthentication(
)
fastapi_users = FastAPIUsers(
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
user_db,
[jwt_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app.include_router(
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]

View File

@@ -50,7 +50,12 @@ jwt_authentication = JWTAuthentication(
app = FastAPI()
fastapi_users = FastAPIUsers(
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
user_db,
[jwt_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app.include_router(
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]

View File

@@ -70,7 +70,12 @@ jwt_authentication = JWTAuthentication(
app = FastAPI()
fastapi_users = FastAPIUsers(
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
user_db,
[jwt_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app.include_router(
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]

View File

@@ -59,7 +59,12 @@ jwt_authentication = JWTAuthentication(
)
fastapi_users = FastAPIUsers(
user_db, [jwt_authentication], User, UserCreate, UserUpdate, UserDB,
user_db,
[jwt_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app.include_router(
fastapi_users.get_auth_router(jwt_authentication), prefix="/auth/jwt", tags=["auth"]

View File

@@ -36,6 +36,7 @@ class CookieAuthentication(BaseAuthentication[str]):
cookie_domain: Optional[str]
cookie_secure: bool
cookie_httponly: bool
cookie_samesite: str
def __init__(
self,
@@ -46,6 +47,7 @@ class CookieAuthentication(BaseAuthentication[str]):
cookie_domain: str = None,
cookie_secure: bool = True,
cookie_httponly: bool = True,
cookie_samesite: str = "lax",
name: str = "cookie",
):
super().__init__(name, logout=True)
@@ -56,10 +58,13 @@ class CookieAuthentication(BaseAuthentication[str]):
self.cookie_domain = cookie_domain
self.cookie_secure = cookie_secure
self.cookie_httponly = cookie_httponly
self.cookie_samesite = cookie_samesite
self.scheme = APIKeyCookie(name=self.cookie_name, auto_error=False)
async def __call__(
self, credentials: Optional[str], user_db: BaseUserDatabase,
self,
credentials: Optional[str],
user_db: BaseUserDatabase,
) -> Optional[BaseUserDB]:
if credentials is None:
return None
@@ -93,6 +98,7 @@ class CookieAuthentication(BaseAuthentication[str]):
domain=self.cookie_domain,
secure=self.cookie_secure,
httponly=self.cookie_httponly,
samesite=self.cookie_samesite,
)
# We shouldn't return directly the response

View File

@@ -39,7 +39,9 @@ class JWTAuthentication(BaseAuthentication[str]):
self.lifetime_seconds = lifetime_seconds
async def __call__(
self, credentials: Optional[str], user_db: BaseUserDatabase,
self,
credentials: Optional[str],
user_db: BaseUserDatabase,
) -> Optional[BaseUserDB]:
if credentials is None:
return None

View File

@@ -73,7 +73,8 @@ class FastAPIUsers:
)
def get_register_router(
self, after_register: Optional[Callable[[models.UD, Request], None]] = None,
self,
after_register: Optional[Callable[[models.UD, Request], None]] = None,
) -> APIRouter:
"""
Return a router with a register route.

View File

@@ -24,7 +24,10 @@ def generate_state_token(
def decode_state_token(token: str, secret: str) -> Dict[str, str]:
return jwt.decode(
token, secret, audience=STATE_TOKEN_AUDIENCE, algorithms=[JWT_ALGORITHM],
token,
secret,
audience=STATE_TOKEN_AUDIENCE,
algorithms=[JWT_ALGORITHM],
)
@@ -43,16 +46,20 @@ def get_oauth_router(
if redirect_url is not None:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client, redirect_url=redirect_url,
oauth_client,
redirect_url=redirect_url,
)
else:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client, route_name=callback_route_name,
oauth_client,
route_name=callback_route_name,
)
@router.get("/authorize")
async def authorize(
request: Request, authentication_backend: str, scopes: List[str] = Query(None),
request: Request,
authentication_backend: str,
scopes: List[str] = Query(None),
):
# Check that authentication_backend exists
backend_exists = False
@@ -73,7 +80,9 @@ def get_oauth_router(
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url, state, scopes,
authorize_redirect_url,
state,
scopes,
)
return {"authorization_url": authorization_url}

View File

@@ -57,7 +57,8 @@ def get_users_router(
user: user_db_model = Depends(get_current_active_user), # type: ignore
):
updated_user = cast(
models.BaseUserUpdate, updated_user,
models.BaseUserUpdate,
updated_user,
) # Prevent mypy complain
updated_user_data = updated_user.create_update_dict()
updated_user = await _update_user(user, updated_user_data, request)
@@ -81,7 +82,8 @@ def get_users_router(
id: UUID4, updated_user: user_update_model, request: Request # type: ignore
):
updated_user = cast(
models.BaseUserUpdate, updated_user,
models.BaseUserUpdate,
updated_user,
) # Prevent mypy complain
user = await _get_or_404(id)
updated_user_data = updated_user.create_update_dict_superuser()

View File

@@ -56,7 +56,8 @@ def event_loop():
@pytest.fixture
def user() -> UserDB:
return UserDB(
email="king.arthur@camelot.bt", hashed_password=guinevere_password_hash,
email="king.arthur@camelot.bt",
hashed_password=guinevere_password_hash,
)

View File

@@ -53,7 +53,8 @@ async def mongodb_user_db_oauth(get_mongodb_user_db):
@pytest.mark.db
async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
user = UserDB(
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
# Create
@@ -116,7 +117,10 @@ async def test_queries(mongodb_user_db: MongoDBUserDatabase[UserDB]):
async def test_email_query(
mongodb_user_db: MongoDBUserDatabase[UserDB], email: str, query: str, found: bool
):
user = UserDB(email=email, hashed_password=get_password_hash("guinevere"),)
user = UserDB(
email=email,
hashed_password=get_password_hash("guinevere"),
)
await mongodb_user_db.create(user)
email_user = await mongodb_user_db.get_by_email(query)

View File

@@ -71,7 +71,8 @@ async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, N
@pytest.mark.db
async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
user = UserDB(
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
# Create
@@ -121,7 +122,8 @@ async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[UserDB]):
# Exception when creating/updating a OAuth user
user_oauth = UserDBOAuth(
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
with pytest.raises(NotSetOAuthAccountTableError):
await sqlalchemy_user_db.create(user_oauth)

View File

@@ -55,7 +55,8 @@ async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None]
@pytest.mark.db
async def test_queries(tortoise_user_db: TortoiseUserDatabase[UserDB]):
user = UserDB(
email="lancelot@camelot.bt", hashed_password=get_password_hash("guinevere"),
email="lancelot@camelot.bt",
hashed_password=get_password_hash("guinevere"),
)
# Create

View File

@@ -12,7 +12,12 @@ async def test_app_client(
mock_user_db, mock_authentication, oauth_client, get_test_client
) -> httpx.AsyncClient:
fastapi_users = FastAPIUsers(
mock_user_db, [mock_authentication], User, UserCreate, UserUpdate, UserDB,
mock_user_db,
[mock_authentication],
User,
UserCreate,
UserUpdate,
UserDB,
)
app = FastAPI()

View File

@@ -82,7 +82,8 @@ class TestAuthorize:
with asynctest.patch.object(oauth_client, "get_authorization_url") as mock:
mock.return_value = "AUTHORIZATION_URL"
response = await test_app_client.get(
"/authorize", params={"scopes": ["scope1", "scope2"]},
"/authorize",
params={"scopes": ["scope1", "scope2"]},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@@ -162,7 +163,8 @@ class TestCallback:
) as get_id_email_mock:
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
response = await test_app_client.get(
"/callback", params={"code": "CODE", "state": "STATE"},
"/callback",
params={"code": "CODE", "state": "STATE"},
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
@@ -194,7 +196,8 @@ class TestCallback:
) as user_update_mock:
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
response = await test_app_client.get(
"/callback", params={"code": "CODE", "state": state_jwt},
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
@@ -232,7 +235,8 @@ class TestCallback:
superuser_oauth.email,
)
response = await test_app_client.get(
"/callback", params={"code": "CODE", "state": state_jwt},
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
@@ -269,7 +273,8 @@ class TestCallback:
"galahad@camelot.bt",
)
response = await test_app_client.get(
"/callback", params={"code": "CODE", "state": state_jwt},
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_id_email_mock.assert_awaited_once_with("TOKEN")
@@ -308,7 +313,8 @@ class TestCallback:
inactive_user_oauth.email,
)
response = await test_app_client.get(
"/callback", params={"code": "CODE", "state": state_jwt},
"/callback",
params={"code": "CODE", "state": state_jwt},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
@@ -337,7 +343,8 @@ class TestCallback:
) as get_id_email_mock:
get_id_email_mock.return_value = ("user_oauth1", user_oauth.email)
response = await test_app_client_redirect_url.get(
"/callback", params={"code": "CODE", "state": state_jwt},
"/callback",
params={"code": "CODE", "state": state_jwt},
)
get_access_token_mock.assert_awaited_once_with(

View File

@@ -32,7 +32,11 @@ async def test_app_client(
mock_user_db, mock_authentication, after_register, get_test_client
) -> httpx.AsyncClient:
register_router = get_register_router(
mock_user_db, User, UserCreate, UserDB, after_register,
mock_user_db,
User,
UserCreate,
UserDB,
after_register,
)
app = FastAPI()

View File

@@ -38,7 +38,12 @@ async def test_app_client(
)
user_router = get_users_router(
mock_user_db, User, UserUpdate, UserDB, authenticator, after_update,
mock_user_db,
User,
UserUpdate,
UserDB,
authenticator,
after_update,
)
app = FastAPI()