From ccf6f79aa09add2d924e431d37a13188050ab3b1 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Wed, 24 Dec 2014 03:24:28 -0800 Subject: [PATCH] respect don contexteone --- net/conn/dial.go | 52 +++++++++++++++++++++++++---------------- net/net.go | 14 ++++++++++- routing/dht/dht_net.go | 2 +- routing/dht/dht_test.go | 3 ++- routing/dht/ext_test.go | 2 +- 5 files changed, 49 insertions(+), 24 deletions(-) diff --git a/net/conn/dial.go b/net/conn/dial.go index 5eed05d06..8ffb441d3 100644 --- a/net/conn/dial.go +++ b/net/conn/dial.go @@ -50,32 +50,44 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( return nil, err } + var connOut Conn + var errOut error + done := make(chan struct{}) + + // do it async to ensure we respect don contexteone + go func() { + defer func() { done <- struct{}{} }() + + c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) + if err != nil { + errOut = err + return + } + + if d.PrivateKey == nil { + log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) + connOut = c + return + } + c2, err := newSecureConn(ctx, d.PrivateKey, c) + if err != nil { + errOut = err + c.Close() + return + } + + connOut = c2 + }() + select { case <-ctx.Done(): maconn.Close() return nil, ctx.Err() - default: + case <-done: + // whew, finished. } - c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) - if err != nil { - return nil, err - } - - if d.PrivateKey == nil { - log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) - return c, nil - } - - select { - case <-ctx.Done(): - c.Close() - return nil, ctx.Err() - default: - } - - // return c, nil - return newSecureConn(ctx, d.PrivateKey, c) + return connOut, errOut } // MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks. diff --git a/net/net.go b/net/net.go index 0eae441c9..39afc6b10 100644 --- a/net/net.go +++ b/net/net.go @@ -148,7 +148,19 @@ func (n *network) DialPeer(ctx context.Context, p peer.ID) error { } // identify the connection before returning. - n.ids.IdentifyConn((*conn_)(sc)) + done := make(chan struct{}) + go func() { + n.ids.IdentifyConn((*conn_)(sc)) + close(done) + }() + + // respect don contexteone + select { + case <-done: + case <-ctx.Done(): + return ctx.Err() + } + log.Debugf("network for %s finished dialing %s", n.local, p) return nil } diff --git a/routing/dht/dht_net.go b/routing/dht/dht_net.go index caf0518c2..d247cf3af 100644 --- a/routing/dht/dht_net.go +++ b/routing/dht/dht_net.go @@ -31,7 +31,7 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) { // receive msg pmes := new(pb.Message) if err := r.ReadMsg(pmes); err != nil { - log.Error("Error unmarshaling data") + log.Errorf("Error unmarshaling data: %s", err) return } diff --git a/routing/dht/dht_test.go b/routing/dht/dht_test.go index 02950e084..f2ff099df 100644 --- a/routing/dht/dht_test.go +++ b/routing/dht/dht_test.go @@ -265,7 +265,8 @@ func TestBootstrap(t *testing.T) { } t.Logf("bootstrapping them so they find each other", nDHTs) - bootstrap(t, ctx, dhts) + ctxT, _ := context.WithTimeout(ctx, 5*time.Second) + bootstrap(t, ctxT, dhts) // the routing tables should be full now. let's inspect them. t.Logf("checking routing table of %d", nDHTs) diff --git a/routing/dht/ext_test.go b/routing/dht/ext_test.go index 04f5111a9..b4b1158d7 100644 --- a/routing/dht/ext_test.go +++ b/routing/dht/ext_test.go @@ -73,7 +73,7 @@ func TestGetFailures(t *testing.T) { }) // This one should fail with NotFound - ctx2, _ := context.WithTimeout(context.Background(), time.Second) + ctx2, _ := context.WithTimeout(context.Background(), 3*time.Second) _, err = d.GetValue(ctx2, u.Key("test")) if err != nil { if err != routing.ErrNotFound {