mirror of
https://github.com/grafana/loki.git
synced 2026-03-13 09:33:58 +08:00
fix: enforce maxRecvMsgSize and maxCompressedMsgSize for OTLP (#20312)
Co-authored-by: shantanualshi <shantanu.alshi@grafana.com>
This commit is contained in:
@@ -115,7 +115,7 @@ func (t *PushTarget) run() error {
|
||||
func (t *PushTarget) handleLoki(w http.ResponseWriter, r *http.Request) {
|
||||
logger := util_log.WithContext(r.Context(), util_log.Logger)
|
||||
userID, _ := tenant.TenantID(r.Context())
|
||||
req, _, err := push.ParseRequest(logger, userID, t.config.MaxSendMsgSize, r, push.EmptyLimits{}, nil, push.ParseLokiRequest, nil, nil, "", "loki")
|
||||
req, _, err := push.ParseRequest(logger, userID, t.config.MaxSendMsgSize, 0, r, push.EmptyLimits{}, nil, push.ParseLokiRequest, nil, nil, "", "loki")
|
||||
if err != nil {
|
||||
level.Warn(t.logger).Log("msg", "failed to parse incoming push request", "err", err.Error())
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
|
||||
@@ -3212,6 +3212,10 @@ ring:
|
||||
# CLI flag: -distributor.max-recv-msg-size
|
||||
[max_recv_msg_size: <int> | default = 104857600]
|
||||
|
||||
# The maximum size of a decompressed message. Defaults to 50x max-recv-msg-size.
|
||||
# CLI flag: -distributor.max-decompressed-size
|
||||
[max_decompressed_size: <int> | default = 5242880000]
|
||||
|
||||
rate_store:
|
||||
# The max number of concurrent requests to make to ingester stream apis
|
||||
# CLI flag: -distributor.rate-store.max-request-parallelism
|
||||
|
||||
@@ -87,7 +87,8 @@ type Config struct {
|
||||
PushWorkerCount int `yaml:"push_worker_count"`
|
||||
|
||||
// Request parser
|
||||
MaxRecvMsgSize int `yaml:"max_recv_msg_size"`
|
||||
MaxRecvMsgSize int `yaml:"max_recv_msg_size"`
|
||||
MaxDecompressedSize int `yaml:"max_decompressed_size"`
|
||||
|
||||
// For testing.
|
||||
factory ring_client.PoolFactory `yaml:"-"`
|
||||
@@ -121,6 +122,7 @@ func (cfg *Config) RegisterFlags(fs *flag.FlagSet) {
|
||||
cfg.RateStore.RegisterFlagsWithPrefix("distributor.rate-store", fs)
|
||||
cfg.WriteFailuresLogging.RegisterFlagsWithPrefix("distributor.write-failures-logging", fs)
|
||||
fs.IntVar(&cfg.MaxRecvMsgSize, "distributor.max-recv-msg-size", 100<<20, "The maximum size of a received message.")
|
||||
fs.IntVar(&cfg.MaxDecompressedSize, "distributor.max-decompressed-size", 5000<<20, "The maximum size of a decompressed message. Defaults to 50x max-recv-msg-size.")
|
||||
fs.IntVar(&cfg.PushWorkerCount, "distributor.push-worker-count", 256, "Number of workers to push batches to ingesters.")
|
||||
fs.BoolVar(&cfg.KafkaEnabled, "distributor.kafka-writes-enabled", false, "Enable writes to Kafka during Push requests.")
|
||||
fs.BoolVar(&cfg.IngesterEnabled, "distributor.ingester-writes-enabled", true, "Enable writes to Ingesters during Push requests. Defaults to true.")
|
||||
@@ -135,6 +137,10 @@ func (cfg *Config) Validate() error {
|
||||
if err := cfg.DataObjTeeConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
// Set default maxDecompressedSize if not configured (50x maxRecvMsgSize)
|
||||
if cfg.MaxDecompressedSize == 0 && cfg.MaxRecvMsgSize > 0 {
|
||||
cfg.MaxDecompressedSize = cfg.MaxRecvMsgSize * 50
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2651,3 +2651,64 @@ func TestDistributor_PushIngestLimits(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
expectedMaxDecompressedSize int
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "sets default maxDecompressedSize when zero and maxRecvMsgSize is set",
|
||||
cfg: Config{
|
||||
MaxRecvMsgSize: 100 << 20, // 100 MB
|
||||
MaxDecompressedSize: 0,
|
||||
KafkaEnabled: false,
|
||||
IngesterEnabled: true,
|
||||
},
|
||||
expectedMaxDecompressedSize: 5000 << 20, // 5000 MB (50x)
|
||||
},
|
||||
{
|
||||
name: "does not override explicit maxDecompressedSize",
|
||||
cfg: Config{
|
||||
MaxRecvMsgSize: 100 << 20, // 100 MB
|
||||
MaxDecompressedSize: 500 << 20, // 500 MB
|
||||
KafkaEnabled: false,
|
||||
IngesterEnabled: true,
|
||||
},
|
||||
expectedMaxDecompressedSize: 500 << 20, // 500 MB (unchanged)
|
||||
},
|
||||
{
|
||||
name: "does not set default when maxRecvMsgSize is zero",
|
||||
cfg: Config{
|
||||
MaxRecvMsgSize: 0,
|
||||
MaxDecompressedSize: 0,
|
||||
KafkaEnabled: false,
|
||||
IngesterEnabled: true,
|
||||
},
|
||||
expectedMaxDecompressedSize: 0, // Should remain 0
|
||||
},
|
||||
{
|
||||
name: "validates kafka and ingester enabled",
|
||||
cfg: Config{
|
||||
KafkaEnabled: false,
|
||||
IngesterEnabled: false,
|
||||
},
|
||||
expectedError: "at least one of kafka and ingestor writes must be enabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.cfg.Validate()
|
||||
if tt.expectedError != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.expectedError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedMaxDecompressedSize, tt.cfg.MaxDecompressedSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (d *Distributor) pushHandler(w http.ResponseWriter, r *http.Request, pushRe
|
||||
logPushRequestStreams := d.tenantConfigs.LogPushRequestStreams(tenantID)
|
||||
filterPushRequestStreamsIPs := d.tenantConfigs.FilterPushRequestStreamsIPs(tenantID)
|
||||
presumedAgentIP := extractPresumedAgentIP(r)
|
||||
req, pushStats, err := push.ParseRequest(logger, tenantID, d.cfg.MaxRecvMsgSize, r, d.validator.Limits, d.tenantConfigs,
|
||||
req, pushStats, err := push.ParseRequest(logger, tenantID, d.cfg.MaxRecvMsgSize, d.cfg.MaxDecompressedSize, r, d.validator.Limits, d.tenantConfigs,
|
||||
pushRequestParser, d.usageTracker, streamResolver, presumedAgentIP, format)
|
||||
if err != nil {
|
||||
switch {
|
||||
|
||||
@@ -176,6 +176,7 @@ func (p *fakeParser) parseRequest(
|
||||
_ push.Limits,
|
||||
_ *runtime.TenantConfigs,
|
||||
_ int,
|
||||
_ int,
|
||||
_ push.UsageTracker,
|
||||
_ push.StreamResolver,
|
||||
_ log.Logger,
|
||||
|
||||
@@ -44,9 +44,9 @@ const (
|
||||
messageSizeLargerErrFmt = "%w than max (%d vs %d)"
|
||||
)
|
||||
|
||||
func ParseOTLPRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) {
|
||||
func ParseOTLPRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize, maxDecompressedSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) {
|
||||
stats := NewPushStats()
|
||||
otlpLogs, err := extractLogs(r, maxRecvMsgSize, stats)
|
||||
otlpLogs, err := extractLogs(r, maxRecvMsgSize, maxDecompressedSize, stats)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func ParseOTLPRequest(userID string, r *http.Request, limits Limits, tenantConfi
|
||||
return req, stats, err
|
||||
}
|
||||
|
||||
func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Logs, error) {
|
||||
func extractLogs(r *http.Request, maxRecvMsgSize, maxDecompressedSize int, pushStats *Stats) (plog.Logs, error) {
|
||||
pushStats.ContentEncoding = r.Header.Get(contentEnc)
|
||||
// bodySize should always reflect the compressed size of the request body
|
||||
bodySize := loki_util.NewSizeReader(r.Body)
|
||||
@@ -67,7 +67,7 @@ func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Lo
|
||||
}
|
||||
switch pushStats.ContentEncoding {
|
||||
case gzipContentEncoding:
|
||||
r, err := gzip.NewReader(bodySize)
|
||||
r, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return plog.NewLogs(), err
|
||||
}
|
||||
@@ -75,14 +75,24 @@ func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Lo
|
||||
defer func(reader *gzip.Reader) {
|
||||
_ = reader.Close()
|
||||
}(r)
|
||||
if maxDecompressedSize > 0 {
|
||||
body = io.LimitReader(body, int64(maxDecompressedSize)+1)
|
||||
}
|
||||
|
||||
case zstdContentEncoding:
|
||||
var err error
|
||||
body, err = zstd.NewReader(body)
|
||||
if err != nil {
|
||||
return plog.NewLogs(), err
|
||||
}
|
||||
if maxDecompressedSize > 0 {
|
||||
body = io.LimitReader(body, int64(maxDecompressedSize)+1)
|
||||
}
|
||||
case lz4ContentEncoding:
|
||||
body = io.NopCloser(lz4.NewReader(body))
|
||||
if maxDecompressedSize > 0 {
|
||||
body = io.LimitReader(body, int64(maxDecompressedSize)+1)
|
||||
}
|
||||
case "":
|
||||
// no content encoding, use the body as is
|
||||
default:
|
||||
@@ -90,12 +100,18 @@ func extractLogs(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (plog.Lo
|
||||
}
|
||||
buf, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
if size := bodySize.Size(); size > int64(maxRecvMsgSize) && maxRecvMsgSize > 0 {
|
||||
return plog.NewLogs(), fmt.Errorf(messageSizeLargerErrFmt, loki_util.ErrMessageSizeTooLarge, size, maxRecvMsgSize)
|
||||
}
|
||||
return plog.NewLogs(), err
|
||||
}
|
||||
|
||||
// Check the size of the compressed body
|
||||
if size := bodySize.Size(); size > int64(maxRecvMsgSize) && maxRecvMsgSize > 0 {
|
||||
return plog.NewLogs(), fmt.Errorf(messageSizeLargerErrFmt, loki_util.ErrMessageSizeTooLarge, size, maxRecvMsgSize)
|
||||
}
|
||||
// Check the size of the decompressed body
|
||||
if len(buf) > maxDecompressedSize && maxDecompressedSize > 0 {
|
||||
return plog.NewLogs(), fmt.Errorf(messageSizeLargerErrFmt, loki_util.ErrMessageDecompressedSizeTooLarge, len(buf), maxDecompressedSize)
|
||||
}
|
||||
|
||||
pushStats.BodySize = bodySize.Size()
|
||||
|
||||
req := plogotlp.NewExportRequest()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package push
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@@ -1297,6 +1298,59 @@ func simpleOTLPLogs() plog.Logs {
|
||||
return ld
|
||||
}
|
||||
|
||||
// largeOTLPLogs creates an OTLP log record which is larger than 1MB
|
||||
// and will compress to less than 1MB (~3kb depending on the compression algorithm).
|
||||
func largeOTLPLogs() plog.Logs {
|
||||
ld := plog.NewLogs()
|
||||
rl := ld.ResourceLogs().AppendEmpty()
|
||||
rl.Resource().Attributes().PutStr("service.name", "test-service")
|
||||
sl := rl.ScopeLogs().AppendEmpty()
|
||||
for i := 0; i < 1024; i++ {
|
||||
logRecord := sl.LogRecords().AppendEmpty()
|
||||
logRecord.Body().SetStr(strings.Repeat(" ", 1024))
|
||||
}
|
||||
return ld
|
||||
}
|
||||
|
||||
func createJSON(logs plog.Logs) ([]byte, error) {
|
||||
req := plogotlp.NewExportRequestFromLogs(logs)
|
||||
jsonBytes, err := req.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jsonBytes, nil
|
||||
}
|
||||
|
||||
func createGzipCompressedProtobuf(logs plog.Logs) ([]byte, error) {
|
||||
req := plogotlp.NewExportRequestFromLogs(logs)
|
||||
protoBytes, err := req.MarshalProto()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return compressWithGzip(protoBytes)
|
||||
}
|
||||
|
||||
func createGzipCompressedJSON(logs plog.Logs) ([]byte, error) {
|
||||
req := plogotlp.NewExportRequestFromLogs(logs)
|
||||
jsonBytes, err := req.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return compressWithGzip(jsonBytes)
|
||||
}
|
||||
|
||||
func compressWithGzip(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
writer := gzip.NewWriter(&buf)
|
||||
if _, err := writer.Write(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func createZstdCompressedProtobuf(logs plog.Logs) ([]byte, error) {
|
||||
req := plogotlp.NewExportRequestFromLogs(logs)
|
||||
protoBytes, err := req.MarshalProto()
|
||||
@@ -1378,15 +1432,109 @@ func createOTLPLogWithNestedAttributes() plog.Logs {
|
||||
return ld
|
||||
}
|
||||
|
||||
func TestContentEncoding(t *testing.T) {
|
||||
func TestContentEncodingAndLength(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
contentType string
|
||||
contentEncoding string
|
||||
generateBody func() ([]byte, error)
|
||||
expectedError bool
|
||||
expectedLogs plog.Logs
|
||||
name string
|
||||
contentType string
|
||||
contentEncoding string
|
||||
generateBody func() ([]byte, error)
|
||||
expectedError bool
|
||||
expectedErrorMessage string
|
||||
expectedLogs plog.Logs
|
||||
maxRecvMsgSize int
|
||||
maxDecompressedSize int
|
||||
}{
|
||||
{
|
||||
name: "identity_valid_json",
|
||||
contentType: "application/json",
|
||||
contentEncoding: "",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createJSON(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "identity_large_json",
|
||||
contentType: "application/json",
|
||||
contentEncoding: "",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createJSON(largeOTLPLogs())
|
||||
},
|
||||
expectedError: true,
|
||||
expectedErrorMessage: "message size too large than max",
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 1 << 20, // 1 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_valid_protobuf",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedProtobuf(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_valid_json",
|
||||
contentType: "application/json",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedJSON(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_invalid_data",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return []byte("invalid gzip data"), nil
|
||||
},
|
||||
expectedError: true,
|
||||
expectedLogs: plog.NewLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_nested_attributes",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedProtobuf(createOTLPLogWithNestedAttributes())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: createOTLPLogWithNestedAttributes(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_large_protobuf",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedProtobuf(largeOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: largeOTLPLogs(),
|
||||
maxRecvMsgSize: 1 << 20, // 1 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_too_large_protobuf",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedProtobuf(largeOTLPLogs())
|
||||
},
|
||||
expectedError: true,
|
||||
expectedErrorMessage: "message size too large than max (40961 vs 40960)",
|
||||
expectedLogs: largeOTLPLogs(),
|
||||
maxRecvMsgSize: 1 << 12, // 4 KB
|
||||
maxDecompressedSize: 40960, // Explicitly set to trigger error
|
||||
},
|
||||
{
|
||||
name: "zstd_valid_protobuf",
|
||||
contentType: "application/x-protobuf",
|
||||
@@ -1394,8 +1542,9 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createZstdCompressedProtobuf(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "zstd_valid_json",
|
||||
@@ -1404,8 +1553,9 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createZstdCompressedJSON(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "zstd_invalid_data",
|
||||
@@ -1414,8 +1564,9 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return []byte("invalid zstd data"), nil
|
||||
},
|
||||
expectedError: true,
|
||||
expectedLogs: plog.NewLogs(),
|
||||
expectedError: true,
|
||||
expectedLogs: plog.NewLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "zstd_nested_attributes",
|
||||
@@ -1424,8 +1575,22 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createZstdCompressedProtobuf(createOTLPLogWithNestedAttributes())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: createOTLPLogWithNestedAttributes(),
|
||||
expectedError: false,
|
||||
expectedLogs: createOTLPLogWithNestedAttributes(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "zstd_too_large_protobuf",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "zstd",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createZstdCompressedProtobuf(largeOTLPLogs())
|
||||
},
|
||||
expectedError: true,
|
||||
expectedErrorMessage: "message size too large than max (40961 vs 40960)",
|
||||
expectedLogs: largeOTLPLogs(),
|
||||
maxRecvMsgSize: 1 << 12, // 4 KB
|
||||
maxDecompressedSize: 40960, // Explicitly set to trigger error
|
||||
},
|
||||
{
|
||||
name: "lz4_valid_protobuf",
|
||||
@@ -1434,8 +1599,9 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createLz4CompressedProtobuf(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "lz4_valid_json",
|
||||
@@ -1444,8 +1610,9 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createLz4CompressedJSON(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "lz4_invalid_data",
|
||||
@@ -1454,8 +1621,82 @@ func TestContentEncoding(t *testing.T) {
|
||||
generateBody: func() ([]byte, error) {
|
||||
return []byte("invalid lz4 data"), nil
|
||||
},
|
||||
expectedError: true,
|
||||
expectedLogs: plog.NewLogs(),
|
||||
expectedError: true,
|
||||
expectedLogs: plog.NewLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "lz4_too_large_protobuf",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "lz4",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createLz4CompressedProtobuf(largeOTLPLogs())
|
||||
},
|
||||
expectedError: true,
|
||||
expectedErrorMessage: "message size too large than max (81921 vs 81920)",
|
||||
expectedLogs: largeOTLPLogs(),
|
||||
maxRecvMsgSize: 1 << 13, // 8 KB
|
||||
maxDecompressedSize: 81920, // Explicitly set to trigger error
|
||||
},
|
||||
{
|
||||
name: "unsupported_encoding",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "br",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return []byte("dummy brotly data"), nil
|
||||
},
|
||||
expectedError: true,
|
||||
expectedErrorMessage: "unsupported content encoding br: only gzip, lz4 and zstd are supported",
|
||||
expectedLogs: plog.NewLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
},
|
||||
{
|
||||
name: "gzip_with_zero_maxDecompressedSize",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedProtobuf(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit (should still work for small payloads)
|
||||
},
|
||||
{
|
||||
name: "gzip_large_with_zero_maxDecompressedSize",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "gzip",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createGzipCompressedProtobuf(largeOTLPLogs())
|
||||
},
|
||||
expectedError: false, // No limit when maxDecompressedSize is 0
|
||||
expectedLogs: largeOTLPLogs(),
|
||||
maxRecvMsgSize: 1 << 20, // 1 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit
|
||||
},
|
||||
{
|
||||
name: "zstd_with_zero_maxDecompressedSize",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "zstd",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createZstdCompressedProtobuf(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit
|
||||
},
|
||||
{
|
||||
name: "lz4_with_zero_maxDecompressedSize",
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "lz4",
|
||||
generateBody: func() ([]byte, error) {
|
||||
return createLz4CompressedProtobuf(simpleOTLPLogs())
|
||||
},
|
||||
expectedError: false,
|
||||
expectedLogs: simpleOTLPLogs(),
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1469,10 +1710,31 @@ func TestContentEncoding(t *testing.T) {
|
||||
req.Header.Set("Content-Encoding", tc.contentEncoding)
|
||||
|
||||
stats := NewPushStats()
|
||||
extractedLogs, err := extractLogs(req, 100<<20, stats)
|
||||
maxDecompressedSize := tc.maxDecompressedSize
|
||||
// Only apply default if maxDecompressedSize is 0 and not explicitly testing zero behavior
|
||||
// For test cases with maxDecompressedSize explicitly set to 0, we want to test the actual behavior
|
||||
// For other cases, calculate as 10x maxRecvMsgSize (matching Validate() behavior) or use 100MB if maxRecvMsgSize is 0
|
||||
zeroMaxDecompressedSizeTests := map[string]bool{
|
||||
"gzip_with_zero_maxDecompressedSize": true,
|
||||
"gzip_large_with_zero_maxDecompressedSize": true,
|
||||
"zstd_with_zero_maxDecompressedSize": true,
|
||||
"lz4_with_zero_maxDecompressedSize": true,
|
||||
}
|
||||
if maxDecompressedSize == 0 && !zeroMaxDecompressedSizeTests[tc.name] {
|
||||
if tc.maxRecvMsgSize > 0 {
|
||||
maxDecompressedSize = tc.maxRecvMsgSize * 50 // 50x default
|
||||
} else {
|
||||
maxDecompressedSize = 5000 << 20 // 5000 MB fallback default (50x 100MB)
|
||||
}
|
||||
}
|
||||
extractedLogs, err := extractLogs(req, tc.maxRecvMsgSize, maxDecompressedSize, stats)
|
||||
|
||||
if tc.expectedError {
|
||||
require.Error(t, err)
|
||||
|
||||
if tc.expectedErrorMessage != "" {
|
||||
require.Contains(t, err.Error(), tc.expectedErrorMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ type StreamResolver interface {
|
||||
}
|
||||
|
||||
type (
|
||||
RequestParser func(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error)
|
||||
RequestParser func(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize, maxDecompressedSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error)
|
||||
RequestParserWrapper func(inner RequestParser) RequestParser
|
||||
ErrorWriter func(w http.ResponseWriter, errorStr string, code int, logger log.Logger)
|
||||
)
|
||||
@@ -171,8 +171,8 @@ type Stats struct {
|
||||
HasInternalStreams bool // True if any of the streams has aggregated metrics or is a pattern stream
|
||||
}
|
||||
|
||||
func ParseRequest(logger log.Logger, userID string, maxRecvMsgSize int, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, pushRequestParser RequestParser, tracker UsageTracker, streamResolver StreamResolver, presumedAgentIP, format string) (*logproto.PushRequest, *Stats, error) {
|
||||
req, pushStats, err := pushRequestParser(userID, r, limits, tenantConfigs, maxRecvMsgSize, tracker, streamResolver, logger)
|
||||
func ParseRequest(logger log.Logger, userID string, maxRecvMsgSize, maxDecompressedSize int, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, pushRequestParser RequestParser, tracker UsageTracker, streamResolver StreamResolver, presumedAgentIP, format string) (*logproto.PushRequest, *Stats, error) {
|
||||
req, pushStats, err := pushRequestParser(userID, r, limits, tenantConfigs, maxRecvMsgSize, maxDecompressedSize, tracker, streamResolver, logger)
|
||||
if err != nil && !errors.Is(err, ErrAllLogsFiltered) {
|
||||
if errors.Is(err, util.ErrMessageSizeTooLarge) {
|
||||
return nil, nil, fmt.Errorf("%w: %s", ErrRequestBodyTooLarge, err.Error())
|
||||
@@ -304,31 +304,46 @@ func ParseRequest(logger log.Logger, userID string, maxRecvMsgSize int, r *http.
|
||||
|
||||
// parsePushRequestBody returns logproto.PushRequest from http.Request body, deserialized according to specified content type.
|
||||
// It also modifies pushStats.
|
||||
func parsePushRequestBody(r *http.Request, maxRecvMsgSize int, pushStats *Stats) (*logproto.PushRequest, error) {
|
||||
func parsePushRequestBody(r *http.Request, maxRecvMsgSize, maxDecompressedSize int, pushStats *Stats) (*logproto.PushRequest, error) {
|
||||
// Body
|
||||
var body io.Reader
|
||||
// bodySize should always reflect the compressed size of the request body
|
||||
bodySize := util.NewSizeReader(r.Body)
|
||||
|
||||
// Apply compressed size limit
|
||||
body = bodySize
|
||||
if maxRecvMsgSize > 0 {
|
||||
body = io.LimitReader(body, int64(maxRecvMsgSize)+1)
|
||||
}
|
||||
|
||||
contentEncoding := r.Header.Get(contentEnc)
|
||||
switch contentEncoding {
|
||||
case "":
|
||||
body = bodySize
|
||||
case "snappy":
|
||||
// Snappy-decoding is done by `util.ParseProtoReader(..., util.RawSnappy)` below.
|
||||
// Snappy-decoding is done by `util.ParseProtoReaderWithLimits(..., util.RawSnappy)` below.
|
||||
// Pass on body bytes. Note: HTTP clients do not need to set this header,
|
||||
// but they sometimes do. See #3407.
|
||||
body = bodySize
|
||||
case "gzip":
|
||||
gzipReader, err := gzip.NewReader(bodySize)
|
||||
gzipReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
defer func(gzipReader *gzip.Reader) {
|
||||
_ = gzipReader.Close()
|
||||
}(gzipReader)
|
||||
body = gzipReader
|
||||
if maxDecompressedSize > 0 {
|
||||
body = io.LimitReader(body, int64(maxDecompressedSize)+1)
|
||||
}
|
||||
case "deflate":
|
||||
flateReader := flate.NewReader(bodySize)
|
||||
defer flateReader.Close()
|
||||
flateReader := flate.NewReader(body)
|
||||
defer func(flateReader io.ReadCloser) {
|
||||
_ = flateReader.Close()
|
||||
}(flateReader)
|
||||
body = flateReader
|
||||
if maxDecompressedSize > 0 {
|
||||
body = io.LimitReader(body, int64(maxDecompressedSize)+1)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("Content-Encoding %q not supported", contentEncoding)
|
||||
}
|
||||
@@ -361,7 +376,7 @@ func parsePushRequestBody(r *http.Request, maxRecvMsgSize int, pushStats *Stats)
|
||||
default:
|
||||
// When no content-type header is set or when it is set to
|
||||
// `application/x-protobuf`: expect snappy compression.
|
||||
if err := util.ParseProtoReader(r.Context(), body, int(r.ContentLength), maxRecvMsgSize, &req, util.RawSnappy); err != nil {
|
||||
if err := util.ParseProtoReaderWithLimits(r.Context(), body, int(r.ContentLength), maxRecvMsgSize, maxDecompressedSize, &req, util.RawSnappy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -370,13 +385,16 @@ func parsePushRequestBody(r *http.Request, maxRecvMsgSize int, pushStats *Stats)
|
||||
pushStats.ContentType = contentType
|
||||
pushStats.ContentEncoding = contentEncoding
|
||||
|
||||
if size := bodySize.Size(); size > int64(maxRecvMsgSize) && maxRecvMsgSize > 0 {
|
||||
return nil, fmt.Errorf("compressed message size %d exceeds limit %d", size, maxRecvMsgSize)
|
||||
}
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) {
|
||||
func ParseLokiRequest(userID string, r *http.Request, limits Limits, tenantConfigs *runtime.TenantConfigs, maxRecvMsgSize, maxDecompressedSize int, tracker UsageTracker, streamResolver StreamResolver, logger log.Logger) (*logproto.PushRequest, *Stats, error) {
|
||||
pushStats := NewPushStats()
|
||||
|
||||
req, err := parsePushRequestBody(r, maxRecvMsgSize, pushStats)
|
||||
req, err := parsePushRequestBody(r, maxRecvMsgSize, maxDecompressedSize, pushStats)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -363,6 +363,7 @@ func TestParseRequest(t *testing.T) {
|
||||
util_log.Logger,
|
||||
"fake",
|
||||
100<<20,
|
||||
100<<20,
|
||||
request,
|
||||
test.fakeLimits,
|
||||
nil,
|
||||
@@ -510,7 +511,7 @@ func Test_ServiceDetection(t *testing.T) {
|
||||
|
||||
limits := &fakeLimits{enabled: true, labels: []string{"foo"}}
|
||||
streamResolver := newMockStreamResolver("fake", limits)
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseLokiRequest, tracker, streamResolver, "", "loki")
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseLokiRequest, tracker, streamResolver, "", "loki")
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, labels.FromStrings("foo", "bar", LabelServiceName, "bar").String(), data.Streams[0].Labels)
|
||||
@@ -522,7 +523,7 @@ func Test_ServiceDetection(t *testing.T) {
|
||||
|
||||
limits := &fakeLimits{enabled: true}
|
||||
streamResolver := newMockStreamResolver("fake", limits)
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki")
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, labels.FromStrings("k8s_job_name", "bar", LabelServiceName, "bar").String(), data.Streams[0].Labels)
|
||||
})
|
||||
@@ -537,7 +538,7 @@ func Test_ServiceDetection(t *testing.T) {
|
||||
indexAttributes: []string{"special"},
|
||||
}
|
||||
streamResolver := newMockStreamResolver("fake", limits)
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki")
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, labels.FromStrings("special", "sauce", LabelServiceName, "sauce").String(), data.Streams[0].Labels)
|
||||
})
|
||||
@@ -552,7 +553,7 @@ func Test_ServiceDetection(t *testing.T) {
|
||||
indexAttributes: []string{},
|
||||
}
|
||||
streamResolver := newMockStreamResolver("fake", limits)
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki")
|
||||
data, _, err := ParseRequest(util_log.Logger, "fake", 100<<20, 100<<20, request, limits, nil, ParseOTLPRequest, tracker, streamResolver, "", "loki")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, labels.FromStrings(LabelServiceName, ServiceUnknown).String(), data.Streams[0].Labels)
|
||||
})
|
||||
@@ -654,7 +655,7 @@ func TestNegativeSizeHandling(t *testing.T) {
|
||||
linesIngested.Reset()
|
||||
|
||||
// Create a custom request parser that will generate negative sizes
|
||||
var mockParser RequestParser = func(_ string, _ *http.Request, _ Limits, _ *runtime.TenantConfigs, _ int, _ UsageTracker, _ StreamResolver, _ kitlog.Logger) (*logproto.PushRequest, *Stats, error) {
|
||||
var mockParser RequestParser = func(_ string, _ *http.Request, _ Limits, _ *runtime.TenantConfigs, _ int, _ int, _ UsageTracker, _ StreamResolver, _ kitlog.Logger) (*logproto.PushRequest, *Stats, error) {
|
||||
// Create a minimal valid request
|
||||
req := &logproto.PushRequest{
|
||||
Streams: []logproto.Stream{
|
||||
@@ -697,6 +698,7 @@ func TestNegativeSizeHandling(t *testing.T) {
|
||||
util_log.Logger,
|
||||
"fake",
|
||||
100<<20,
|
||||
100<<20,
|
||||
request,
|
||||
&fakeLimits{},
|
||||
nil,
|
||||
@@ -721,6 +723,108 @@ func TestNegativeSizeHandling(t *testing.T) {
|
||||
require.Equal(t, float64(0), testutil.ToFloat64(structuredMetadataBytesIngested.WithLabelValues(userID, "1", isAggregatedMetric, policy, "loki")))
|
||||
}
|
||||
|
||||
func TestParseRequestWithZeroMaxDecompressedSize(t *testing.T) {
|
||||
streamResolver := newMockStreamResolver("fake", &fakeLimits{})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
body string
|
||||
contentType string
|
||||
contentEncoding string
|
||||
maxRecvMsgSize int
|
||||
maxDecompressedSize int
|
||||
expectedError bool
|
||||
expectedErrorMessage string
|
||||
}{
|
||||
{
|
||||
name: "gzip_with_zero_maxDecompressedSize",
|
||||
body: gzipString(`{"streams": [{ "stream": { "foo": "bar" }, "values": [ [ "1570818238000000000", "test message" ] ] }]}`),
|
||||
contentType: "application/json",
|
||||
contentEncoding: "gzip",
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "deflate_with_zero_maxDecompressedSize",
|
||||
body: deflateString(`{"streams": [{ "stream": { "foo": "bar" }, "values": [ [ "1570818238000000000", "test message" ] ] }]}`),
|
||||
contentType: "application/json",
|
||||
contentEncoding: "deflate",
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "snappy_with_zero_maxDecompressedSize",
|
||||
body: "", // Will be set below
|
||||
contentType: "application/x-protobuf",
|
||||
contentEncoding: "",
|
||||
maxRecvMsgSize: 100 << 20, // 100 MB
|
||||
maxDecompressedSize: 0, // 0 means no limit
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var body []byte
|
||||
var err error
|
||||
|
||||
if tc.name == "snappy_with_zero_maxDecompressedSize" {
|
||||
// Create a snappy-compressed protobuf request
|
||||
req := &logproto.PushRequest{
|
||||
Streams: []logproto.Stream{
|
||||
{
|
||||
Labels: `{foo="bar"}`,
|
||||
Entries: []logproto.Entry{
|
||||
{
|
||||
Timestamp: time.Now(),
|
||||
Line: "test message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
protoBytes, err := proto.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
body = snappy.Encode(nil, protoBytes)
|
||||
} else {
|
||||
body = []byte(tc.body)
|
||||
}
|
||||
|
||||
request := httptest.NewRequest("POST", "/loki/api/v1/push", bytes.NewReader(body))
|
||||
request.Header.Set("Content-Type", tc.contentType)
|
||||
if tc.contentEncoding != "" {
|
||||
request.Header.Set("Content-Encoding", tc.contentEncoding)
|
||||
}
|
||||
|
||||
_, _, err = ParseRequest(
|
||||
util_log.Logger,
|
||||
"fake",
|
||||
tc.maxRecvMsgSize,
|
||||
tc.maxDecompressedSize,
|
||||
request,
|
||||
&fakeLimits{},
|
||||
nil,
|
||||
ParseLokiRequest,
|
||||
NewMockTracker(),
|
||||
streamResolver,
|
||||
"",
|
||||
"loki",
|
||||
)
|
||||
|
||||
if tc.expectedError {
|
||||
require.Error(t, err)
|
||||
if tc.expectedErrorMessage != "" {
|
||||
require.Contains(t, err.Error(), tc.expectedErrorMessage)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeLimits struct {
|
||||
enabled bool
|
||||
labels []string
|
||||
|
||||
@@ -17,14 +17,17 @@ import (
|
||||
"github.com/go-kit/log/level"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/golang/snappy"
|
||||
attribute "go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const messageSizeLargerErrFmt = "%w than max (%d vs %d)"
|
||||
|
||||
var ErrMessageSizeTooLarge = errors.New("message size too large")
|
||||
var (
|
||||
ErrMessageSizeTooLarge = errors.New("message size too large")
|
||||
ErrMessageDecompressedSizeTooLarge = errors.New("decompressed message size too large")
|
||||
)
|
||||
|
||||
const (
|
||||
HTTPRateLimited = "rate_limited"
|
||||
@@ -164,10 +167,17 @@ const (
|
||||
)
|
||||
|
||||
// ParseProtoReader parses a compressed proto from an io.Reader.
|
||||
// Deprecated: Use ParseProtoReaderWithLimits for separate compressed/decompressed limits.
|
||||
func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, req proto.Message, compression CompressionType) error {
|
||||
return ParseProtoReaderWithLimits(ctx, reader, expectedSize, maxSize, maxSize, req, compression)
|
||||
}
|
||||
|
||||
// ParseProtoReaderWithLimits parses a compressed proto from an io.Reader with separate size limits.
|
||||
// maxCompressedSize limits the compressed input size, maxDecompressedSize limits the decompressed output size.
|
||||
func ParseProtoReaderWithLimits(ctx context.Context, reader io.Reader, expectedSize, maxCompressedSize, maxDecompressedSize int, req proto.Message, compression CompressionType) error {
|
||||
sp := trace.SpanFromContext(ctx)
|
||||
sp.AddEvent("util.ParseProtoRequest[start reading]")
|
||||
body, err := decompressRequest(reader, expectedSize, maxSize, compression, sp)
|
||||
body, err := decompressRequest(reader, expectedSize, maxCompressedSize, maxDecompressedSize, compression, sp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -189,25 +199,25 @@ func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSi
|
||||
return nil
|
||||
}
|
||||
|
||||
func decompressRequest(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp trace.Span) (body []byte, err error) {
|
||||
func decompressRequest(reader io.Reader, expectedSize, maxCompressedSize, maxDecompressedSize int, compression CompressionType, sp trace.Span) (body []byte, err error) {
|
||||
defer func() {
|
||||
if err != nil && len(body) > maxSize {
|
||||
err = fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(body), maxSize)
|
||||
if err != nil && maxDecompressedSize > 0 && len(body) > maxDecompressedSize {
|
||||
err = fmt.Errorf(messageSizeLargerErrFmt, ErrMessageDecompressedSizeTooLarge, len(body), maxDecompressedSize)
|
||||
}
|
||||
}()
|
||||
if expectedSize > maxSize {
|
||||
return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, expectedSize, maxSize)
|
||||
if expectedSize > maxCompressedSize {
|
||||
return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, expectedSize, maxCompressedSize)
|
||||
}
|
||||
buffer, ok := tryBufferFromReader(reader)
|
||||
if ok {
|
||||
body, err = decompressFromBuffer(buffer, maxSize, compression, sp)
|
||||
body, err = decompressFromBuffer(buffer, maxCompressedSize, maxDecompressedSize, compression, sp)
|
||||
return
|
||||
}
|
||||
body, err = decompressFromReader(reader, expectedSize, maxSize, compression, sp)
|
||||
body, err = decompressFromReader(reader, expectedSize, maxCompressedSize, maxDecompressedSize, compression, sp)
|
||||
return
|
||||
}
|
||||
|
||||
func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp trace.Span) ([]byte, error) {
|
||||
func decompressFromReader(reader io.Reader, expectedSize, maxCompressedSize, maxDecompressedSize int, compression CompressionType, sp trace.Span) ([]byte, error) {
|
||||
var (
|
||||
buf bytes.Buffer
|
||||
body []byte
|
||||
@@ -218,7 +228,7 @@ func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compressi
|
||||
}
|
||||
// Read from LimitReader with limit max+1. So if the underlying
|
||||
// reader is over limit, the result will be bigger than max.
|
||||
reader = io.LimitReader(reader, int64(maxSize)+1)
|
||||
reader = io.LimitReader(reader, int64(maxCompressedSize)+1)
|
||||
switch compression {
|
||||
case NoCompression:
|
||||
_, err = buf.ReadFrom(reader)
|
||||
@@ -228,15 +238,16 @@ func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compressi
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, err = decompressFromBuffer(&buf, maxSize, RawSnappy, sp)
|
||||
body, err = decompressFromBuffer(&buf, maxCompressedSize, maxDecompressedSize, RawSnappy, sp)
|
||||
}
|
||||
return body, err
|
||||
}
|
||||
|
||||
func decompressFromBuffer(buffer *bytes.Buffer, maxSize int, compression CompressionType, sp trace.Span) ([]byte, error) {
|
||||
func decompressFromBuffer(buffer *bytes.Buffer, maxCompressedSize, maxDecompressedSize int, compression CompressionType, sp trace.Span) ([]byte, error) {
|
||||
bufBytes := buffer.Bytes()
|
||||
if len(bufBytes) > maxSize {
|
||||
return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(bufBytes), maxSize)
|
||||
// Check compressed size
|
||||
if len(bufBytes) > maxCompressedSize {
|
||||
return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(bufBytes), maxCompressedSize)
|
||||
}
|
||||
switch compression {
|
||||
case NoCompression:
|
||||
@@ -249,8 +260,9 @@ func decompressFromBuffer(buffer *bytes.Buffer, maxSize int, compression Compres
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if size > maxSize {
|
||||
return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, size, maxSize)
|
||||
// Check decompressed size (only if limit is set)
|
||||
if maxDecompressedSize > 0 && size > maxDecompressedSize {
|
||||
return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageDecompressedSizeTooLarge, size, maxDecompressedSize)
|
||||
}
|
||||
body, err := snappy.Decode(nil, bufBytes)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,6 +13,6 @@ func BenchmarkDecompressFromBuffer(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
decompressFromBuffer(&buf, 1000, RawSnappy, nil) //nolint:errcheck
|
||||
decompressFromBuffer(&buf, 1000, 1000, RawSnappy, nil) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user