mirror of
https://github.com/grafana/loki.git
synced 2026-03-13 09:33:58 +08:00
feat: Resolve ingestion policy via a header (#19548)
This commit is contained in:
@@ -601,7 +601,7 @@ func (d *Distributor) PushWithResolver(ctx context.Context, req *logproto.PushRe
|
||||
|
||||
var lbs labels.Labels
|
||||
var retentionHours, policy string
|
||||
lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver, format)
|
||||
lbs, stream.Labels, stream.Hash, retentionHours, policy, err = d.parseStreamLabels(ctx, validationContext, stream.Labels, stream, streamResolver, format)
|
||||
if err != nil {
|
||||
d.writeFailuresManager.Log(tenantID, err)
|
||||
validationErrors.Add(err)
|
||||
@@ -915,7 +915,7 @@ func (d *Distributor) trackDiscardedData(
|
||||
|
||||
if d.usageTracker != nil {
|
||||
for _, stream := range req.Streams {
|
||||
lbs, _, _, _, _, err := d.parseStreamLabels(validationContext, stream.Labels, stream, streamResolver, format)
|
||||
lbs, _, _, _, _, err := d.parseStreamLabels(ctx, validationContext, stream.Labels, stream, streamResolver, format)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -1309,10 +1309,10 @@ type labelData struct {
|
||||
}
|
||||
|
||||
// parseStreamLabels parses stream labels using a request-scoped policy resolver
|
||||
func (d *Distributor) parseStreamLabels(vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver, format string) (labels.Labels, string, uint64, string, string, error) {
|
||||
func (d *Distributor) parseStreamLabels(ctx context.Context, vContext validationContext, key string, stream logproto.Stream, streamResolver push.StreamResolver, format string) (labels.Labels, string, uint64, string, string, error) {
|
||||
if val, ok := d.labelCache.Get(key); ok {
|
||||
retentionHours := streamResolver.RetentionHoursFor(val.ls)
|
||||
policy := streamResolver.PolicyFor(val.ls)
|
||||
policy := streamResolver.PolicyFor(ctx, val.ls)
|
||||
return val.ls, val.ls.String(), val.hash, retentionHours, policy, nil
|
||||
}
|
||||
|
||||
@@ -1323,7 +1323,7 @@ func (d *Distributor) parseStreamLabels(vContext validationContext, key string,
|
||||
return labels.EmptyLabels(), "", 0, retentionHours, "", fmt.Errorf(validation.InvalidLabelsErrorMsg, key, err)
|
||||
}
|
||||
|
||||
policy := streamResolver.PolicyFor(ls)
|
||||
policy := streamResolver.PolicyFor(ctx, ls)
|
||||
retentionHours := d.tenantsRetention.RetentionHoursFor(vContext.userID, ls)
|
||||
|
||||
if err := d.validator.ValidateLabels(vContext, ls, stream, retentionHours, policy, format); err != nil {
|
||||
@@ -1454,8 +1454,8 @@ func (r requestScopedStreamResolver) RetentionHoursFor(lbs labels.Labels) string
|
||||
return r.retention.RetentionHoursFor(lbs)
|
||||
}
|
||||
|
||||
func (r requestScopedStreamResolver) PolicyFor(lbs labels.Labels) string {
|
||||
policies := r.policyStreamMappings.PolicyFor(lbs)
|
||||
func (r requestScopedStreamResolver) PolicyFor(ctx context.Context, lbs labels.Labels) string {
|
||||
policies := r.policyStreamMappings.PolicyFor(ctx, lbs)
|
||||
|
||||
var policy string
|
||||
if len(policies) > 0 {
|
||||
|
||||
@@ -1291,7 +1291,7 @@ func Benchmark_SortLabelsOnPush(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
stream := request.Streams[0]
|
||||
stream.Labels = `{buzz="f", a="b"}`
|
||||
_, _, _, _, _, err := d.parseStreamLabels(vCtx, stream.Labels, stream, streamResolver, constants.Loki)
|
||||
_, _, _, _, _, err := d.parseStreamLabels(context.Background(), vCtx, stream.Labels, stream, streamResolver, constants.Loki)
|
||||
if err != nil {
|
||||
panic("parseStreamLabels fail,err:" + err.Error())
|
||||
}
|
||||
@@ -1331,7 +1331,7 @@ func TestParseStreamLabels(t *testing.T) {
|
||||
vCtx := d.validator.getValidationContextForTime(testTime, "123")
|
||||
streamResolver := newRequestScopedStreamResolver("123", d.validator.Limits, nil)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
lbs, lbsString, hash, _, _, err := d.parseStreamLabels(vCtx, tc.origLabels, logproto.Stream{
|
||||
lbs, lbsString, hash, _, _, err := d.parseStreamLabels(context.Background(), vCtx, tc.origLabels, logproto.Stream{
|
||||
Labels: tc.origLabels,
|
||||
}, streamResolver, constants.Loki)
|
||||
if tc.expectedErr != nil {
|
||||
@@ -2336,10 +2336,10 @@ func TestRequestScopedStreamResolver(t *testing.T) {
|
||||
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
|
||||
require.Equal(t, 24*time.Hour, retentionPeriod)
|
||||
|
||||
policy := resolver.PolicyFor(labels.FromStrings("env", "prod"))
|
||||
policy := resolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod"))
|
||||
require.Equal(t, "policy0", policy)
|
||||
|
||||
policy = resolver.PolicyFor(labels.FromStrings("env", "dev"))
|
||||
policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev"))
|
||||
require.Empty(t, policy)
|
||||
|
||||
// We now modify the underlying limits to test that the resolver is not affected by changes to the limits
|
||||
@@ -2378,10 +2378,10 @@ func TestRequestScopedStreamResolver(t *testing.T) {
|
||||
retentionPeriod = resolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
|
||||
require.Equal(t, 24*time.Hour, retentionPeriod)
|
||||
|
||||
policy = resolver.PolicyFor(labels.FromStrings("env", "prod"))
|
||||
policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod"))
|
||||
require.Equal(t, "policy0", policy)
|
||||
|
||||
policy = resolver.PolicyFor(labels.FromStrings("env", "dev"))
|
||||
policy = resolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev"))
|
||||
require.Empty(t, policy)
|
||||
|
||||
// But a new resolver should return the new values
|
||||
@@ -2397,10 +2397,10 @@ func TestRequestScopedStreamResolver(t *testing.T) {
|
||||
retentionPeriod = newResolver.RetentionPeriodFor(labels.FromStrings("env", "dev"))
|
||||
require.Equal(t, 72*time.Hour, retentionPeriod)
|
||||
|
||||
policy = newResolver.PolicyFor(labels.FromStrings("env", "prod"))
|
||||
policy = newResolver.PolicyFor(t.Context(), labels.FromStrings("env", "prod"))
|
||||
require.Empty(t, policy)
|
||||
|
||||
policy = newResolver.PolicyFor(labels.FromStrings("env", "dev"))
|
||||
policy = newResolver.PolicyFor(t.Context(), labels.FromStrings("env", "dev"))
|
||||
require.Equal(t, "policy1", policy)
|
||||
}
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe
|
||||
"stream", s.Labels,
|
||||
"streamLabelsHash", util.HashedQuery(s.Labels), // this is to make it easier to do searching and grouping
|
||||
"streamSizeBytes", humanize.Bytes(uint64(pushStats.StreamSizeBytes[s.Labels])),
|
||||
"policy", streamResolver.PolicyFor(lbs),
|
||||
"policy", streamResolver.PolicyFor(r.Context(), lbs),
|
||||
}
|
||||
if timestamp, ok := pushStats.MostRecentEntryTimestampPerStream[s.Labels]; ok {
|
||||
logValues = append(logValues, "mostRecentLagMs", time.Since(timestamp).Milliseconds())
|
||||
|
||||
@@ -201,7 +201,7 @@ func (i *instance) consumeChunk(ctx context.Context, ls labels.Labels, chunk *lo
|
||||
|
||||
s, _, _ := i.streams.LoadOrStoreNewByFP(fp,
|
||||
func() (*stream, error) {
|
||||
s, err := i.createStreamByFP(ls, fp)
|
||||
s, err := i.createStreamByFP(ctx, ls, fp)
|
||||
s.chunkMtx.Lock() // Lock before return, because we have defer that unlocks it.
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -299,7 +299,7 @@ func (i *instance) createStream(ctx context.Context, pushReqStream logproto.Stre
|
||||
}
|
||||
|
||||
retentionHours := util.RetentionHours(i.tenantsRetention.RetentionPeriodFor(i.instanceID, labels))
|
||||
policy := i.resolvePolicyForStream(labels)
|
||||
policy := i.resolvePolicyForStream(ctx, labels)
|
||||
|
||||
if record != nil {
|
||||
err = i.streamCountLimiter.AssertNewStreamAllowed(i.instanceID, policy)
|
||||
@@ -336,9 +336,9 @@ func (i *instance) createStream(ctx context.Context, pushReqStream logproto.Stre
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (i *instance) resolvePolicyForStream(labels labels.Labels) string {
|
||||
func (i *instance) resolvePolicyForStream(ctx context.Context, labels labels.Labels) string {
|
||||
mapping := i.limiter.limits.PoliciesStreamMapping(i.instanceID)
|
||||
policies := mapping.PolicyFor(labels)
|
||||
policies := mapping.PolicyFor(ctx, labels)
|
||||
// NOTE: We previously resolved the policy on distributors and logged when multiple policies were matched.
|
||||
// As on distributors, we use the first policy by alphabetical order.
|
||||
var policy string
|
||||
@@ -400,7 +400,7 @@ func (i *instance) onStreamCreated(s *stream) {
|
||||
}
|
||||
}
|
||||
|
||||
func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*stream, error) {
|
||||
func (i *instance) createStreamByFP(ctx context.Context, ls labels.Labels, fp model.Fingerprint) (*stream, error) {
|
||||
sortedLabels := i.index.Add(logproto.FromLabelsToLabelAdapters(ls), fp)
|
||||
|
||||
chunkfmt, headfmt, err := i.chunkFormatAt(model.Now())
|
||||
@@ -409,7 +409,7 @@ func (i *instance) createStreamByFP(ls labels.Labels, fp model.Fingerprint) (*st
|
||||
}
|
||||
|
||||
retentionHours := util.RetentionHours(i.tenantsRetention.RetentionPeriodFor(i.instanceID, ls))
|
||||
policy := i.resolvePolicyForStream(ls)
|
||||
policy := i.resolvePolicyForStream(ctx, ls)
|
||||
|
||||
s := newStream(chunkfmt, headfmt, i.cfg, i.limiter.rateLimitStrategy, i.instanceID, fp, sortedLabels, i.limiter.UnorderedWrites(i.instanceID), i.streamRateCalculator, i.metrics, i.writeFailures, i.configs, retentionHours, policy)
|
||||
|
||||
|
||||
@@ -581,7 +581,7 @@ func Benchmark_instance_addNewTailer(b *testing.B) {
|
||||
chunkfmt, headfmt, err := inst.chunkFormatAt(model.Now())
|
||||
require.NoError(b, err)
|
||||
retentionHours := util.RetentionHours(tenantsRetention.RetentionPeriodFor("test", lbs))
|
||||
policy := inst.resolvePolicyForStream(lbs)
|
||||
policy := inst.resolvePolicyForStream(context.Background(), lbs)
|
||||
|
||||
b.Run("addTailersToNewStream", func(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
|
||||
@@ -243,7 +243,7 @@ func createStream(t *testing.T, inst *instance, fingerprint int) *stream {
|
||||
lbls := labels.FromStrings("mock", strconv.Itoa(fingerprint))
|
||||
|
||||
stream, _, err := inst.streams.LoadOrStoreNew(lbls.String(), func() (*stream, error) {
|
||||
return inst.createStreamByFP(lbls, model.Fingerprint(fingerprint))
|
||||
return inst.createStreamByFP(context.Background(), lbls, model.Fingerprint(fingerprint))
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
return stream
|
||||
|
||||
@@ -242,7 +242,7 @@ func otlpToLokiPushRequest(ctx context.Context, ld plog.Logs, userID string, otl
|
||||
// Calculate resource attributes metadata size for stats
|
||||
resourceAttributesAsStructuredMetadataSize := loki_util.StructuredMetadataSize(resourceAttributesAsStructuredMetadata)
|
||||
retentionPeriodForUser := streamResolver.RetentionPeriodFor(lbs)
|
||||
policy := streamResolver.PolicyFor(lbs)
|
||||
policy := streamResolver.PolicyFor(ctx, lbs)
|
||||
|
||||
// Check if the stream has the exporter=OTLP label; set flag instead of incrementing per stream
|
||||
if value, ok := streamLabels[model.LabelName("exporter")]; ok && value == "OTLP" {
|
||||
@@ -386,7 +386,7 @@ func otlpToLokiPushRequest(ctx context.Context, ld plog.Logs, userID string, otl
|
||||
pushRequestsByStream[entryLabelsStr] = stream
|
||||
|
||||
entryRetentionPeriod := streamResolver.RetentionPeriodFor(entryLbs)
|
||||
entryPolicy := streamResolver.PolicyFor(entryLbs)
|
||||
entryPolicy := streamResolver.PolicyFor(ctx, entryLbs)
|
||||
|
||||
if _, ok := stats.StructuredMetadataBytes[entryPolicy]; !ok {
|
||||
stats.StructuredMetadataBytes[entryPolicy] = make(map[time.Duration]int64)
|
||||
|
||||
@@ -583,7 +583,7 @@ func TestOTLPToLokiPushRequest(t *testing.T) {
|
||||
stats := NewPushStats()
|
||||
tracker := NewMockTracker()
|
||||
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
|
||||
streamResolver.policyForOverride = func(lbs labels.Labels) string {
|
||||
streamResolver.policyForOverride = func(_ context.Context, lbs labels.Labels) string {
|
||||
if lbs.Get("service_name") == "service-1" {
|
||||
return "service-1-policy"
|
||||
}
|
||||
@@ -926,7 +926,7 @@ func TestOTLPLogAttributesAsIndexLabels(t *testing.T) {
|
||||
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
|
||||
|
||||
// All logs will use the same policy for simplicity
|
||||
streamResolver.policyForOverride = func(_ labels.Labels) string {
|
||||
streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string {
|
||||
return "test-policy"
|
||||
}
|
||||
|
||||
@@ -1029,7 +1029,7 @@ func TestOTLPStructuredMetadataCalculation(t *testing.T) {
|
||||
tracker := NewMockTracker()
|
||||
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
|
||||
|
||||
streamResolver.policyForOverride = func(_ labels.Labels) string {
|
||||
streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string {
|
||||
return "test-policy"
|
||||
}
|
||||
|
||||
@@ -1215,7 +1215,7 @@ func TestOTLPSeverityTextAsLabel(t *testing.T) {
|
||||
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
|
||||
|
||||
// All logs will use the same policy for simplicity
|
||||
streamResolver.policyForOverride = func(_ labels.Labels) string {
|
||||
streamResolver.policyForOverride = func(_ context.Context, _ labels.Labels) string {
|
||||
return "test-policy"
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package push
|
||||
import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
@@ -113,7 +114,7 @@ func (EmptyLimits) PolicyFor(_ string, _ labels.Labels) string {
|
||||
type StreamResolver interface {
|
||||
RetentionPeriodFor(lbs labels.Labels) time.Duration
|
||||
RetentionHoursFor(lbs labels.Labels) string
|
||||
PolicyFor(lbs labels.Labels) string
|
||||
PolicyFor(ctx context.Context, lbs labels.Labels) string
|
||||
}
|
||||
|
||||
type (
|
||||
@@ -443,7 +444,7 @@ func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfi
|
||||
req.Streams[i] = s
|
||||
}
|
||||
|
||||
err = CalculateStreamsStats(userID, req, streamResolver, tenantConfigs, pushStats)
|
||||
err = CalculateStreamsStats(r.Context(), userID, req, streamResolver, tenantConfigs, pushStats)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -452,7 +453,7 @@ func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfi
|
||||
}
|
||||
|
||||
// CalculateStreamsStats modifies pushStats with statistics about all the streams from req.
|
||||
func CalculateStreamsStats(userID string, req *logproto.PushRequest, streamResolver StreamResolver, tenantConfigs *runtime.TenantConfigs, pushStats *Stats) error {
|
||||
func CalculateStreamsStats(ctx context.Context, userID string, req *logproto.PushRequest, streamResolver StreamResolver, tenantConfigs *runtime.TenantConfigs, pushStats *Stats) error {
|
||||
logPushRequestStreams := false
|
||||
if tenantConfigs != nil {
|
||||
logPushRequestStreams = tenantConfigs.LogPushRequestStreams(userID)
|
||||
@@ -471,7 +472,7 @@ func CalculateStreamsStats(userID string, req *logproto.PushRequest, streamResol
|
||||
var policy string
|
||||
if streamResolver != nil {
|
||||
retentionPeriod = streamResolver.RetentionPeriodFor(lbs)
|
||||
policy = streamResolver.PolicyFor(lbs)
|
||||
policy = streamResolver.PolicyFor(ctx, lbs)
|
||||
}
|
||||
|
||||
if _, ok := pushStats.LogLinesBytes[policy]; !ok {
|
||||
|
||||
@@ -787,7 +787,7 @@ type mockStreamResolver struct {
|
||||
tenant string
|
||||
limits *fakeLimits
|
||||
|
||||
policyForOverride func(lbs labels.Labels) string
|
||||
policyForOverride func(ctx context.Context, lbs labels.Labels) string
|
||||
}
|
||||
|
||||
func newMockStreamResolver(tenant string, limits *fakeLimits) *mockStreamResolver {
|
||||
@@ -805,9 +805,9 @@ func (m mockStreamResolver) RetentionHoursFor(lbs labels.Labels) string {
|
||||
return m.limits.RetentionHoursFor(m.tenant, lbs)
|
||||
}
|
||||
|
||||
func (m mockStreamResolver) PolicyFor(lbs labels.Labels) string {
|
||||
func (m mockStreamResolver) PolicyFor(ctx context.Context, lbs labels.Labels) string {
|
||||
if m.policyForOverride != nil {
|
||||
return m.policyForOverride(lbs)
|
||||
return m.policyForOverride(ctx, lbs)
|
||||
}
|
||||
|
||||
return m.limits.PolicyFor(m.tenant, lbs)
|
||||
|
||||
@@ -391,6 +391,7 @@ func (t *Loki) initDistributor() (services.Service, error) {
|
||||
httpPushHandlerMiddleware := middleware.Merge(
|
||||
serverutil.RecoveryHTTPMiddleware,
|
||||
t.HTTPAuthMiddleware,
|
||||
validation.NewIngestionPolicyMiddleware(util_log.Logger),
|
||||
)
|
||||
|
||||
lokiPushHandler := httpPushHandlerMiddleware.Wrap(http.HandlerFunc(t.distributor.PushHandler))
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/go-kit/log"
|
||||
"github.com/grafana/dskit/middleware"
|
||||
"github.com/prometheus/prometheus/model/labels"
|
||||
|
||||
"github.com/grafana/loki/v3/pkg/logql/syntax"
|
||||
@@ -11,6 +15,8 @@ import (
|
||||
|
||||
const (
|
||||
GlobalPolicy = "*"
|
||||
|
||||
HTTPHeaderIngestionPolicyKey = "X-Loki-Ingestion-Policy"
|
||||
)
|
||||
|
||||
type PriorityStream struct {
|
||||
@@ -54,7 +60,15 @@ func (p *PolicyStreamMapping) Validate() error {
|
||||
// with the same priority.
|
||||
// Returned policies are sorted alphabetically.
|
||||
// If no policies match, it returns an empty slice.
|
||||
func (p *PolicyStreamMapping) PolicyFor(lbs labels.Labels) []string {
|
||||
// If a policy is set via the X-Loki-Ingestion-Policy header (passed through context), it overrides
|
||||
// all stream-to-policy mappings and returns that policy.
|
||||
func (p *PolicyStreamMapping) PolicyFor(ctx context.Context, lbs labels.Labels) []string {
|
||||
// Check if a policy was set via the HTTP header (X-Loki-Ingestion-Policy)
|
||||
// This overrides any stream-to-policy mappings
|
||||
if headerPolicy := ExtractIngestionPolicyContext(ctx); headerPolicy != "" {
|
||||
return []string{headerPolicy}
|
||||
}
|
||||
|
||||
var (
|
||||
found bool
|
||||
highestPriority int
|
||||
@@ -143,3 +157,54 @@ func (p *PolicyStreamMapping) ApplyDefaultPolicyStreamMappings(defaults PolicySt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// policyContextKey is used as a key for context values to avoid collisions
|
||||
type policyContextKey int
|
||||
|
||||
const (
|
||||
ingestionPolicyContextKey policyContextKey = 1
|
||||
)
|
||||
|
||||
// ExtractIngestionPolicyHTTP retrieves the ingestion policy from the HTTP header and returns it.
|
||||
// If no policy is found, it returns an empty string.
|
||||
func ExtractIngestionPolicyHTTP(r *http.Request) string {
|
||||
return r.Header.Get(HTTPHeaderIngestionPolicyKey)
|
||||
}
|
||||
|
||||
// InjectIngestionPolicyContext returns a derived context containing the provided ingestion policy.
|
||||
func InjectIngestionPolicyContext(ctx context.Context, policy string) context.Context {
|
||||
return context.WithValue(ctx, ingestionPolicyContextKey, policy)
|
||||
}
|
||||
|
||||
// ExtractIngestionPolicyContext gets the embedded ingestion policy from the context.
|
||||
// If no policy is found, it returns an empty string.
|
||||
func ExtractIngestionPolicyContext(ctx context.Context) string {
|
||||
policy, ok := ctx.Value(ingestionPolicyContextKey).(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return policy
|
||||
}
|
||||
|
||||
type ingestionPolicyMiddleware struct {
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// NewIngestionPolicyMiddleware creates a middleware that extracts the ingestion policy
|
||||
// from the HTTP header and injects it into the context of the request.
|
||||
func NewIngestionPolicyMiddleware(logger log.Logger) middleware.Interface {
|
||||
return &ingestionPolicyMiddleware{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap implements the middleware interface
|
||||
func (m *ingestionPolicyMiddleware) Wrap(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if policy := ExtractIngestionPolicyHTTP(r); policy != "" {
|
||||
r = r.Clone(InjectIngestionPolicyContext(r.Context(), policy))
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/prometheus/prometheus/model/labels"
|
||||
@@ -102,19 +105,20 @@ func Test_PolicyStreamMapping_PolicyFor(t *testing.T) {
|
||||
|
||||
require.NoError(t, mapping.Validate())
|
||||
|
||||
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(labels.FromStrings("foo", "bar")))
|
||||
ctx := t.Context()
|
||||
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
|
||||
// matches both policy2 and policy1 but policy1 has higher priority.
|
||||
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(labels.FromStrings("foo", "bar", "daz", "baz")))
|
||||
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar", "daz", "baz")))
|
||||
// matches policy3 and policy4 but policy3 has higher priority..
|
||||
require.Equal(t, []string{"policy3"}, mapping.PolicyFor(labels.FromStrings("qyx", "qzx", "qox", "qox")))
|
||||
require.Equal(t, []string{"policy3"}, mapping.PolicyFor(ctx, labels.FromStrings("qyx", "qzx", "qox", "qox")))
|
||||
// matches no policy.
|
||||
require.Empty(t, mapping.PolicyFor(labels.FromStrings("foo", "fooz", "daz", "qux", "quux", "corge")))
|
||||
require.Empty(t, mapping.PolicyFor(ctx, labels.FromStrings("foo", "fooz", "daz", "qux", "quux", "corge")))
|
||||
// matches policy5 through regex.
|
||||
require.Equal(t, []string{"policy5"}, mapping.PolicyFor(labels.FromStrings("qab", "qzxqox")))
|
||||
require.Equal(t, []string{"policy5"}, mapping.PolicyFor(ctx, labels.FromStrings("qab", "qzxqox")))
|
||||
|
||||
require.Equal(t, []string{"policy6"}, mapping.PolicyFor(labels.FromStrings("env", "prod", "team", "finance")))
|
||||
require.Equal(t, []string{"policy6"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod", "team", "finance")))
|
||||
// Matches policy7 and policy8 which have the same priority.
|
||||
require.Equal(t, []string{"policy7", "policy8"}, mapping.PolicyFor(labels.FromStrings("env", "prod")))
|
||||
require.Equal(t, []string{"policy7", "policy8"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod")))
|
||||
}
|
||||
|
||||
func TestPolicyStreamMapping_ApplyDefaultPolicyStreamMappings(t *testing.T) {
|
||||
@@ -284,3 +288,187 @@ func TestPolicyStreamMapping_ApplyDefaultPolicyStreamMappings_Validation(t *test
|
||||
// Verify the result is valid
|
||||
require.NoError(t, existing.Validate())
|
||||
}
|
||||
|
||||
func Test_PolicyStreamMapping_PolicyFor_WithHeaderOverride(t *testing.T) {
|
||||
mapping := PolicyStreamMapping{
|
||||
"policy1": []*PriorityStream{
|
||||
{
|
||||
Selector: `{foo="bar"}`,
|
||||
Priority: 2,
|
||||
Matchers: []*labels.Matcher{
|
||||
labels.MustNewMatcher(labels.MatchEqual, "foo", "bar"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"policy2": []*PriorityStream{
|
||||
{
|
||||
Selector: `{env="prod"}`,
|
||||
Priority: 1,
|
||||
Matchers: []*labels.Matcher{
|
||||
labels.MustNewMatcher(labels.MatchEqual, "env", "prod"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, mapping.Validate())
|
||||
|
||||
t.Run("without header context, uses normal mapping", func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
// Should match policy1 based on labels
|
||||
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
|
||||
// Should match policy2 based on labels
|
||||
require.Equal(t, []string{"policy2"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod")))
|
||||
// Should match no policy
|
||||
require.Empty(t, mapping.PolicyFor(ctx, labels.FromStrings("unknown", "label")))
|
||||
})
|
||||
|
||||
t.Run("with header context, overrides all mappings", func(t *testing.T) {
|
||||
ctx := InjectIngestionPolicyContext(t.Context(), "override-policy")
|
||||
|
||||
// Even though labels match policy1, header policy overrides
|
||||
require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
|
||||
|
||||
// Even though labels match policy2, header policy overrides
|
||||
require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("env", "prod")))
|
||||
|
||||
// Even though labels don't match anything, header policy is used
|
||||
require.Equal(t, []string{"override-policy"}, mapping.PolicyFor(ctx, labels.FromStrings("unknown", "label")))
|
||||
})
|
||||
|
||||
t.Run("empty header context is ignored", func(t *testing.T) {
|
||||
// Inject empty string - should be treated as not set
|
||||
ctx := InjectIngestionPolicyContext(t.Context(), "")
|
||||
|
||||
// Should fall back to normal mapping behavior
|
||||
require.Equal(t, []string{"policy1"}, mapping.PolicyFor(ctx, labels.FromStrings("foo", "bar")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractInjectIngestionPolicyContext(t *testing.T) {
|
||||
t.Run("inject and extract policy", func(t *testing.T) {
|
||||
policy := "test-policy"
|
||||
|
||||
ctx := InjectIngestionPolicyContext(t.Context(), policy)
|
||||
extracted := ExtractIngestionPolicyContext(ctx)
|
||||
require.Equal(t, policy, extracted)
|
||||
})
|
||||
|
||||
t.Run("extract from empty context", func(t *testing.T) {
|
||||
extracted := ExtractIngestionPolicyContext(t.Context())
|
||||
require.Empty(t, extracted)
|
||||
})
|
||||
|
||||
t.Run("inject empty string", func(t *testing.T) {
|
||||
ctx := InjectIngestionPolicyContext(t.Context(), "")
|
||||
extracted := ExtractIngestionPolicyContext(ctx)
|
||||
require.Empty(t, extracted)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractIngestionPolicyHTTP(t *testing.T) {
|
||||
t.Run("extract policy from header", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(HTTPHeaderIngestionPolicyKey, "my-policy")
|
||||
|
||||
policy := ExtractIngestionPolicyHTTP(req)
|
||||
require.Equal(t, "my-policy", policy)
|
||||
})
|
||||
|
||||
t.Run("no header present", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := ExtractIngestionPolicyHTTP(req)
|
||||
require.Empty(t, policy)
|
||||
})
|
||||
|
||||
t.Run("empty header value", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(HTTPHeaderIngestionPolicyKey, "")
|
||||
|
||||
policy := ExtractIngestionPolicyHTTP(req)
|
||||
require.Empty(t, policy)
|
||||
})
|
||||
|
||||
t.Run("header with whitespace", func(t *testing.T) {
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(HTTPHeaderIngestionPolicyKey, " policy-with-spaces ")
|
||||
|
||||
policy := ExtractIngestionPolicyHTTP(req)
|
||||
require.Equal(t, " policy-with-spaces ", policy)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIngestionPolicyMiddleware(t *testing.T) {
|
||||
t.Run("middleware injects policy into context", func(t *testing.T) {
|
||||
var capturedCtx context.Context
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCtx = r.Context()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := NewIngestionPolicyMiddleware(nil)
|
||||
wrappedHandler := middleware.Wrap(handler)
|
||||
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(HTTPHeaderIngestionPolicyKey, "test-policy")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
policy := ExtractIngestionPolicyContext(capturedCtx)
|
||||
require.Equal(t, "test-policy", policy)
|
||||
})
|
||||
|
||||
t.Run("middleware does not modify context when no header", func(t *testing.T) {
|
||||
var capturedCtx context.Context
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCtx = r.Context()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := NewIngestionPolicyMiddleware(nil)
|
||||
wrappedHandler := middleware.Wrap(handler)
|
||||
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
policy := ExtractIngestionPolicyContext(capturedCtx)
|
||||
require.Empty(t, policy)
|
||||
})
|
||||
|
||||
t.Run("middleware does not inject empty header value", func(t *testing.T) {
|
||||
var capturedCtx context.Context
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedCtx = r.Context()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
middleware := NewIngestionPolicyMiddleware(nil)
|
||||
wrappedHandler := middleware.Wrap(handler)
|
||||
|
||||
req, err := http.NewRequest("POST", "/loki/api/v1/push", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(HTTPHeaderIngestionPolicyKey, "")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
policy := ExtractIngestionPolicyContext(capturedCtx)
|
||||
require.Empty(t, policy)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user