mirror of
https://github.com/teamhanko/hanko.git
synced 2025-10-27 06:06:54 +08:00
feat: custom jwt claims
This commit is contained in:
committed by
GitHub
parent
bc9eece531
commit
efeca4a76f
@ -23,6 +23,7 @@ easily integrated into any web app with as little as two lines of code.
|
||||
- [Account linking](#account-linking)
|
||||
- [User import](#user-import)
|
||||
- [Webhooks](#webhooks)
|
||||
- [Session JWT templates](#session-jwt-templates)
|
||||
- [API specification](#api-specification)
|
||||
- [Configuration reference](#configuration-reference)
|
||||
- [License](#license)
|
||||
@ -576,6 +577,65 @@ webhooks:
|
||||
- user
|
||||
```
|
||||
|
||||
### Session JWT templates
|
||||
|
||||
You can define custom claims that will be added to session JWTs through the `session.jwt_template.claims`
|
||||
configuration option.
|
||||
|
||||
These claims are processed at JWT generation time and can include static values,
|
||||
templated strings using Go's text/template syntax, or nested structures (maps and slices).
|
||||
|
||||
The template has access to user data via the `.User` field, which includes:
|
||||
- `.User.UserID`: The user's unique ID (string)
|
||||
- `.User.Email`: Email details (optional, with `.Address`, `.IsPrimary`, `.IsVerified`)
|
||||
- `.User.Username`: The user's username (string, optional)
|
||||
|
||||
Claims that fail to process (e.g., due to invalid templates) are logged and skipped,
|
||||
ensuring JWT generation continues without interruption.
|
||||
|
||||
|
||||
Example usage in YAML configuration:
|
||||
```yaml
|
||||
session:
|
||||
lifespan: 24h
|
||||
jwt_template:
|
||||
claims:
|
||||
role: "user" # Static value
|
||||
user_email: "{{.User.Email.Address}}" # Templated string
|
||||
is_verified: "{{.User.Email.IsVerified}}" # Boolean from user data
|
||||
metadata: # Nested map
|
||||
source: "hanko"
|
||||
greeting: "Hello {{.User.Username}}"
|
||||
scopes: # Slice with templated value
|
||||
- "read"
|
||||
- "write"
|
||||
- "{{if .User.Email.IsVerified}}admin{{else}}basic{{end}}"
|
||||
```
|
||||
|
||||
In this example:
|
||||
- `role` is a static string ("user").
|
||||
- `user_email` dynamically inserts the user's email address.
|
||||
- `is_verified` inserts a boolean indicating email verification status.
|
||||
- `metadata` is a nested map with a static `source` and a templated `greeting`.
|
||||
- `scopes` is a slice combining static values and a conditional template.
|
||||
|
||||
Notes:
|
||||
- Claims with the following keys will be ignored because they are currently added to the JWT
|
||||
by default:
|
||||
- sub
|
||||
- iat
|
||||
- exp
|
||||
- aud
|
||||
- iss
|
||||
- email
|
||||
- username
|
||||
- session_id
|
||||
- Templates must be valid Go `text/template` syntax. Invalid templates are logged and ignored.
|
||||
- Boolean strings ("true" or "false") from templates are automatically converted to actual booleans.
|
||||
- Use conditionals (e.g., `{{if .User.Email}}`) to handle optional fields safely.
|
||||
|
||||
For more details on template syntax, see: https://pkg.go.dev/text/template
|
||||
|
||||
## API specification
|
||||
|
||||
- [Hanko Public API](https://docs.hanko.io/api-reference/public/introduction)
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
@ -11,12 +15,12 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/persistence"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"log"
|
||||
)
|
||||
|
||||
func NewCreateCommand() *cobra.Command {
|
||||
var (
|
||||
configFile string
|
||||
pretty bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
@ -56,18 +60,13 @@ func NewCreateCommand() *cobra.Command {
|
||||
|
||||
userId := uuid.FromStringOrNil(args[0])
|
||||
|
||||
emails, err := persister.GetEmailPersister().FindByUserId(userId)
|
||||
userModel, err := persister.GetUserPersister().Get(userId)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to get emails from db: %s", err)
|
||||
fmt.Printf("failed to get user from db: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var emailJwt *dto.EmailJwt
|
||||
if e := emails.GetPrimary(); e != nil {
|
||||
emailJwt = dto.JwtFromEmailModel(e)
|
||||
}
|
||||
|
||||
token, rawToken, err := sessionManager.GenerateJWT(userId, emailJwt)
|
||||
token, rawToken, err := sessionManager.GenerateJWT(dto.UserJWTFromUserModel(userModel))
|
||||
if err != nil {
|
||||
fmt.Printf("failed to generate token: %s", err)
|
||||
return
|
||||
@ -91,11 +90,25 @@ func NewCreateCommand() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("token: %s", token)
|
||||
fmt.Printf("Token: %s\n", token)
|
||||
|
||||
if pretty {
|
||||
rawTokenMap, err := rawToken.AsMap(context.Background())
|
||||
if err != nil {
|
||||
fmt.Println("failed to get JWT payload as map:", err)
|
||||
return
|
||||
}
|
||||
payloadJSON, err := json.MarshalIndent(rawTokenMap, "", " ")
|
||||
if err != nil {
|
||||
fmt.Println("failed to marshal JWT payload as JSON:", err)
|
||||
}
|
||||
fmt.Printf("JWT payload: %s\n", string(payloadJSON))
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&configFile, "config", "", "config file")
|
||||
cmd.Flags().BoolVar(&pretty, "pretty", true, "pretty print the JWT payload")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@ -2,8 +2,9 @@ package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/invopop/jsonschema"
|
||||
"time"
|
||||
|
||||
"github.com/invopop/jsonschema"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
@ -38,6 +39,62 @@ type Session struct {
|
||||
// Deprecated. Use settings in parent object.
|
||||
//`server_side` contains configuration for server-side sessions.
|
||||
ServerSide *ServerSide `yaml:"server_side" json:"server_side" koanf:"server_side"`
|
||||
// `jwt_template` defines a template for adding custom `claims` to session JWTs.
|
||||
//
|
||||
// These claims are processed at JWT generation time and can include static values,
|
||||
// templated strings using Go's text/template syntax, or nested structures (maps and slices).
|
||||
//
|
||||
// The template has access to user data via the `.User` field, which includes:
|
||||
// - `.User.UserID`: The user's unique ID (string)
|
||||
// - `.User.Email`: Email details (optional, with `.Address`, `.IsPrimary`, `.IsVerified`)
|
||||
// - `.User.Username`: The user's username (string, optional)
|
||||
//
|
||||
// Claims that fail to process (e.g., due to invalid templates) are logged and skipped,
|
||||
// ensuring JWT generation continues without interruption.
|
||||
//
|
||||
//
|
||||
// Example usage in YAML configuration:
|
||||
// ```yaml
|
||||
// session:
|
||||
// lifespan: 24h
|
||||
// jwt_template:
|
||||
// claims:
|
||||
// role: "user" # Static value
|
||||
// user_email: "{{.User.Email.Address}}" # Templated string
|
||||
// is_verified: "{{.User.Email.IsVerified}}" # Boolean from user data
|
||||
// metadata: # Nested map
|
||||
// source: "hanko"
|
||||
// greeting: "Hello {{.User.Username}}"
|
||||
// scopes: # Slice with templated value
|
||||
// - "read"
|
||||
// - "write"
|
||||
// - "{{if .User.Email.IsVerified}}admin{{else}}basic{{end}}"
|
||||
// ```
|
||||
//
|
||||
// In this example:
|
||||
// - `role` is a static string ("user").
|
||||
// - `user_email` dynamically inserts the user's email address.
|
||||
// - `is_verified` inserts a boolean indicating email verification status.
|
||||
// - `metadata` is a nested map with a static `source` and a templated `greeting`.
|
||||
// - `scopes` is a slice combining static values and a conditional template.
|
||||
//
|
||||
// Notes:
|
||||
// - Claims with the following keys will be ignored because they are currently added to the JWT
|
||||
// by default:
|
||||
// - sub
|
||||
// - iat
|
||||
// - exp
|
||||
// - aud
|
||||
// - iss
|
||||
// - email
|
||||
// - username
|
||||
// - session_id
|
||||
// - Templates must be valid Go `text/template` syntax. Invalid templates are logged and ignored.
|
||||
// - Boolean strings ("true" or "false") from templates are automatically converted to actual booleans.
|
||||
// - Use conditionals (e.g., `{{if .User.Email}}`) to handle optional fields safely.
|
||||
//
|
||||
// For more details on template syntax, see: https://pkg.go.dev/text/template
|
||||
JWTTemplate *JWTTemplate `yaml:"jwt_template" json:"jwt_template,omitempty" koanf:"jwt_template"`
|
||||
}
|
||||
|
||||
func (s *Session) Validate() error {
|
||||
@ -45,6 +102,7 @@ func (s *Session) Validate() error {
|
||||
if err != nil {
|
||||
return errors.New("failed to parse lifespan")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -97,3 +155,7 @@ type ServerSide struct {
|
||||
// older sessions are invalidated.
|
||||
Limit int `yaml:"limit" json:"limit,omitempty" koanf:"limit" jsonschema:"default=100"`
|
||||
}
|
||||
|
||||
type JWTTemplate struct {
|
||||
Claims map[string]interface{} `yaml:"claims" json:"claims,omitempty" koanf:"claims"`
|
||||
}
|
||||
|
||||
@ -41,18 +41,18 @@ func FromEmailModel(email *models.Email, cfg *config.Config) *EmailResponse {
|
||||
return emailResponse
|
||||
}
|
||||
|
||||
type EmailJwt struct {
|
||||
type EmailJWT struct {
|
||||
Address string `json:"address"`
|
||||
IsPrimary bool `json:"is_primary"`
|
||||
IsVerified bool `json:"is_verified"`
|
||||
}
|
||||
|
||||
func JwtFromEmailModel(email *models.Email) *EmailJwt {
|
||||
func EmailJWTFromEmailModel(email *models.Email) *EmailJWT {
|
||||
if email == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &EmailJwt{
|
||||
return &EmailJWT{
|
||||
Address: email.Address,
|
||||
IsPrimary: email.IsPrimary(),
|
||||
IsVerified: email.Verified,
|
||||
|
||||
@ -3,11 +3,12 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/mileusna/useragent"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionData struct {
|
||||
@ -47,27 +48,55 @@ func FromSessionModel(model models.Session, current bool) SessionData {
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Subject uuid.UUID `json:"subject"`
|
||||
IssuedAt *time.Time `json:"issued_at,omitempty"`
|
||||
Expiration time.Time `json:"expiration"`
|
||||
Audience []string `json:"audience,omitempty"`
|
||||
Issuer *string `json:"issuer,omitempty"`
|
||||
Email *EmailJwt `json:"email,omitempty"`
|
||||
Username *string `json:"username,omitempty"`
|
||||
SessionID uuid.UUID `json:"session_id"`
|
||||
Subject uuid.UUID `json:"subject"`
|
||||
IssuedAt *time.Time `json:"issued_at,omitempty"`
|
||||
Expiration time.Time `json:"expiration"`
|
||||
Audience []string `json:"audience,omitempty"`
|
||||
Issuer *string `json:"issuer,omitempty"`
|
||||
Email *EmailJWT `json:"email,omitempty"`
|
||||
Username *string `json:"username,omitempty"`
|
||||
SessionID uuid.UUID `json:"session_id"`
|
||||
CustomClaims map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
type ValidateSessionResponse struct {
|
||||
IsValid bool `json:"is_valid"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
// deprecated
|
||||
ExpirationTime *time.Time `json:"expiration_time,omitempty"`
|
||||
// deprecated
|
||||
UserID *uuid.UUID `json:"user_id,omitempty"`
|
||||
// Custom MarshalJSON to flatten CustomClaims into the top level
|
||||
func (c Claims) MarshalJSON() ([]byte, error) {
|
||||
// Create a map to hold the flattened structure
|
||||
flattened := make(map[string]interface{})
|
||||
|
||||
// Marshal basic fields into the flattened map
|
||||
flattened["subject"] = c.Subject
|
||||
flattened["expiration"] = c.Expiration
|
||||
flattened["session_id"] = c.SessionID
|
||||
|
||||
if c.IssuedAt != nil {
|
||||
flattened["issued_at"] = c.IssuedAt
|
||||
}
|
||||
if len(c.Audience) > 0 {
|
||||
flattened["audience"] = c.Audience
|
||||
}
|
||||
if c.Issuer != nil {
|
||||
flattened["issuer"] = c.Issuer
|
||||
}
|
||||
if c.Email != nil {
|
||||
flattened["email"] = c.Email
|
||||
}
|
||||
if c.Username != nil {
|
||||
flattened["username"] = c.Username
|
||||
}
|
||||
|
||||
// Flatten CustomClaims into the top level
|
||||
for key, value := range c.CustomClaims {
|
||||
flattened[key] = value
|
||||
}
|
||||
|
||||
return json.Marshal(flattened)
|
||||
}
|
||||
|
||||
func GetClaimsFromToken(token jwt.Token) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
claims := &Claims{
|
||||
CustomClaims: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
if subject := token.Subject(); len(subject) > 0 {
|
||||
s, err := uuid.FromString(subject)
|
||||
@ -118,9 +147,30 @@ func GetClaimsFromToken(token jwt.Token) (*Claims, error) {
|
||||
|
||||
claims.Expiration = token.Expiration()
|
||||
|
||||
hankoClaims := map[string]bool{
|
||||
"email": true,
|
||||
"username": true,
|
||||
"session_id": true,
|
||||
}
|
||||
|
||||
for key, value := range token.PrivateClaims() {
|
||||
if !hankoClaims[key] {
|
||||
claims.CustomClaims[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
type ValidateSessionResponse struct {
|
||||
IsValid bool `json:"is_valid"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
// deprecated
|
||||
ExpirationTime *time.Time `json:"expiration_time,omitempty"`
|
||||
// deprecated
|
||||
UserID *uuid.UUID `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type ValidateSessionRequest struct {
|
||||
SessionToken string `json:"session_token" validate:"required"`
|
||||
}
|
||||
|
||||
222
backend/dto/session_test.go
Normal file
222
backend/dto/session_test.go
Normal file
@ -0,0 +1,222 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetClaimsFromToken(t *testing.T) {
|
||||
subject := uuid.Must(uuid.NewV4())
|
||||
sessionID := uuid.Must(uuid.NewV4())
|
||||
now := time.Now()
|
||||
expiration := now.Add(1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token jwt.Token
|
||||
expected *Claims
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "valid token with all claims",
|
||||
token: func() jwt.Token {
|
||||
token, _ := jwt.NewBuilder().
|
||||
Subject(subject.String()).
|
||||
IssuedAt(now).
|
||||
Audience([]string{"test-audience"}).
|
||||
Issuer("test-issuer").
|
||||
Expiration(expiration).
|
||||
Claim("session_id", sessionID.String()).
|
||||
Claim("email", map[string]interface{}{
|
||||
"address": "test@example.com",
|
||||
"is_verified": true,
|
||||
"is_primary": true,
|
||||
}).
|
||||
Claim("username", "testuser").
|
||||
Claim("custom", "value").
|
||||
Build()
|
||||
return token
|
||||
}(),
|
||||
expected: &Claims{
|
||||
Subject: subject,
|
||||
SessionID: sessionID,
|
||||
IssuedAt: &now,
|
||||
Audience: []string{"test-audience"},
|
||||
Issuer: stringPtr("test-issuer"),
|
||||
Email: &EmailJWT{
|
||||
Address: "test@example.com",
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
},
|
||||
Username: stringPtr("testuser"),
|
||||
Expiration: expiration,
|
||||
CustomClaims: map[string]interface{}{
|
||||
"custom": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := GetClaimsFromToken(tt.token)
|
||||
if tt.expectedError != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, claims)
|
||||
|
||||
// Compare the claims
|
||||
if tt.expected != nil {
|
||||
assert.Equal(t, tt.expected.Subject, claims.Subject)
|
||||
assert.Equal(t, tt.expected.SessionID, claims.SessionID)
|
||||
assert.Equal(t, tt.expected.Audience, claims.Audience)
|
||||
assert.Equal(t, tt.expected.Issuer, claims.Issuer)
|
||||
assert.Equal(t, tt.expected.Username, claims.Username)
|
||||
assert.Equal(t, tt.expected.CustomClaims, claims.CustomClaims)
|
||||
|
||||
if tt.expected.Email != nil {
|
||||
assert.Equal(t, tt.expected.Email.Address, claims.Email.Address)
|
||||
assert.Equal(t, tt.expected.Email.IsVerified, claims.Email.IsVerified)
|
||||
assert.Equal(t, tt.expected.Email.IsPrimary, claims.Email.IsPrimary)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaims_MarshalJSON(t *testing.T) {
|
||||
subject := uuid.Must(uuid.NewV4())
|
||||
sessionID := uuid.Must(uuid.NewV4())
|
||||
now := time.Now().Truncate(time.Second)
|
||||
expiration := now.Add(1 * time.Hour)
|
||||
username := "testuser"
|
||||
issuer := "test-issuer"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims Claims
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "all fields populated",
|
||||
claims: Claims{
|
||||
Subject: subject,
|
||||
SessionID: sessionID,
|
||||
IssuedAt: &now,
|
||||
Audience: []string{"test-audience"},
|
||||
Issuer: &issuer,
|
||||
Email: &EmailJWT{
|
||||
Address: "test@example.com",
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
},
|
||||
Username: &username,
|
||||
Expiration: expiration,
|
||||
CustomClaims: map[string]interface{}{
|
||||
"custom": "value",
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"subject": subject.String(),
|
||||
"session_id": sessionID.String(),
|
||||
"issued_at": now,
|
||||
"audience": []interface{}{"test-audience"},
|
||||
"issuer": issuer,
|
||||
"email": map[string]interface{}{
|
||||
"address": "test@example.com",
|
||||
"is_verified": true,
|
||||
"is_primary": true,
|
||||
},
|
||||
"username": username,
|
||||
"expiration": expiration,
|
||||
"custom": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal fields",
|
||||
claims: Claims{
|
||||
Subject: subject,
|
||||
SessionID: sessionID,
|
||||
Expiration: expiration,
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"subject": subject.String(),
|
||||
"session_id": sessionID.String(),
|
||||
"expiration": expiration,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with custom claims only",
|
||||
claims: Claims{
|
||||
Subject: subject,
|
||||
SessionID: sessionID,
|
||||
Expiration: expiration,
|
||||
CustomClaims: map[string]interface{}{
|
||||
"custom1": "value1",
|
||||
"custom2": "value2",
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"subject": subject.String(),
|
||||
"session_id": sessionID.String(),
|
||||
"expiration": expiration,
|
||||
"custom1": "value1",
|
||||
"custom2": "value2",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal the claims to JSON
|
||||
jsonData, err := json.Marshal(tt.claims)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, jsonData)
|
||||
|
||||
// Unmarshal the JSON back to a map for comparison
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(jsonData, &result)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Compare the expected and actual results
|
||||
for key, expectedValue := range tt.expected {
|
||||
actualValue := result[key]
|
||||
switch v := expectedValue.(type) {
|
||||
case time.Time:
|
||||
// For time values, compare the string representation after truncating to seconds
|
||||
expectedTime := v.Truncate(time.Second).UTC()
|
||||
actualTime, err := time.Parse(time.RFC3339, actualValue.(string))
|
||||
assert.NoError(t, err)
|
||||
actualTime = actualTime.Truncate(time.Second).UTC()
|
||||
assert.Equal(t, expectedTime, actualTime, "time mismatch for key: %s", key)
|
||||
case *time.Time:
|
||||
// For pointer to time values, compare the string representation after truncating to seconds
|
||||
expectedTime := v.Truncate(time.Second).UTC()
|
||||
actualTime, err := time.Parse(time.RFC3339, actualValue.(string))
|
||||
assert.NoError(t, err)
|
||||
actualTime = actualTime.Truncate(time.Second).UTC()
|
||||
assert.Equal(t, expectedTime, actualTime, "time mismatch for key: %s", key)
|
||||
case uuid.UUID:
|
||||
// For UUID values, compare the string representation
|
||||
assert.Equal(t, v.String(), actualValue, "UUID mismatch for key: %s", key)
|
||||
default:
|
||||
assert.Equal(t, expectedValue, actualValue, "mismatch for key: %s", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create a string pointer
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@ -1,9 +1,10 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CreateUserResponse struct {
|
||||
@ -27,3 +28,26 @@ type UserInfoResponse struct {
|
||||
Verified bool `json:"verified"`
|
||||
HasWebauthnCredential bool `json:"has_webauthn_credential"`
|
||||
}
|
||||
|
||||
// UserJWT represents an abstracted user model for session management
|
||||
type UserJWT struct {
|
||||
UserID string
|
||||
Email *EmailJWT
|
||||
Username string
|
||||
}
|
||||
|
||||
func UserJWTFromUserModel(userModel *models.User) UserJWT {
|
||||
userJWT := UserJWT{
|
||||
UserID: userModel.ID.String(),
|
||||
}
|
||||
|
||||
if primaryEmail := userModel.Emails.GetPrimary(); primaryEmail != nil {
|
||||
userJWT.Email = EmailJWTFromEmailModel(primaryEmail)
|
||||
}
|
||||
|
||||
if userModel.Username != nil {
|
||||
userJWT.Username = userModel.Username.Username
|
||||
}
|
||||
|
||||
return userJWT
|
||||
}
|
||||
|
||||
@ -3,14 +3,14 @@ package shared
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gobuffalo/nulls"
|
||||
"github.com/gofrs/uuid"
|
||||
auditlog "github.com/teamhanko/hanko/backend/audit_log"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"github.com/teamhanko/hanko/backend/flowpilot"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IssueSession struct {
|
||||
@ -36,17 +36,9 @@ func (h IssueSession) Execute(c flowpilot.HookExecutionContext) error {
|
||||
return fmt.Errorf("failed to fetch user from db: %w", err)
|
||||
}
|
||||
|
||||
var emailDTO *dto.EmailJwt
|
||||
if email := userModel.Emails.GetPrimary(); email != nil {
|
||||
emailDTO = dto.JwtFromEmailModel(email)
|
||||
}
|
||||
userJWT := dto.UserJWTFromUserModel(userModel)
|
||||
|
||||
var generateJWTOptions []session.JWTOptions
|
||||
if userModel.Username != nil {
|
||||
generateJWTOptions = append(generateJWTOptions, session.WithValue("username", userModel.Username.Username))
|
||||
}
|
||||
|
||||
signedSessionToken, rawToken, err := deps.SessionManager.GenerateJWT(userId, emailDTO, generateJWTOptions...)
|
||||
signedSessionToken, rawToken, err := deps.SessionManager.GenerateJWT(userJWT)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate JWT: %w", err)
|
||||
}
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/teamhanko/hanko/backend/crypto/jwk"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"github.com/teamhanko/hanko/backend/persistence"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"github.com/teamhanko/hanko/backend/test"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func getDefaultSessionManager(storage persistence.Persister) session.Manager {
|
||||
@ -19,7 +21,9 @@ func getDefaultSessionManager(storage persistence.Persister) session.Manager {
|
||||
|
||||
func generateSessionCookie(storage persistence.Persister, userId uuid.UUID) (*http.Cookie, error) {
|
||||
manager := getDefaultSessionManager(storage)
|
||||
token, rawToken, err := manager.GenerateJWT(userId, nil)
|
||||
token, rawToken, err := manager.GenerateJWT(dto.UserJWT{
|
||||
UserID: userId.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -404,12 +404,15 @@ func (h *PasscodeHandler) Finish(c echo.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
var emailJwt *dto.EmailJwt
|
||||
var emailJwt *dto.EmailJWT
|
||||
if e := userModel.Emails.GetPrimary(); e != nil {
|
||||
emailJwt = dto.JwtFromEmailModel(e)
|
||||
emailJwt = dto.EmailJWTFromEmailModel(e)
|
||||
}
|
||||
|
||||
token, rawToken, err := h.sessionManager.GenerateJWT(*passcode.UserId, emailJwt)
|
||||
token, rawToken, err := h.sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: passcode.UserId.String(),
|
||||
Email: emailJwt,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt: %w", err)
|
||||
}
|
||||
|
||||
@ -4,6 +4,11 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
@ -13,10 +18,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"github.com/teamhanko/hanko/backend/test"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPasscodeSuite(t *testing.T) {
|
||||
@ -301,13 +302,17 @@ func (s *passcodeSuite) TestPasscodeHandler_Finish() {
|
||||
req := httptest.NewRequest(http.MethodPost, "/passcode/login/finalize", bytes.NewReader(bodyJson))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if currentTest.sendSessionTokenInAuthHeader {
|
||||
sessionToken, _, err := sessionManager.GenerateJWT(uuid.FromStringOrNil(currentTest.userId), nil)
|
||||
sessionToken, _, err := sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: currentTest.userId,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", sessionToken))
|
||||
}
|
||||
|
||||
if currentTest.sendSessionTokenInCookie {
|
||||
sessionToken, _, err := sessionManager.GenerateJWT(uuid.FromStringOrNil(currentTest.userId), nil)
|
||||
sessionToken, _, err := sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: currentTest.userId,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
sessionCookie, err := sessionManager.GenerateCookie(sessionToken)
|
||||
|
||||
@ -3,12 +3,15 @@ package handler
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/sethvargo/go-limiter"
|
||||
"github.com/teamhanko/hanko/backend/audit_log"
|
||||
auditlog "github.com/teamhanko/hanko/backend/audit_log"
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"github.com/teamhanko/hanko/backend/persistence"
|
||||
@ -16,8 +19,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/rate_limiter"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"net/http"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type PasswordHandler struct {
|
||||
@ -218,12 +219,15 @@ func (h *PasswordHandler) Login(c echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusUnauthorized).SetInternal(err)
|
||||
}
|
||||
|
||||
var emailJwt *dto.EmailJwt
|
||||
var emailJwt *dto.EmailJWT
|
||||
if e := user.Emails.GetPrimary(); e != nil {
|
||||
emailJwt = dto.JwtFromEmailModel(e)
|
||||
emailJwt = dto.EmailJWTFromEmailModel(e)
|
||||
}
|
||||
|
||||
token, rawToken, err := h.sessionManager.GenerateJWT(pw.UserId, emailJwt)
|
||||
token, rawToken, err := h.sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: pw.UserId.String(),
|
||||
Email: emailJwt,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt: %w", err)
|
||||
}
|
||||
|
||||
@ -2,6 +2,8 @@ package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gobuffalo/nulls"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
@ -13,7 +15,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/persistence"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type SessionAdminHandler struct {
|
||||
@ -56,12 +57,15 @@ func (h *SessionAdminHandler) Generate(ctx echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusNotFound, "user not found")
|
||||
}
|
||||
|
||||
var emailDTO *dto.EmailJwt
|
||||
var emailDTO *dto.EmailJWT
|
||||
if email := user.Emails.GetPrimary(); email != nil {
|
||||
emailDTO = dto.JwtFromEmailModel(email)
|
||||
emailDTO = dto.EmailJWTFromEmailModel(email)
|
||||
}
|
||||
|
||||
encodedToken, rawToken, err := h.sessionManger.GenerateJWT(userID, emailDTO)
|
||||
encodedToken, rawToken, err := h.sessionManger.GenerateJWT(dto.UserJWT{
|
||||
UserID: userID.String(),
|
||||
Email: emailDTO,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate JWT: %w", err)
|
||||
}
|
||||
|
||||
@ -3,6 +3,9 @@ package handler
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
@ -14,8 +17,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
rateLimit "github.com/teamhanko/hanko/backend/rate_limiter"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TokenHandler struct {
|
||||
@ -87,12 +88,15 @@ func (h TokenHandler) Validate(c echo.Context) error {
|
||||
return fmt.Errorf("failed to get emails from db: %w", err)
|
||||
}
|
||||
|
||||
var emailJwt *dto.EmailJwt
|
||||
var emailJwt *dto.EmailJWT
|
||||
if e := emails.GetPrimary(); e != nil {
|
||||
emailJwt = dto.JwtFromEmailModel(e)
|
||||
emailJwt = dto.EmailJWTFromEmailModel(e)
|
||||
}
|
||||
|
||||
jwtToken, rawToken, err := h.sessionManager.GenerateJWT(token.UserID, emailJwt)
|
||||
jwtToken, rawToken, err := h.sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: token.UserID.String(),
|
||||
Email: emailJwt,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt: %w", err)
|
||||
}
|
||||
|
||||
@ -3,11 +3,14 @@ package handler
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/teamhanko/hanko/backend/audit_log"
|
||||
auditlog "github.com/teamhanko/hanko/backend/audit_log"
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"github.com/teamhanko/hanko/backend/dto/admin"
|
||||
@ -16,8 +19,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"github.com/teamhanko/hanko/backend/webhooks/events"
|
||||
"github.com/teamhanko/hanko/backend/webhooks/utils"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
@ -110,12 +111,15 @@ func (h *UserHandler) Create(c echo.Context) error {
|
||||
return fmt.Errorf("failed to get email from db: %w", err)
|
||||
}
|
||||
|
||||
var emailJwt *dto.EmailJwt
|
||||
var emailJwt *dto.EmailJWT
|
||||
if e := emails.GetPrimary(); e != nil {
|
||||
emailJwt = dto.JwtFromEmailModel(e)
|
||||
emailJwt = dto.EmailJWTFromEmailModel(e)
|
||||
}
|
||||
|
||||
token, _, err := h.sessionManager.GenerateJWT(newUser.ID, emailJwt)
|
||||
token, _, err := h.sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: newUser.ID.String(),
|
||||
Email: emailJwt,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt: %w", err)
|
||||
|
||||
@ -4,13 +4,17 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"github.com/gobuffalo/pop/v6"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/teamhanko/hanko/backend/audit_log"
|
||||
auditlog "github.com/teamhanko/hanko/backend/audit_log"
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"github.com/teamhanko/hanko/backend/dto/intern"
|
||||
@ -18,9 +22,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/persistence"
|
||||
"github.com/teamhanko/hanko/backend/persistence/models"
|
||||
"github.com/teamhanko/hanko/backend/session"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WebauthnHandler struct {
|
||||
@ -418,12 +419,15 @@ func (h *WebauthnHandler) FinishAuthentication(c echo.Context) error {
|
||||
return fmt.Errorf("failed to delete assertion session data: %w", err)
|
||||
}
|
||||
|
||||
var emailJwt *dto.EmailJwt
|
||||
var emailJwt *dto.EmailJWT
|
||||
if e := user.Emails.GetPrimary(); e != nil {
|
||||
emailJwt = dto.JwtFromEmailModel(e)
|
||||
emailJwt = dto.EmailJWTFromEmailModel(e)
|
||||
}
|
||||
|
||||
token, rawToken, err := h.sessionManager.GenerateJWT(webauthnUser.UserId, emailJwt)
|
||||
token, rawToken, err := h.sessionManager.GenerateJWT(dto.UserJWT{
|
||||
UserID: user.ID.String(),
|
||||
Email: emailJwt,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jwt: %w", err)
|
||||
}
|
||||
|
||||
@ -2,18 +2,19 @@ package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
hankoJwk "github.com/teamhanko/hanko/backend/crypto/jwk"
|
||||
hankoJwt "github.com/teamhanko/hanko/backend/crypto/jwt"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error)
|
||||
GenerateJWT(user dto.UserJWT, opts ...JWTOptions) (string, jwt.Token, error)
|
||||
Verify(string) (jwt.Token, error)
|
||||
GenerateCookie(token string) (*http.Cookie, error)
|
||||
DeleteCookie() (*http.Cookie, error)
|
||||
@ -26,6 +27,7 @@ type manager struct {
|
||||
cookieConfig cookieConfig
|
||||
issuer string
|
||||
audience []string
|
||||
jwtTemplate *config.JWTTemplate
|
||||
}
|
||||
|
||||
type cookieConfig struct {
|
||||
@ -85,17 +87,26 @@ func NewManager(jwkManager hankoJwk.Manager, config config.Config) (Manager, err
|
||||
SameSite: sameSite,
|
||||
Secure: config.Session.Cookie.Secure,
|
||||
},
|
||||
audience: audience,
|
||||
audience: audience,
|
||||
jwtTemplate: config.Session.JWTTemplate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateJWT creates a new session JWT for the given user
|
||||
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error) {
|
||||
func (m *manager) GenerateJWT(user dto.UserJWT, opts ...JWTOptions) (string, jwt.Token, error) {
|
||||
token := jwt.New()
|
||||
|
||||
// Process the claim template if found
|
||||
if m.jwtTemplate != nil {
|
||||
if err := ProcessJWTTemplate(token, m.jwtTemplate.Claims, user); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
issuedAt := time.Now()
|
||||
expiration := issuedAt.Add(m.sessionLength)
|
||||
|
||||
token := jwt.New()
|
||||
_ = token.Set(jwt.SubjectKey, userId.String())
|
||||
_ = token.Set(jwt.SubjectKey, user.UserID)
|
||||
_ = token.Set(jwt.IssuedAtKey, issuedAt)
|
||||
_ = token.Set(jwt.ExpirationKey, expiration)
|
||||
_ = token.Set(jwt.AudienceKey, m.audience)
|
||||
@ -106,8 +117,12 @@ func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt, opts ...JWT
|
||||
}
|
||||
_ = token.Set("session_id", sessionID.String())
|
||||
|
||||
if email != nil {
|
||||
_ = token.Set("email", &email)
|
||||
if user.Email != nil {
|
||||
_ = token.Set("email", user.Email)
|
||||
}
|
||||
|
||||
if user.Username != "" {
|
||||
_ = token.Set("username", user.Username)
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
|
||||
@ -2,6 +2,9 @@ package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -9,8 +12,6 @@ import (
|
||||
"github.com/teamhanko/hanko/backend/config"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
"github.com/teamhanko/hanko/backend/test"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewGenerator(t *testing.T) {
|
||||
@ -31,7 +32,9 @@ func TestGenerator_Generate(t *testing.T) {
|
||||
userId, err := uuid.NewV4()
|
||||
assert.NoError(t, err)
|
||||
|
||||
session, _, err := sessionGenerator.GenerateJWT(userId, nil)
|
||||
session, _, err := sessionGenerator.GenerateJWT(dto.UserJWT{
|
||||
UserID: userId.String(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
require.NotEmpty(t, session)
|
||||
}
|
||||
@ -51,13 +54,16 @@ func TestGenerator_Verify(t *testing.T) {
|
||||
|
||||
testEmail := "lorem@ipsum.local"
|
||||
|
||||
emailDto := &dto.EmailJwt{
|
||||
emailDto := &dto.EmailJWT{
|
||||
Address: testEmail,
|
||||
IsPrimary: true,
|
||||
IsVerified: false,
|
||||
}
|
||||
|
||||
session, _, err := sessionGenerator.GenerateJWT(userId, emailDto)
|
||||
session, _, err := sessionGenerator.GenerateJWT(dto.UserJWT{
|
||||
UserID: userId.String(),
|
||||
Email: emailDto,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
require.NotEmpty(t, session)
|
||||
|
||||
@ -72,9 +78,9 @@ func TestGenerator_Verify(t *testing.T) {
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, emailClaim)
|
||||
|
||||
// Workaround as .(EmailJwt) interface conversion is not possible
|
||||
// Workaround as .(EmailJWT) interface conversion is not possible
|
||||
emailJson, _ := json.Marshal(emailClaim)
|
||||
var tokenEmail dto.EmailJwt
|
||||
var tokenEmail dto.EmailJWT
|
||||
_ = json.Unmarshal(emailJson, &tokenEmail)
|
||||
|
||||
assert.Equal(t, testEmail, tokenEmail.Address)
|
||||
@ -103,7 +109,9 @@ func TestManager_GenerateJWT_IssAndAud(t *testing.T) {
|
||||
require.NotEmpty(t, sessionGenerator)
|
||||
|
||||
userId, _ := uuid.NewV4()
|
||||
j, _, err := sessionGenerator.GenerateJWT(userId, nil)
|
||||
j, _, err := sessionGenerator.GenerateJWT(dto.UserJWT{
|
||||
UserID: userId.String(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
token, err := jwt.ParseString(j, jwt.WithVerify(false))
|
||||
@ -134,7 +142,9 @@ func TestManager_GenerateJWT_AdditionalAudiences(t *testing.T) {
|
||||
require.NotEmpty(t, sessionGenerator)
|
||||
|
||||
userId, _ := uuid.NewV4()
|
||||
j, _, err := sessionGenerator.GenerateJWT(userId, nil)
|
||||
j, _, err := sessionGenerator.GenerateJWT(dto.UserJWT{
|
||||
UserID: userId.String(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
token, err := jwt.ParseString(j, jwt.WithVerify(false))
|
||||
@ -168,7 +178,9 @@ func Test_GenerateJWT_SessionID(t *testing.T) {
|
||||
require.NotEmpty(t, sessionGenerator)
|
||||
|
||||
userId, _ := uuid.NewV4()
|
||||
tokenString, _, err := sessionGenerator.GenerateJWT(userId, nil)
|
||||
tokenString, _, err := sessionGenerator.GenerateJWT(dto.UserJWT{
|
||||
UserID: userId.String(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
token, err := jwt.ParseString(tokenString, jwt.WithVerify(false))
|
||||
|
||||
95
backend/session/template.go
Normal file
95
backend/session/template.go
Normal file
@ -0,0 +1,95 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"text/template"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
)
|
||||
|
||||
// JWTTemplateData holds the data available for template processing
|
||||
type JWTTemplateData struct {
|
||||
User *dto.UserJWT
|
||||
}
|
||||
|
||||
// ProcessJWTTemplate processes a map of claims using the provided user data and sets them on the token
|
||||
func ProcessJWTTemplate(token jwt.Token, claims map[string]interface{}, user dto.UserJWT) error {
|
||||
claimTemplateData := JWTTemplateData{
|
||||
User: &user,
|
||||
}
|
||||
for key, value := range claims {
|
||||
processedValue, err := processClaimTemplate(value, claimTemplateData)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("session", key).Msgf("failed to process custom JWT claim template: %+v", value)
|
||||
continue
|
||||
}
|
||||
err = token.Set(key, processedValue)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Str("session", key).Msgf("failed to set processed JWT claim %+v to token", value)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processClaimTemplate processes a claim value, handling both string templates and nested structures
|
||||
func processClaimTemplate(value interface{}, data JWTTemplateData) (interface{}, error) {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return parseClaimTemplateValue(v, data)
|
||||
case map[string]interface{}:
|
||||
result := make(map[string]interface{})
|
||||
for key, val := range v {
|
||||
processed, err := processClaimTemplate(val, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[key] = processed
|
||||
}
|
||||
return result, nil
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, val := range v {
|
||||
processed, err := processClaimTemplate(val, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[i] = processed
|
||||
}
|
||||
return result, nil
|
||||
default:
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// parseClaimTemplateValue parses and executes a template string using the provided data
|
||||
func parseClaimTemplateValue(tmplStr string, data JWTTemplateData) (interface{}, error) {
|
||||
tmpl, err := template.New("").Parse(tmplStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse template: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err = tmpl.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("failed to execute template: %w", err)
|
||||
}
|
||||
|
||||
// "Workaround"/"hack" for when the template expression evaluates to a boolean string, i.e. "true"
|
||||
// or "false". This converts it to a bool for consistency's sake (i.e. to prevent that both boolean
|
||||
// values and boolean strings are eventually set in the JWT).
|
||||
resultString := buf.String()
|
||||
if resultString == "true" || resultString == "false" {
|
||||
b, err := strconv.ParseBool(buf.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse string as bool: %w", err)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
return resultString, nil
|
||||
}
|
||||
261
backend/session/template_test.go
Normal file
261
backend/session/template_test.go
Normal file
@ -0,0 +1,261 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/teamhanko/hanko/backend/dto"
|
||||
)
|
||||
|
||||
func TestProcessTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
template string
|
||||
data JWTTemplateData
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple template",
|
||||
template: "Hello {{.User.Email.Address}}",
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "Hello test@example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "template with pipelining",
|
||||
template: "Hello {{.User.Email.Address | printf \"%s\" }}",
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "Hello test@example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "template with conditional",
|
||||
template: "{{if .User.Email.IsVerified}}Verified{{else}}Unverified{{end}} user {{.User.Email.Address}}",
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
IsVerified: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "Verified user test@example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid template",
|
||||
template: "Hello {{.InvalidField}}",
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{},
|
||||
},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: "",
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{},
|
||||
},
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseClaimTemplateValue(tt.template, tt.data)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClaimValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
data JWTTemplateData
|
||||
want interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "string template",
|
||||
value: "Hello {{.User.Email.Address}}",
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: "Hello test@example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nested map with templates",
|
||||
value: map[string]interface{}{
|
||||
"greeting": "Hello {{.User.Email.Address}}",
|
||||
"nested": map[string]interface{}{
|
||||
"message": "Welcome {{.User.Email.Address}}",
|
||||
},
|
||||
},
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: map[string]interface{}{
|
||||
"greeting": "Hello test@example.com",
|
||||
"nested": map[string]interface{}{
|
||||
"message": "Welcome test@example.com",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "slice with templates",
|
||||
value: []interface{}{
|
||||
"Hello {{.User.Email.Address}}",
|
||||
map[string]interface{}{
|
||||
"message": "Welcome {{.User.Email.Address}}",
|
||||
},
|
||||
},
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []interface{}{
|
||||
"Hello test@example.com",
|
||||
map[string]interface{}{
|
||||
"message": "Welcome test@example.com",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-string primitive",
|
||||
value: 42,
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{},
|
||||
},
|
||||
want: 42,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid template in map",
|
||||
value: map[string]interface{}{
|
||||
"message": "Hello {{.InvalidField}}",
|
||||
},
|
||||
data: JWTTemplateData{
|
||||
User: &dto.UserJWT{},
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := processClaimTemplate(tt.value, tt.data)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClaimTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims map[string]interface{}
|
||||
user dto.UserJWT
|
||||
expectedClaims map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "successful claim processing",
|
||||
claims: map[string]interface{}{
|
||||
"email": "{{.User.Email.Address}}",
|
||||
"verified": "{{.User.Email.IsVerified}}",
|
||||
"static_string": "static-value",
|
||||
"static_bool": false,
|
||||
},
|
||||
user: dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
IsVerified: true,
|
||||
},
|
||||
},
|
||||
expectedClaims: map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"verified": true,
|
||||
"static_string": "static-value",
|
||||
"static_bool": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "partial claim processing with errors",
|
||||
claims: map[string]interface{}{
|
||||
"valid": "{{.User.Email.Address}}",
|
||||
"invalid": "{{.InvalidField}}",
|
||||
"static": "static-value",
|
||||
},
|
||||
user: dto.UserJWT{
|
||||
Email: &dto.EmailJWT{
|
||||
Address: "test@example.com",
|
||||
},
|
||||
},
|
||||
expectedClaims: map[string]interface{}{
|
||||
"valid": "test@example.com",
|
||||
"static": "static-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := jwt.New()
|
||||
err := ProcessJWTTemplate(token, tt.claims, tt.user)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify each expected claim
|
||||
for key, expectedValue := range tt.expectedClaims {
|
||||
value, exists := token.Get(key)
|
||||
assert.True(t, exists, "claim %s should exist", key)
|
||||
assert.Equal(t, expectedValue, value, "claim %s should have correct value", key)
|
||||
}
|
||||
|
||||
// For the error case, verify the invalid claim was not set
|
||||
if tt.name == "partial claim processing with errors" {
|
||||
_, exists := token.Get("invalid")
|
||||
assert.False(t, exists, "invalid claim should not be set")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -216,6 +216,52 @@ language to use for outgoing emails. If you have disabled email delivery through
|
||||
`email.send` event, the value for the `lang` parameter is reflected in the JWT payload of the token contained in the
|
||||
webhook request in the "Language" claim.
|
||||
|
||||
### Custom session claim type safety
|
||||
|
||||
The Hanko backend allows you to define custom claims that are added to issued session JWTs
|
||||
(see [here](https://github.com/teamhanko/hanko/blob/main/backend/README.md#session-jwt-templates) for more info).
|
||||
|
||||
To allow for IDE autocompletion and to maintain type safety for your custom claims:
|
||||
|
||||
1. Create a TypeScript definition file (`*.d.ts`) in your project (alternatively, modify an existing one).
|
||||
2. Import the `Claims` type from the frontend SDK.
|
||||
3. Declare a custom type that extends the `Claims` type.
|
||||
4. Add your custom claims to your custom type.
|
||||
|
||||
```ts
|
||||
import type { Claims } from "@teamhanko/hanko-frontend-sdk" // 2.
|
||||
// import type { Claims } from "@teamhanko/elements" // alternatively, if you use Hanko Elements, which
|
||||
// re-exports most SDK types
|
||||
|
||||
|
||||
type CustomClaims = Claims<{ // 3.
|
||||
custom_claim?: string // 4.
|
||||
}>;
|
||||
```
|
||||
|
||||
5. Use your custom type when accessing claims, e.g. in session details received in event callbacks or when accessing
|
||||
claims in responses from session validation
|
||||
[endpoints](https://docs.hanko.io/api-reference/public/session-management/validate-a-session):
|
||||
|
||||
```ts
|
||||
import type { CustomClaims } from "..."; // path to your type declaration file
|
||||
|
||||
hanko.onSessionCreated((sessionDetail) => {
|
||||
const claims = sessionDetail.claims as CustomClaims;
|
||||
console.info("My custom claim:", claims.custom_claim);
|
||||
});
|
||||
```
|
||||
|
||||
```ts
|
||||
import type { CustomClaims } from "..."; // path to your type declaration file
|
||||
|
||||
async function session() {
|
||||
const session = await hanko.sessionClient.validate();
|
||||
const claims = session.claims as CustomClaims;
|
||||
console.info("My custom claim:", claims.custom_claim);
|
||||
};
|
||||
```
|
||||
|
||||
## Bugs
|
||||
|
||||
Found a bug? Please report on our [GitHub](https://github.com/teamhanko/hanko/issues) page.
|
||||
|
||||
@ -51,6 +51,7 @@ import {
|
||||
WebauthnCredentials,
|
||||
Identity,
|
||||
SessionCheckResponse,
|
||||
Claims,
|
||||
} from "./lib/Dto";
|
||||
|
||||
export type {
|
||||
@ -74,6 +75,7 @@ export type {
|
||||
WebauthnCredentials,
|
||||
Identity,
|
||||
SessionCheckResponse,
|
||||
Claims,
|
||||
};
|
||||
|
||||
// Errors
|
||||
|
||||
@ -245,7 +245,12 @@ export interface Identity {
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the claims associated with a session or token.
|
||||
* Represents the claims associated with a session or token. Includes standard claims such as `subject`, `issued_at`,
|
||||
* `expiration`, and others, as well as custom claims defined by the user.
|
||||
*
|
||||
* @template TCustomClaims - An optional generic parameter that represents custom claims.
|
||||
* It extends a record with string keys and unknown values.
|
||||
* Defaults to `Record<string, unknown>` if not provided.
|
||||
*
|
||||
* @interface
|
||||
* @category SDK
|
||||
@ -258,8 +263,13 @@ export interface Identity {
|
||||
* @property {Pick<Email, "address" | "is_primary" | "is_verified">} [email] - Email information associated with the subject (optional).
|
||||
* @property {string} [username] - The subject's username (optional).
|
||||
* @property {string} session_id - The session identifier linked to the claims.
|
||||
*
|
||||
* @description Custom claims can be added via the `TCustomClaims` generic parameter, which will be merged
|
||||
* with the standard claims properties. These custom claims must follow the `Record<string, unknown>` pattern.
|
||||
*/
|
||||
export interface Claims {
|
||||
export type Claims<
|
||||
TCustomClaims extends Record<string, unknown> = Record<string, unknown>,
|
||||
> = {
|
||||
subject: string;
|
||||
issued_at?: string;
|
||||
expiration: string;
|
||||
@ -268,7 +278,7 @@ export interface Claims {
|
||||
email?: Pick<Email, "address" | "is_primary" | "is_verified">;
|
||||
username?: string;
|
||||
session_id: string;
|
||||
}
|
||||
} & TCustomClaims;
|
||||
|
||||
/**
|
||||
* Represents the response from a session validation or retrieval operation.
|
||||
|
||||
Reference in New Issue
Block a user