Files
hanko/backend/handler/thirdparty_test.go
2023-03-02 10:11:05 +01:00

207 lines
7.6 KiB
Go

package handler
import (
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/teamhanko/hanko/backend/config"
"github.com/teamhanko/hanko/backend/dto"
"github.com/teamhanko/hanko/backend/session"
"github.com/teamhanko/hanko/backend/test"
"github.com/teamhanko/hanko/backend/thirdparty"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
// TestThirdPartyHandler_Auth drives ThirdPartyHandler.Auth through a table of
// GET /thirdparty/auth requests and checks the redirect each one produces.
//
// Every case expects Auth to return no error and to answer with a 307
// redirect whose scheme, host and path equal expectedBaseURL. Cases that set
// expectedError additionally expect matching "error" and "error_description"
// query parameters (the description only needs to contain the expected text).
// All other cases expect the redirect to carry the configured redirect_uri,
// the provider's client_id, response_type=code and a state value that
// thirdparty.VerifyState accepts.
func TestThirdPartyHandler_Auth(t *testing.T) {
	tests := []struct {
		name                     string
		referer                  string
		enabledProviders         []string
		allowedRedirectURLs      []string
		requestedProvider        string
		requestedRedirectTo      string
		expectedBaseURL          string
		expectedError            string
		expectedErrorDescription string // can be a partial message
	}{
		{
			name:                "successful redirect to google",
			referer:             "https://login.test.example",
			enabledProviders:    []string{"google"},
			allowedRedirectURLs: []string{"https://*.test.example"},
			requestedProvider:   "google",
			requestedRedirectTo: "https://app.test.example",
			expectedBaseURL:     "https://" + thirdparty.GoogleAuthBase + thirdparty.GoogleOauthAuthEndpoint,
		},
		{
			name:                "successful redirect to github",
			referer:             "https://login.test.example",
			enabledProviders:    []string{"github"},
			allowedRedirectURLs: []string{"https://*.test.example"},
			requestedProvider:   "github",
			requestedRedirectTo: "https://app.test.example",
			expectedBaseURL:     "https://" + thirdparty.GithubAuthBase + thirdparty.GithubOauthAuthEndpoint,
		},
		{
			name:                     "error redirect on missing provider",
			referer:                  "https://login.test.example",
			requestedRedirectTo:      "https://app.test.example",
			expectedBaseURL:          "https://login.test.example",
			expectedError:            thirdparty.ThirdPartyErrorCodeInvalidRequest,
			expectedErrorDescription: "is a required field",
		},
		{
			name:                     "error redirect on missing redirectTo",
			referer:                  "https://login.test.example",
			requestedProvider:        "google",
			expectedBaseURL:          "https://login.test.example",
			expectedError:            thirdparty.ThirdPartyErrorCodeInvalidRequest,
			expectedErrorDescription: "is a required field",
		},
		{
			name:                     "error redirect when requested provider is disabled",
			referer:                  "https://login.test.example",
			enabledProviders:         []string{"github"},
			allowedRedirectURLs:      []string{"https://*.test.example"},
			requestedProvider:        "google",
			requestedRedirectTo:      "https://app.test.example",
			expectedBaseURL:          "https://login.test.example",
			expectedError:            thirdparty.ThirdPartyErrorCodeInvalidRequest,
			expectedErrorDescription: "provider is disabled",
		},
		{
			name:                     "error redirect when requesting an unknown provider",
			referer:                  "https://login.test.example",
			allowedRedirectURLs:      []string{"https://*.test.example"},
			requestedProvider:        "unknownProvider",
			requestedRedirectTo:      "https://app.test.example",
			expectedBaseURL:          "https://login.test.example",
			expectedError:            thirdparty.ThirdPartyErrorCodeInvalidRequest,
			expectedErrorDescription: "is not supported",
		},
		{
			name:                     "error redirect when requesting a redirectTo that is not allowed",
			referer:                  "https://login.test.example",
			enabledProviders:         []string{"google"},
			allowedRedirectURLs:      []string{"https://*.test.example"},
			requestedProvider:        "google",
			requestedRedirectTo:      "https://app.test.wrong",
			expectedBaseURL:          "https://login.test.example",
			expectedError:            thirdparty.ThirdPartyErrorCodeInvalidRequest,
			expectedErrorDescription: "redirect to 'https://app.test.wrong' not allowed",
		},
		{
			name:                     "error redirect with redirect to error redirect url if referer not present",
			allowedRedirectURLs:      []string{"https://*.test.example"},
			requestedProvider:        "unknownProvider",
			requestedRedirectTo:      "https://app.test.example",
			expectedBaseURL:          "https://error.test.example",
			expectedError:            thirdparty.ThirdPartyErrorCodeInvalidRequest,
			expectedErrorDescription: "is not supported",
		},
	}

	for _, testData := range tests {
		t.Run(testData.name, func(t *testing.T) {
			cfg := setUpConfig(t, testData.enabledProviders, testData.allowedRedirectURLs)

			e := echo.New()
			e.Validator = dto.NewCustomValidator()

			req := httptest.NewRequest(http.MethodGet, "/thirdparty/auth", nil)

			// Only add the query parameters the case defines so the
			// "missing provider" / "missing redirectTo" cases omit them.
			params := url.Values{}
			if testData.requestedProvider != "" {
				params.Add("provider", testData.requestedProvider)
			}
			if testData.requestedRedirectTo != "" {
				params.Add("redirect_to", testData.requestedRedirectTo)
			}
			req.URL.RawQuery = params.Encode()

			// Only send a Referer header when the case defines one, so the
			// "referer not present" case really omits the header instead of
			// sending it with an empty value.
			if testData.referer != "" {
				req.Header.Set("Referer", testData.referer)
			}

			rec := httptest.NewRecorder()
			c := e.NewContext(req, rec)

			p := test.NewPersister(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
			jwkManager := test.JwkManager{}
			sessionMgr, err := session.NewManager(jwkManager, cfg.Session)
			require.NoError(t, err)

			handler := NewThirdPartyHandler(cfg, p, sessionMgr, test.NewAuditLogger(), jwkManager)

			err = handler.Auth(c)
			require.NoError(t, err)

			assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)

			u, err := url.Parse(rec.Header().Get("Location"))
			// url.Parse returns a nil *url.URL on failure; abort here instead
			// of dereferencing it in the assertions below.
			require.NoError(t, err, "redirect url parse failed")

			assert.Equal(t, testData.expectedBaseURL, u.Scheme+"://"+u.Host+u.Path)

			q := u.Query()

			if testData.expectedError != "" {
				assert.Equal(t, testData.expectedError, q.Get("error"))
				errorDescription := q.Get("error_description")
				isCorrectErrorDescription := strings.Contains(errorDescription, testData.expectedErrorDescription)
				assert.Truef(t, isCorrectErrorDescription, "error description '%s' does not contain '%s'", errorDescription, testData.expectedErrorDescription)
			} else {
				assert.Equal(t, cfg.ThirdParty.RedirectURL, q.Get("redirect_uri"))
				assert.Equal(t, cfg.ThirdParty.Providers.Get(testData.requestedProvider).ClientID, q.Get("client_id"))
				assert.Equal(t, "code", q.Get("response_type"))

				// The state parameter must verify against the same session
				// manager the handler was built with.
				state, err := thirdparty.VerifyState(sessionMgr, q.Get("state"))
				require.NoError(t, err)

				assert.Equal(t, strings.ToLower(testData.requestedProvider), state.Provider)

				if testData.requestedRedirectTo == "" {
					assert.Equal(t, cfg.ThirdParty.ErrorRedirectURL, state.RedirectTo)
				} else {
					assert.Equal(t, testData.requestedRedirectTo, state.RedirectTo)
				}
			}
		})
	}
}
// setUpConfig builds a *config.Config for the third-party handler tests.
//
// Both the Google and GitHub providers start out disabled with fake client
// credentials. Each entry in enabledProviders equal to "google" or "github"
// switches the corresponding provider on; any other entry is ignored.
// allowedRedirectURLs is assigned to ThirdParty.AllowedRedirectURLS as given.
// The calling test fails immediately if cfg.PostProcess returns an error.
func setUpConfig(t *testing.T, enabledProviders []string, allowedRedirectURLs []string) *config.Config {
	// Mark this function as a helper so a failure below is reported at the
	// caller's line rather than inside setUpConfig.
	t.Helper()

	cfg := &config.Config{ThirdParty: config.ThirdParty{
		Providers: config.ThirdPartyProviders{
			Google: config.ThirdPartyProvider{
				Enabled:  false,
				ClientID: "fakeClientID",
				Secret:   "fakeClientSecret",
			},
			GitHub: config.ThirdPartyProvider{
				Enabled:  false,
				ClientID: "fakeClientID",
				Secret:   "fakeClientSecret",
			},
		},
		ErrorRedirectURL:    "https://error.test.example",
		RedirectURL:         "https://api.test.example/callback",
		AllowedRedirectURLS: allowedRedirectURLs,
	}}

	for _, provider := range enabledProviders {
		switch provider {
		case "google":
			cfg.ThirdParty.Providers.Google.Enabled = true
		case "github":
			cfg.ThirdParty.Providers.GitHub.Enabled = true
		}
	}

	err := cfg.PostProcess()
	require.NoError(t, err)

	return cfg
}