diff --git a/exchange/bitswap/bitswap.go b/exchange/bitswap/bitswap.go
index cae1baa33..ee80df950 100644
--- a/exchange/bitswap/bitswap.go
+++ b/exchange/bitswap/bitswap.go
@@ -19,6 +19,7 @@ import (
 	peer "github.com/jbenet/go-ipfs/peer"
 	u "github.com/jbenet/go-ipfs/util"
 	eventlog "github.com/jbenet/go-ipfs/util/eventlog"
+	pset "github.com/jbenet/go-ipfs/util/peerset"
 )
 
 var log = eventlog.Logger("bitswap")
@@ -204,57 +205,34 @@ func (bs *bitswap) sendWantListTo(ctx context.Context, peers <-chan peer.Peer) e
 }
 
 func (bs *bitswap) sendWantlistToProviders(ctx context.Context, wantlist *wl.Wantlist) {
-	provset := make(map[u.Key]peer.Peer)
-	provcollect := make(chan peer.Peer)
-
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	wg := sync.WaitGroup{}
-	// Get providers for all entries in wantlist (could take a while)
-	for _, e := range wantlist.Entries() {
-		wg.Add(1)
-		go func(k u.Key) {
-			child, _ := context.WithTimeout(ctx, providerRequestTimeout)
-			providers := bs.routing.FindProvidersAsync(child, k, maxProvidersPerRequest)
-
-			for prov := range providers {
-				provcollect <- prov
-			}
-			wg.Done()
-		}(e.Value)
-	}
-
-	// When all workers finish, close the providers channel
-	go func() {
-		wg.Wait()
-		close(provcollect)
-	}()
-
-	// Filter out duplicates,
-	// no need to send our wantlists out twice in a given time period
-	for {
-		select {
-		case p, ok := <-provcollect:
-			if !ok {
-				break
-			}
-			provset[p.Key()] = p
-		case <-ctx.Done():
-			log.Error("Context cancelled before we got all the providers!")
-			return
-		}
-	}
-
 	message := bsmsg.New()
 	message.SetFull(true)
 	for _, e := range bs.wantlist.Entries() {
 		message.AddEntry(e.Value, e.Priority, false)
 	}
 
-	for _, prov := range provset {
-		bs.send(ctx, prov, message)
+	ps := pset.NewPeerSet()
+
+	// Get providers for all entries in wantlist (could take a while)
+	wg := sync.WaitGroup{}
+	for _, e := range wantlist.Entries() {
+		wg.Add(1)
+		go func(k u.Key) {
+			defer wg.Done()
+			child, _ := context.WithTimeout(ctx, providerRequestTimeout)
+			providers := bs.routing.FindProvidersAsync(child, k, maxProvidersPerRequest)
+
+			for prov := range providers {
+				if ps.AddIfSmallerThan(prov, -1) { // Do once per peer; -1 means no size limit
+					bs.send(ctx, prov, message)
+				}
+			}
+		}(e.Value)
 	}
+	wg.Wait()
 }
 
 func (bs *bitswap) roundWorker(ctx context.Context) {
diff --git a/routing/dht/routing.go b/routing/dht/routing.go
index 76260e710..f8036ca38 100644
--- a/routing/dht/routing.go
+++ b/routing/dht/routing.go
@@ -11,6 +11,7 @@ import (
 	pb "github.com/jbenet/go-ipfs/routing/dht/pb"
 	kb "github.com/jbenet/go-ipfs/routing/kbucket"
 	u "github.com/jbenet/go-ipfs/util"
+	pset "github.com/jbenet/go-ipfs/util/peerset"
 )
 
 // asyncQueryBuffer is the size of buffered channels in async queries. This
@@ -140,7 +141,7 @@ func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key u.Key, count int
 func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key u.Key, count int, peerOut chan peer.Peer) {
 	defer close(peerOut)
 
-	ps := newPeerSet()
+	ps := pset.NewPeerSet()
 	provs := dht.providers.GetProviders(ctx, key)
 	for _, p := range provs {
 		// NOTE: assuming that this list of peers is unique
@@ -207,7 +208,7 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key u.Key, co
 	}
 }
 
-func (dht *IpfsDHT) addPeerListAsync(ctx context.Context, k u.Key, peers []*pb.Message_Peer, ps *peerSet, count int, out chan peer.Peer) {
+func (dht *IpfsDHT) addPeerListAsync(ctx context.Context, k u.Key, peers []*pb.Message_Peer, ps *pset.PeerSet, count int, out chan peer.Peer) {
 	var wg sync.WaitGroup
 	for _, pbp := range peers {
 		wg.Add(1)
diff --git a/routing/dht/util.go b/routing/dht/util.go
index 00ac38dbc..2b0c1e2a2 100644
--- a/routing/dht/util.go
+++ b/routing/dht/util.go
@@ -2,8 +2,6 @@ package dht
 
 import (
 	"sync"
-
-	peer "github.com/jbenet/go-ipfs/peer"
 )
 
 // Pool size is the number of nodes used for group find/set RPC calls
@@ -39,45 +37,3 @@ func (c *counter) Size() (s int) {
 	c.mut.Unlock()
 	return
 }
-
-// peerSet is a threadsafe set of peers
-type peerSet struct {
-	ps map[string]bool
-	lk sync.RWMutex
-}
-
-func newPeerSet() *peerSet {
-	ps := new(peerSet)
-	ps.ps = make(map[string]bool)
-	return ps
-}
-
-func (ps *peerSet) Add(p peer.Peer) {
-	ps.lk.Lock()
-	ps.ps[string(p.ID())] = true
-	ps.lk.Unlock()
-}
-
-func (ps *peerSet) Contains(p peer.Peer) bool {
-	ps.lk.RLock()
-	_, ok := ps.ps[string(p.ID())]
-	ps.lk.RUnlock()
-	return ok
-}
-
-func (ps *peerSet) Size() int {
-	ps.lk.RLock()
-	defer ps.lk.RUnlock()
-	return len(ps.ps)
-}
-
-func (ps *peerSet) AddIfSmallerThan(p peer.Peer, maxsize int) bool {
-	var success bool
-	ps.lk.Lock()
-	if _, ok := ps.ps[string(p.ID())]; !ok && len(ps.ps) < maxsize {
-		success = true
-		ps.ps[string(p.ID())] = true
-	}
-	ps.lk.Unlock()
-	return success
-}
diff --git a/util/peerset/peerset.go b/util/peerset/peerset.go
new file mode 100644
index 000000000..c2a488e43
--- /dev/null
+++ b/util/peerset/peerset.go
@@ -0,0 +1,50 @@
+package peerset
+
+import (
+	peer "github.com/jbenet/go-ipfs/peer"
+	"sync"
+)
+
+// PeerSet is a threadsafe set of peers
+type PeerSet struct {
+	ps map[string]bool
+	lk sync.RWMutex
+}
+
+func NewPeerSet() *PeerSet {
+	ps := new(PeerSet)
+	ps.ps = make(map[string]bool)
+	return ps
+}
+
+func (ps *PeerSet) Add(p peer.Peer) {
+	ps.lk.Lock()
+	ps.ps[string(p.ID())] = true
+	ps.lk.Unlock()
+}
+
+func (ps *PeerSet) Contains(p peer.Peer) bool {
+	ps.lk.RLock()
+	_, ok := ps.ps[string(p.ID())]
+	ps.lk.RUnlock()
+	return ok
+}
+
+func (ps *PeerSet) Size() int {
+	ps.lk.RLock()
+	defer ps.lk.RUnlock()
+	return len(ps.ps)
+}
+
+// AddIfSmallerThan adds p only if it is not already present and the set holds
+// fewer than maxsize peers; a negative maxsize means no size limit.
+func (ps *PeerSet) AddIfSmallerThan(p peer.Peer, maxsize int) bool {
+	var success bool
+	ps.lk.Lock()
+	if _, ok := ps.ps[string(p.ID())]; !ok && (maxsize < 0 || len(ps.ps) < maxsize) {
+		success = true
+		ps.ps[string(p.ID())] = true
+	}
+	ps.lk.Unlock()
+	return success
+}
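
A minimal usage sketch of the new util/peerset package, illustrating both call sites above: bitswap passes a negative maxsize so AddIfSmallerThan acts purely as a dedup filter, while the DHT passes a positive cap. This is illustrative only, not part of the change; peer.WithIDString is assumed here as the peer package's helper for building throwaway peers.

package main

import (
	"fmt"

	peer "github.com/jbenet/go-ipfs/peer"
	pset "github.com/jbenet/go-ipfs/util/peerset"
)

func main() {
	// Dedup only (the bitswap call site): a negative maxsize disables the
	// size cap, so AddIfSmallerThan reports true exactly once per peer.
	ps := pset.NewPeerSet()
	providers := []peer.Peer{
		peer.WithIDString("QmPeerA"),
		peer.WithIDString("QmPeerB"),
		peer.WithIDString("QmPeerA"), // duplicate provider record
	}
	for _, p := range providers {
		if ps.AddIfSmallerThan(p, -1) {
			fmt.Printf("would send wantlist to %s\n", string(p.ID()))
		}
	}
	fmt.Println("unique providers:", ps.Size()) // 2

	// Bounded collection (the DHT call site): with a positive cap the set
	// refuses new peers once it is full.
	capped := pset.NewPeerSet()
	fmt.Println(capped.AddIfSmallerThan(peer.WithIDString("QmPeerC"), 1)) // true
	fmt.Println(capped.AddIfSmallerThan(peer.WithIDString("QmPeerD"), 1)) // false: set is full
}

Folding the membership check and the size cap into one locked operation is what lets many provider-fetching goroutines share the set safely; a separate Contains-then-Add sequence would race.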