grafana/pkg/util/scheduler/scheduler_test.go

package scheduler

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/grafana/dskit/services"
	"github.com/grafana/grafana/pkg/infra/log"
	"github.com/stretchr/testify/require"
)
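// TestScheduler exercises the scheduler end to end: configuration validation
// and defaulting, construction, dskit service lifecycle transitions, work
// distribution across tenants, and shutdown behaviour.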
func TestScheduler(t *testing.T) {
	t.Parallel()
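	// ConfigValidation checks that Config.validate accepts a fully populated
	// config and fills in defaults for zero workers, a nil logger, zero
	// backoff, and zero retries.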
t.Run("ConfigValidation", func(t *testing.T) {
t.Parallel()
t.Run("ValidConfig", func(t *testing.T) {
cfg := Config{
NumWorkers: 5,
MaxBackoff: 1 * time.Second,
Logger: log.New("qos.test"),
}
err := cfg.validate()
require.NoError(t, err)
require.NotNil(t, cfg.Logger)
})
t.Run("ZeroWorkersGetDefault", func(t *testing.T) {
cfg := Config{
NumWorkers: 0,
MaxBackoff: 1 * time.Second,
}
err := cfg.validate()
require.NoError(t, err)
require.Equal(t, cfg.NumWorkers, DefaultNumWorkers)
require.NotNil(t, cfg.Logger, "Logger should not be nil")
})
t.Run("NilLoggerGetsDefault", func(t *testing.T) {
cfg := Config{
NumWorkers: 1,
MaxBackoff: 1 * time.Second,
Logger: nil,
}
err := cfg.validate()
require.NoError(t, err)
require.NotNil(t, cfg.Logger)
})
t.Run("ZeroTimeoutGetsDefault", func(t *testing.T) {
cfg := Config{
NumWorkers: 1,
MaxBackoff: 0,
Logger: log.New("qos.test"),
}
err := cfg.validate()
require.NoError(t, err)
require.Equal(t, cfg.MaxBackoff, DefaultMaxBackoff)
})
t.Run("ZeroRetriesGetsDefault", func(t *testing.T) {
cfg := Config{
NumWorkers: 1,
MaxBackoff: 1 * time.Second,
MaxRetries: 0,
Logger: log.New("qos.test"),
}
err := cfg.validate()
require.NoError(t, err)
require.Equal(t, cfg.MaxRetries, DefaultMaxRetries)
})
})
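	// NewScheduler checks constructor behaviour: a running queue plus a valid
	// config yields a scheduler that starts and stops cleanly, while a nil
	// queue is rejected with an error.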
t.Run("NewScheduler", func(t *testing.T) {
t.Parallel()
t.Run("ValidParameters", func(t *testing.T) {
q := NewQueue(QueueOptionsWithDefaults(nil))
require.NoError(t, services.StartAndAwaitRunning(context.Background(), q))
cfg := Config{
NumWorkers: 2,
MaxBackoff: 1 * time.Second,
Logger: log.New("qos.test"),
}
scheduler, err := NewScheduler(q, &cfg)
require.NoError(t, err)
require.NotNil(t, scheduler)
require.NoError(t, services.StartAndAwaitRunning(context.Background(), scheduler))
require.Equal(t, q, scheduler.queue)
require.Equal(t, cfg.NumWorkers, scheduler.numWorkers)
require.True(t, scheduler.State() == services.Running)
require.NoError(t, services.StopAndAwaitTerminated(context.Background(), scheduler))
})
t.Run("NilQueue", func(t *testing.T) {
cfg := Config{
NumWorkers: 2,
MaxBackoff: 1 * time.Second,
}
scheduler, err := NewScheduler(nil, &cfg)
require.Error(t, err)
require.Nil(t, scheduler)
})
})
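	// Lifecycle walks the scheduler through the dskit service states:
	// New -> Running -> Terminated.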
t.Run("Lifecycle", func(t *testing.T) {
t.Parallel()
q := NewQueue(QueueOptionsWithDefaults(nil))
require.NoError(t, services.StartAndAwaitRunning(context.Background(), q))
scheduler, err := NewScheduler(q, &Config{
NumWorkers: 3,
MaxBackoff: 100 * time.Millisecond,
Logger: log.New("qos.test"),
})
require.NoError(t, err)
require.True(t, scheduler.State() == services.New)
require.NoError(t, services.StartAndAwaitRunning(context.Background(), scheduler))
require.True(t, scheduler.State() == services.Running)
require.NoError(t, services.StopAndAwaitTerminated(context.Background(), scheduler))
require.True(t, scheduler.State() == services.Terminated)
})
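	// ProcessItems enqueues 1000 callbacks spread across 10 tenants and
	// verifies that 10 workers execute every one of them within the timeout.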
t.Run("ProcessItems", func(t *testing.T) {
t.Parallel()
q := NewQueue(QueueOptionsWithDefaults(&QueueOptions{MaxSizePerTenant: 1000}))
require.NoError(t, services.StartAndAwaitRunning(context.Background(), q))
const itemCount = 1000
var processed sync.Map
var wg sync.WaitGroup
wg.Add(itemCount)
scheduler, err := NewScheduler(q, &Config{
NumWorkers: 10,
MaxBackoff: 100 * time.Millisecond,
Logger: log.New("qos.test"),
})
require.NoError(t, err)
require.NoError(t, services.StartAndAwaitRunning(context.Background(), scheduler))
for i := 0; i < itemCount; i++ {
itemID := i
tenantIndex := itemID % 10
tenantID := fmt.Sprintf("tenant-%d", tenantIndex)
require.NoError(t, q.Enqueue(context.Background(), tenantID, func() {
processed.Store(itemID, true)
time.Sleep(10 * time.Millisecond)
wg.Done()
}))
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("Timed out waiting for all items to be processed")
}
count := 0
processed.Range(func(_, _ any) bool {
count++
return true
})
require.Equal(t, itemCount, count, "Not all items were processed")
require.NoError(t, services.StopAndAwaitTerminated(context.Background(), scheduler))
})
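	// GracefulShutdown enqueues several quick tasks plus one slow task on a
	// single worker, then calls StopAsync while the slow task is running and
	// verifies that the in-flight task completes before the scheduler
	// terminates.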
t.Run("GracefulShutdown", func(t *testing.T) {
t.Parallel()
q := NewQueue(QueueOptionsWithDefaults(nil))
require.NoError(t, services.StartAndAwaitRunning(context.Background(), q))
var processed atomic.Int32
taskStarted := make(chan struct{})
taskFinished := make(chan struct{})
scheduler, err := NewScheduler(q, &Config{
NumWorkers: 1,
MaxBackoff: 100 * time.Millisecond,
Logger: log.New("qos.test"),
})
require.NoError(t, err)
require.NoError(t, services.StartAndAwaitRunning(context.Background(), scheduler))
for i := 0; i < 5; i++ {
require.NoError(t, q.Enqueue(context.Background(), "tenant-1", func() {
processed.Add(1)
}))
}
require.NoError(t, q.Enqueue(context.Background(), "tenant-1", func() {
close(taskStarted)
time.Sleep(1 * time.Second)
processed.Add(1)
close(taskFinished)
}))
select {
case <-taskStarted:
case <-time.After(2 * time.Second):
t.Fatal("Timed out waiting for long-running task to start")
}
scheduler.StopAsync()
select {
case <-taskFinished:
case <-time.After(2 * time.Second):
t.Fatal("Timed out waiting for long-running task to finish")
}
require.Equal(t, int32(6), processed.Load(), "Not all items were processed")
require.NoError(t, scheduler.AwaitTerminated(context.Background()))
})
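	// WithQueueClosed verifies that starting the scheduler fails when its
	// backing queue service has never been started.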
t.Run("WithQueueClosed", func(t *testing.T) {
t.Parallel()
q := NewQueue(QueueOptionsWithDefaults(nil))
scheduler, err := NewScheduler(q, &Config{
NumWorkers: 2,
MaxBackoff: 100 * time.Millisecond,
Logger: log.New("qos.test"),
})
require.NoError(t, err)
require.ErrorContains(t, services.StartAndAwaitRunning(context.Background(), scheduler), "queue is not running")
})
}