Use "sub" claim instead of "user_id" for JWT, verify and reset password tokens

This commit is contained in:
octicon-git-branch(16/)
octicon-tag(16/)
François Voron
2023-01-16 11:44:42 +01:00
gitea-unlock(16/)
parent 794133c4fe
commit b18389439a
octicon-diff(16/tw-mr-1) 4 changed files with 12 additions and 12 deletions

4
fastapi_users/authentication/strategy/jwt.py
View File

@@ -44,7 +44,7 @@ class JWTStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID])
data = decode_jwt( data = decode_jwt(
token, self.decode_key, self.token_audience, algorithms=[self.algorithm] token, self.decode_key, self.token_audience, algorithms=[self.algorithm]
) )
user_id = data.get("user_id") user_id = data.get("sub")
if user_id is None: if user_id is None:
return None return None
except jwt.PyJWTError: except jwt.PyJWTError:
@@ -57,7 +57,7 @@ class JWTStrategy(Strategy[models.UP, models.ID], Generic[models.UP, models.ID])
return None return None
async def write_token(self, user: models.UP) -> str: async def write_token(self, user: models.UP) -> str:
data = {"user_id": str(user.id), "aud": self.token_audience} data = {"sub": str(user.id), "aud": self.token_audience}
return generate_jwt( return generate_jwt(
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
) )

8
fastapi_users/manager.py
View File

@@ -286,7 +286,7 @@ class BaseUserManager(Generic[models.UP, models.ID]):
raise exceptions.UserAlreadyVerified() raise exceptions.UserAlreadyVerified()
token_data = { token_data = {
"user_id": str(user.id), "sub": str(user.id),
"email": user.email, "email": user.email,
"aud": self.verification_token_audience, "aud": self.verification_token_audience,
} }
@@ -322,7 +322,7 @@ class BaseUserManager(Generic[models.UP, models.ID]):
raise exceptions.InvalidVerifyToken() raise exceptions.InvalidVerifyToken()
try: try:
user_id = data["user_id"] user_id = data["sub"]
email = data["email"] email = data["email"]
except KeyError: except KeyError:
raise exceptions.InvalidVerifyToken() raise exceptions.InvalidVerifyToken()
@@ -366,7 +366,7 @@ class BaseUserManager(Generic[models.UP, models.ID]):
raise exceptions.UserInactive() raise exceptions.UserInactive()
token_data = { token_data = {
"user_id": str(user.id), "sub": str(user.id),
"password_fgpt": self.password_helper.hash(user.hashed_password), "password_fgpt": self.password_helper.hash(user.hashed_password),
"aud": self.reset_password_token_audience, "aud": self.reset_password_token_audience,
} }
@@ -404,7 +404,7 @@ class BaseUserManager(Generic[models.UP, models.ID]):
raise exceptions.InvalidResetPasswordToken() raise exceptions.InvalidResetPasswordToken()
try: try:
user_id = data["user_id"] user_id = data["sub"]
password_fingerprint = data["password_fgpt"] password_fingerprint = data["password_fgpt"]
except KeyError: except KeyError:
raise exceptions.InvalidResetPasswordToken() raise exceptions.InvalidResetPasswordToken()

4
tests/test_authentication_strategy_jwt.py
View File

@@ -79,7 +79,7 @@ def token(jwt_strategy: JWTStrategy[UserModel, IDType]):
def _token(user_id=None, lifetime=LIFETIME): def _token(user_id=None, lifetime=LIFETIME):
data = {"aud": "fastapi-users:auth"} data = {"aud": "fastapi-users:auth"}
if user_id is not None: if user_id is not None:
data["user_id"] = str(user_id) data["sub"] = str(user_id)
return generate_jwt( return generate_jwt(
data, jwt_strategy.encode_key, lifetime, algorithm=jwt_strategy.algorithm data, jwt_strategy.encode_key, lifetime, algorithm=jwt_strategy.algorithm
) )
@@ -148,7 +148,7 @@ async def test_write_token(jwt_strategy: JWTStrategy[UserModel, IDType], user):
audience=jwt_strategy.token_audience, audience=jwt_strategy.token_audience,
algorithms=[jwt_strategy.algorithm], algorithms=[jwt_strategy.algorithm],
) )
assert decoded["user_id"] == str(user.id) assert decoded["sub"] == str(user.id)
@pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True) @pytest.mark.parametrize("jwt_strategy", ["HS256", "RS256", "ES256"], indirect=True)

8
tests/test_manager.py
View File

@@ -35,7 +35,7 @@ def verify_token(user_manager: UserManagerMock[UserModel]):
): ):
data = {"aud": user_manager.verification_token_audience} data = {"aud": user_manager.verification_token_audience}
if user_id is not None: if user_id is not None:
data["user_id"] = str(user_id) data["sub"] = str(user_id)
if email is not None: if email is not None:
data["email"] = email data["email"] = email
return generate_jwt(data, user_manager.verification_token_secret, lifetime) return generate_jwt(data, user_manager.verification_token_secret, lifetime)
@@ -52,7 +52,7 @@ def forgot_password_token(user_manager: UserManagerMock[UserModel]):
): ):
data = {"aud": user_manager.reset_password_token_audience} data = {"aud": user_manager.reset_password_token_audience}
if user_id is not None: if user_id is not None:
data["user_id"] = str(user_id) data["sub"] = str(user_id)
if current_password_hash is not None: if current_password_hash is not None:
data["password_fgpt"] = user_manager.password_helper.hash( data["password_fgpt"] = user_manager.password_helper.hash(
current_password_hash current_password_hash
@@ -299,7 +299,7 @@ class TestRequestVerifyUser:
user_manager.verification_token_secret, user_manager.verification_token_secret,
audience=[user_manager.verification_token_audience], audience=[user_manager.verification_token_audience],
) )
assert decoded_token["user_id"] == str(user.id) assert decoded_token["sub"] == str(user.id)
assert decoded_token["email"] == str(user.email) assert decoded_token["email"] == str(user.email)
@@ -413,7 +413,7 @@ class TestForgotPassword:
user_manager.reset_password_token_secret, user_manager.reset_password_token_secret,
audience=[user_manager.reset_password_token_audience], audience=[user_manager.reset_password_token_audience],
) )
assert decoded_token["user_id"] == str(user.id) assert decoded_token["sub"] == str(user.id)
valid_fingerprint, _ = user_manager.password_helper.verify_and_update( valid_fingerprint, _ = user_manager.password_helper.verify_and_update(
user.hashed_password, decoded_token["password_fgpt"] user.hashed_password, decoded_token["password_fgpt"]