diff --git a/net/conn/dial.go b/net/conn/dial.go index 5eed05d06..8ffb441d3 100644 --- a/net/conn/dial.go +++ b/net/conn/dial.go @@ -50,32 +50,44 @@ func (d *Dialer) Dial(ctx context.Context, raddr ma.Multiaddr, remote peer.ID) ( return nil, err } + var connOut Conn + var errOut error + done := make(chan struct{}) + + // do it async to ensure we respect don contexteone + go func() { + defer func() { done <- struct{}{} }() + + c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) + if err != nil { + errOut = err + return + } + + if d.PrivateKey == nil { + log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) + connOut = c + return + } + c2, err := newSecureConn(ctx, d.PrivateKey, c) + if err != nil { + errOut = err + c.Close() + return + } + + connOut = c2 + }() + select { case <-ctx.Done(): maconn.Close() return nil, ctx.Err() - default: + case <-done: + // whew, finished. } - c, err := newSingleConn(ctx, d.LocalPeer, remote, maconn) - if err != nil { - return nil, err - } - - if d.PrivateKey == nil { - log.Warning("dialer %s dialing INSECURELY %s at %s!", d, remote, raddr) - return c, nil - } - - select { - case <-ctx.Done(): - c.Close() - return nil, ctx.Err() - default: - } - - // return c, nil - return newSecureConn(ctx, d.PrivateKey, c) + return connOut, errOut } // MultiaddrProtocolsMatch returns whether two multiaddrs match in protocol stacks. diff --git a/net/conn/listen.go b/net/conn/listen.go index 17eb03dbe..e8eddb997 100644 --- a/net/conn/listen.go +++ b/net/conn/listen.go @@ -109,7 +109,7 @@ func Listen(ctx context.Context, addr ma.Multiaddr, local peer.ID, sk ic.PrivKey } l.cg.SetTeardown(l.teardown) - log.Infof("swarm listening on %s\n", l.Multiaddr()) + log.Infof("swarm listening on %s", l.Multiaddr()) log.Event(ctx, "swarmListen", l) return l, nil } diff --git a/net/id.go b/net/id.go index 17dc76610..802d54794 100644 --- a/net/id.go +++ b/net/id.go @@ -38,10 +38,11 @@ func NewIDService(n Network) *IDService { func (ids *IDService) IdentifyConn(c Conn) { ids.currmu.Lock() - if _, found := ids.currid[c]; found { + if wait, found := ids.currid[c]; found { ids.currmu.Unlock() log.Debugf("IdentifyConn called twice on: %s", c) - return // already identifying it. + <-wait // already identifying it. wait for it. + return } ids.currid[c] = make(chan struct{}) ids.currmu.Unlock() @@ -50,10 +51,11 @@ func (ids *IDService) IdentifyConn(c Conn) { if err != nil { log.Error("network: unable to open initial stream for %s", ProtocolIdentify) log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer()) - } + } else { - // ok give the response to our handler. - ids.ResponseHandler(s) + // ok give the response to our handler. + ids.ResponseHandler(s) + } ids.currmu.Lock() ch, found := ids.currid[c] diff --git a/net/interface.go b/net/interface.go index 74354e5cd..7f9f3e617 100644 --- a/net/interface.go +++ b/net/interface.go @@ -82,15 +82,6 @@ type Network interface { // If ProtocolID is "", writes no header. NewStream(ProtocolID, peer.ID) (Stream, error) - // Peers returns the peers connected - Peers() []peer.ID - - // Conns returns the connections in this Netowrk - Conns() []Conn - - // ConnsToPeer returns the connections in this Netowrk for given peer. 
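The dial.go hunk earlier in this patch relies on a standard Go shape: run the blocking handshake in its own goroutine and select on the context. A minimal sketch of that shape, with a hypothetical setup callback standing in for the newSingleConn/newSecureConn steps (this is illustrative, not the patch's actual code):

package example

import (
	"context"
	"io"
)

// dialWithContext is an illustrative helper: it runs a blocking setup step in
// a goroutine and honors ctx cancellation, mirroring the structure of the
// Dialer.Dial change above. setup stands in for the handshake work.
func dialWithContext(ctx context.Context, setup func() (io.Closer, error)) (io.Closer, error) {
	var out io.Closer
	var errOut error
	done := make(chan struct{})

	go func() {
		defer close(done)
		out, errOut = setup() // the blocking handshake work
	}()

	select {
	case <-ctx.Done():
		// The goroutine is not killed; the caller must close the underlying
		// connection so the blocked setup eventually returns and is not leaked.
		return nil, ctx.Err()
	case <-done:
		return out, errOut
	}
}

In this sketch the done channel is closed rather than sent on, so the worker never blocks if the context fires first.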
- ConnsToPeer(p peer.ID) []Conn - // BandwidthTotals returns the total number of bytes passed through // the network since it was instantiated BandwidthTotals() (uint64, uint64) @@ -133,6 +124,15 @@ type Dialer interface { // Connectedness returns a state signaling connection capabilities Connectedness(peer.ID) Connectedness + + // Peers returns the peers connected + Peers() []peer.ID + + // Conns returns the connections in this Netowrk + Conns() []Conn + + // ConnsToPeer returns the connections in this Netowrk for given peer. + ConnsToPeer(p peer.ID) []Conn } // Connectedness signals the capacity for a connection with a given node. diff --git a/net/net.go b/net/net.go index 0eae441c9..39afc6b10 100644 --- a/net/net.go +++ b/net/net.go @@ -148,7 +148,19 @@ func (n *network) DialPeer(ctx context.Context, p peer.ID) error { } // identify the connection before returning. - n.ids.IdentifyConn((*conn_)(sc)) + done := make(chan struct{}) + go func() { + n.ids.IdentifyConn((*conn_)(sc)) + close(done) + }() + + // respect don contexteone + select { + case <-done: + case <-ctx.Done(): + return ctx.Err() + } + log.Debugf("network for %s finished dialing %s", n.local, p) return nil } diff --git a/net/swarm/swarm_test.go b/net/swarm/swarm_test.go index c0a1ab9fa..3d692a064 100644 --- a/net/swarm/swarm_test.go +++ b/net/swarm/swarm_test.go @@ -248,15 +248,21 @@ func TestConnHandler(t *testing.T) { <-time.After(time.Millisecond) // should've gotten 5 by now. - close(gotconn) + + swarms[0].SetConnHandler(nil) expect := 4 - actual := 0 - for _ = range gotconn { - actual++ + for i := 0; i < expect; i++ { + select { + case <-time.After(time.Second): + t.Fatal("failed to get connections") + case <-gotconn: + } } - if actual != expect { - t.Fatal("should have connected to %d swarms. got: %d", actual, expect) + select { + case <-gotconn: + t.Fatalf("should have connected to %d swarms", expect) + default: } } diff --git a/routing/dht/dht.go b/routing/dht/dht.go index a893e62b8..4cbf68e43 100644 --- a/routing/dht/dht.go +++ b/routing/dht/dht.go @@ -28,6 +28,10 @@ var log = eventlog.Logger("dht") const doPinging = false +// NumBootstrapQueries defines the number of random dht queries to do to +// collect members of the routing table. +const NumBootstrapQueries = 5 + // TODO. SEE https://github.com/jbenet/node-ipfs/blob/master/submodules/ipfs-dht/index.js // IpfsDHT is an implementation of Kademlia with Coral and S/Kademlia modifications. @@ -361,25 +365,20 @@ func (dht *IpfsDHT) PingRoutine(t time.Duration) { } // Bootstrap builds up list of peers by requesting random peer IDs -func (dht *IpfsDHT) Bootstrap(ctx context.Context) { - - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - id := make([]byte, 16) - rand.Read(id) - pi, err := dht.FindPeer(ctx, peer.ID(id)) - if err != nil { - // NOTE: this is not an error. this is expected! - log.Errorf("Bootstrap peer error: %s", err) - } +func (dht *IpfsDHT) Bootstrap(ctx context.Context, queries int) { + // bootstrap sequentially, as results will compound + for i := 0; i < NumBootstrapQueries; i++ { + id := make([]byte, 16) + rand.Read(id) + pi, err := dht.FindPeer(ctx, peer.ID(id)) + if err == routing.ErrNotFound { + // this isn't an error. this is precisely what we expect. + } else if err != nil { + log.Errorf("Bootstrap peer error: %s", err) + } else { // woah, we got a peer under a random id? it _cannot_ be valid. 
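The sequential Bootstrap loop above boils down to: ask for a handful of random peer IDs and let each lookup's traversal populate the routing table as a side effect. A stripped-down sketch of that loop, using stand-ins (errNotFound, findPeer) rather than the real routing.ErrNotFound and dht.FindPeer:

package example

import (
	"context"
	"crypto/rand"
	"errors"
	"log"
)

// errNotFound and the findPeer callback stand in for routing.ErrNotFound and
// dht.FindPeer; this only sketches the sequential loop, not the dht API.
var errNotFound = errors.New("routing: not found")

func bootstrapQueries(ctx context.Context, queries int, findPeer func(context.Context, string) error) {
	// sequential on purpose: peers learned by one query compound into the next.
	for i := 0; i < queries; i++ {
		id := make([]byte, 16)
		rand.Read(id) // a random, almost certainly nonexistent, peer ID

		err := findPeer(ctx, string(id))
		switch {
		case errors.Is(err, errNotFound):
			// expected: nobody owns a random ID, but the traversal itself
			// has taught the routing table about the peers it crossed.
		case err != nil:
			log.Printf("bootstrap query %d error: %s", i, err)
		default:
			log.Printf("bootstrap query %d unexpectedly found a peer", i)
		}
	}
}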
log.Errorf("dht seemingly found a peer at a random bootstrap id (%s)...", pi) - }() + } } - wg.Wait() } diff --git a/routing/dht/dht_net.go b/routing/dht/dht_net.go index a91e0f53c..d247cf3af 100644 --- a/routing/dht/dht_net.go +++ b/routing/dht/dht_net.go @@ -7,6 +7,7 @@ import ( inet "github.com/jbenet/go-ipfs/net" peer "github.com/jbenet/go-ipfs/peer" pb "github.com/jbenet/go-ipfs/routing/dht/pb" + ctxutil "github.com/jbenet/go-ipfs/util/ctx" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io" @@ -21,18 +22,21 @@ func (dht *IpfsDHT) handleNewMessage(s inet.Stream) { defer s.Close() ctx := dht.Context() - r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) - w := ggio.NewDelimitedWriter(s) + cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func + cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func + r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax) + w := ggio.NewDelimitedWriter(cw) mPeer := s.Conn().RemotePeer() // receive msg pmes := new(pb.Message) if err := r.ReadMsg(pmes); err != nil { - log.Error("Error unmarshaling data") + log.Errorf("Error unmarshaling data: %s", err) return } + // update the peer (on valid msgs only) - dht.Update(ctx, mPeer) + dht.updateFromMessage(ctx, mPeer, pmes) log.Event(ctx, "foo", dht.self, mPeer, pmes) @@ -76,8 +80,10 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message } defer s.Close() - r := ggio.NewDelimitedReader(s, inet.MessageSizeMax) - w := ggio.NewDelimitedWriter(s) + cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func + cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func + r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax) + w := ggio.NewDelimitedWriter(cw) start := time.Now() @@ -98,6 +104,9 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message return nil, errors.New("no response to request") } + // update the peer (on valid msgs only) + dht.updateFromMessage(ctx, p, rpmes) + dht.peerstore.RecordLatency(p, time.Since(start)) log.Event(ctx, "dhtReceivedMessage", dht.self, p, rpmes) return rpmes, nil @@ -113,7 +122,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message } defer s.Close() - w := ggio.NewDelimitedWriter(s) + cw := ctxutil.NewWriter(ctx, s) // ok to use. 
we defer close stream in this func + w := ggio.NewDelimitedWriter(cw) log.Debugf("%s writing", dht.self) if err := w.WriteMsg(pmes); err != nil { @@ -123,3 +133,8 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message log.Debugf("%s done", dht.self) return nil } + +func (dht *IpfsDHT) updateFromMessage(ctx context.Context, p peer.ID, mes *pb.Message) error { + dht.Update(ctx, p) + return nil +} diff --git a/routing/dht/dht_test.go b/routing/dht/dht_test.go index b378675c6..5603c4d5c 100644 --- a/routing/dht/dht_test.go +++ b/routing/dht/dht_test.go @@ -2,7 +2,9 @@ package dht import ( "bytes" + "fmt" "sort" + "sync" "testing" "time" @@ -15,10 +17,22 @@ import ( // ci "github.com/jbenet/go-ipfs/crypto" inet "github.com/jbenet/go-ipfs/net" peer "github.com/jbenet/go-ipfs/peer" + routing "github.com/jbenet/go-ipfs/routing" u "github.com/jbenet/go-ipfs/util" testutil "github.com/jbenet/go-ipfs/util/testutil" ) +var testCaseValues = map[u.Key][]byte{} + +func init() { + testCaseValues["hello"] = []byte("world") + for i := 0; i < 100; i++ { + k := fmt.Sprintf("%d -- key", i) + v := fmt.Sprintf("%d -- value", i) + testCaseValues[u.Key(k)] = []byte(v) + } +} + func setupDHT(ctx context.Context, t *testing.T, addr ma.Multiaddr) *IpfsDHT { sk, pk, err := testutil.RandKeyPair(512) @@ -78,6 +92,27 @@ func connect(t *testing.T, ctx context.Context, a, b *IpfsDHT) { } } +func bootstrap(t *testing.T, ctx context.Context, dhts []*IpfsDHT) { + + ctx, cancel := context.WithCancel(ctx) + + rounds := 1 + for i := 0; i < rounds; i++ { + log.Debugf("bootstrapping round %d/%d\n", i, rounds) + + // tried async. sequential fares much better. compare: + // 100 async https://gist.github.com/jbenet/56d12f0578d5f34810b2 + // 100 sync https://gist.github.com/jbenet/6c59e7c15426e48aaedd + // probably because results compound + for _, dht := range dhts { + log.Debugf("bootstrapping round %d/%d -- %s\n", i, rounds, dht.self) + dht.Bootstrap(ctx, 3) + } + } + + cancel() +} + func TestPing(t *testing.T) { // t.Skip("skipping test to debug another") ctx := context.Background() @@ -174,37 +209,208 @@ func TestProvides(t *testing.T) { connect(t, ctx, dhts[1], dhts[2]) connect(t, ctx, dhts[1], dhts[3]) - err := dhts[3].putLocal(u.Key("hello"), []byte("world")) - if err != nil { - t.Fatal(err) + for k, v := range testCaseValues { + log.Debugf("adding local values for %s = %s", k, v) + err := dhts[3].putLocal(k, v) + if err != nil { + t.Fatal(err) + } + + bits, err := dhts[3].getLocal(k) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(bits, v) { + t.Fatal("didn't store the right bits (%s, %s)", k, v) + } } - bits, err := dhts[3].getLocal(u.Key("hello")) - if err != nil && bytes.Equal(bits, []byte("world")) { - t.Fatal(err) - } - - err = dhts[3].Provide(ctx, u.Key("hello")) - if err != nil { - t.Fatal(err) + for k, _ := range testCaseValues { + log.Debugf("announcing provider for %s", k) + if err := dhts[3].Provide(ctx, k); err != nil { + t.Fatal(err) + } } // what is this timeout for? was 60ms before. 
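The dht_net.go hunks above wrap each stream in the new context-aware reader and writer before handing it to the delimited protobuf codec. Roughly, the wiring looks like the sketch below, with maxMessageSize standing in for inet.MessageSizeMax and the stream reduced to a plain io.ReadWriteCloser:

package example

import (
	"context"
	"io"

	ggio "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/gogoprotobuf/io"
	ctxutil "github.com/jbenet/go-ipfs/util/ctx"
)

// maxMessageSize is a stand-in for inet.MessageSizeMax.
const maxMessageSize = 1 << 22

// newDelimitedCodec mirrors the wiring in handleNewMessage/sendRequest above:
// the raw stream is wrapped so a cancelled ctx unblocks pending reads and
// writes, and the delimited protobuf codec sits on top of the wrappers. This
// is only safe because the caller defers s.Close(), which unblocks the
// goroutine doing the real I/O.
func newDelimitedCodec(ctx context.Context, s io.ReadWriteCloser) (ggio.ReadCloser, ggio.WriteCloser) {
	cr := ctxutil.NewReader(ctx, s)
	cw := ctxutil.NewWriter(ctx, s)
	return ggio.NewDelimitedReader(cr, maxMessageSize), ggio.NewDelimitedWriter(cw)
}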
time.Sleep(time.Millisecond * 6) - ctxT, _ := context.WithTimeout(ctx, time.Second) - provchan := dhts[0].FindProvidersAsync(ctxT, u.Key("hello"), 1) + n := 0 + for k, _ := range testCaseValues { + n = (n + 1) % 3 - select { - case prov := <-provchan: - if prov.ID == "" { - t.Fatal("Got back nil provider") + log.Debugf("getting providers for %s from %d", k, n) + ctxT, _ := context.WithTimeout(ctx, time.Second) + provchan := dhts[n].FindProvidersAsync(ctxT, k, 1) + + select { + case prov := <-provchan: + if prov.ID == "" { + t.Fatal("Got back nil provider") + } + if prov.ID != dhts[3].self { + t.Fatal("Got back wrong provider") + } + case <-ctxT.Done(): + t.Fatal("Did not get a provider back.") } - if prov.ID != dhts[3].self { - t.Fatal("Got back nil provider") + } +} + +func TestBootstrap(t *testing.T) { + if testing.Short() { + t.SkipNow() + } + + ctx := context.Background() + + nDHTs := 15 + _, _, dhts := setupDHTS(ctx, nDHTs, t) + defer func() { + for i := 0; i < nDHTs; i++ { + dhts[i].Close() + defer dhts[i].network.Close() } - case <-ctxT.Done(): - t.Fatal("Did not get a provider back.") + }() + + t.Logf("connecting %d dhts in a ring", nDHTs) + for i := 0; i < nDHTs; i++ { + connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)]) + } + + <-time.After(100 * time.Millisecond) + t.Logf("bootstrapping them so they find each other", nDHTs) + ctxT, _ := context.WithTimeout(ctx, 5*time.Second) + bootstrap(t, ctxT, dhts) + + if u.Debug { + // the routing tables should be full now. let's inspect them. + <-time.After(5 * time.Second) + t.Logf("checking routing table of %d", nDHTs) + for _, dht := range dhts { + fmt.Printf("checking routing table of %s\n", dht.self) + dht.routingTable.Print() + fmt.Println("") + } + } + + // test "well-formed-ness" (>= 3 peers in every routing table) + for _, dht := range dhts { + rtlen := dht.routingTable.Size() + if rtlen < 4 { + t.Errorf("routing table for %s only has %d peers", dht.self, rtlen) + } + } +} + +func TestProvidesMany(t *testing.T) { + t.Skip("this test doesn't work") + // t.Skip("skipping test to debug another") + ctx := context.Background() + + nDHTs := 40 + _, _, dhts := setupDHTS(ctx, nDHTs, t) + defer func() { + for i := 0; i < nDHTs; i++ { + dhts[i].Close() + defer dhts[i].network.Close() + } + }() + + t.Logf("connecting %d dhts in a ring", nDHTs) + for i := 0; i < nDHTs; i++ { + connect(t, ctx, dhts[i], dhts[(i+1)%len(dhts)]) + } + + <-time.After(100 * time.Millisecond) + t.Logf("bootstrapping them so they find each other", nDHTs) + ctxT, _ := context.WithTimeout(ctx, 5*time.Second) + bootstrap(t, ctxT, dhts) + + if u.Debug { + // the routing tables should be full now. let's inspect them. + <-time.After(5 * time.Second) + t.Logf("checking routing table of %d", nDHTs) + for _, dht := range dhts { + fmt.Printf("checking routing table of %s\n", dht.self) + dht.routingTable.Print() + fmt.Println("") + } + } + + var providers = map[u.Key]peer.ID{} + + d := 0 + for k, v := range testCaseValues { + d = (d + 1) % len(dhts) + dht := dhts[d] + providers[k] = dht.self + + t.Logf("adding local values for %s = %s (on %s)", k, v, dht.self) + err := dht.putLocal(k, v) + if err != nil { + t.Fatal(err) + } + + bits, err := dht.getLocal(k) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(bits, v) { + t.Fatal("didn't store the right bits (%s, %s)", k, v) + } + + t.Logf("announcing provider for %s", k) + if err := dht.Provide(ctx, k); err != nil { + t.Fatal(err) + } + } + + // what is this timeout for? was 60ms before. 
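TestProvides and TestProvidesMany spread work round-robin: keys are stored on rotating nodes and lookups are issued from rotating nodes, with a map recording which node is expected to provide each key. The bookkeeping is just modular indexing, as in this illustrative helper (names are made up):

package example

// distribute assigns each key to a node in round-robin order, the same way
// TestProvidesMany spreads testCaseValues across the DHTs; the returned map
// plays the role of the providers map in the test above.
func distribute(keys, nodes []string) map[string]string {
	owner := make(map[string]string, len(keys))
	for i, k := range keys {
		owner[k] = nodes[i%len(nodes)]
	}
	return owner
}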
+ time.Sleep(time.Millisecond * 6) + + errchan := make(chan error) + + ctxT, _ = context.WithTimeout(ctx, 5*time.Second) + + var wg sync.WaitGroup + getProvider := func(dht *IpfsDHT, k u.Key) { + defer wg.Done() + + expected := providers[k] + + provchan := dht.FindProvidersAsync(ctxT, k, 1) + select { + case prov := <-provchan: + actual := prov.ID + if actual == "" { + errchan <- fmt.Errorf("Got back nil provider (%s at %s)", k, dht.self) + } else if actual != expected { + errchan <- fmt.Errorf("Got back wrong provider (%s != %s) (%s at %s)", + expected, actual, k, dht.self) + } + case <-ctxT.Done(): + errchan <- fmt.Errorf("Did not get a provider back (%s at %s)", k, dht.self) + } + } + + for k, _ := range testCaseValues { + // everyone should be able to find it... + for _, dht := range dhts { + log.Debugf("getting providers for %s at %s", k, dht.self) + wg.Add(1) + go getProvider(dht, k) + } + } + + // we need this because of printing errors + go func() { + wg.Wait() + close(errchan) + }() + + for err := range errchan { + t.Error(err) } } @@ -291,18 +497,20 @@ func TestLayeredGet(t *testing.T) { t.Fatal(err) } - time.Sleep(time.Millisecond * 60) + time.Sleep(time.Millisecond * 6) + t.Log("interface was changed. GetValue should not use providers.") ctxT, _ := context.WithTimeout(ctx, time.Second) val, err := dhts[0].GetValue(ctxT, u.Key("/v/hello")) - if err != nil { - t.Fatal(err) + if err != routing.ErrNotFound { + t.Error(err) } - - if string(val) != "world" { - t.Fatal("Got incorrect value.") + if string(val) == "world" { + t.Error("should not get value.") + } + if len(val) > 0 && string(val) != "world" { + t.Error("worse, there's a value and its not even the right one.") } - } func TestFindPeer(t *testing.T) { diff --git a/routing/dht/ext_test.go b/routing/dht/ext_test.go index 04f5111a9..b4b1158d7 100644 --- a/routing/dht/ext_test.go +++ b/routing/dht/ext_test.go @@ -73,7 +73,7 @@ func TestGetFailures(t *testing.T) { }) // This one should fail with NotFound - ctx2, _ := context.WithTimeout(context.Background(), time.Second) + ctx2, _ := context.WithTimeout(context.Background(), 3*time.Second) _, err = d.GetValue(ctx2, u.Key("test")) if err != nil { if err != routing.ErrNotFound { diff --git a/routing/dht/handlers.go b/routing/dht/handlers.go index 070f320a9..5aec6c2ff 100644 --- a/routing/dht/handlers.go +++ b/routing/dht/handlers.go @@ -148,7 +148,7 @@ func (dht *IpfsDHT) handleFindPeer(ctx context.Context, p peer.ID, pmes *pb.Mess } if closest == nil { - log.Errorf("handleFindPeer: could not find anything.") + log.Debugf("handleFindPeer: could not find anything.") return resp, nil } diff --git a/routing/dht/query.go b/routing/dht/query.go index c45fa239f..6a7bb687d 100644 --- a/routing/dht/query.go +++ b/routing/dht/query.go @@ -12,6 +12,7 @@ import ( todoctr "github.com/jbenet/go-ipfs/util/todocounter" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" + ctxgroup "github.com/jbenet/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-ctxgroup" ) var maxQueryConcurrency = AlphaValue @@ -78,9 +79,8 @@ type dhtQueryRunner struct { // peersRemaining is a counter of peers remaining (toQuery + processing) peersRemaining todoctr.Counter - // context - ctx context.Context - cancel context.CancelFunc + // context group + cg ctxgroup.ContextGroup // result result *dhtQueryResult @@ -93,16 +93,13 @@ type dhtQueryRunner struct { } func newQueryRunner(ctx context.Context, q *dhtQuery) *dhtQueryRunner { - ctx, cancel := context.WithCancel(ctx) - return 
&dhtQueryRunner{ - ctx: ctx, - cancel: cancel, query: q, peersToQuery: queue.NewChanQueue(ctx, queue.NewXORDistancePQ(q.key)), peersRemaining: todoctr.NewSyncCounter(), peersSeen: peer.Set{}, rateLimit: make(chan struct{}, q.concurrency), + cg: ctxgroup.WithContext(ctx), } } @@ -120,11 +117,13 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { // add all the peers we got first. for _, p := range peers { - r.addPeerToQuery(p, "") // don't have access to self here... + r.addPeerToQuery(r.cg.Context(), p, "") // don't have access to self here... } // go do this thing. - go r.spawnWorkers() + // do it as a child func to make sure Run exits + // ONLY AFTER spawn workers has exited. + r.cg.AddChildFunc(r.spawnWorkers) // so workers are working. @@ -133,7 +132,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { select { case <-r.peersRemaining.Done(): - r.cancel() // ran all and nothing. cancel all outstanding workers. + r.cg.Close() r.RLock() defer r.RUnlock() @@ -141,10 +140,10 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { err = r.errs[0] } - case <-r.ctx.Done(): + case <-r.cg.Closed(): r.RLock() defer r.RUnlock() - err = r.ctx.Err() + err = r.cg.Context().Err() // collect the error. } if r.result != nil && r.result.success { @@ -154,7 +153,7 @@ func (r *dhtQueryRunner) Run(peers []peer.ID) (*dhtQueryResult, error) { return nil, err } -func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) { +func (r *dhtQueryRunner) addPeerToQuery(ctx context.Context, next peer.ID, benchmark peer.ID) { // if new peer is ourselves... if next == r.query.dialer.LocalPeer() { return @@ -180,43 +179,48 @@ func (r *dhtQueryRunner) addPeerToQuery(next peer.ID, benchmark peer.ID) { r.peersSeen[next] = struct{}{} r.Unlock() - log.Debugf("adding peer to query: %v\n", next) + log.Debugf("adding peer to query: %v", next) // do this after unlocking to prevent possible deadlocks. r.peersRemaining.Increment(1) select { case r.peersToQuery.EnqChan <- next: - case <-r.ctx.Done(): + case <-ctx.Done(): } } -func (r *dhtQueryRunner) spawnWorkers() { +func (r *dhtQueryRunner) spawnWorkers(parent ctxgroup.ContextGroup) { for { select { case <-r.peersRemaining.Done(): return - case <-r.ctx.Done(): + case <-r.cg.Closing(): return case p, more := <-r.peersToQuery.DeqChan: if !more { return // channel closed. } - log.Debugf("spawning worker for: %v\n", p) - go r.queryPeer(p) + log.Debugf("spawning worker for: %v", p) + + // do it as a child func to make sure Run exits + // ONLY AFTER spawn workers has exited. + parent.AddChildFunc(func(cg ctxgroup.ContextGroup) { + r.queryPeer(cg, p) + }) } } } -func (r *dhtQueryRunner) queryPeer(p peer.ID) { +func (r *dhtQueryRunner) queryPeer(cg ctxgroup.ContextGroup, p peer.ID) { log.Debugf("spawned worker for: %v", p) // make sure we rate limit concurrency. select { case <-r.rateLimit: - case <-r.ctx.Done(): + case <-cg.Closing(): r.peersRemaining.Decrement(1) return } @@ -233,17 +237,22 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) { }() // make sure we're connected to the peer. - err := r.query.dialer.DialPeer(r.ctx, p) - if err != nil { - log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err) - r.Lock() - r.errs = append(r.errs, err) - r.Unlock() - return + if conns := r.query.dialer.ConnsToPeer(p); len(conns) == 0 { + log.Infof("worker for: %v -- not connected. 
dial start", p) + + if err := r.query.dialer.DialPeer(cg.Context(), p); err != nil { + log.Debugf("ERROR worker for: %v -- err connecting: %v", p, err) + r.Lock() + r.errs = append(r.errs, err) + r.Unlock() + return + } + + log.Infof("worker for: %v -- not connected. dial success!", p) } // finally, run the query against this peer - res, err := r.query.qfunc(r.ctx, p) + res, err := r.query.qfunc(cg.Context(), p) if err != nil { log.Debugf("ERROR worker for: %v %v", p, err) @@ -256,14 +265,20 @@ func (r *dhtQueryRunner) queryPeer(p peer.ID) { r.Lock() r.result = res r.Unlock() - r.cancel() // signal to everyone that we're done. + go r.cg.Close() // signal to everyone that we're done. + // must be async, as we're one of the children, and Close blocks. } else if len(res.closerPeers) > 0 { log.Debugf("PEERS CLOSER -- worker for: %v (%d closer peers)", p, len(res.closerPeers)) for _, next := range res.closerPeers { // add their addresses to the dialer's peerstore + conns := r.query.dialer.ConnsToPeer(next.ID) + if len(conns) == 0 { + log.Infof("PEERS CLOSER -- worker for %v FOUND NEW PEER: %s %s", p, next.ID, next.Addrs) + } + r.query.dialer.Peerstore().AddAddresses(next.ID, next.Addrs) - r.addPeerToQuery(next.ID, p) + r.addPeerToQuery(cg.Context(), next.ID, p) log.Debugf("PEERS CLOSER -- worker for: %v added %v (%v)", p, next.ID, next.Addrs) } } else { diff --git a/routing/kbucket/table.go b/routing/kbucket/table.go index da4c6e720..bed7447a5 100644 --- a/routing/kbucket/table.go +++ b/routing/kbucket/table.go @@ -223,8 +223,16 @@ func (rt *RoutingTable) ListPeers() []peer.ID { func (rt *RoutingTable) Print() { fmt.Printf("Routing Table, bs = %d, Max latency = %d\n", rt.bucketsize, rt.maxLatency) rt.tabLock.RLock() - peers := rt.ListPeers() - for i, p := range peers { - fmt.Printf("%d) %s %s\n", i, p.Pretty(), rt.metrics.LatencyEWMA(p).String()) + + for i, b := range rt.Buckets { + fmt.Printf("\tbucket: %d\n", i) + + b.lk.RLock() + for e := b.list.Front(); e != nil; e = e.Next() { + p := e.Value.(peer.ID) + fmt.Printf("\t\t- %s %s\n", p.Pretty(), rt.metrics.LatencyEWMA(p).String()) + } + b.lk.RUnlock() } + rt.tabLock.RUnlock() } diff --git a/util/ctx/ctxio.go b/util/ctx/ctxio.go new file mode 100644 index 000000000..0b41086df --- /dev/null +++ b/util/ctx/ctxio.go @@ -0,0 +1,110 @@ +package ctxutil + +import ( + "io" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +type ioret struct { + n int + err error +} + +type Writer interface { + io.Writer +} + +type ctxWriter struct { + w io.Writer + ctx context.Context +} + +// NewWriter wraps a writer to make it respect given Context. +// If there is a blocking write, the returned Writer will return +// whenever the context is cancelled (the return values are n=0 +// and err=ctx.Err().) +// +// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying +// write-- there is no way to do that with the standard go io +// interface. So the read and write _will_ happen or hang. So, use +// this sparingly, make sure to cancel the read or write as necesary +// (e.g. closing a connection whose context is up, etc.) +// +// Furthermore, in order to protect your memory from being read +// _after_ you've cancelled the context, this io.Writer will +// first make a **copy** of the buffer. 
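TestProvidesMany, earlier in this patch, fans out one lookup goroutine per key and node and funnels failures through an error channel that is closed once the WaitGroup drains, so every failure gets reported rather than only the first. The same pattern in isolation, as a hedged sketch:

package example

import (
	"fmt"
	"sync"
	"testing"
)

// checkAll runs check concurrently for every item and reports every failure,
// mirroring the errchan/WaitGroup shape used in TestProvidesMany.
func checkAll(t *testing.T, items []string, check func(string) error) {
	errchan := make(chan error)

	var wg sync.WaitGroup
	for _, it := range items {
		wg.Add(1)
		go func(it string) {
			defer wg.Done()
			if err := check(it); err != nil {
				errchan <- fmt.Errorf("%s: %w", it, err)
			}
		}(it)
	}

	// close the channel only after all workers finish, so the range terminates.
	go func() {
		wg.Wait()
		close(errchan)
	}()

	for err := range errchan {
		t.Error(err)
	}
}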
+func NewWriter(ctx context.Context, w io.Writer) *ctxWriter { + if ctx == nil { + ctx = context.Background() + } + return &ctxWriter{ctx: ctx, w: w} +} + +func (w *ctxWriter) Write(buf []byte) (int, error) { + buf2 := make([]byte, len(buf)) + copy(buf2, buf) + + c := make(chan ioret, 1) + + go func() { + n, err := w.w.Write(buf2) + c <- ioret{n, err} + close(c) + }() + + select { + case r := <-c: + return r.n, r.err + case <-w.ctx.Done(): + return 0, w.ctx.Err() + } +} + +type Reader interface { + io.Reader +} + +type ctxReader struct { + r io.Reader + ctx context.Context +} + +// NewReader wraps a reader to make it respect given Context. +// If there is a blocking read, the returned Reader will return +// whenever the context is cancelled (the return values are n=0 +// and err=ctx.Err().) +// +// Note well: this wrapper DOES NOT ACTUALLY cancel the underlying +// write-- there is no way to do that with the standard go io +// interface. So the read and write _will_ happen or hang. So, use +// this sparingly, make sure to cancel the read or write as necesary +// (e.g. closing a connection whose context is up, etc.) +// +// Furthermore, in order to protect your memory from being read +// _before_ you've cancelled the context, this io.Reader will +// allocate a buffer of the same size, and **copy** into the client's +// if the read succeeds in time. +func NewReader(ctx context.Context, r io.Reader) *ctxReader { + return &ctxReader{ctx: ctx, r: r} +} + +func (r *ctxReader) Read(buf []byte) (int, error) { + buf2 := make([]byte, len(buf)) + + c := make(chan ioret, 1) + + go func() { + n, err := r.r.Read(buf2) + c <- ioret{n, err} + close(c) + }() + + select { + case ret := <-c: + copy(buf, buf2) + return ret.n, ret.err + case <-r.ctx.Done(): + return 0, r.ctx.Err() + } +} diff --git a/util/ctx/ctxio_test.go b/util/ctx/ctxio_test.go new file mode 100644 index 000000000..4104fb4a0 --- /dev/null +++ b/util/ctx/ctxio_test.go @@ -0,0 +1,273 @@ +package ctxutil + +import ( + "bytes" + "io" + "testing" + "time" + + context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func TestReader(t *testing.T) { + buf := []byte("abcdef") + buf2 := make([]byte, 3) + r := NewReader(context.Background(), bytes.NewReader(buf)) + + // read first half + n, err := r.Read(buf2) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf2) != string(buf[:3]) { + t.Error("incorrect contents") + } + + // read second half + n, err = r.Read(buf2) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf2) != string(buf[3:6]) { + t.Error("incorrect contents") + } + + // read more. 
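As the doc comments above note, these wrappers cannot interrupt the underlying call; they only return early. The intended usage pairs them with something that unblocks the real I/O when the context ends, typically closing the connection. A hedged usage sketch (the net.Conn and the timeout are arbitrary choices for illustration):

package example

import (
	"context"
	"net"
	"time"

	ctxutil "github.com/jbenet/go-ipfs/util/ctx"
)

// readWithTimeout shows one way to use ctxutil.NewReader: the wrapper returns
// promptly on cancellation, and closing the conn makes the goroutine doing
// the real Read return too, so it is not leaked forever.
// Note: in this sketch the conn is closed once the deadline passes or the
// function returns, since cancel is deferred.
func readWithTimeout(conn net.Conn, d time.Duration, buf []byte) (int, error) {
	ctx, cancel := context.WithTimeout(context.Background(), d)
	defer cancel()

	go func() {
		<-ctx.Done()
		conn.Close() // unblock the underlying Read once the context is done
	}()

	r := ctxutil.NewReader(ctx, conn)
	return r.Read(buf)
}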
+ n, err = r.Read(buf2) + if n != 0 { + t.Error("n should be 0", n) + } + if err != io.EOF { + t.Error("should be EOF", err) + } +} + +func TestWriter(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(context.Background(), &buf) + + // write three + n, err := w.Write([]byte("abc")) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf.Bytes()) != string("abc") { + t.Error("incorrect contents") + } + + // write three more + n, err = w.Write([]byte("def")) + if n != 3 { + t.Error("n should be 3") + } + if err != nil { + t.Error("should have no error") + } + if string(buf.Bytes()) != string("abcdef") { + t.Error("incorrect contents") + } +} + +func TestReaderCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + r := NewReader(ctx, piper) + + buf := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := r.Read(buf) + done <- ioret{n, err} + }() + + pipew.Write([]byte("abcdefghij")) + + select { + case ret := <-done: + if ret.n != 10 { + t.Error("ret.n should be 10", ret.n) + } + if ret.err != nil { + t.Error("ret.err should be nil", ret.err) + } + if string(buf) != "abcdefghij" { + t.Error("read contents differ") + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to read") + } + + go func() { + n, err := r.Read(buf) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop reading after cancel") + } +} + +func TestWriterCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + w := NewWriter(ctx, pipew) + + buf := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := w.Write([]byte("abcdefghij")) + done <- ioret{n, err} + }() + + piper.Read(buf) + + select { + case ret := <-done: + if ret.n != 10 { + t.Error("ret.n should be 10", ret.n) + } + if ret.err != nil { + t.Error("ret.err should be nil", ret.err) + } + if string(buf) != "abcdefghij" { + t.Error("write contents differ") + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to write") + } + + go func() { + n, err := w.Write([]byte("abcdefghij")) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop writing after cancel") + } +} + +func TestReadPostCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() + r := NewReader(ctx, piper) + + buf := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := r.Read(buf) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop reading after cancel") + } + + pipew.Write([]byte("abcdefghij")) + + if !bytes.Equal(buf, make([]byte, len(buf))) { + t.Fatal("buffer should have not been written to") + } +} + +func TestWritePostCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + piper, pipew := io.Pipe() 
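The cancellation tests in this file all share one shape: run the blocking call in a goroutine, then select between its result and a short time.After to decide whether it returned promptly. That shape could be factored into a small helper like the illustrative one below:

package example

import (
	"testing"
	"time"
)

// within fails the test if fn does not return before d elapses. It mirrors
// the select-on-done-or-time.After pattern used throughout ctxio_test.go.
// If fn never returns, its goroutine is leaked for the rest of the test run.
func within(t *testing.T, d time.Duration, fn func()) {
	t.Helper()
	done := make(chan struct{})
	go func() {
		defer close(done)
		fn()
	}()

	select {
	case <-done:
	case <-time.After(d):
		t.Fatal("timed out waiting for call to return")
	}
}

A call site would then read, for example, within(t, 20*time.Millisecond, func() { pipew.Write([]byte("abc")) }).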
+ w := NewWriter(ctx, pipew) + + buf := []byte("abcdefghij") + buf2 := make([]byte, 10) + done := make(chan ioret) + + go func() { + n, err := w.Write(buf) + done <- ioret{n, err} + }() + + piper.Read(buf2) + + select { + case ret := <-done: + if ret.n != 10 { + t.Error("ret.n should be 10", ret.n) + } + if ret.err != nil { + t.Error("ret.err should be nil", ret.err) + } + if string(buf2) != "abcdefghij" { + t.Error("write contents differ") + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to write") + } + + go func() { + n, err := w.Write(buf) + done <- ioret{n, err} + }() + + cancel() + + select { + case ret := <-done: + if ret.n != 0 { + t.Error("ret.n should be 0", ret.n) + } + if ret.err == nil { + t.Error("ret.err should be ctx error", ret.err) + } + case <-time.After(20 * time.Millisecond): + t.Fatal("failed to stop writing after cancel") + } + + copy(buf, []byte("aaaaaaaaaa")) + + piper.Read(buf2) + + if string(buf2) == "aaaaaaaaaa" { + t.Error("buffer was read from after ctx cancel") + } else if string(buf2) != "abcdefghij" { + t.Error("write contents differ from expected") + } +} diff --git a/util/testutil/gen.go b/util/testutil/gen.go index 16f39ef45..a8b832084 100644 --- a/util/testutil/gen.go +++ b/util/testutil/gen.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "sync" "testing" ci "github.com/jbenet/go-ipfs/crypto" @@ -49,17 +50,24 @@ func RandLocalTCPAddress() ma.Multiaddr { // most ports above 10000 aren't in use by long running processes, so yay. // (maybe there should be a range of "loopback" ports that are guaranteed // to be open for the process, but naturally can only talk to self.) - if lastPort == 0 { - lastPort = 10000 + SeededRand.Intn(50000) - } - lastPort++ - addr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", lastPort) + lastPort.Lock() + if lastPort.port == 0 { + lastPort.port = 10000 + SeededRand.Intn(50000) + } + port := lastPort.port + lastPort.port++ + lastPort.Unlock() + + addr := fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", port) maddr, _ := ma.NewMultiaddr(addr) return maddr } -var lastPort = 0 +var lastPort = struct { + port int + sync.Mutex +}{} // PeerNetParams is a struct to bundle together the four things // you need to run a connection with a peer: id, 2keys, and addr.
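Finally, the util/testutil/gen.go change guards the shared port counter by bundling it with its mutex in one anonymous struct, which keeps the lock next to the data it protects. The same shape in isolation (the counter name is illustrative):

package example

import "sync"

// counter groups a value with the mutex that guards it, the same shape as the
// lastPort variable in util/testutil/gen.go.
var counter = struct {
	sync.Mutex
	n int
}{}

// next returns a fresh value; safe for concurrent use.
func next() int {
	counter.Lock()
	defer counter.Unlock()
	counter.n++
	return counter.n
}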