diff --git a/routing/dht/dht.go b/routing/dht/dht.go
index b00ae0a4d..c28ca0a0f 100644
--- a/routing/dht/dht.go
+++ b/routing/dht/dht.go
@@ -39,10 +39,6 @@ type IpfsDHT struct {
 	providers    map[u.Key][]*providerInfo
 	providerLock sync.RWMutex
 
-	// map of channels waiting for reply messages
-	listeners  map[uint64]*listenInfo
-	listenLock sync.RWMutex
-
 	// Signal to shutdown dht
 	shutdown chan struct{}
 
@@ -51,18 +47,9 @@ type IpfsDHT struct {
 
 	//lock to make diagnostics work better
 	diaglock sync.Mutex
-}
 
-// The listen info struct holds information about a message that is being waited for
-type listenInfo struct {
-	// Responses matching the listen ID will be sent through resp
-	resp chan *swarm.Message
-
-	// count is the number of responses to listen for
-	count int
-
-	// eol is the time at which this listener will expire
-	eol time.Time
+	// listener is a server to register to listen for responses to messages
+	listener *MesListener
 }
 
 // NewDHT creates a new DHT object with the given peer as the 'local' host
@@ -71,7 +58,6 @@ func NewDHT(p *peer.Peer, net swarm.Network) *IpfsDHT {
 	dht.network = net
 	dht.datastore = ds.NewMapDatastore()
 	dht.self = p
-	dht.listeners = make(map[uint64]*listenInfo)
 	dht.providers = make(map[u.Key][]*providerInfo)
 	dht.shutdown = make(chan struct{})
 
@@ -80,6 +66,7 @@ func NewDHT(p *peer.Peer, net swarm.Network) *IpfsDHT {
 	dht.routes[1] = kb.NewRoutingTable(20, kb.ConvertPeerID(p.ID), time.Millisecond*100)
 	dht.routes[2] = kb.NewRoutingTable(20, kb.ConvertPeerID(p.ID), time.Hour)
 
+	dht.listener = NewMesListener()
 	dht.birth = time.Now()
 	return dht
 }
@@ -135,25 +122,7 @@ func (dht *IpfsDHT) handleMessages() {
 
 			// Note: not sure if this is the correct place for this
 			if pmes.GetResponse() {
-				dht.listenLock.RLock()
-				list, ok := dht.listeners[pmes.GetId()]
-				dht.listenLock.RUnlock()
-				if time.Now().After(list.eol) {
-					dht.Unlisten(pmes.GetId())
-					ok = false
-				}
-				if list.count > 1 {
-					list.count--
-				}
-				if ok {
-					list.resp <- mes
-					if list.count == 1 {
-						dht.Unlisten(pmes.GetId())
-					}
-				} else {
-					u.DOut("Received response with nobody listening...")
-				}
-
+				dht.listener.Respond(pmes.GetId(), mes)
 				continue
 			}
 			//
@@ -187,7 +156,6 @@ func (dht *IpfsDHT) handleMessages() {
 		case <-checkTimeouts.C:
 			// Time to collect some garbage!
 			dht.cleanExpiredProviders()
-			dht.cleanExpiredListeners()
 		}
 	}
 }
@@ -206,21 +174,6 @@ func (dht *IpfsDHT) cleanExpiredProviders() {
 	dht.providerLock.Unlock()
 }
 
-func (dht *IpfsDHT) cleanExpiredListeners() {
-	dht.listenLock.Lock()
-	var remove []uint64
-	now := time.Now()
-	for k, v := range dht.listeners {
-		if now.After(v.eol) {
-			remove = append(remove, k)
-		}
-	}
-	for _, k := range remove {
-		delete(dht.listeners, k)
-	}
-	dht.listenLock.Unlock()
-}
-
 func (dht *IpfsDHT) putValueToNetwork(p *peer.Peer, key string, value []byte) error {
 	pmes := DHTMessage{
 		Type: PBDHTMessage_PUT_VALUE,
@@ -393,41 +346,6 @@ func (dht *IpfsDHT) handleAddProvider(p *peer.Peer, pmes *PBDHTMessage) {
 	dht.addProviderEntry(key, p)
 }
 
-// Register a handler for a specific message ID, used for getting replies
-// to certain messages (i.e. response to a GET_VALUE message)
-func (dht *IpfsDHT) ListenFor(mesid uint64, count int, timeout time.Duration) <-chan *swarm.Message {
-	lchan := make(chan *swarm.Message)
-	dht.listenLock.Lock()
-	dht.listeners[mesid] = &listenInfo{lchan, count, time.Now().Add(timeout)}
-	dht.listenLock.Unlock()
-	return lchan
-}
-
-// Unregister the given message id from the listener map
-func (dht *IpfsDHT) Unlisten(mesid uint64) {
-	dht.listenLock.Lock()
-	list, ok := dht.listeners[mesid]
-	if ok {
-		delete(dht.listeners, mesid)
-	}
-	dht.listenLock.Unlock()
-	close(list.resp)
-}
-
-// Check whether or not the dht is currently listening for mesid
-func (dht *IpfsDHT) IsListening(mesid uint64) bool {
-	dht.listenLock.RLock()
-	li, ok := dht.listeners[mesid]
-	dht.listenLock.RUnlock()
-	if time.Now().After(li.eol) {
-		dht.listenLock.Lock()
-		delete(dht.listeners, mesid)
-		dht.listenLock.Unlock()
-		return false
-	}
-	return ok
-}
-
 // Stop all communications from this peer and shut down
 func (dht *IpfsDHT) Halt() {
 	dht.shutdown <- struct{}{}
@@ -444,16 +362,8 @@ func (dht *IpfsDHT) addProviderEntry(key u.Key, p *peer.Peer) {
 
 // NOTE: not yet finished, low priority
 func (dht *IpfsDHT) handleDiagnostic(p *peer.Peer, pmes *PBDHTMessage) {
-	dht.diaglock.Lock()
-	if dht.IsListening(pmes.GetId()) {
-		//TODO: ehhh..........
-		dht.diaglock.Unlock()
-		return
-	}
-	dht.diaglock.Unlock()
-
 	seq := dht.routes[0].NearestPeers(kb.ConvertPeerID(dht.self.ID), 10)
-	listenChan := dht.ListenFor(pmes.GetId(), len(seq), time.Second*30)
+	listenChan := dht.listener.Listen(pmes.GetId(), len(seq), time.Second*30)
 
 	for _, ps := range seq {
 		mes := swarm.NewMessage(ps, pmes)
@@ -499,7 +409,7 @@ out:
 func (dht *IpfsDHT) getValueOrPeers(p *peer.Peer, key u.Key, timeout time.Duration, level int) ([]byte, []*peer.Peer, error) {
 	pmes, err := dht.getValueSingle(p, key, timeout, level)
 	if err != nil {
-		return nil, nil, u.WrapError(err, "getValue Error")
+		return nil, nil, err
 	}
 
 	if pmes.GetSuccess() {
@@ -517,6 +427,9 @@ func (dht *IpfsDHT) getValueOrPeers(p *peer.Peer, key u.Key, timeout time.Durati
 	// We were given a closer node
 	var peers []*peer.Peer
 	for _, pb := range pmes.GetPeers() {
+		if peer.ID(pb.GetId()).Equal(dht.self.ID) {
+			continue
+		}
 		addr, err := ma.NewMultiaddr(pb.GetAddr())
 		if err != nil {
 			u.PErr(err.Error())
@@ -543,7 +456,7 @@ func (dht *IpfsDHT) getValueSingle(p *peer.Peer, key u.Key, timeout time.Duratio
 		Value: []byte{byte(level)},
 		Id:    GenerateMessageID(),
 	}
-	response_chan := dht.ListenFor(pmes.Id, 1, time.Minute)
+	response_chan := dht.listener.Listen(pmes.Id, 1, time.Minute)
 
 	mes := swarm.NewMessage(p, pmes.ToProtobuf())
 	t := time.Now()
@@ -553,7 +466,7 @@ func (dht *IpfsDHT) getValueSingle(p *peer.Peer, key u.Key, timeout time.Duratio
 	timeup := time.After(timeout)
 	select {
 	case <-timeup:
-		dht.Unlisten(pmes.Id)
+		dht.listener.Unlisten(pmes.Id)
 		return nil, u.ErrTimeout
 	case resp, ok := <-response_chan:
 		if !ok {
@@ -658,13 +571,13 @@ func (dht *IpfsDHT) findPeerSingle(p *peer.Peer, id peer.ID, timeout time.Durati
 	}
 
 	mes := swarm.NewMessage(p, pmes.ToProtobuf())
-	listenChan := dht.ListenFor(pmes.Id, 1, time.Minute)
+	listenChan := dht.listener.Listen(pmes.Id, 1, time.Minute)
 	t := time.Now()
 	dht.network.Send(mes)
 	after := time.After(timeout)
 	select {
 	case <-after:
-		dht.Unlisten(pmes.Id)
+		dht.listener.Unlisten(pmes.Id)
 		return nil, u.ErrTimeout
 	case resp := <-listenChan:
 		roundtrip := time.Since(t)
@@ -695,12 +608,12 @@ func (dht *IpfsDHT) findProvidersSingle(p *peer.Peer, key u.Key, level int, time
 
 	mes := swarm.NewMessage(p, pmes.ToProtobuf())
 
-	listenChan := dht.ListenFor(pmes.Id, 1, time.Minute)
+	listenChan := dht.listener.Listen(pmes.Id, 1, time.Minute)
 	dht.network.Send(mes)
 	after := time.After(timeout)
 	select {
 	case <-after:
-		dht.Unlisten(pmes.Id)
+		dht.listener.Unlisten(pmes.Id)
 		return nil, u.ErrTimeout
 	case resp := <-listenChan:
 		u.DOut("FindProviders: got response.")
diff --git a/routing/dht/routing.go b/routing/dht/routing.go
index 2ecd8ba45..e56a54e0c 100644
--- a/routing/dht/routing.go
+++ b/routing/dht/routing.go
@@ -81,6 +81,36 @@ func (c *counter) Size() int {
 	return c.n
 }
 
+type peerSet struct {
+	ps map[string]bool
+	lk sync.RWMutex
+}
+
+func newPeerSet() *peerSet {
+	ps := new(peerSet)
+	ps.ps = make(map[string]bool)
+	return ps
+}
+
+func (ps *peerSet) Add(p *peer.Peer) {
+	ps.lk.Lock()
+	ps.ps[string(p.ID)] = true
+	ps.lk.Unlock()
+}
+
+func (ps *peerSet) Contains(p *peer.Peer) bool {
+	ps.lk.RLock()
+	_, ok := ps.ps[string(p.ID)]
+	ps.lk.RUnlock()
+	return ok
+}
+
+func (ps *peerSet) Size() int {
+	ps.lk.RLock()
+	defer ps.lk.RUnlock()
+	return len(ps.ps)
+}
+
 // GetValue searches for the value corresponding to given Key.
 // If the search does not succeed, a multiaddr string of a closer peer is
 // returned along with util.ErrSearchIncomplete
@@ -111,8 +141,10 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) {
 	proc_peer := make(chan *peer.Peer, 30)
 	err_chan := make(chan error)
 	after := time.After(timeout)
+	pset := newPeerSet()
 
 	for _, p := range closest {
+		pset.Add(p)
 		npeer_chan <- p
 	}
 
@@ -130,6 +162,7 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) {
 					break
 				}
 				c.Increment()
+				proc_peer <- p
 
 			default:
 				if c.Size() == 0 {
@@ -161,7 +194,10 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) {
 
 				for _, np := range peers {
 					// TODO: filter out peers that arent closer
-					npeer_chan <- np
+					if !pset.Contains(np) && pset.Size() < limit {
+						pset.Add(np) //This is racey... make a single function to do operation
+						npeer_chan <- np
+					}
 				}
 				c.Decrement()
 			}
@@ -175,13 +211,10 @@ func (s *IpfsDHT) GetValue(key u.Key, timeout time.Duration) ([]byte, error) {
 
 	select {
 	case val := <-val_chan:
-		close(npeer_chan)
 		return val, nil
 	case err := <-err_chan:
-		close(npeer_chan)
 		return nil, err
 	case <-after:
-		close(npeer_chan)
 		return nil, u.ErrTimeout
 	}
 }
@@ -288,12 +321,12 @@ func (s *IpfsDHT) FindPeer(id peer.ID, timeout time.Duration) (*peer.Peer, error
 
 	addr, err := ma.NewMultiaddr(found.GetAddr())
 	if err != nil {
-		return nil, u.WrapError(err, "FindPeer received bad info")
+		return nil, err
 	}
 
 	nxtPeer, err := s.network.GetConnection(peer.ID(found.GetId()), addr)
 	if err != nil {
-		return nil, u.WrapError(err, "FindPeer failed to connect to new peer.")
+		return nil, err
 	}
 	if pmes.GetSuccess() {
 		if !id.Equal(nxtPeer.ID) {
@@ -316,7 +349,7 @@ func (dht *IpfsDHT) Ping(p *peer.Peer, timeout time.Duration) error {
 	mes := swarm.NewMessage(p, pmes.ToProtobuf())
 
 	before := time.Now()
-	response_chan := dht.ListenFor(pmes.Id, 1, time.Minute)
+	response_chan := dht.listener.Listen(pmes.Id, 1, time.Minute)
 	dht.network.Send(mes)
 
 	tout := time.After(timeout)
@@ -329,7 +362,7 @@ func (dht *IpfsDHT) Ping(p *peer.Peer, timeout time.Duration) error {
 	case <-tout:
 		// Timed out, think about removing peer from network
 		u.DOut("Ping peer timed out.")
-		dht.Unlisten(pmes.Id)
+		dht.listener.Unlisten(pmes.Id)
 		return u.ErrTimeout
 	}
 }
@@ -345,7 +378,7 @@ func (dht *IpfsDHT) GetDiagnostic(timeout time.Duration) ([]*diagInfo, error) {
 		Id:   GenerateMessageID(),
 	}
 
-	listenChan := dht.ListenFor(pmes.Id, len(targets), time.Minute*2)
+	listenChan := dht.listener.Listen(pmes.Id, len(targets), time.Minute*2)
 
 	pbmes := pmes.ToProtobuf()
 	for _, p := range targets {
diff --git a/util/util.go b/util/util.go
index 8ffd43f11..ac9ca8100 100644
--- a/util/util.go
+++ b/util/util.go
@@ -1,12 +1,10 @@
 package util
 
 import (
-	"bytes"
 	"errors"
 	"fmt"
 	"os"
 	"os/user"
-	"runtime"
 	"strings"
 
 	b58 "github.com/jbenet/go-base58"
@@ -36,30 +34,6 @@ func (k Key) Pretty() string {
 	return b58.Encode([]byte(k))
 }
 
-type IpfsError struct {
-	Inner error
-	Note  string
-	Stack string
-}
-
-func (ie *IpfsError) Error() string {
-	buf := new(bytes.Buffer)
-	fmt.Fprintln(buf, ie.Inner)
-	fmt.Fprintln(buf, ie.Note)
-	fmt.Fprintln(buf, ie.Stack)
-	return buf.String()
-}
-
-func WrapError(err error, note string) error {
-	ie := new(IpfsError)
-	ie.Inner = err
-	ie.Note = note
-	stack := make([]byte, 2048)
-	n := runtime.Stack(stack, false)
-	ie.Stack = string(stack[:n])
-	return ie
-}
-
 // Hash is the global IPFS hash function. uses multihash SHA2_256, 256 bits
 func Hash(data []byte) (mh.Multihash, error) {
 	return mh.Sum(data, mh.SHA2_256, -1)
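
Note on MesListener: the type that dht.go now constructs (NewMesListener) and calls (Listen, Unlisten, Respond) is not defined in any hunk of this patch, so the diff alone does not show how the old listener-map behaviour is carried over. The code below is a hypothetical sketch of that type, inferred only from the call sites above and from the removed listenInfo fields; it is not the implementation that ships with this commit, and the swarm import path simply follows the one dht.go already uses.

// Hypothetical sketch, not part of this patch: one way a MesListener matching
// the calls in dht.go (Listen, Unlisten, Respond) could be structured.
package dht

import (
	"sync"
	"time"

	swarm "github.com/jbenet/go-ipfs/swarm"
)

// MesListener tracks outstanding requests by message ID and routes responses
// to the goroutine that sent the request.
type MesListener struct {
	listeners map[uint64]*listenInfo
	lock      sync.Mutex
}

// listenInfo mirrors the struct removed from dht.go above.
type listenInfo struct {
	resp  chan *swarm.Message // responses matching the ID are delivered here
	count int                 // number of responses still expected
	eol   time.Time           // time after which this listener is considered dead
}

func NewMesListener() *MesListener {
	return &MesListener{listeners: make(map[uint64]*listenInfo)}
}

// Listen registers interest in `count` responses to message `id`, valid until timeout.
func (ml *MesListener) Listen(id uint64, count int, timeout time.Duration) <-chan *swarm.Message {
	ml.lock.Lock()
	defer ml.lock.Unlock()
	li := &listenInfo{
		// Buffered so Respond never blocks while holding the lock.
		resp:  make(chan *swarm.Message, count),
		count: count,
		eol:   time.Now().Add(timeout),
	}
	ml.listeners[id] = li
	return li.resp
}

// Unlisten drops the listener for `id`, if any, and closes its channel.
func (ml *MesListener) Unlisten(id uint64) {
	ml.lock.Lock()
	defer ml.lock.Unlock()
	if li, ok := ml.listeners[id]; ok {
		delete(ml.listeners, id)
		close(li.resp)
	}
}

// Respond delivers a response to the listener for `id`, removing the listener
// once the expected number of responses has arrived or it has expired.
func (ml *MesListener) Respond(id uint64, mes *swarm.Message) {
	ml.lock.Lock()
	defer ml.lock.Unlock()

	li, ok := ml.listeners[id]
	if !ok {
		// Nobody is listening for this ID; drop the message.
		return
	}
	if time.Now().After(li.eol) {
		// Listener expired before the response arrived.
		delete(ml.listeners, id)
		close(li.resp)
		return
	}

	li.resp <- mes
	li.count--
	if li.count <= 0 {
		delete(ml.listeners, id)
		close(li.resp)
	}
}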
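Note on the peerSet race: the new code in GetValue calls pset.Contains, pset.Size and pset.Add as three separate operations, each taking and releasing the lock on its own, which is exactly what the in-line TODO flags ("This is racey... make a single function to do operation"): two worker goroutines can both pass the check and push the set past limit, or enqueue the same peer twice. A sketch of the single-function fix the TODO asks for is below; the method name AddIfSmallerThan and its use are illustrative, not part of this patch.

// Hypothetical helper (not in this patch): do the "not already seen and still
// under the limit" test and the insert atomically, under one write lock, so
// concurrent workers cannot both pass the check.
func (ps *peerSet) AddIfSmallerThan(p *peer.Peer, maxsize int) bool {
	ps.lk.Lock()
	defer ps.lk.Unlock()

	if _, found := ps.ps[string(p.ID)]; found {
		return false // peer already queued once
	}
	if len(ps.ps) >= maxsize {
		return false // set already at its size limit
	}
	ps.ps[string(p.ID)] = true
	return true
}

With that helper, the loop body in GetValue reduces to:

	if pset.AddIfSmallerThan(np, limit) {
		npeer_chan <- np
	}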