diff --git a/routing/dht/routing.go b/routing/dht/routing.go index 8b0bb1670..98c4ce3d4 100644 --- a/routing/dht/routing.go +++ b/routing/dht/routing.go @@ -159,52 +159,48 @@ func (dht *IpfsDHT) getClosestPeers(ctx context.Context, key u.Key, count int) ( peerset := pset.NewLimited(count) for _, p := range tablepeers { - out <- p + select { + case out <- p: + case <-ctx.Done(): + return nil, ctx.Err() + } peerset.Add(p) } - wg := sync.WaitGroup{} - for _, p := range tablepeers { - wg.Add(1) - go func(p peer.ID) { - dht.getClosestPeersRecurse(ctx, key, p, peerset, out) - wg.Done() - }(p) - } + query := newQuery(key, dht.network, func(ctx context.Context, p peer.ID) (*dhtQueryResult, error) { + closer, err := dht.closerPeersSingle(ctx, key, p) + if err != nil { + log.Errorf("error getting closer peers: %s", err) + return nil, err + } + + var filtered []peer.PeerInfo + for _, p := range closer { + if kb.Closer(p, dht.self, key) && peerset.TryAdd(p) { + select { + case out <- p: + case <-ctx.Done(): + return nil, ctx.Err() + } + filtered = append(filtered, dht.peerstore.PeerInfo(p)) + } + } + + return &dhtQueryResult{closerPeers: filtered}, nil + }) go func() { - wg.Wait() - close(out) + defer close(out) + // run it! + _, err := query.Run(ctx, tablepeers) + if err != nil { + log.Errorf("closestPeers query run error: %s", err) + } }() return out, nil } -func (dht *IpfsDHT) getClosestPeersRecurse(ctx context.Context, key u.Key, p peer.ID, peers *pset.PeerSet, peerOut chan<- peer.ID) { - closer, err := dht.closerPeersSingle(ctx, key, p) - if err != nil { - log.Errorf("error getting closer peers: %s", err) - return - } - - wg := sync.WaitGroup{} - for _, p := range closer { - if kb.Closer(p, dht.self, key) && peers.TryAdd(p) { - select { - case peerOut <- p: - case <-ctx.Done(): - return - } - wg.Add(1) - go func(p peer.ID) { - dht.getClosestPeersRecurse(ctx, key, p, peers, peerOut) - wg.Done() - }(p) - } - } - wg.Wait() -} - func (dht *IpfsDHT) closerPeersSingle(ctx context.Context, key u.Key, p peer.ID) ([]peer.ID, error) { pmes, err := dht.findPeerSingle(ctx, p, peer.ID(key)) if err != nil {