From fcece3e32e54daea7ae940ad7eece0df23b60564 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Mon, 12 Jan 2015 22:14:23 -0800 Subject: [PATCH] p2p/net/swarm: dial once at a time --- p2p/net/swarm/simul_test.go | 43 +++++++++++++++++++++++++++++++++++++ p2p/net/swarm/swarm.go | 15 +++++++++---- p2p/net/swarm/swarm_dial.go | 39 ++++++++++++++++++++++++++++----- 3 files changed, 88 insertions(+), 9 deletions(-) diff --git a/p2p/net/swarm/simul_test.go b/p2p/net/swarm/simul_test.go index c87df91c3..9382cb645 100644 --- a/p2p/net/swarm/simul_test.go +++ b/p2p/net/swarm/simul_test.go @@ -11,6 +11,49 @@ import ( ma "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" ) +func TestSimultDials(t *testing.T) { + // t.Skip("skipping for another test") + + ctx := context.Background() + swarms := makeSwarms(ctx, t, 2) + + // connect everyone + { + var wg sync.WaitGroup + connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) { + // copy for other peer + log.Debugf("TestSimultOpen: connecting: %s --> %s (%s)", s.local, dst, addr) + s.peers.AddAddress(dst, addr) + if _, err := s.Dial(ctx, dst); err != nil { + t.Fatal("error swarm dialing to peer", err) + } + wg.Done() + } + + log.Info("Connecting swarms simultaneously.") + for i := 0; i < 10; i++ { // connect 10x for each. + wg.Add(2) + go connect(swarms[0], swarms[1].local, swarms[1].ListenAddresses()[0]) + go connect(swarms[1], swarms[0].local, swarms[0].ListenAddresses()[0]) + } + wg.Wait() + } + + // should still just have 1, at most 2 connections :) + c01l := len(swarms[0].ConnectionsToPeer(swarms[1].local)) + if c01l > 2 { + t.Error("0->1 has", c01l) + } + c10l := len(swarms[1].ConnectionsToPeer(swarms[0].local)) + if c10l > 2 { + t.Error("1->0 has", c10l) + } + + for _, s := range swarms { + s.Close() + } +} + func TestSimultOpen(t *testing.T) { // t.Skip("skipping for another test") diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 8080fcde9..42511752a 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -4,6 +4,7 @@ package swarm import ( "fmt" + "sync" inet "github.com/jbenet/go-ipfs/p2p/net" addrutil "github.com/jbenet/go-ipfs/p2p/net/swarm/addr" @@ -33,6 +34,11 @@ type Swarm struct { peers peer.Peerstore connh ConnHandler + // dialing is a channel for the current peers being dialed. + // this way, we dont kick off N dials simultaneously. + dialing map[peer.ID]chan struct{} + dialingmu sync.Mutex + cg ctxgroup.ContextGroup } @@ -49,10 +55,11 @@ func NewSwarm(ctx context.Context, listenAddrs []ma.Multiaddr, } s := &Swarm{ - swarm: ps.NewSwarm(PSTransport), - local: local, - peers: peers, - cg: ctxgroup.WithContext(ctx), + swarm: ps.NewSwarm(PSTransport), + local: local, + peers: peers, + cg: ctxgroup.WithContext(ctx), + dialing: map[peer.ID]chan struct{}{}, } // configure Swarm diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 96581009e..dc2e94e01 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -25,12 +25,41 @@ func (s *Swarm) Dial(ctx context.Context, p peer.ID) (*Conn, error) { return nil, errors.New("Attempted connection to self!") } - // check if we already have an open connection first - cs := s.ConnectionsToPeer(p) - for _, c := range cs { - if c != nil { // dump out the first one we find - return c, nil + for { + // check if we already have an open connection first + cs := s.ConnectionsToPeer(p) + for _, c := range cs { + if c != nil { // dump out the first one we find + return c, nil + } } + + // check if there's an ongoing dial to this peer + s.dialingmu.Lock() + dialDone, found := s.dialing[p] + if !found { // if not, set one up. + dialDone = make(chan struct{}) + s.dialing[p] = dialDone + } + s.dialingmu.Unlock() + + if found { + select { + case <-dialDone: // wait for that dial to finish. + continue // and see if it worked (loop). it may not have. + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + // else, we're the ones dialing for others. + defer func() { + s.dialingmu.Lock() + delete(s.dialing, p) + close(dialDone) + s.dialingmu.Unlock() + }() + break } sk := s.peers.PrivKey(s.local)