dns: rate limit DNS resolution requests (#2760)
This commit is contained in:

committed by
Doug Fawley

parent
d5973a9170
commit
5ed5cbab96
@ -66,6 +66,9 @@ var (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
defaultResolver netResolver = net.DefaultResolver
|
defaultResolver netResolver = net.DefaultResolver
|
||||||
|
// To prevent excessive re-resolution, we enforce a rate limit on DNS
|
||||||
|
// resolution requests.
|
||||||
|
minDNSResRate = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
|
var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
|
|||||||
return
|
return
|
||||||
case <-d.t.C:
|
case <-d.t.C:
|
||||||
case <-d.rn:
|
case <-d.rn:
|
||||||
|
if !d.t.Stop() {
|
||||||
|
// Before resetting a timer, it should be stopped to prevent racing with
|
||||||
|
// reads on it's channel.
|
||||||
|
<-d.t.C
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result, sc := d.lookup()
|
result, sc := d.lookup()
|
||||||
// Next lookup should happen within an interval defined by d.freq. It may be
|
// Next lookup should happen within an interval defined by d.freq. It may be
|
||||||
// more often due to exponential retry on empty address list.
|
// more often due to exponential retry on empty address list.
|
||||||
@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
|
|||||||
}
|
}
|
||||||
d.cc.NewServiceConfig(sc)
|
d.cc.NewServiceConfig(sc)
|
||||||
d.cc.NewAddress(result)
|
d.cc.NewAddress(result)
|
||||||
|
|
||||||
|
// Sleep to prevent excessive re-resolutions. Incoming resolution requests
|
||||||
|
// will be queued in d.rn.
|
||||||
|
t := time.NewTimer(minDNSResRate)
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
t.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
cleanup := replaceNetFunc()
|
// Set a valid duration for the re-resolution rate only for tests which are
|
||||||
|
// actually testing that feature.
|
||||||
|
dc := replaceDNSResRate(time.Duration(0))
|
||||||
|
defer dc()
|
||||||
|
|
||||||
|
cleanup := replaceNetFunc(nil)
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
cleanup()
|
cleanup()
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
@ -85,9 +90,16 @@ func (t *testClientConn) getSc() (string, int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type testResolver struct {
|
type testResolver struct {
|
||||||
|
// A write to this channel is made when this resolver receives a resolution
|
||||||
|
// request. Tests can rely on reading from this channel to be notified about
|
||||||
|
// resolution requests instead of sleeping for a predefined period of time.
|
||||||
|
ch chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||||
|
if tr.ch != nil {
|
||||||
|
tr.ch <- struct{}{}
|
||||||
|
}
|
||||||
return hostLookup(host)
|
return hostLookup(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,15 +111,24 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro
|
|||||||
return txtLookup(host)
|
return txtLookup(host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func replaceNetFunc() func() {
|
func replaceNetFunc(ch chan struct{}) func() {
|
||||||
oldResolver := defaultResolver
|
oldResolver := defaultResolver
|
||||||
defaultResolver = &testResolver{}
|
defaultResolver = &testResolver{ch: ch}
|
||||||
|
|
||||||
return func() {
|
return func() {
|
||||||
defaultResolver = oldResolver
|
defaultResolver = oldResolver
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func replaceDNSResRate(d time.Duration) func() {
|
||||||
|
oldMinDNSResRate := minDNSResRate
|
||||||
|
minDNSResRate = d
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
minDNSResRate = oldMinDNSResRate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var hostLookupTbl = struct {
|
var hostLookupTbl = struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
tbl map[string][]string
|
tbl map[string][]string
|
||||||
@ -1126,3 +1147,98 @@ func TestCustomAuthority(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedResolve exercises the rate limit enforced on re-resolution
|
||||||
|
// requests. It sets the re-resolution rate to a small value and repeatedly
|
||||||
|
// calls ResolveNow() and ensures only the expected number of resolution
|
||||||
|
// requests are made.
|
||||||
|
func TestRateLimitedResolve(t *testing.T) {
|
||||||
|
defer leakcheck.Check(t)
|
||||||
|
|
||||||
|
const dnsResRate = 100 * time.Millisecond
|
||||||
|
dc := replaceDNSResRate(dnsResRate)
|
||||||
|
defer dc()
|
||||||
|
|
||||||
|
// Create a new testResolver{} for this test because we want the exact count
|
||||||
|
// of the number of times the resolver was invoked.
|
||||||
|
nc := replaceNetFunc(make(chan struct{}, 1))
|
||||||
|
defer nc()
|
||||||
|
|
||||||
|
target := "foo.bar.com"
|
||||||
|
b := NewBuilder()
|
||||||
|
cc := &testClientConn{target: target}
|
||||||
|
r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOption{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolver.Build() returned error: %v\n", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
dnsR, ok := r.(*dnsResolver)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("resolver.Build() returned unexpected type: %T\n", dnsR)
|
||||||
|
}
|
||||||
|
tr, ok := dnsR.resolver.(*testResolver)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the first resolution request to be done. This happens as part of
|
||||||
|
// the first iteration of the for loop in watcher() because we start with a
|
||||||
|
// timer of zero duration.
|
||||||
|
<-tr.ch
|
||||||
|
|
||||||
|
// Here we start a couple of goroutines. One repeatedly calls ResolveNow()
|
||||||
|
// until asked to stop, and the other waits for two resolution requests to be
|
||||||
|
// made to our testResolver and stops the former. We measure the start and
|
||||||
|
// end times, and expect the duration elapsed to be in the interval
|
||||||
|
// {2*dnsResRate, 3*dnsResRate}
|
||||||
|
start := time.Now()
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
r.ResolveNow(resolver.ResolveNowOption{})
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
gotCalls := 0
|
||||||
|
const wantCalls = 2
|
||||||
|
min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate
|
||||||
|
tMax := time.NewTimer(max)
|
||||||
|
for gotCalls != wantCalls {
|
||||||
|
select {
|
||||||
|
case <-tr.ch:
|
||||||
|
gotCalls++
|
||||||
|
case <-tMax.C:
|
||||||
|
t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if gotCalls != wantCalls {
|
||||||
|
t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls)
|
||||||
|
}
|
||||||
|
if elapsed < min {
|
||||||
|
t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
|
||||||
|
var gotAddrs []resolver.Address
|
||||||
|
for {
|
||||||
|
var cnt int
|
||||||
|
gotAddrs, cnt = cc.getAddress()
|
||||||
|
if cnt > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(gotAddrs, wantAddrs) {
|
||||||
|
t.Errorf("Resolved addresses of target: %q = %+v, want %+v\n", target, gotAddrs, wantAddrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user