mirror of
https://github.com/teamhanko/hanko.git
synced 2025-10-27 06:06:54 +08:00
207 lines
7.6 KiB
Go
207 lines
7.6 KiB
Go
package handler
|
|
|
|
import (
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/teamhanko/hanko/backend/config"
|
|
"github.com/teamhanko/hanko/backend/dto"
|
|
"github.com/teamhanko/hanko/backend/session"
|
|
"github.com/teamhanko/hanko/backend/test"
|
|
"github.com/teamhanko/hanko/backend/thirdparty"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestThirdPartyHandler_Auth(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
name string
|
|
referer string
|
|
enabledProviders []string
|
|
allowedRedirectURLs []string
|
|
requestedProvider string
|
|
requestedRedirectTo string
|
|
expectedBaseURL string
|
|
expectedError string
|
|
expectedErrorDescription string // can be a partial message
|
|
}{
|
|
{
|
|
name: "successful redirect to google",
|
|
referer: "https://login.test.example",
|
|
enabledProviders: []string{"google"},
|
|
allowedRedirectURLs: []string{"https://*.test.example"},
|
|
requestedProvider: "google",
|
|
requestedRedirectTo: "https://app.test.example",
|
|
expectedBaseURL: "https://" + thirdparty.GoogleAuthBase + thirdparty.GoogleOauthAuthEndpoint,
|
|
},
|
|
{
|
|
name: "successful redirect to github",
|
|
referer: "https://login.test.example",
|
|
enabledProviders: []string{"github"},
|
|
allowedRedirectURLs: []string{"https://*.test.example"},
|
|
requestedProvider: "github",
|
|
requestedRedirectTo: "https://app.test.example",
|
|
expectedBaseURL: "https://" + thirdparty.GithubAuthBase + thirdparty.GithubOauthAuthEndpoint,
|
|
},
|
|
{
|
|
name: "error redirect on missing provider",
|
|
referer: "https://login.test.example",
|
|
requestedRedirectTo: "https://app.test.example",
|
|
expectedBaseURL: "https://login.test.example",
|
|
expectedError: thirdparty.ThirdPartyErrorCodeInvalidRequest,
|
|
expectedErrorDescription: "is a required field",
|
|
},
|
|
{
|
|
name: "error redirect on missing redirectTo",
|
|
referer: "https://login.test.example",
|
|
requestedProvider: "google",
|
|
expectedBaseURL: "https://login.test.example",
|
|
expectedError: thirdparty.ThirdPartyErrorCodeInvalidRequest,
|
|
expectedErrorDescription: "is a required field",
|
|
},
|
|
{
|
|
name: "error redirect when requested provider is disabled",
|
|
referer: "https://login.test.example",
|
|
enabledProviders: []string{"github"},
|
|
allowedRedirectURLs: []string{"https://*.test.example"},
|
|
requestedProvider: "google",
|
|
requestedRedirectTo: "https://app.test.example",
|
|
expectedBaseURL: "https://login.test.example",
|
|
expectedError: thirdparty.ThirdPartyErrorCodeInvalidRequest,
|
|
expectedErrorDescription: "provider is disabled",
|
|
},
|
|
{
|
|
name: "error redirect when requesting an unknown provider",
|
|
referer: "https://login.test.example",
|
|
allowedRedirectURLs: []string{"https://*.test.example"},
|
|
requestedProvider: "unknownProvider",
|
|
requestedRedirectTo: "https://app.test.example",
|
|
expectedBaseURL: "https://login.test.example",
|
|
expectedError: thirdparty.ThirdPartyErrorCodeInvalidRequest,
|
|
expectedErrorDescription: "is not supported",
|
|
},
|
|
{
|
|
name: "error redirect when requesting a redirectTo that is not allowed",
|
|
referer: "https://login.test.example",
|
|
enabledProviders: []string{"google"},
|
|
allowedRedirectURLs: []string{"https://*.test.example"},
|
|
requestedProvider: "google",
|
|
requestedRedirectTo: "https://app.test.wrong",
|
|
expectedBaseURL: "https://login.test.example",
|
|
expectedError: thirdparty.ThirdPartyErrorCodeInvalidRequest,
|
|
expectedErrorDescription: "redirect to 'https://app.test.wrong' not allowed",
|
|
},
|
|
{
|
|
name: "error redirect with redirect to error redirect url if referer not present",
|
|
allowedRedirectURLs: []string{"https://*.test.example"},
|
|
requestedProvider: "unknownProvider",
|
|
requestedRedirectTo: "https://app.test.example",
|
|
expectedBaseURL: "https://error.test.example",
|
|
expectedError: thirdparty.ThirdPartyErrorCodeInvalidRequest,
|
|
expectedErrorDescription: "is not supported",
|
|
},
|
|
}
|
|
|
|
for _, testData := range tests {
|
|
t.Run(testData.name, func(t *testing.T) {
|
|
cfg := setUpConfig(t, testData.enabledProviders, testData.allowedRedirectURLs)
|
|
e := echo.New()
|
|
e.Validator = dto.NewCustomValidator()
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/thirdparty/auth", nil)
|
|
|
|
params := url.Values{}
|
|
if testData.requestedProvider != "" {
|
|
params.Add("provider", testData.requestedProvider)
|
|
}
|
|
if testData.requestedRedirectTo != "" {
|
|
params.Add("redirect_to", testData.requestedRedirectTo)
|
|
}
|
|
req.URL.RawQuery = params.Encode()
|
|
|
|
req.Header.Set("Referer", testData.referer)
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
c := e.NewContext(req, rec)
|
|
p := test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
|
|
|
jwkManager := test.JwkManager{}
|
|
sessionMgr, err := session.NewManager(jwkManager, cfg.Session)
|
|
require.NoError(t, err)
|
|
|
|
handler := NewThirdPartyHandler(cfg, p, sessionMgr, test.NewAuditLogger(), jwkManager)
|
|
|
|
err = handler.Auth(c)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)
|
|
|
|
u, err := url.Parse(rec.Header().Get("Location"))
|
|
assert.NoError(t, err, "redirect url parse failed")
|
|
|
|
assert.Equal(t, testData.expectedBaseURL, u.Scheme+"://"+u.Host+u.Path)
|
|
|
|
q := u.Query()
|
|
|
|
if testData.expectedError != "" {
|
|
assert.Equal(t, testData.expectedError, q.Get("error"))
|
|
errorDescription := q.Get("error_description")
|
|
isCorrectErrorDescription := strings.Contains(errorDescription, testData.expectedErrorDescription)
|
|
assert.Truef(t, isCorrectErrorDescription, "error description '%s' does not contain '%s'", errorDescription, testData.expectedErrorDescription)
|
|
} else {
|
|
assert.Equal(t, cfg.ThirdParty.RedirectURL, q.Get("redirect_uri"))
|
|
assert.Equal(t, cfg.ThirdParty.Providers.Get(testData.requestedProvider).ClientID, q.Get("client_id"))
|
|
assert.Equal(t, "code", q.Get("response_type"))
|
|
|
|
state, err := thirdparty.VerifyState(sessionMgr, q.Get("state"))
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, strings.ToLower(testData.requestedProvider), state.Provider)
|
|
|
|
if testData.requestedRedirectTo == "" {
|
|
assert.Equal(t, cfg.ThirdParty.ErrorRedirectURL, state.RedirectTo)
|
|
} else {
|
|
assert.Equal(t, testData.requestedRedirectTo, state.RedirectTo)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setUpConfig(t *testing.T, enabledProviders []string, allowedRedirectURLs []string) *config.Config {
|
|
cfg := &config.Config{ThirdParty: config.ThirdParty{
|
|
Providers: config.ThirdPartyProviders{
|
|
Google: config.ThirdPartyProvider{
|
|
Enabled: false,
|
|
ClientID: "fakeClientID",
|
|
Secret: "fakeClientSecret",
|
|
}, GitHub: config.ThirdPartyProvider{
|
|
Enabled: false,
|
|
ClientID: "fakeClientID",
|
|
Secret: "fakeClientSecret",
|
|
}},
|
|
ErrorRedirectURL: "https://error.test.example",
|
|
RedirectURL: "https://api.test.example/callback",
|
|
AllowedRedirectURLS: allowedRedirectURLs,
|
|
}}
|
|
|
|
for _, provider := range enabledProviders {
|
|
switch provider {
|
|
case "google":
|
|
cfg.ThirdParty.Providers.Google.Enabled = true
|
|
case "github":
|
|
cfg.ThirdParty.Providers.GitHub.Enabled = true
|
|
}
|
|
}
|
|
|
|
err := cfg.PostProcess()
|
|
require.NoError(t, err)
|
|
|
|
return cfg
|
|
}
|