dns: rate limit DNS resolution requests (#2760)
This commit is contained in:

committed by
Doug Fawley

parent
d5973a9170
commit
5ed5cbab96
@ -66,6 +66,9 @@ var (
|
||||
|
||||
var (
|
||||
defaultResolver netResolver = net.DefaultResolver
|
||||
// To prevent excessive re-resolution, we enforce a rate limit on DNS
|
||||
// resolution requests.
|
||||
minDNSResRate = 30 * time.Second
|
||||
)
|
||||
|
||||
var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
|
||||
return
|
||||
case <-d.t.C:
|
||||
case <-d.rn:
|
||||
if !d.t.Stop() {
|
||||
// Before resetting a timer, it should be stopped to prevent racing with
|
||||
// reads on it's channel.
|
||||
<-d.t.C
|
||||
}
|
||||
}
|
||||
|
||||
result, sc := d.lookup()
|
||||
// Next lookup should happen within an interval defined by d.freq. It may be
|
||||
// more often due to exponential retry on empty address list.
|
||||
@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
|
||||
}
|
||||
d.cc.NewServiceConfig(sc)
|
||||
d.cc.NewAddress(result)
|
||||
|
||||
// Sleep to prevent excessive re-resolutions. Incoming resolution requests
|
||||
// will be queued in d.rn.
|
||||
t := time.NewTimer(minDNSResRate)
|
||||
select {
|
||||
case <-t.C:
|
||||
case <-d.ctx.Done():
|
||||
t.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,12 @@ import (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
cleanup := replaceNetFunc()
|
||||
// Set a valid duration for the re-resolution rate only for tests which are
|
||||
// actually testing that feature.
|
||||
dc := replaceDNSResRate(time.Duration(0))
|
||||
defer dc()
|
||||
|
||||
cleanup := replaceNetFunc(nil)
|
||||
code := m.Run()
|
||||
cleanup()
|
||||
os.Exit(code)
|
||||
@ -85,9 +90,16 @@ func (t *testClientConn) getSc() (string, int) {
|
||||
}
|
||||
|
||||
type testResolver struct {
|
||||
// A write to this channel is made when this resolver receives a resolution
|
||||
// request. Tests can rely on reading from this channel to be notified about
|
||||
// resolution requests instead of sleeping for a predefined period of time.
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
func (*testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||
func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||
if tr.ch != nil {
|
||||
tr.ch <- struct{}{}
|
||||
}
|
||||
return hostLookup(host)
|
||||
}
|
||||
|
||||
@ -99,15 +111,24 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro
|
||||
return txtLookup(host)
|
||||
}
|
||||
|
||||
func replaceNetFunc() func() {
|
||||
func replaceNetFunc(ch chan struct{}) func() {
|
||||
oldResolver := defaultResolver
|
||||
defaultResolver = &testResolver{}
|
||||
defaultResolver = &testResolver{ch: ch}
|
||||
|
||||
return func() {
|
||||
defaultResolver = oldResolver
|
||||
}
|
||||
}
|
||||
|
||||
func replaceDNSResRate(d time.Duration) func() {
|
||||
oldMinDNSResRate := minDNSResRate
|
||||
minDNSResRate = d
|
||||
|
||||
return func() {
|
||||
minDNSResRate = oldMinDNSResRate
|
||||
}
|
||||
}
|
||||
|
||||
var hostLookupTbl = struct {
|
||||
sync.Mutex
|
||||
tbl map[string][]string
|
||||
@ -1126,3 +1147,98 @@ func TestCustomAuthority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimitedResolve exercises the rate limit enforced on re-resolution
|
||||
// requests. It sets the re-resolution rate to a small value and repeatedly
|
||||
// calls ResolveNow() and ensures only the expected number of resolution
|
||||
// requests are made.
|
||||
func TestRateLimitedResolve(t *testing.T) {
|
||||
defer leakcheck.Check(t)
|
||||
|
||||
const dnsResRate = 100 * time.Millisecond
|
||||
dc := replaceDNSResRate(dnsResRate)
|
||||
defer dc()
|
||||
|
||||
// Create a new testResolver{} for this test because we want the exact count
|
||||
// of the number of times the resolver was invoked.
|
||||
nc := replaceNetFunc(make(chan struct{}, 1))
|
||||
defer nc()
|
||||
|
||||
target := "foo.bar.com"
|
||||
b := NewBuilder()
|
||||
cc := &testClientConn{target: target}
|
||||
r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOption{})
|
||||
if err != nil {
|
||||
t.Fatalf("resolver.Build() returned error: %v\n", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
dnsR, ok := r.(*dnsResolver)
|
||||
if !ok {
|
||||
t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
|
||||
}
|
||||
tr, ok := dnsR.resolver.(*testResolver)
|
||||
if !ok {
|
||||
t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
|
||||
}
|
||||
|
||||
// Wait for the first resolution request to be done. This happens as part of
|
||||
// the first iteration of the for loop in watcher() because we start with a
|
||||
// timer of zero duration.
|
||||
<-tr.ch
|
||||
|
||||
// Here we start a couple of goroutines. One repeatedly calls ResolveNow()
|
||||
// until asked to stop, and the other waits for two resolution requests to be
|
||||
// made to our testResolver and stops the former. We measure the start and
|
||||
// end times, and expect the duration elapsed to be in the interval
|
||||
// {2*dnsResRate, 3*dnsResRate}
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
r.ResolveNow(resolver.ResolveNowOption{})
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
gotCalls := 0
|
||||
const wantCalls = 2
|
||||
min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate
|
||||
tMax := time.NewTimer(max)
|
||||
for gotCalls != wantCalls {
|
||||
select {
|
||||
case <-tr.ch:
|
||||
gotCalls++
|
||||
case <-tMax.C:
|
||||
t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls)
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if gotCalls != wantCalls {
|
||||
t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls)
|
||||
}
|
||||
if elapsed < min {
|
||||
t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max)
|
||||
}
|
||||
|
||||
wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
|
||||
var gotAddrs []resolver.Address
|
||||
for {
|
||||
var cnt int
|
||||
gotAddrs, cnt = cc.getAddress()
|
||||
if cnt > 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
if !reflect.DeepEqual(gotAddrs, wantAddrs) {
|
||||
t.Errorf("Resolved addresses of target: %q = %+v, want %+v\n", target, gotAddrs, wantAddrs)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user