mirror of
https://github.com/teamhanko/hanko.git
synced 2025-10-28 06:37:57 +08:00
* feat: add auth prompt config option * feat: add pkce oauth flow When the oauth flow is initialized with a `code_verifier` the state cookie is optional and on hanko_token exchange the client must also send the `code_verifier` in addition to the `hanko_token`. * fix: fix runtime errors & tests
124 lines
3.1 KiB
Go
124 lines
3.1 KiB
Go
package thirdparty
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/teamhanko/hanko/backend/v2/config"
|
|
"github.com/teamhanko/hanko/backend/v2/crypto"
|
|
"github.com/teamhanko/hanko/backend/v2/crypto/aes_gcm"
|
|
"time"
|
|
)
|
|
|
|
func GenerateStateForFlowAPI(isFlow bool) func(*State) {
|
|
return func(state *State) {
|
|
state.IsFlow = isFlow
|
|
}
|
|
}
|
|
|
|
func GenerateStateWithPKCECodeVerifier(codeVerifier string) func(state *State) {
|
|
return func(state *State) {
|
|
if codeVerifier != "" {
|
|
state.CodeVerifier = codeVerifier
|
|
}
|
|
}
|
|
}
|
|
|
|
func GenerateState(config *config.Config, provider string, redirectTo string, options ...func(*State)) ([]byte, error) {
|
|
if provider == "" {
|
|
return nil, errors.New("provider must be present")
|
|
}
|
|
|
|
if redirectTo == "" {
|
|
redirectTo = config.ThirdParty.ErrorRedirectURL
|
|
}
|
|
|
|
nonce, err := crypto.GenerateRandomStringURLSafe(32)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not generate nonce: %w", err)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
state := State{
|
|
Provider: provider,
|
|
RedirectTo: redirectTo,
|
|
IssuedAt: now,
|
|
ExpiresAt: now.Add(time.Minute * 5),
|
|
Nonce: nonce,
|
|
}
|
|
|
|
for _, option := range options {
|
|
option(&state)
|
|
}
|
|
|
|
stateJson, err := json.Marshal(state)
|
|
|
|
aes, err := aes_gcm.NewAESGCM(config.Secrets.Keys)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not instantiate aesgcm: %w", err)
|
|
}
|
|
|
|
encryptedState, err := aes.Encrypt(stateJson)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not encrypt state: %w", err)
|
|
}
|
|
|
|
return []byte(encryptedState), nil
|
|
}
|
|
|
|
// State is the payload that GenerateState serializes to JSON and encrypts,
// and that VerifyState decrypts and checks when a third-party flow returns.
type State struct {
	// Provider is the name of the third-party provider the flow was started for.
	Provider string `json:"provider"`

	// RedirectTo is the redirect target recorded when the state was generated.
	RedirectTo string `json:"redirect_to"`

	// IssuedAt is the UTC time at which the state was generated.
	IssuedAt time.Time `json:"issued_at"`

	// ExpiresAt is the time after which VerifyState rejects the state.
	ExpiresAt time.Time `json:"expires_at"`

	// Nonce is the random value compared against the expected state's nonce
	// when no code verifier is present.
	Nonce string `json:"nonce"`

	// IsFlow is set through GenerateStateForFlowAPI.
	IsFlow bool `json:"is_flow"`

	// CodeVerifier is set through GenerateStateWithPKCECodeVerifier; it is
	// omitted from the JSON encoding when empty.
	CodeVerifier string `json:"code_verifier,omitempty"`
}
|
|
|
|
func VerifyState(config *config.Config, state string, expectedState string) (*State, error) {
|
|
decodedState, err := decodeState(config, state)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not decode state: %w", err)
|
|
}
|
|
|
|
if decodedState.CodeVerifier == "" {
|
|
if expectedState == "" {
|
|
return nil, errors.New("expected state must not be empty")
|
|
}
|
|
decodedExpectedState, err := decodeState(config, expectedState)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not decode expectedState: %w", err)
|
|
}
|
|
|
|
if decodedState.Nonce != decodedExpectedState.Nonce {
|
|
return nil, errors.New("could not verify state")
|
|
}
|
|
}
|
|
|
|
if time.Now().UTC().After(decodedState.ExpiresAt) {
|
|
return nil, errors.New("state is expired")
|
|
}
|
|
|
|
return decodedState, nil
|
|
}
|
|
|
|
func decodeState(config *config.Config, state string) (*State, error) {
|
|
aes, err := aes_gcm.NewAESGCM(config.Secrets.Keys)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not instantiate aesgcm: %w", err)
|
|
}
|
|
|
|
decryptedState, err := aes.Decrypt(state)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not decrypt state: %w", err)
|
|
}
|
|
|
|
var unmarshalledState State
|
|
err = json.Unmarshal(decryptedState, &unmarshalledState)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not unmarshal state: %w", err)
|
|
}
|
|
return &unmarshalledState, nil
|
|
}
|