dns: rate limit DNS resolution requests (#2760)

This commit is contained in:
Easwar Swaminathan
2019-05-02 10:23:31 -07:00
committed by Doug Fawley
parent d5973a9170
commit 5ed5cbab96
2 changed files with 139 additions and 4 deletions

View File

@ -66,6 +66,9 @@ var (
var (
defaultResolver netResolver = net.DefaultResolver
// To prevent excessive re-resolution, we enforce a rate limit on DNS
// resolution requests.
minDNSResRate = 30 * time.Second
)
var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
return
case <-d.t.C:
case <-d.rn:
if !d.t.Stop() {
// Before resetting a timer, it should be stopped to prevent racing with
// reads on it's channel.
<-d.t.C
}
}
result, sc := d.lookup()
// Next lookup should happen within an interval defined by d.freq. It may be
// more often due to exponential retry on empty address list.
@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
}
d.cc.NewServiceConfig(sc)
d.cc.NewAddress(result)
// Sleep to prevent excessive re-resolutions. Incoming resolution requests
// will be queued in d.rn.
t := time.NewTimer(minDNSResRate)
select {
case <-t.C:
case <-d.ctx.Done():
t.Stop()
return
}
}
}

View File

@ -34,7 +34,12 @@ import (
)
func TestMain(m *testing.M) {
cleanup := replaceNetFunc()
// Set a valid duration for the re-resolution rate only for tests which are
// actually testing that feature.
dc := replaceDNSResRate(time.Duration(0))
defer dc()
cleanup := replaceNetFunc(nil)
code := m.Run()
cleanup()
os.Exit(code)
@ -85,9 +90,16 @@ func (t *testClientConn) getSc() (string, int) {
}
type testResolver struct {
// A write to this channel is made when this resolver receives a resolution
// request. Tests can rely on reading from this channel to be notified about
// resolution requests instead of sleeping for a predefined period of time.
ch chan struct{}
}
func (*testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
if tr.ch != nil {
tr.ch <- struct{}{}
}
return hostLookup(host)
}
@ -99,15 +111,24 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro
return txtLookup(host)
}
func replaceNetFunc() func() {
func replaceNetFunc(ch chan struct{}) func() {
oldResolver := defaultResolver
defaultResolver = &testResolver{}
defaultResolver = &testResolver{ch: ch}
return func() {
defaultResolver = oldResolver
}
}
func replaceDNSResRate(d time.Duration) func() {
oldMinDNSResRate := minDNSResRate
minDNSResRate = d
return func() {
minDNSResRate = oldMinDNSResRate
}
}
var hostLookupTbl = struct {
sync.Mutex
tbl map[string][]string
@ -1126,3 +1147,98 @@ func TestCustomAuthority(t *testing.T) {
}
}
}
// TestRateLimitedResolve exercises the rate limit enforced on re-resolution
// requests. It sets the re-resolution rate to a small value and repeatedly
// calls ResolveNow() and ensures only the expected number of resolution
// requests are made.
func TestRateLimitedResolve(t *testing.T) {
defer leakcheck.Check(t)
const dnsResRate = 100 * time.Millisecond
dc := replaceDNSResRate(dnsResRate)
defer dc()
// Create a new testResolver{} for this test because we want the exact count
// of the number of times the resolver was invoked.
nc := replaceNetFunc(make(chan struct{}, 1))
defer nc()
target := "foo.bar.com"
b := NewBuilder()
cc := &testClientConn{target: target}
r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOption{})
if err != nil {
t.Fatalf("resolver.Build() returned error: %v\n", err)
}
defer r.Close()
dnsR, ok := r.(*dnsResolver)
if !ok {
t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
}
tr, ok := dnsR.resolver.(*testResolver)
if !ok {
t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
}
// Wait for the first resolution request to be done. This happens as part of
// the first iteration of the for loop in watcher() because we start with a
// timer of zero duration.
<-tr.ch
// Here we start a couple of goroutines. One repeatedly calls ResolveNow()
// until asked to stop, and the other waits for two resolution requests to be
// made to our testResolver and stops the former. We measure the start and
// end times, and expect the duration elapsed to be in the interval
// {2*dnsResRate, 3*dnsResRate}
start := time.Now()
done := make(chan struct{})
go func() {
for {
select {
case <-done:
return
default:
r.ResolveNow(resolver.ResolveNowOption{})
time.Sleep(1 * time.Millisecond)
}
}
}()
gotCalls := 0
const wantCalls = 2
min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate
tMax := time.NewTimer(max)
for gotCalls != wantCalls {
select {
case <-tr.ch:
gotCalls++
case <-tMax.C:
t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls)
}
}
close(done)
elapsed := time.Since(start)
if gotCalls != wantCalls {
t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls)
}
if elapsed < min {
t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max)
}
wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
var gotAddrs []resolver.Address
for {
var cnt int
gotAddrs, cnt = cc.getAddress()
if cnt > 0 {
break
}
time.Sleep(time.Millisecond)
}
if !reflect.DeepEqual(gotAddrs, wantAddrs) {
t.Errorf("Resolved addresses of target: %q = %+v, want %+v\n", target, gotAddrs, wantAddrs)
}
}