mirror of
https://github.com/grafana/grafana.git
synced 2026-03-13 15:29:48 +08:00
Auth: Fix token refresh when using Entra ID OAuth with workload_identity (federated credentials) (#114172)
This commit is contained in:
@@ -8,16 +8,17 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/apimachinery/identity"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/ssosettings/validation"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -215,14 +217,119 @@ func (s *SocialAzureAD) Exchange(ctx context.Context, code string, authOptions .
|
||||
return s.Config.Exchange(ctx, code, authOptions...)
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
|
||||
s.reloadMutex.RLock()
|
||||
defer s.reloadMutex.RUnlock()
|
||||
|
||||
if s.info.ClientAuthentication == social.WorkloadIdentity {
|
||||
return &azureADTokenSource{
|
||||
log: s.log,
|
||||
ctx: ctx,
|
||||
conf: s.Config,
|
||||
token: t,
|
||||
clientId: s.info.ClientId,
|
||||
workloadIdentityTokenFile: s.info.WorkloadIdentityTokenFile,
|
||||
}
|
||||
}
|
||||
|
||||
return s.Config.TokenSource(ctx, t)
|
||||
}
|
||||
|
||||
type azureADTokenSource struct {
|
||||
log log.Logger
|
||||
ctx context.Context
|
||||
conf *oauth2.Config
|
||||
token *oauth2.Token
|
||||
clientId string
|
||||
workloadIdentityTokenFile string
|
||||
}
|
||||
|
||||
func (s *azureADTokenSource) Token() (*oauth2.Token, error) {
|
||||
s.log.Debug("Fetching Token with AzureAD Token Source and Workload Identity")
|
||||
if s.token.Valid() {
|
||||
return s.token, nil
|
||||
}
|
||||
|
||||
if s.token.RefreshToken == "" {
|
||||
s.log.Warn("AzureADToken fetchToken failed: no refresh token available")
|
||||
return nil, fmt.Errorf("no refresh token available to refresh the access token")
|
||||
}
|
||||
|
||||
// refresh the expired token using the refresh token
|
||||
federatedToken, err := os.ReadFile(s.workloadIdentityTokenFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read workload identity token file: %w", err)
|
||||
}
|
||||
|
||||
v := url.Values{}
|
||||
v.Set("client_id", s.clientId)
|
||||
v.Set("grant_type", "refresh_token")
|
||||
v.Set("refresh_token", s.token.RefreshToken)
|
||||
v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
|
||||
v.Set("client_assertion", strings.TrimSpace(string(federatedToken)))
|
||||
|
||||
return s.fetchToken(v)
|
||||
}
|
||||
|
||||
func (s *azureADTokenSource) fetchToken(params url.Values) (*oauth2.Token, error) {
|
||||
req, err := http.NewRequestWithContext(s.ctx, "POST", s.conf.Endpoint.TokenURL, strings.NewReader(params.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
// Correct way to get HTTP client from context
|
||||
httpClient, ok := s.ctx.Value(oauth2.HTTPClient).(*http.Client)
|
||||
if !ok {
|
||||
httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
s.log.Error("Failed to close response body", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
s.log.Debug("oauth2: cannot fetch token", "status", resp.Status, "body", body)
|
||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", resp.Status)
|
||||
}
|
||||
|
||||
var rawResponse interface{}
|
||||
if err := json.Unmarshal(body, &rawResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal raw response body: %w", err)
|
||||
}
|
||||
|
||||
var token *oauth2.Token
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal token response body: %w", err)
|
||||
}
|
||||
|
||||
if token.ExpiresIn > 0 {
|
||||
token.Expiry = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
|
||||
s.log.Debug("AzureADToken fetchToken completed", "expiry", token.Expiry)
|
||||
return token.WithExtra(rawResponse), nil
|
||||
}
|
||||
|
||||
// ManagedIdentityCallback retrieves a token using the managed identity credential of the Azure service.
|
||||
func (s *SocialAzureAD) managedIdentityCallback(ctx context.Context) (string, error) {
|
||||
// Validate required fields for Managed Identity authentication
|
||||
if s.info.ManagedIdentityClientID == "" {
|
||||
return "", fmt.Errorf("ManagedIdentityClientID is required for Managed Identity authentication")
|
||||
return "", fmt.Errorf("ManagedIdentityClientID is required for Managed Identity or Workload Identity authentication")
|
||||
}
|
||||
if s.info.FederatedCredentialAudience == "" {
|
||||
return "", fmt.Errorf("FederatedCredentialAudience is required for Managed Identity authentication")
|
||||
return "", fmt.Errorf("FederatedCredentialAudience is required for Managed Identity or Workload Identity authentication")
|
||||
}
|
||||
|
||||
// Prepare Managed Identity Credential
|
||||
@@ -230,7 +337,7 @@ func (s *SocialAzureAD) managedIdentityCallback(ctx context.Context) (string, er
|
||||
ID: azidentity.ClientID(s.info.ManagedIdentityClientID),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error constructing managed identity credential: %w", err)
|
||||
return "", fmt.Errorf("error constructing managed/workload identity credential: %w", err)
|
||||
}
|
||||
|
||||
// Request token and return
|
||||
@@ -238,7 +345,7 @@ func (s *SocialAzureAD) managedIdentityCallback(ctx context.Context) (string, er
|
||||
Scopes: []string{fmt.Sprintf("%s/.default", s.info.FederatedCredentialAudience)},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting managed identity token: %w", err)
|
||||
return "", fmt.Errorf("error getting managed/workload identity token: %w", err)
|
||||
}
|
||||
|
||||
return tk.Token, nil
|
||||
|
||||
@@ -7,12 +7,15 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/go-jose/go-jose/v4/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
@@ -1449,3 +1452,164 @@ func TestSocialAzureAD_Reload_ExtraFields(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAzureAD_TokenSource_WorkloadIdentity(t *testing.T) {
|
||||
info := &social.OAuthInfo{
|
||||
ClientId: "some-client-id",
|
||||
ClientAuthentication: social.WorkloadIdentity,
|
||||
FederatedCredentialAudience: "api://AzureADTokenExchange",
|
||||
TokenUrl: "https://login.microsoftonline.com/token",
|
||||
}
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
workloadFile := path.Join(t.TempDir(), "workload.json")
|
||||
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
|
||||
s.info.WorkloadIdentityTokenFile = workloadFile
|
||||
|
||||
// Mock the token endpoint
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// Verify that client_assertion is present in the request
|
||||
if r.FormValue("client_assertion") != "mock-client-assertion" {
|
||||
t.Errorf("expected client_assertion to be 'mock-client-assertion', got '%s'", r.FormValue("client_assertion"))
|
||||
}
|
||||
if r.FormValue("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
|
||||
t.Errorf("expected client_assertion_type to be 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', got '%s'", r.FormValue("client_assertion_type"))
|
||||
}
|
||||
if r.FormValue("grant_type") != "refresh_token" {
|
||||
t.Errorf("expected grant_type to be 'refresh_token', got '%s'", r.FormValue("grant_type"))
|
||||
}
|
||||
if r.FormValue("client_id") != "some-client-id" {
|
||||
t.Errorf("expected client_id to be 'client-id', got '%s'", r.FormValue("client_id"))
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"access_token": "new-access-token",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Update TokenURL to point to the mock server
|
||||
s.Endpoint.TokenURL = server.URL
|
||||
|
||||
// Create a token source with an expired token
|
||||
now := time.Now()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "old-access-token",
|
||||
RefreshToken: "old-refresh-token",
|
||||
Expiry: now.Add(-time.Hour),
|
||||
}
|
||||
|
||||
// Create a context with the mock client
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
|
||||
|
||||
// Get a new token (this should trigger a refresh)
|
||||
ts := s.TokenSource(ctx, token)
|
||||
newToken, err := ts.Token()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new-access-token", newToken.AccessToken)
|
||||
assert.Equal(t, "new-refresh-token", newToken.RefreshToken)
|
||||
assert.WithinDuration(t, now.Add(time.Hour), newToken.Expiry, time.Second)
|
||||
})
|
||||
|
||||
t.Run("error when workload token file does not exist", func(t *testing.T) {
|
||||
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
|
||||
s.info.WorkloadIdentityTokenFile = "/non/existent/file"
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "old-access-token",
|
||||
RefreshToken: "old-refresh-token",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
|
||||
ts := s.TokenSource(context.Background(), token)
|
||||
_, err := ts.Token()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to read workload identity token file")
|
||||
})
|
||||
|
||||
t.Run("error when token endpoint returns error", func(t *testing.T) {
|
||||
workloadFile := path.Join(t.TempDir(), "workload.json")
|
||||
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
|
||||
s.info.WorkloadIdentityTokenFile = workloadFile
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
s.Endpoint.TokenURL = server.URL
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "old-access-token",
|
||||
RefreshToken: "old-refresh-token",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
|
||||
|
||||
ts := s.TokenSource(ctx, token)
|
||||
_, err = ts.Token()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "oauth2: cannot fetch token: 500 Internal Server Error")
|
||||
})
|
||||
|
||||
t.Run("error when missing refresh token", func(t *testing.T) {
|
||||
workloadFile := path.Join(t.TempDir(), "workload.json")
|
||||
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
|
||||
s.info.WorkloadIdentityTokenFile = workloadFile
|
||||
|
||||
// No RefreshToken
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "old-access-token",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
|
||||
ts := s.TokenSource(context.Background(), token)
|
||||
_, err = ts.Token()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "no refresh token available to refresh the access token")
|
||||
})
|
||||
|
||||
t.Run("error when invalid token response", func(t *testing.T) {
|
||||
workloadFile := path.Join(t.TempDir(), "workload.json")
|
||||
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
|
||||
s.info.WorkloadIdentityTokenFile = workloadFile
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("invalid-json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
s.Endpoint.TokenURL = server.URL
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "old-access-token",
|
||||
RefreshToken: "old-refresh-token",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
|
||||
|
||||
ts := s.TokenSource(ctx, token)
|
||||
_, err = ts.Token()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "unable to unmarshal raw response body")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user