package thirdparty import ( "encoding/json" "errors" "fmt" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/crypto" "github.com/teamhanko/hanko/backend/crypto/aes_gcm" "time" ) func GenerateStateForFlowAPI(isFlow bool) func(*State) { return func(state *State) { state.IsFlow = isFlow } } func GenerateState(config *config.Config, provider string, redirectTo string, options ...func(*State)) ([]byte, error) { if provider == "" { return nil, errors.New("provider must be present") } if redirectTo == "" { redirectTo = config.ThirdParty.ErrorRedirectURL } nonce, err := crypto.GenerateRandomStringURLSafe(32) if err != nil { return nil, fmt.Errorf("could not generate nonce: %w", err) } now := time.Now().UTC() state := State{ Provider: provider, RedirectTo: redirectTo, IssuedAt: now, ExpiresAt: now.Add(time.Minute * 5), Nonce: nonce, } for _, option := range options { option(&state) } stateJson, err := json.Marshal(state) aes, err := aes_gcm.NewAESGCM(config.Secrets.Keys) if err != nil { return nil, fmt.Errorf("could not instantiate aesgcm: %w", err) } encryptedState, err := aes.Encrypt(stateJson) if err != nil { return nil, fmt.Errorf("could not encrypt state: %w", err) } return []byte(encryptedState), nil } type State struct { Provider string `json:"provider"` RedirectTo string `json:"redirect_to"` IssuedAt time.Time `json:"issued_at"` ExpiresAt time.Time `json:"expires_at"` Nonce string `json:"nonce"` IsFlow bool `json:"is_flow"` } func VerifyState(config *config.Config, state string, expectedState string) (*State, error) { decodedState, err := decodeState(config, state) if err != nil { return nil, fmt.Errorf("could not decode state: %w", err) } decodedExpectedState, err := decodeState(config, expectedState) if err != nil { return nil, fmt.Errorf("could not decode expectedState: %w", err) } if decodedState.Nonce != decodedExpectedState.Nonce { return nil, errors.New("could not verify state") } if time.Now().UTC().After(decodedState.ExpiresAt) { return nil, errors.New("state is expired") } return decodedState, nil } func decodeState(config *config.Config, state string) (*State, error) { aes, err := aes_gcm.NewAESGCM(config.Secrets.Keys) if err != nil { return nil, fmt.Errorf("could not instantiate aesgcm: %w", err) } decryptedState, err := aes.Decrypt(state) if err != nil { return nil, fmt.Errorf("could not decrypt state: %w", err) } var unmarshalledState State err = json.Unmarshal(decryptedState, &unmarshalledState) if err != nil { return nil, fmt.Errorf("could not unmarshal state: %w", err) } return &unmarshalledState, nil }