Auth: Fix token refresh when using Entra ID OAuth with workload_identity (federated credentials) (#114172)

This commit is contained in:
Richard Hagen
2026-02-23 14:38:58 +01:00
committed by GitHub
parent dbe3a18a6b
commit 3b71ed6c58
2 changed files with 277 additions and 6 deletions

View File

@@ -8,16 +8,17 @@ import (
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
jose "github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/featuremgmt"
@@ -27,6 +28,7 @@ import (
"github.com/grafana/grafana/pkg/services/ssosettings/validation"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"golang.org/x/oauth2"
)
const (
@@ -215,14 +217,119 @@ func (s *SocialAzureAD) Exchange(ctx context.Context, code string, authOptions .
return s.Config.Exchange(ctx, code, authOptions...)
}
func (s *SocialAzureAD) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
s.reloadMutex.RLock()
defer s.reloadMutex.RUnlock()
if s.info.ClientAuthentication == social.WorkloadIdentity {
return &azureADTokenSource{
log: s.log,
ctx: ctx,
conf: s.Config,
token: t,
clientId: s.info.ClientId,
workloadIdentityTokenFile: s.info.WorkloadIdentityTokenFile,
}
}
return s.Config.TokenSource(ctx, t)
}
type azureADTokenSource struct {
log log.Logger
ctx context.Context
conf *oauth2.Config
token *oauth2.Token
clientId string
workloadIdentityTokenFile string
}
func (s *azureADTokenSource) Token() (*oauth2.Token, error) {
s.log.Debug("Fetching Token with AzureAD Token Source and Workload Identity")
if s.token.Valid() {
return s.token, nil
}
if s.token.RefreshToken == "" {
s.log.Warn("AzureADToken fetchToken failed: no refresh token available")
return nil, fmt.Errorf("no refresh token available to refresh the access token")
}
// refresh the expired token using the refresh token
federatedToken, err := os.ReadFile(s.workloadIdentityTokenFile)
if err != nil {
return nil, fmt.Errorf("failed to read workload identity token file: %w", err)
}
v := url.Values{}
v.Set("client_id", s.clientId)
v.Set("grant_type", "refresh_token")
v.Set("refresh_token", s.token.RefreshToken)
v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
v.Set("client_assertion", strings.TrimSpace(string(federatedToken)))
return s.fetchToken(v)
}
func (s *azureADTokenSource) fetchToken(params url.Values) (*oauth2.Token, error) {
req, err := http.NewRequestWithContext(s.ctx, "POST", s.conf.Endpoint.TokenURL, strings.NewReader(params.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// Correct way to get HTTP client from context
httpClient, ok := s.ctx.Value(oauth2.HTTPClient).(*http.Client)
if !ok {
httpClient = http.DefaultClient
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() {
if err := resp.Body.Close(); err != nil {
s.log.Error("Failed to close response body", "error", err)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
s.log.Debug("oauth2: cannot fetch token", "status", resp.Status, "body", body)
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", resp.Status)
}
var rawResponse interface{}
if err := json.Unmarshal(body, &rawResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal raw response body: %w", err)
}
var token *oauth2.Token
if err := json.Unmarshal(body, &token); err != nil {
return nil, fmt.Errorf("unable to unmarshal token response body: %w", err)
}
if token.ExpiresIn > 0 {
token.Expiry = time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)
}
s.log.Debug("AzureADToken fetchToken completed", "expiry", token.Expiry)
return token.WithExtra(rawResponse), nil
}
// ManagedIdentityCallback retrieves a token using the managed identity credential of the Azure service.
func (s *SocialAzureAD) managedIdentityCallback(ctx context.Context) (string, error) {
// Validate required fields for Managed Identity authentication
if s.info.ManagedIdentityClientID == "" {
return "", fmt.Errorf("ManagedIdentityClientID is required for Managed Identity authentication")
return "", fmt.Errorf("ManagedIdentityClientID is required for Managed Identity or Workload Identity authentication")
}
if s.info.FederatedCredentialAudience == "" {
return "", fmt.Errorf("FederatedCredentialAudience is required for Managed Identity authentication")
return "", fmt.Errorf("FederatedCredentialAudience is required for Managed Identity or Workload Identity authentication")
}
// Prepare Managed Identity Credential
@@ -230,7 +337,7 @@ func (s *SocialAzureAD) managedIdentityCallback(ctx context.Context) (string, er
ID: azidentity.ClientID(s.info.ManagedIdentityClientID),
})
if err != nil {
return "", fmt.Errorf("error constructing managed identity credential: %w", err)
return "", fmt.Errorf("error constructing managed/workload identity credential: %w", err)
}
// Request token and return
@@ -238,7 +345,7 @@ func (s *SocialAzureAD) managedIdentityCallback(ctx context.Context) (string, er
Scopes: []string{fmt.Sprintf("%s/.default", s.info.FederatedCredentialAudience)},
})
if err != nil {
return "", fmt.Errorf("error getting managed identity token: %w", err)
return "", fmt.Errorf("error getting managed/workload identity token: %w", err)
}
return tk.Token, nil

View File

@@ -7,12 +7,15 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path"
"strings"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
@@ -1449,3 +1452,164 @@ func TestSocialAzureAD_Reload_ExtraFields(t *testing.T) {
})
}
}
func TestSocialAzureAD_TokenSource_WorkloadIdentity(t *testing.T) {
info := &social.OAuthInfo{
ClientId: "some-client-id",
ClientAuthentication: social.WorkloadIdentity,
FederatedCredentialAudience: "api://AzureADTokenExchange",
TokenUrl: "https://login.microsoftonline.com/token",
}
t.Run("success", func(t *testing.T) {
workloadFile := path.Join(t.TempDir(), "workload.json")
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
require.NoError(t, err)
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
s.info.WorkloadIdentityTokenFile = workloadFile
// Mock the token endpoint
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Error(err)
}
// Verify that client_assertion is present in the request
if r.FormValue("client_assertion") != "mock-client-assertion" {
t.Errorf("expected client_assertion to be 'mock-client-assertion', got '%s'", r.FormValue("client_assertion"))
}
if r.FormValue("client_assertion_type") != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
t.Errorf("expected client_assertion_type to be 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', got '%s'", r.FormValue("client_assertion_type"))
}
if r.FormValue("grant_type") != "refresh_token" {
t.Errorf("expected grant_type to be 'refresh_token', got '%s'", r.FormValue("grant_type"))
}
if r.FormValue("client_id") != "some-client-id" {
t.Errorf("expected client_id to be 'client-id', got '%s'", r.FormValue("client_id"))
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": "new-access-token",
"token_type": "Bearer",
"refresh_token": "new-refresh-token",
"expires_in": 3600,
})
require.NoError(t, err)
}))
defer server.Close()
// Update TokenURL to point to the mock server
s.Endpoint.TokenURL = server.URL
// Create a token source with an expired token
now := time.Now()
token := &oauth2.Token{
AccessToken: "old-access-token",
RefreshToken: "old-refresh-token",
Expiry: now.Add(-time.Hour),
}
// Create a context with the mock client
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
// Get a new token (this should trigger a refresh)
ts := s.TokenSource(ctx, token)
newToken, err := ts.Token()
require.NoError(t, err)
assert.Equal(t, "new-access-token", newToken.AccessToken)
assert.Equal(t, "new-refresh-token", newToken.RefreshToken)
assert.WithinDuration(t, now.Add(time.Hour), newToken.Expiry, time.Second)
})
t.Run("error when workload token file does not exist", func(t *testing.T) {
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
s.info.WorkloadIdentityTokenFile = "/non/existent/file"
token := &oauth2.Token{
AccessToken: "old-access-token",
RefreshToken: "old-refresh-token",
Expiry: time.Now().Add(-time.Hour),
}
ts := s.TokenSource(context.Background(), token)
_, err := ts.Token()
require.Error(t, err)
require.Contains(t, err.Error(), "failed to read workload identity token file")
})
t.Run("error when token endpoint returns error", func(t *testing.T) {
workloadFile := path.Join(t.TempDir(), "workload.json")
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
require.NoError(t, err)
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
s.info.WorkloadIdentityTokenFile = workloadFile
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
s.Endpoint.TokenURL = server.URL
token := &oauth2.Token{
AccessToken: "old-access-token",
RefreshToken: "old-refresh-token",
Expiry: time.Now().Add(-time.Hour),
}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
ts := s.TokenSource(ctx, token)
_, err = ts.Token()
require.Error(t, err)
require.Contains(t, err.Error(), "oauth2: cannot fetch token: 500 Internal Server Error")
})
t.Run("error when missing refresh token", func(t *testing.T) {
workloadFile := path.Join(t.TempDir(), "workload.json")
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
require.NoError(t, err)
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
s.info.WorkloadIdentityTokenFile = workloadFile
// No RefreshToken
token := &oauth2.Token{
AccessToken: "old-access-token",
Expiry: time.Now().Add(-time.Hour),
}
ts := s.TokenSource(context.Background(), token)
_, err = ts.Token()
require.Error(t, err)
require.Contains(t, err.Error(), "no refresh token available to refresh the access token")
})
t.Run("error when invalid token response", func(t *testing.T) {
workloadFile := path.Join(t.TempDir(), "workload.json")
err := os.WriteFile(workloadFile, []byte("mock-client-assertion"), 0600)
require.NoError(t, err)
s := NewAzureADProvider(info, setting.NewCfg(), nil, ssosettingstests.NewFakeService(), featuremgmt.WithFeatures(), remotecache.FakeCacheStorage{})
s.info.WorkloadIdentityTokenFile = workloadFile
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("invalid-json"))
}))
defer server.Close()
s.Endpoint.TokenURL = server.URL
token := &oauth2.Token{
AccessToken: "old-access-token",
RefreshToken: "old-refresh-token",
Expiry: time.Now().Add(-time.Hour),
}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, server.Client())
ts := s.TokenSource(ctx, token)
_, err = ts.Token()
require.Error(t, err)
require.Contains(t, err.Error(), "unable to unmarshal raw response body")
})
}