Files
Misi 1f3dc0533c Auth: Add tracing to auth clients and AuthToken service (#107878)
* Add tracing to auth clients + authtoken svc

* Fix span names

* Fix ext_jwt.go

* Fix idimpl/service

* Update wire_gen.go

* Add tracing to JWT client

* Lint
2025-07-10 15:41:00 +02:00

201 lines
5.5 KiB
Go

package idimpl
import (
"context"
"errors"
"fmt"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/singleflight"
authnlib "github.com/grafana/authlib/authn"
claims "github.com/grafana/authlib/types"
"github.com/grafana/grafana/pkg/apimachinery/identity"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/services/apiserver/endpoints/request"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/setting"
)
// cachePrefix is prepended to every cache key built by getCacheKey.
const cachePrefix = "id-token"

const (
	// tokenTTL is the lifetime given to a newly signed ID token.
	tokenTTL = 10 * time.Minute
	// cacheLeeway is subtracted from a token's remaining lifetime when it is
	// cached, so a cached entry expires before the token itself does.
	cacheLeeway = 30 * time.Second
)
// Compile-time check that *Service implements auth.IDService.
var _ auth.IDService = (*Service)(nil)
// ProvideService builds the ID token Service from its dependencies and
// registers Service.SyncIDToken as a post-auth hook on authnService.
//
// NOTE(review): 140 is passed as the second argument to RegisterPostAuthHook;
// presumably it is the hook's ordering priority — confirm against authn.Service.
func ProvideService(
	cfg *setting.Cfg, signer auth.IDSigner,
	cache remotecache.CacheStorage, authnService authn.Service,
	reg prometheus.Registerer, tracer trace.Tracer,
) *Service {
	logger := log.New("id-service")
	svcMetrics := newMetrics(reg)
	mapper := request.GetNamespaceMapper(cfg)

	svc := &Service{
		cfg:      cfg,
		logger:   logger,
		signer:   signer,
		cache:    cache,
		metrics:  svcMetrics,
		nsMapper: mapper,
		tracer:   tracer,
	}

	authnService.RegisterPostAuthHook(svc.SyncIDToken, 140)

	return svc
}
// Service signs ID tokens for identities and caches the signed tokens.
type Service struct {
	// Configuration and observability.
	cfg     *setting.Cfg // AppURL is used as the token issuer
	logger  log.Logger
	metrics *metrics
	tracer  trace.Tracer

	// Token production.
	signer   auth.IDSigner           // signs the claims built in SignIdentity
	nsMapper request.NamespaceMapper // maps an org ID to the namespace claim

	// Token reuse.
	cache remotecache.CacheStorage // stores signed tokens by cache key
	si    singleflight.Group       // collapses concurrent SignIdentity calls per cache key
}
// SignIdentity returns a signed ID token for id together with its claims.
//
// A token already stored in the cache under the identity's cache key is reused
// and its claims are parsed from the token. Otherwise a new token is built,
// signed with the configured signer and cached until cacheLeeway before it
// expires; in that case the returned claims are the ones that were signed.
// Concurrent calls that share a cache key are collapsed into one execution by
// the singleflight group and all receive the same result.
//
// An error is returned when a cached token cannot be parsed, when signing
// fails, or when the freshly signed token cannot be parsed. Failing to write
// the token to the cache is only logged.
func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (string, *authnlib.Claims[authnlib.IDTokenClaims], error) {
	ctx, span := s.tracer.Start(ctx, "user.sync.SignIdentity")
	defer span.End()

	// Observe the duration of the whole call, whatever its outcome.
	defer func(t time.Time) {
		s.metrics.tokenSigningDurationHistogram.Observe(time.Since(t).Seconds())
	}(time.Now())

	cacheKey := getCacheKey(id)

	type resultType struct {
		token    string
		idClaims *auth.IDClaims
	}

	result, err, _ := s.si.Do(cacheKey, func() (any, error) {
		// Any cache error is treated as a miss and falls through to signing.
		cachedToken, err := s.cache.Get(ctx, cacheKey)
		if err == nil {
			s.metrics.tokenSigningFromCacheCounter.Inc()
			s.logger.FromContext(ctx).Debug("Cached token found", "id", id.GetID())

			tokenClaims, err := s.extractTokenClaims(string(cachedToken))
			if err != nil {
				return resultType{}, err
			}
			return resultType{token: string(cachedToken), idClaims: tokenClaims}, nil
		}

		s.metrics.tokenSigningCounter.Inc()
		s.logger.FromContext(ctx).Debug("Sign new id token", "id", id.GetID())

		now := time.Now()
		idClaims := &auth.IDClaims{
			Claims: jwt.Claims{
				Issuer:   s.cfg.AppURL,
				Audience: getAudience(id.GetOrgID()),
				Subject:  id.GetID(),
				Expiry:   jwt.NewNumericDate(now.Add(tokenTTL)),
				IssuedAt: jwt.NewNumericDate(now),
			},
			Rest: authnlib.IDTokenClaims{
				Namespace:  s.nsMapper(id.GetOrgID()),
				Identifier: id.GetRawIdentifier(),
				Type:       id.GetIdentityType(),
			},
		}

		// Profile claims are only populated for user identities.
		if id.IsIdentityType(claims.TypeUser) {
			idClaims.Rest.Email = id.GetEmail()
			idClaims.Rest.EmailVerified = id.GetEmailVerified()
			idClaims.Rest.AuthenticatedBy = id.GetAuthenticatedBy()
			idClaims.Rest.Username = id.GetLogin()
			idClaims.Rest.DisplayName = id.GetName()
		}

		if id.GetOrgRole().IsValid() {
			idClaims.Rest.Role = string(id.GetOrgRole())
		}

		token, err := s.signer.SignIDToken(ctx, idClaims)
		if err != nil {
			s.metrics.failedTokenSigningCounter.Inc()
			// Propagate the failure; returning a nil error here would hand the
			// caller an empty token that looks like a success.
			return resultType{}, err
		}

		// Re-parse the signed token to read the expiry it actually carries.
		extracted, err := s.extractTokenClaims(token)
		if err != nil {
			return resultType{}, err
		}

		expires := time.Until(extracted.Expiry.Time())
		if err := s.cache.Set(ctx, cacheKey, []byte(token), expires-cacheLeeway); err != nil {
			s.logger.FromContext(ctx).Error("Failed to add id token to cache", "error", err)
		}

		return resultType{token: token, idClaims: idClaims}, nil
	})
	if err != nil {
		return "", nil, err
	}

	res := result.(resultType)
	return res.token, res.idClaims, nil
}
// RemoveIDToken deletes the cache entry stored under the cache key of id and
// returns the error reported by the cache, if any.
func (s *Service) RemoveIDToken(ctx context.Context, id identity.Requester) error {
	spanCtx, span := s.tracer.Start(ctx, "user.sync.RemoveIDToken")
	defer span.End()

	key := getCacheKey(id)
	return s.cache.Delete(spanCtx, key)
}
// SyncIDToken signs an ID token for identity and, on success, stores the token
// and its claims on the identity. It is the hook that ProvideService registers
// with the authn service.
//
// It always returns nil: a signing failure is logged (unless shouldLogErr
// filters it out) and the identity is left unchanged.
func (s *Service) SyncIDToken(ctx context.Context, identity *authn.Identity, _ *authn.Request) error {
	ctx, span := s.tracer.Start(ctx, "user.sync.SyncIDToken")
	defer span.End()

	// FIXME(kalleep): we should probably lazy load this
	signed, signedClaims, signErr := s.SignIdentity(ctx, identity)
	if signErr == nil {
		identity.IDToken = signed
		identity.IDTokenClaims = signedClaims
		return nil
	}

	if shouldLogErr(signErr) {
		s.logger.FromContext(ctx).Error("Failed to sign id token", "err", signErr, "id", identity.GetID())
	}

	// for now don't return error so we don't break authentication from this hook
	return nil
}
// extractTokenClaims parses token as a signed JWT and decodes its claims
// without verifying the signature. Both a parse failure and a decode failure
// increment the failed-token-signing counter and return the error.
func (s *Service) extractTokenClaims(token string) (*authnlib.Claims[authnlib.IDTokenClaims], error) {
	parsed, parseErr := jwt.ParseSigned(token)
	if parseErr != nil {
		s.metrics.failedTokenSigningCounter.Inc()
		return nil, parseErr
	}

	var out authnlib.Claims[authnlib.IDTokenClaims]

	// We don't need to verify the signature here, we are only interested in checking
	// when the token expires.
	decodeErr := parsed.UnsafeClaimsWithoutVerification(&out)
	if decodeErr != nil {
		s.metrics.failedTokenSigningCounter.Inc()
		return nil, decodeErr
	}

	return &out, nil
}
// getAudience returns a single-entry audience of the form "org:<orgID>".
func getAudience(orgID int64) jwt.Audience {
	orgAudience := fmt.Sprintf("org:%d", orgID)
	return jwt.Audience{orgAudience}
}
// getCacheKey builds the cache key for ident by concatenating cachePrefix,
// the identity's own cache key and its org role, with no separators.
func getCacheKey(ident identity.Requester) string {
	identKey := ident.GetCacheKey()
	role := ident.GetOrgRole()
	return cachePrefix + identKey + string(role)
}
// shouldLogErr reports whether err is worth logging. Errors that are, or wrap,
// context.Canceled are not.
func shouldLogErr(err error) bool {
	canceled := errors.Is(err, context.Canceled)
	return !canceled
}