Plugins: Use plugins SDK constants for OAuth header names (#90892)

use SDK consts
This commit is contained in:
Will Browne
2024-08-20 13:29:41 +01:00
committed by GitHub
parent e788df921c
commit d35e9264bb
2 changed files with 14 additions and 19 deletions

View File

@ -27,11 +27,6 @@ func NewOAuthTokenMiddleware(oAuthTokenService oauthtoken.OAuthTokenService) plu
}) })
} }
const (
tokenHeaderName = "Authorization"
idTokenHeaderName = "X-ID-Token"
)
type OAuthTokenMiddleware struct { type OAuthTokenMiddleware struct {
baseMiddleware baseMiddleware
oAuthTokenService oauthtoken.OAuthTokenService oAuthTokenService oauthtoken.OAuthTokenService
@ -69,19 +64,19 @@ func (m *OAuthTokenMiddleware) applyToken(ctx context.Context, pCtx backend.Plug
switch t := req.(type) { switch t := req.(type) {
case *backend.QueryDataRequest: case *backend.QueryDataRequest:
t.Headers[tokenHeaderName] = authorizationHeader t.Headers[backend.OAuthIdentityTokenHeaderName] = authorizationHeader
if idTokenHeader != "" { if idTokenHeader != "" {
t.Headers[idTokenHeaderName] = idTokenHeader t.Headers[backend.OAuthIdentityIDTokenHeaderName] = idTokenHeader
} }
case *backend.CheckHealthRequest: case *backend.CheckHealthRequest:
t.Headers[tokenHeaderName] = authorizationHeader t.Headers[backend.OAuthIdentityTokenHeaderName] = authorizationHeader
if idTokenHeader != "" { if idTokenHeader != "" {
t.Headers[idTokenHeaderName] = idTokenHeader t.Headers[backend.OAuthIdentityIDTokenHeaderName] = idTokenHeader
} }
case *backend.CallResourceRequest: case *backend.CallResourceRequest:
t.Headers[tokenHeaderName] = []string{authorizationHeader} t.Headers[backend.OAuthIdentityTokenHeaderName] = []string{authorizationHeader}
if idTokenHeader != "" { if idTokenHeader != "" {
t.Headers[idTokenHeaderName] = []string{idTokenHeader} t.Headers[backend.OAuthIdentityIDTokenHeaderName] = []string{idTokenHeader}
} }
} }
} }

View File

@ -112,8 +112,8 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.NotNil(t, cdt.QueryDataReq) require.NotNil(t, cdt.QueryDataReq)
require.Len(t, cdt.QueryDataReq.Headers, 3) require.Len(t, cdt.QueryDataReq.Headers, 3)
require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader]) require.Equal(t, "test", cdt.QueryDataReq.Headers[otherHeader])
require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[tokenHeaderName]) require.Equal(t, "Bearer access-token", cdt.QueryDataReq.Headers[backend.OAuthIdentityTokenHeaderName])
require.Equal(t, "id-token", cdt.QueryDataReq.Headers[idTokenHeaderName]) require.Equal(t, "id-token", cdt.QueryDataReq.Headers[backend.OAuthIdentityIDTokenHeaderName])
}) })
t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) { t.Run("Should forward OAuth Identity when calling CallResource", func(t *testing.T) {
@ -125,10 +125,10 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.NotNil(t, cdt.CallResourceReq) require.NotNil(t, cdt.CallResourceReq)
require.Len(t, cdt.CallResourceReq.Headers, 3) require.Len(t, cdt.CallResourceReq.Headers, 3)
require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0]) require.Equal(t, "test", cdt.CallResourceReq.Headers[otherHeader][0])
require.Len(t, cdt.CallResourceReq.Headers[tokenHeaderName], 1) require.Len(t, cdt.CallResourceReq.Headers[backend.OAuthIdentityTokenHeaderName], 1)
require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[tokenHeaderName][0]) require.Equal(t, "Bearer access-token", cdt.CallResourceReq.Headers[backend.OAuthIdentityTokenHeaderName][0])
require.Len(t, cdt.CallResourceReq.Headers[idTokenHeaderName], 1) require.Len(t, cdt.CallResourceReq.Headers[backend.OAuthIdentityIDTokenHeaderName], 1)
require.Equal(t, "id-token", cdt.CallResourceReq.Headers[idTokenHeaderName][0]) require.Equal(t, "id-token", cdt.CallResourceReq.Headers[backend.OAuthIdentityIDTokenHeaderName][0])
}) })
t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) { t.Run("Should forward OAuth Identity when calling CheckHealth", func(t *testing.T) {
@ -140,8 +140,8 @@ func TestOAuthTokenMiddleware(t *testing.T) {
require.NotNil(t, cdt.CheckHealthReq) require.NotNil(t, cdt.CheckHealthReq)
require.Len(t, cdt.CheckHealthReq.Headers, 3) require.Len(t, cdt.CheckHealthReq.Headers, 3)
require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader]) require.Equal(t, "test", cdt.CheckHealthReq.Headers[otherHeader])
require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[tokenHeaderName]) require.Equal(t, "Bearer access-token", cdt.CheckHealthReq.Headers[backend.OAuthIdentityTokenHeaderName])
require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[idTokenHeaderName]) require.Equal(t, "id-token", cdt.CheckHealthReq.Headers[backend.OAuthIdentityIDTokenHeaderName])
}) })
}) })
} }