diff --git a/clients/pkg/promtail/targets/lokipush/pushtarget.go b/clients/pkg/promtail/targets/lokipush/pushtarget.go index 881e7635f2..8ace3a8076 100644 --- a/clients/pkg/promtail/targets/lokipush/pushtarget.go +++ b/clients/pkg/promtail/targets/lokipush/pushtarget.go @@ -115,7 +115,7 @@ func (t *PushTarget) run() error { func (t *PushTarget) handleLoki(w http.ResponseWriter, r *http.Request) { logger := util_log.WithContext(r.Context(), util_log.Logger) userID, _ := tenant.TenantID(r.Context()) - req, _, err := push.ParseRequest(logger, userID, t.config.MaxSendMsgSize, r, push.EmptyLimits{}, nil, push.ParseLokiRequest, nil, nil, "", "loki") + req, _, err := push.ParseRequest(logger, userID, t.config.MaxSendMsgSize, 0, r, push.EmptyLimits{}, nil, push.ParseLokiRequest, nil, nil, "", "loki") if err != nil { level.Warn(t.logger).Log("msg", "failed to parse incoming push request", "err", err.Error()) http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/docs/sources/shared/configuration.md b/docs/sources/shared/configuration.md index 9e9510fe50..b66d6fd297 100644 --- a/docs/sources/shared/configuration.md +++ b/docs/sources/shared/configuration.md @@ -3212,6 +3212,10 @@ ring: # CLI flag: -distributor.max-recv-msg-size [max_recv_msg_size: | default = 104857600] +# The maximum size of a decompressed message. Defaults to 50x max-recv-msg-size. +# CLI flag: -distributor.max-decompressed-size +[max_decompressed_size: | default = 5242880000] + rate_store: # The max number of concurrent requests to make to ingester stream apis # CLI flag: -distributor.rate-store.max-request-parallelism diff --git a/pkg/distributor/distributor.go b/pkg/distributor/distributor.go index b65c159aad..f018bec782 100644 --- a/pkg/distributor/distributor.go +++ b/pkg/distributor/distributor.go @@ -87,7 +87,8 @@ type Config struct { PushWorkerCount int `yaml:"push_worker_count"` // Request parser - MaxRecvMsgSize int `yaml:"max_recv_msg_size"` + MaxRecvMsgSize int `yaml:"max_recv_msg_size"` + MaxDecompressedSize int `yaml:"max_decompressed_size"` // For testing. factory ring_client.PoolFactory `yaml:"-"` @@ -121,6 +122,7 @@ func (cfg *Config) RegisterFlags(fs *flag.FlagSet) { cfg.RateStore.RegisterFlagsWithPrefix("distributor.rate-store", fs) cfg.WriteFailuresLogging.RegisterFlagsWithPrefix("distributor.write-failures-logging", fs) fs.IntVar(&cfg.MaxRecvMsgSize, "distributor.max-recv-msg-size", 100<<20, "The maximum size of a received message.") + fs.IntVar(&cfg.MaxDecompressedSize, "distributor.max-decompressed-size", 5000<<20, "The maximum size of a decompressed message. Defaults to 50x max-recv-msg-size.") fs.IntVar(&cfg.PushWorkerCount, "distributor.push-worker-count", 256, "Number of workers to push batches to ingesters.") fs.BoolVar(&cfg.KafkaEnabled, "distributor.kafka-writes-enabled", false, "Enable writes to Kafka during Push requests.") fs.BoolVar(&cfg.IngesterEnabled, "distributor.ingester-writes-enabled", true, "Enable writes to Ingesters during Push requests. Defaults to true.") @@ -135,6 +137,10 @@ func (cfg *Config) Validate() error { if err := cfg.DataObjTeeConfig.Validate(); err != nil { return err } + // Set default maxDecompressedSize if not configured (50x maxRecvMsgSize) + if cfg.MaxDecompressedSize == 0 && cfg.MaxRecvMsgSize > 0 { + cfg.MaxDecompressedSize = cfg.MaxRecvMsgSize * 50 + } return nil } diff --git a/pkg/distributor/distributor_test.go b/pkg/distributor/distributor_test.go index b27d599b29..efc89b62d9 100644 --- a/pkg/distributor/distributor_test.go +++ b/pkg/distributor/distributor_test.go @@ -2651,3 +2651,64 @@ func TestDistributor_PushIngestLimits(t *testing.T) { }) } } + +func TestConfig_Validate(t *testing.T) { + tests := []struct { + name string + cfg Config + expectedMaxDecompressedSize int + expectedError string + }{ + { + name: "sets default maxDecompressedSize when zero and maxRecvMsgSize is set", + cfg: Config{ + MaxRecvMsgSize: 100 << 20, // 100 MB + MaxDecompressedSize: 0, + KafkaEnabled: false, + IngesterEnabled: true, + }, + expectedMaxDecompressedSize: 5000 << 20, // 5000 MB (50x) + }, + { + name: "does not override explicit maxDecompressedSize", + cfg: Config{ + MaxRecvMsgSize: 100 << 20, // 100 MB + MaxDecompressedSize: 500 << 20, // 500 MB + KafkaEnabled: false, + IngesterEnabled: true, + }, + expectedMaxDecompressedSize: 500 << 20, // 500 MB (unchanged) + }, + { + name: "does not set default when maxRecvMsgSize is zero", + cfg: Config{ + MaxRecvMsgSize: 0, + MaxDecompressedSize: 0, + KafkaEnabled: false, + IngesterEnabled: true, + }, + expectedMaxDecompressedSize: 0, // Should remain 0 + }, + { + name: "validates kafka and ingester enabled", + cfg: Config{ + KafkaEnabled: false, + IngesterEnabled: false, + }, + expectedError: "at least one of kafka and ingestor writes must be enabled", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.Validate() + if tt.expectedError != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedMaxDecompressedSize, tt.cfg.MaxDecompressedSize) + } + }) + } +} diff --git a/pkg/distributor/http.go b/pkg/distributor/http.go index 551c91b8f8..3bbdc1bd60 100644 --- a/pkg/distributor/http.go +++ b/pkg/distributor/http.go @@ -53,7 +53,7 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe logPushRequestStreams := d.tenantConfigs.LogPushRequestStreams(tenantID) filterPushRequestStreamsIPs := d.tenantConfigs.FilterPushRequestStreamsIPs(tenantID) presumedAgentIP := extractPresumedAgentIP(r) - req, pushStats, err := push.ParseRequest(logger, tenantID, d.cfg.MaxRecvMsgSize, r, d.validator.Limits, d.tenantConfigs, + req, pushStats, err := push.ParseRequest(logger, tenantID, d.cfg.MaxRecvMsgSize, d.cfg.MaxDecompressedSize, r, d.validator.Limits, d.tenantConfigs, pushRequestParser, d.usageTracker, streamResolver, presumedAgentIP, format) if err != nil { switch { diff --git a/pkg/distributor/http_test.go b/pkg/distributor/http_test.go index 180affdde8..035ef0df34 100644 --- a/pkg/distributor/http_test.go +++ b/pkg/distributor/http_test.go @@ -176,6 +176,7 @@ func (p *fakeParser) parseRequest( _ push.Limits, _ *runtime.TenantConfigs, _ int, + _ int, _ push.UsageTracker, _ push.StreamResolver, _ log.Logger, diff --git a/pkg/loghttp/push/otlp.go b/pkg/loghttp/push/otlp.go index 88feceee8b..41e230d89f 100644 --- a/pkg/loghttp/push/otlp.go +++ b/pkg/loghttp/push/otlp.go @@ -44,9 +44,9 @@ const ( messageSizeLargerErrFmt = "%w than max (%d vs %d)" ) -func ParseOTLPRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) { +func ParseOTLPRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize, maxDecompressedSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) { stats := NewPushStats() - otlpLogs, err := extractLogs(r, maxRecvMsgSize, stats) + otlpLogs, err := extractLogs(r, maxRecvMsgSize, maxDecompressedSize, stats) if err != nil { return nil, nil, err } @@ -55,7 +55,7 @@ func ParseOTLPRequest(userID string, r *http.Request, limits Limits, tenantConfi return req, stats, err } -func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Logs, error) { +func extractLogs(r *http.Request, maxRecvMsgSize, maxDecompressedSize int, pushStats *Stats) (plog.Logs, error) { pushStats.ContentEncoding = r.Header.Get(contentEnc) // bodySize should always reflect the compressed size of the request body bodySize := loki_util.NewSizeReader(r.Body) @@ -67,7 +67,7 @@ func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Lo } switch pushStats.ContentEncoding { case gzipContentEncoding: - r, err := gzip.NewReader(bodySize) + r, err := gzip.NewReader(body) if err != nil { return plog.NewLogs(), err } @@ -75,14 +75,24 @@ func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Lo defer func(reader *gzip.Reader) { _ = reader.Close() }(r) + if maxDecompressedSize > 0 { + body = io.LimitReader(body, int64(maxDecompressedSize)+1) + } + case zstdContentEncoding: var err error body, err = zstd.NewReader(body) if err != nil { return plog.NewLogs(), err } + if maxDecompressedSize > 0 { + body = io.LimitReader(body, int64(maxDecompressedSize)+1) + } case lz4ContentEncoding: body = io.NopCloser(lz4.NewReader(body)) + if maxDecompressedSize > 0 { + body = io.LimitReader(body, int64(maxDecompressedSize)+1) + } case "": // no content encoding, use the body as is default: @@ -90,12 +100,18 @@ func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Lo } buf, err := io.ReadAll(body) if err != nil { - if size := bodySize.Size(); size > int64(maxRecvMsgSize) && maxRecvMsgSize > 0 { - return plog.NewLogs(), fmt.Errorf(messageSizeLargerErrFmt, loki_util.ErrMessageSizeTooLarge, size, maxRecvMsgSize) - } return plog.NewLogs(), err } + // Check the size of the compressed body + if size := bodySize.Size(); size > int64(maxRecvMsgSize) && maxRecvMsgSize > 0 { + return plog.NewLogs(), fmt.Errorf(messageSizeLargerErrFmt, loki_util.ErrMessageSizeTooLarge, size, maxRecvMsgSize) + } + // Check the size of the decompressed body + if len(buf) > maxDecompressedSize && maxDecompressedSize > 0 { + return plog.NewLogs(), fmt.Errorf(messageSizeLargerErrFmt, loki_util.ErrMessageDecompressedSizeTooLarge, len(buf), maxDecompressedSize) + } + pushStats.BodySize = bodySize.Size() req := plogotlp.NewExportRequest() diff --git a/pkg/loghttp/push/otlp_test.go b/pkg/loghttp/push/otlp_test.go index 670fdb5e3f..adf1fa82e6 100644 --- a/pkg/loghttp/push/otlp_test.go +++ b/pkg/loghttp/push/otlp_test.go @@ -1,6 +1,7 @@ package push import ( + "compress/gzip" "context" "encoding/base64" "fmt" @@ -1297,6 +1298,59 @@ func simpleOTLPLogs() plog.Logs { return ld } +// largeOTLPLogs creates an OTLP log record which is larger than 1MB +// and will compress to less than 1MB (~3kb depending on the compression algorithm). +func largeOTLPLogs() plog.Logs { + ld := plog.NewLogs() + rl := ld.ResourceLogs().AppendEmpty() + rl.Resource().Attributes().PutStr("service.name", "test-service") + sl := rl.ScopeLogs().AppendEmpty() + for i := 0; i < 1024; i++ { + logRecord := sl.LogRecords().AppendEmpty() + logRecord.Body().SetStr(strings.Repeat(" ", 1024)) + } + return ld +} + +func createJSON(logs plog.Logs) ([]byte, error) { + req := plogotlp.NewExportRequestFromLogs(logs) + jsonBytes, err := req.MarshalJSON() + if err != nil { + return nil, err + } + return jsonBytes, nil +} + +func createGzipCompressedProtobuf(logs plog.Logs) ([]byte, error) { + req := plogotlp.NewExportRequestFromLogs(logs) + protoBytes, err := req.MarshalProto() + if err != nil { + return nil, err + } + return compressWithGzip(protoBytes) +} + +func createGzipCompressedJSON(logs plog.Logs) ([]byte, error) { + req := plogotlp.NewExportRequestFromLogs(logs) + jsonBytes, err := req.MarshalJSON() + if err != nil { + return nil, err + } + return compressWithGzip(jsonBytes) +} + +func compressWithGzip(data []byte) ([]byte, error) { + var buf bytes.Buffer + writer := gzip.NewWriter(&buf) + if _, err := writer.Write(data); err != nil { + return nil, err + } + if err := writer.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + func createZstdCompressedProtobuf(logs plog.Logs) ([]byte, error) { req := plogotlp.NewExportRequestFromLogs(logs) protoBytes, err := req.MarshalProto() @@ -1378,15 +1432,109 @@ func createOTLPLogWithNestedAttributes() plog.Logs { return ld } -func TestContentEncoding(t *testing.T) { +func TestContentEncodingAndLength(t *testing.T) { testCases := []struct { - name string - contentType string - contentEncoding string - generateBody func() ([]byte, error) - expectedError bool - expectedLogs plog.Logs + name string + contentType string + contentEncoding string + generateBody func() ([]byte, error) + expectedError bool + expectedErrorMessage string + expectedLogs plog.Logs + maxRecvMsgSize int + maxDecompressedSize int }{ + { + name: "identity_valid_json", + contentType: "application/json", + contentEncoding: "", + generateBody: func() ([]byte, error) { + return createJSON(simpleOTLPLogs()) + }, + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "identity_large_json", + contentType: "application/json", + contentEncoding: "", + generateBody: func() ([]byte, error) { + return createJSON(largeOTLPLogs()) + }, + expectedError: true, + expectedErrorMessage: "message size too large than max", + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 1 << 20, // 1 MB + }, + { + name: "gzip_valid_protobuf", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedProtobuf(simpleOTLPLogs()) + }, + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "gzip_valid_json", + contentType: "application/json", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedJSON(simpleOTLPLogs()) + }, + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "gzip_invalid_data", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return []byte("invalid gzip data"), nil + }, + expectedError: true, + expectedLogs: plog.NewLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "gzip_nested_attributes", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedProtobuf(createOTLPLogWithNestedAttributes()) + }, + expectedError: false, + expectedLogs: createOTLPLogWithNestedAttributes(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "gzip_large_protobuf", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedProtobuf(largeOTLPLogs()) + }, + expectedError: false, + expectedLogs: largeOTLPLogs(), + maxRecvMsgSize: 1 << 20, // 1 MB + }, + { + name: "gzip_too_large_protobuf", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedProtobuf(largeOTLPLogs()) + }, + expectedError: true, + expectedErrorMessage: "message size too large than max (40961 vs 40960)", + expectedLogs: largeOTLPLogs(), + maxRecvMsgSize: 1 << 12, // 4 KB + maxDecompressedSize: 40960, // Explicitly set to trigger error + }, { name: "zstd_valid_protobuf", contentType: "application/x-protobuf", @@ -1394,8 +1542,9 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return createZstdCompressedProtobuf(simpleOTLPLogs()) }, - expectedError: false, - expectedLogs: simpleOTLPLogs(), + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB }, { name: "zstd_valid_json", @@ -1404,8 +1553,9 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return createZstdCompressedJSON(simpleOTLPLogs()) }, - expectedError: false, - expectedLogs: simpleOTLPLogs(), + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB }, { name: "zstd_invalid_data", @@ -1414,8 +1564,9 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return []byte("invalid zstd data"), nil }, - expectedError: true, - expectedLogs: plog.NewLogs(), + expectedError: true, + expectedLogs: plog.NewLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB }, { name: "zstd_nested_attributes", @@ -1424,8 +1575,22 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return createZstdCompressedProtobuf(createOTLPLogWithNestedAttributes()) }, - expectedError: false, - expectedLogs: createOTLPLogWithNestedAttributes(), + expectedError: false, + expectedLogs: createOTLPLogWithNestedAttributes(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "zstd_too_large_protobuf", + contentType: "application/x-protobuf", + contentEncoding: "zstd", + generateBody: func() ([]byte, error) { + return createZstdCompressedProtobuf(largeOTLPLogs()) + }, + expectedError: true, + expectedErrorMessage: "message size too large than max (40961 vs 40960)", + expectedLogs: largeOTLPLogs(), + maxRecvMsgSize: 1 << 12, // 4 KB + maxDecompressedSize: 40960, // Explicitly set to trigger error }, { name: "lz4_valid_protobuf", @@ -1434,8 +1599,9 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return createLz4CompressedProtobuf(simpleOTLPLogs()) }, - expectedError: false, - expectedLogs: simpleOTLPLogs(), + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB }, { name: "lz4_valid_json", @@ -1444,8 +1610,9 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return createLz4CompressedJSON(simpleOTLPLogs()) }, - expectedError: false, - expectedLogs: simpleOTLPLogs(), + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB }, { name: "lz4_invalid_data", @@ -1454,8 +1621,82 @@ func TestContentEncoding(t *testing.T) { generateBody: func() ([]byte, error) { return []byte("invalid lz4 data"), nil }, - expectedError: true, - expectedLogs: plog.NewLogs(), + expectedError: true, + expectedLogs: plog.NewLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "lz4_too_large_protobuf", + contentType: "application/x-protobuf", + contentEncoding: "lz4", + generateBody: func() ([]byte, error) { + return createLz4CompressedProtobuf(largeOTLPLogs()) + }, + expectedError: true, + expectedErrorMessage: "message size too large than max (81921 vs 81920)", + expectedLogs: largeOTLPLogs(), + maxRecvMsgSize: 1 << 13, // 8 KB + maxDecompressedSize: 81920, // Explicitly set to trigger error + }, + { + name: "unsupported_encoding", + contentType: "application/x-protobuf", + contentEncoding: "br", + generateBody: func() ([]byte, error) { + return []byte("dummy brotly data"), nil + }, + expectedError: true, + expectedErrorMessage: "unsupported content encoding br: only gzip, lz4 and zstd are supported", + expectedLogs: plog.NewLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + }, + { + name: "gzip_with_zero_maxDecompressedSize", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedProtobuf(simpleOTLPLogs()) + }, + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + maxDecompressedSize: 0, // 0 means no limit (should still work for small payloads) + }, + { + name: "gzip_large_with_zero_maxDecompressedSize", + contentType: "application/x-protobuf", + contentEncoding: "gzip", + generateBody: func() ([]byte, error) { + return createGzipCompressedProtobuf(largeOTLPLogs()) + }, + expectedError: false, // No limit when maxDecompressedSize is 0 + expectedLogs: largeOTLPLogs(), + maxRecvMsgSize: 1 << 20, // 1 MB + maxDecompressedSize: 0, // 0 means no limit + }, + { + name: "zstd_with_zero_maxDecompressedSize", + contentType: "application/x-protobuf", + contentEncoding: "zstd", + generateBody: func() ([]byte, error) { + return createZstdCompressedProtobuf(simpleOTLPLogs()) + }, + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + maxDecompressedSize: 0, // 0 means no limit + }, + { + name: "lz4_with_zero_maxDecompressedSize", + contentType: "application/x-protobuf", + contentEncoding: "lz4", + generateBody: func() ([]byte, error) { + return createLz4CompressedProtobuf(simpleOTLPLogs()) + }, + expectedError: false, + expectedLogs: simpleOTLPLogs(), + maxRecvMsgSize: 100 << 20, // 100 MB + maxDecompressedSize: 0, // 0 means no limit }, } @@ -1469,10 +1710,31 @@ func TestContentEncoding(t *testing.T) { req.Header.Set("Content-Encoding", tc.contentEncoding) stats := NewPushStats() - extractedLogs, err := extractLogs(req, 100<<20, stats) + maxDecompressedSize := tc.maxDecompressedSize + // Only apply default if maxDecompressedSize is 0 and not explicitly testing zero behavior + // For test cases with maxDecompressedSize explicitly set to 0, we want to test the actual behavior + // For other cases, calculate as 10x maxRecvMsgSize (matching Validate() behavior) or use 100MB if maxRecvMsgSize is 0 + zeroMaxDecompressedSizeTests := map[string]bool{ + "gzip_with_zero_maxDecompressedSize": true, + "gzip_large_with_zero_maxDecompressedSize": true, + "zstd_with_zero_maxDecompressedSize": true, + "lz4_with_zero_maxDecompressedSize": true, + } + if maxDecompressedSize == 0 && !zeroMaxDecompressedSizeTests[tc.name] { + if tc.maxRecvMsgSize > 0 { + maxDecompressedSize = tc.maxRecvMsgSize * 50 // 50x default + } else { + maxDecompressedSize = 5000 << 20 // 5000 MB fallback default (50x 100MB) + } + } + extractedLogs, err := extractLogs(req, tc.maxRecvMsgSize, maxDecompressedSize, stats) if tc.expectedError { require.Error(t, err) + + if tc.expectedErrorMessage != "" { + require.Contains(t, err.Error(), tc.expectedErrorMessage) + } return } diff --git a/pkg/loghttp/push/push.go b/pkg/loghttp/push/push.go index 8d4d502254..4a27a83be9 100644 --- a/pkg/loghttp/push/push.go +++ b/pkg/loghttp/push/push.go @@ -118,7 +118,7 @@ type StreamResolver interface { } type ( - RequestParser func(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) + RequestParser func(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize, maxDecompressedSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) RequestParserWrapper func(inner RequestParser) RequestParser ErrorWriter func(w http.ResponseWriter, errorStr string, code int, logger log.Logger) ) @@ -171,8 +171,8 @@ type Stats struct { HasInternalStreams bool // True if any of the streams has aggregated metrics or is a pattern stream } -func ParseRequest(logger log.Logger, userID string, maxRecvMsgSize int, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, pushRequestParser RequestParser, tracker UsageTracker, streamResolver StreamResolver, presumedAgentIP, format string) (*logproto.PushRequest, *Stats, error) { - req, pushStats, err := pushRequestParser(userID, r, limits, tenantConfigs, maxRecvMsgSize, tracker, streamResolver, logger) +func ParseRequest(logger log.Logger, userID string, maxRecvMsgSize, maxDecompressedSize int, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, pushRequestParser RequestParser, tracker UsageTracker, streamResolver StreamResolver, presumedAgentIP, format string) (*logproto.PushRequest, *Stats, error) { + req, pushStats, err := pushRequestParser(userID, r, limits, tenantConfigs, maxRecvMsgSize, maxDecompressedSize, tracker, streamResolver, logger) if err != nil && !errors.Is(err, ErrAllLogsFiltered) { if errors.Is(err, util.ErrMessageSizeTooLarge) { return nil, nil, fmt.Errorf("%w: %s", ErrRequestBodyTooLarge, err.Error()) @@ -304,31 +304,46 @@ func ParseRequest(logger log.Logger, userID string, maxRecvMsgSize int, r *http. // parsePushRequestBody returns logproto.PushRequest from http.Request body, deserialized according to specified content type. // It also modifies pushStats. -func parsePushRequestBody(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (*logproto.PushRequest, error) { +func parsePushRequestBody(r *http.Request, maxRecvMsgSize, maxDecompressedSize int, pushStats *Stats) (*logproto.PushRequest, error) { // Body var body io.Reader // bodySize should always reflect the compressed size of the request body bodySize := util.NewSizeReader(r.Body) + + // Apply compressed size limit + body = bodySize + if maxRecvMsgSize > 0 { + body = io.LimitReader(body, int64(maxRecvMsgSize)+1) + } + contentEncoding := r.Header.Get(contentEnc) switch contentEncoding { case "": - body = bodySize case "snappy": - // Snappy-decoding is done by `util.ParseProtoReader(..., util.RawSnappy)` below. + // Snappy-decoding is done by `util.ParseProtoReaderWithLimits(..., util.RawSnappy)` below. // Pass on body bytes. Note: HTTP clients do not need to set this header, // but they sometimes do. See #3407. - body = bodySize case "gzip": - gzipReader, err := gzip.NewReader(bodySize) + gzipReader, err := gzip.NewReader(body) if err != nil { return nil, err } - defer gzipReader.Close() + defer func(gzipReader *gzip.Reader) { + _ = gzipReader.Close() + }(gzipReader) body = gzipReader + if maxDecompressedSize > 0 { + body = io.LimitReader(body, int64(maxDecompressedSize)+1) + } case "deflate": - flateReader := flate.NewReader(bodySize) - defer flateReader.Close() + flateReader := flate.NewReader(body) + defer func(flateReader io.ReadCloser) { + _ = flateReader.Close() + }(flateReader) body = flateReader + if maxDecompressedSize > 0 { + body = io.LimitReader(body, int64(maxDecompressedSize)+1) + } default: return nil, fmt.Errorf("Content-Encoding %q not supported", contentEncoding) } @@ -361,7 +376,7 @@ func parsePushRequestBody(r *http.Request, maxRecvMsgSize int, pushStats *Stats) default: // When no content-type header is set or when it is set to // `application/x-protobuf`: expect snappy compression. - if err := util.ParseProtoReader(r.Context(), body, int(r.ContentLength), maxRecvMsgSize, &req, util.RawSnappy); err != nil { + if err := util.ParseProtoReaderWithLimits(r.Context(), body, int(r.ContentLength), maxRecvMsgSize, maxDecompressedSize, &req, util.RawSnappy); err != nil { return nil, err } } @@ -370,13 +385,16 @@ func parsePushRequestBody(r *http.Request, maxRecvMsgSize int, pushStats *Stats) pushStats.ContentType = contentType pushStats.ContentEncoding = contentEncoding + if size := bodySize.Size(); size > int64(maxRecvMsgSize) && maxRecvMsgSize > 0 { + return nil, fmt.Errorf("compressed message size %d exceeds limit %d", size, maxRecvMsgSize) + } return &req, nil } -func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) { +func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize, maxDecompressedSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) { pushStats := NewPushStats() - req, err := parsePushRequestBody(r, maxRecvMsgSize, pushStats) + req, err := parsePushRequestBody(r, maxRecvMsgSize, maxDecompressedSize, pushStats) if err != nil { return nil, nil, err } diff --git a/pkg/loghttp/push/push_test.go b/pkg/loghttp/push/push_test.go index 70fa3d4d21..88879f18ea 100644 --- a/pkg/loghttp/push/push_test.go +++ b/pkg/loghttp/push/push_test.go @@ -363,6 +363,7 @@ func TestParseRequest(t *testing.T) { util_log.Logger, "fake", 100<<20, + 100<<20, request, test.fakeLimits, nil, @@ -510,7 +511,7 @@ func Test_ServiceDetection(t *testing.T) { limits := &fakeLimits{enabled: true, labels: []string{"foo"}} streamResolver := newMockStreamResolver("fake", limits) - data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseLokiRequest, tracker, streamResolver, "", "loki") + data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseLokiRequest, tracker, streamResolver, "", "loki") require.NoError(t, err) require.Equal(t, labels.FromStrings("foo", "bar", LabelServiceName, "bar").String(), data.Streams[0].Labels) @@ -522,7 +523,7 @@ func Test_ServiceDetection(t *testing.T) { limits := &fakeLimits{enabled: true} streamResolver := newMockStreamResolver("fake", limits) - data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki") + data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki") require.NoError(t, err) require.Equal(t, labels.FromStrings("k8s_job_name", "bar", LabelServiceName, "bar").String(), data.Streams[0].Labels) }) @@ -537,7 +538,7 @@ func Test_ServiceDetection(t *testing.T) { indexAttributes: []string{"special"}, } streamResolver := newMockStreamResolver("fake", limits) - data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki") + data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki") require.NoError(t, err) require.Equal(t, labels.FromStrings("special", "sauce", LabelServiceName, "sauce").String(), data.Streams[0].Labels) }) @@ -552,7 +553,7 @@ func Test_ServiceDetection(t *testing.T) { indexAttributes: []string{}, } streamResolver := newMockStreamResolver("fake", limits) - data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki") + data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki") require.NoError(t, err) require.Equal(t, labels.FromStrings(LabelServiceName, ServiceUnknown).String(), data.Streams[0].Labels) }) @@ -654,7 +655,7 @@ func TestNegativeSizeHandling(t *testing.T) { linesIngested.Reset() // Create a custom request parser that will generate negative sizes - var mockParser RequestParser = func(_ string, _ *http.Request, _ Limits, _ *runtime.TenantConfigs, _ int, _ UsageTracker, _ StreamResolver, _ kitlog.Logger) (*logproto.PushRequest, *Stats, error) { + var mockParser RequestParser = func(_ string, _ *http.Request, _ Limits, _ *runtime.TenantConfigs, _ int, _ int, _ UsageTracker, _ StreamResolver, _ kitlog.Logger) (*logproto.PushRequest, *Stats, error) { // Create a minimal valid request req := &logproto.PushRequest{ Streams: []logproto.Stream{ @@ -697,6 +698,7 @@ func TestNegativeSizeHandling(t *testing.T) { util_log.Logger, "fake", 100<<20, + 100<<20, request, &fakeLimits{}, nil, @@ -721,6 +723,108 @@ func TestNegativeSizeHandling(t *testing.T) { require.Equal(t, float64(0), testutil.ToFloat64(structuredMetadataBytesIngested.WithLabelValues(userID, "1", isAggregatedMetric, policy, "loki"))) } +func TestParseRequestWithZeroMaxDecompressedSize(t *testing.T) { + streamResolver := newMockStreamResolver("fake", &fakeLimits{}) + + testCases := []struct { + name string + body string + contentType string + contentEncoding string + maxRecvMsgSize int + maxDecompressedSize int + expectedError bool + expectedErrorMessage string + }{ + { + name: "gzip_with_zero_maxDecompressedSize", + body: gzipString(`{"streams": [{ "stream": { "foo": "bar" }, "values": [ [ "1570818238000000000", "test message" ] ] }]}`), + contentType: "application/json", + contentEncoding: "gzip", + maxRecvMsgSize: 100 << 20, // 100 MB + maxDecompressedSize: 0, // 0 means no limit + expectedError: false, + }, + { + name: "deflate_with_zero_maxDecompressedSize", + body: deflateString(`{"streams": [{ "stream": { "foo": "bar" }, "values": [ [ "1570818238000000000", "test message" ] ] }]}`), + contentType: "application/json", + contentEncoding: "deflate", + maxRecvMsgSize: 100 << 20, // 100 MB + maxDecompressedSize: 0, // 0 means no limit + expectedError: false, + }, + { + name: "snappy_with_zero_maxDecompressedSize", + body: "", // Will be set below + contentType: "application/x-protobuf", + contentEncoding: "", + maxRecvMsgSize: 100 << 20, // 100 MB + maxDecompressedSize: 0, // 0 means no limit + expectedError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var body []byte + var err error + + if tc.name == "snappy_with_zero_maxDecompressedSize" { + // Create a snappy-compressed protobuf request + req := &logproto.PushRequest{ + Streams: []logproto.Stream{ + { + Labels: `{foo="bar"}`, + Entries: []logproto.Entry{ + { + Timestamp: time.Now(), + Line: "test message", + }, + }, + }, + }, + } + protoBytes, err := proto.Marshal(req) + require.NoError(t, err) + body = snappy.Encode(nil, protoBytes) + } else { + body = []byte(tc.body) + } + + request := httptest.NewRequest("POST", "/loki/api/v1/push", bytes.NewReader(body)) + request.Header.Set("Content-Type", tc.contentType) + if tc.contentEncoding != "" { + request.Header.Set("Content-Encoding", tc.contentEncoding) + } + + _, _, err = ParseRequest( + util_log.Logger, + "fake", + tc.maxRecvMsgSize, + tc.maxDecompressedSize, + request, + &fakeLimits{}, + nil, + ParseLokiRequest, + NewMockTracker(), + streamResolver, + "", + "loki", + ) + + if tc.expectedError { + require.Error(t, err) + if tc.expectedErrorMessage != "" { + require.Contains(t, err.Error(), tc.expectedErrorMessage) + } + } else { + require.NoError(t, err) + } + }) + } +} + type fakeLimits struct { enabled bool labels []string diff --git a/pkg/util/http.go b/pkg/util/http.go index 6c062f47dc..c6687421c6 100644 --- a/pkg/util/http.go +++ b/pkg/util/http.go @@ -17,14 +17,17 @@ import ( "github.com/go-kit/log/level" "github.com/gogo/protobuf/proto" "github.com/golang/snappy" - attribute "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "gopkg.in/yaml.v2" ) const messageSizeLargerErrFmt = "%w than max (%d vs %d)" -var ErrMessageSizeTooLarge = errors.New("message size too large") +var ( + ErrMessageSizeTooLarge = errors.New("message size too large") + ErrMessageDecompressedSizeTooLarge = errors.New("decompressed message size too large") +) const ( HTTPRateLimited = "rate_limited" @@ -164,10 +167,17 @@ const ( ) // ParseProtoReader parses a compressed proto from an io.Reader. +// Deprecated: Use ParseProtoReaderWithLimits for separate compressed/decompressed limits. func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, req proto.Message, compression CompressionType) error { + return ParseProtoReaderWithLimits(ctx, reader, expectedSize, maxSize, maxSize, req, compression) +} + +// ParseProtoReaderWithLimits parses a compressed proto from an io.Reader with separate size limits. +// maxCompressedSize limits the compressed input size, maxDecompressedSize limits the decompressed output size. +func ParseProtoReaderWithLimits(ctx context.Context, reader io.Reader, expectedSize, maxCompressedSize, maxDecompressedSize int, req proto.Message, compression CompressionType) error { sp := trace.SpanFromContext(ctx) sp.AddEvent("util.ParseProtoRequest[start reading]") - body, err := decompressRequest(reader, expectedSize, maxSize, compression, sp) + body, err := decompressRequest(reader, expectedSize, maxCompressedSize, maxDecompressedSize, compression, sp) if err != nil { return err } @@ -189,25 +199,25 @@ func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSi return nil } -func decompressRequest(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp trace.Span) (body []byte, err error) { +func decompressRequest(reader io.Reader, expectedSize, maxCompressedSize, maxDecompressedSize int, compression CompressionType, sp trace.Span) (body []byte, err error) { defer func() { - if err != nil && len(body) > maxSize { - err = fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(body), maxSize) + if err != nil && maxDecompressedSize > 0 && len(body) > maxDecompressedSize { + err = fmt.Errorf(messageSizeLargerErrFmt, ErrMessageDecompressedSizeTooLarge, len(body), maxDecompressedSize) } }() - if expectedSize > maxSize { - return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, expectedSize, maxSize) + if expectedSize > maxCompressedSize { + return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, expectedSize, maxCompressedSize) } buffer, ok := tryBufferFromReader(reader) if ok { - body, err = decompressFromBuffer(buffer, maxSize, compression, sp) + body, err = decompressFromBuffer(buffer, maxCompressedSize, maxDecompressedSize, compression, sp) return } - body, err = decompressFromReader(reader, expectedSize, maxSize, compression, sp) + body, err = decompressFromReader(reader, expectedSize, maxCompressedSize, maxDecompressedSize, compression, sp) return } -func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp trace.Span) ([]byte, error) { +func decompressFromReader(reader io.Reader, expectedSize, maxCompressedSize, maxDecompressedSize int, compression CompressionType, sp trace.Span) ([]byte, error) { var ( buf bytes.Buffer body []byte @@ -218,7 +228,7 @@ func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compressi } // Read from LimitReader with limit max+1. So if the underlying // reader is over limit, the result will be bigger than max. - reader = io.LimitReader(reader, int64(maxSize)+1) + reader = io.LimitReader(reader, int64(maxCompressedSize)+1) switch compression { case NoCompression: _, err = buf.ReadFrom(reader) @@ -228,15 +238,16 @@ func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compressi if err != nil { return nil, err } - body, err = decompressFromBuffer(&buf, maxSize, RawSnappy, sp) + body, err = decompressFromBuffer(&buf, maxCompressedSize, maxDecompressedSize, RawSnappy, sp) } return body, err } -func decompressFromBuffer(buffer *bytes.Buffer, maxSize int, compression CompressionType, sp trace.Span) ([]byte, error) { +func decompressFromBuffer(buffer *bytes.Buffer, maxCompressedSize, maxDecompressedSize int, compression CompressionType, sp trace.Span) ([]byte, error) { bufBytes := buffer.Bytes() - if len(bufBytes) > maxSize { - return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(bufBytes), maxSize) + // Check compressed size + if len(bufBytes) > maxCompressedSize { + return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(bufBytes), maxCompressedSize) } switch compression { case NoCompression: @@ -249,8 +260,9 @@ func decompressFromBuffer(buffer *bytes.Buffer, maxSize int, compression Compres if err != nil { return nil, err } - if size > maxSize { - return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, size, maxSize) + // Check decompressed size (only if limit is set) + if maxDecompressedSize > 0 && size > maxDecompressedSize { + return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageDecompressedSizeTooLarge, size, maxDecompressedSize) } body, err := snappy.Decode(nil, bufBytes) if err != nil { diff --git a/pkg/util/http_loki_test.go b/pkg/util/http_loki_test.go index 29b6a5d428..a84ff6362c 100644 --- a/pkg/util/http_loki_test.go +++ b/pkg/util/http_loki_test.go @@ -13,6 +13,6 @@ func BenchmarkDecompressFromBuffer(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - decompressFromBuffer(&buf, 1000, RawSnappy, nil) //nolint:errcheck + decompressFromBuffer(&buf, 1000, 1000, RawSnappy, nil) //nolint:errcheck } }