diff --git a/diagnostics/diag.go b/diagnostics/diag.go index c54b7dd51..dbdf2be86 100644 --- a/diagnostics/diag.go +++ b/diagnostics/diag.go @@ -4,11 +4,9 @@ package diagnostics import ( - "bytes" "encoding/json" "errors" "fmt" - "io" "sync" "time" @@ -33,6 +31,8 @@ var log = util.Logger("diagnostics") // ProtocolDiag is the diagnostics protocol.ID var ProtocolDiag protocol.ID = "/ipfs/diagnostics" +var ErrAlreadyRunning = errors.New("diagnostic with that ID already running") + const ResponseTimeout = time.Second * 10 const HopTimeoutDecrement = time.Second * 2 @@ -159,86 +159,56 @@ func (d *Diagnostics) GetDiagnostic(timeout time.Duration) ([]*DiagInfo, error) return nil, fmt.Errorf("diagnostic from peers err: %s", err) } - var out []*DiagInfo di := d.getDiagInfo() - out = append(out, di) - for _, dpi := range dpeers { - out = appendDiagnostics(out, dpi) + out := []*DiagInfo{di} + for dpi := range dpeers { + out = append(out, dpi) } return out, nil } -func appendDiagnostics(cur []*DiagInfo, data []byte) []*DiagInfo { - buf := bytes.NewBuffer(data) - dec := json.NewDecoder(buf) - for { - di := new(DiagInfo) - err := dec.Decode(di) - if err != nil { - if err != io.EOF { - log.Errorf("error decoding DiagInfo: %v", err) - } - break - } - cur = append(cur, di) - } - return cur -} - -func (d *Diagnostics) getDiagnosticFromPeers(ctx context.Context, peers map[peer.ID]int, pmes *pb.Message) ([][]byte, error) { - timeout := pmes.GetTimeoutDuration() - if timeout < 1 { - return nil, fmt.Errorf("timeout too short: %s", timeout) - } - ctx, _ = context.WithTimeout(ctx, timeout) - - respdata := make(chan []byte) - sendcount := 0 - for p, _ := range peers { - log.Debugf("Sending diagnostic request to peer: %s", p) - sendcount++ - go func(p peer.ID) { - out, err := d.getDiagnosticFromPeer(ctx, p, pmes) - if err != nil { - log.Errorf("getDiagnostic error: %v", err) - respdata <- nil - return - } - respdata <- out - }(p) - } - - outall := make([][]byte, 0, len(peers)) - for i := 0; i < sendcount; i++ { - out := <-respdata - outall = append(outall, out) - } - - return outall, nil -} - -// TODO: this method no longer needed. -func (d *Diagnostics) getDiagnosticFromPeer(ctx context.Context, p peer.ID, mes *pb.Message) ([]byte, error) { - rpmes, err := d.sendRequest(ctx, p, mes) +func decodeDiagJson(data []byte) (*DiagInfo, error) { + di := new(DiagInfo) + err := json.Unmarshal(data, di) if err != nil { return nil, err } - return rpmes.GetData(), nil + + return di, nil } -func newMessage(diagID string) *pb.Message { - pmes := new(pb.Message) - pmes.DiagID = proto.String(diagID) - return pmes +func (d *Diagnostics) getDiagnosticFromPeers(ctx context.Context, peers map[peer.ID]int, pmes *pb.Message) (<-chan *DiagInfo, error) { + respdata := make(chan *DiagInfo) + wg := sync.WaitGroup{} + for p, _ := range peers { + wg.Add(1) + log.Debugf("Sending diagnostic request to peer: %s", p) + go func(p peer.ID) { + defer wg.Done() + out, err := d.getDiagnosticFromPeer(ctx, p, pmes) + if err != nil { + log.Errorf("Error getting diagnostic from %s: %s", p, err) + return + } + for d := range out { + respdata <- d + } + }(p) + } + + go func() { + wg.Wait() + close(respdata) + }() + + return respdata, nil } -func (d *Diagnostics) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { - +func (d *Diagnostics) getDiagnosticFromPeer(ctx context.Context, p peer.ID, pmes *pb.Message) (<-chan *DiagInfo, error) { s, err := d.host.NewStream(ProtocolDiag, p) if err != nil { return nil, err } - defer s.Close() cr := ctxutil.NewReader(ctx, s) // ok to use. we defer close stream in this func cw := ctxutil.NewWriter(ctx, s) // ok to use. we defer close stream in this func @@ -251,51 +221,57 @@ func (d *Diagnostics) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Messa return nil, err } - rpmes := new(pb.Message) - if err := r.ReadMsg(rpmes); err != nil { - return nil, err - } - if rpmes == nil { - return nil, errors.New("no response to request") - } + out := make(chan *DiagInfo) + go func() { - rtt := time.Since(start) - log.Infof("diagnostic request took: %s", rtt.String()) - return rpmes, nil + defer func() { + close(out) + s.Close() + rtt := time.Since(start) + log.Infof("diagnostic request took: %s", rtt.String()) + }() + + for { + rpmes := new(pb.Message) + if err := r.ReadMsg(rpmes); err != nil { + log.Errorf("Error reading diagnostic from stream: %s", err) + return + } + if rpmes == nil { + log.Error("Got no response back from diag request.") + return + } + + di, err := decodeDiagJson(rpmes.GetData()) + if err != nil { + log.Error(err) + return + } + + select { + case out <- di: + case <-ctx.Done(): + return + } + } + + }() + + return out, nil } -func (d *Diagnostics) handleDiagnostic(p peer.ID, pmes *pb.Message) (*pb.Message, error) { - log.Debugf("HandleDiagnostic from %s for id = %s", p, util.Key(pmes.GetDiagID()).B58String()) - resp := newMessage(pmes.GetDiagID()) - - // Make sure we havent already handled this request to prevent loops - d.diagLock.Lock() - _, found := d.diagMap[pmes.GetDiagID()] - if found { - d.diagLock.Unlock() - return resp, nil - } - d.diagMap[pmes.GetDiagID()] = time.Now() - d.diagLock.Unlock() - - di := d.getDiagInfo() - resp.Data = di.Marshal() - dpeers, err := d.getDiagnosticFromPeers(context.TODO(), d.getPeers(), pmes) - if err != nil { - log.Errorf("diagnostic from peers err: %s", err) - } else { - for _, b := range dpeers { - resp.Data = append(resp.Data, b...) // concatenate them all. - } - } - - return resp, nil +func newMessage(diagID string) *pb.Message { + pmes := new(pb.Message) + pmes.DiagID = proto.String(diagID) + return pmes } func (d *Diagnostics) HandleMessage(ctx context.Context, s inet.Stream) error { - r := ggio.NewDelimitedReader(s, 32768) // maxsize - w := ggio.NewDelimitedWriter(s) + cr := ctxutil.NewReader(ctx, s) + cw := ctxutil.NewWriter(ctx, s) + r := ggio.NewDelimitedReader(cr, inet.MessageSizeMax) // maxsize + w := ggio.NewDelimitedWriter(cw) // deserialize msg pmes := new(pb.Message) @@ -308,25 +284,51 @@ func (d *Diagnostics) HandleMessage(ctx context.Context, s inet.Stream) error { log.Infof("[peer: %s] Got message from [%s]\n", d.self.Pretty(), s.Conn().RemotePeer()) - // dispatch handler. - p := s.Conn().RemotePeer() - rpmes, err := d.handleDiagnostic(p, pmes) + // Make sure we havent already handled this request to prevent loops + if err := d.startDiag(pmes.GetDiagID()); err != nil { + return nil + } + + resp := newMessage(pmes.GetDiagID()) + resp.Data = d.getDiagInfo().Marshal() + if err := w.WriteMsg(resp); err != nil { + log.Errorf("Failed to write protobuf message over stream: %s", err) + return err + } + + timeout := pmes.GetTimeoutDuration() + if timeout < HopTimeoutDecrement { + return fmt.Errorf("timeout too short: %s", timeout) + } + ctx, _ = context.WithTimeout(ctx, timeout) + pmes.SetTimeoutDuration(timeout - HopTimeoutDecrement) + + dpeers, err := d.getDiagnosticFromPeers(ctx, d.getPeers(), pmes) if err != nil { - log.Errorf("handleDiagnostic error: %s", err) - return nil + log.Errorf("diagnostic from peers err: %s", err) + return err + } + for b := range dpeers { + resp := newMessage(pmes.GetDiagID()) + resp.Data = b.Marshal() + if err := w.WriteMsg(resp); err != nil { + log.Errorf("Failed to write protobuf message over stream: %s", err) + return err + } } - // if nil response, return it before serializing - if rpmes == nil { - return nil - } + return nil +} - // serialize + send response msg - if err := w.WriteMsg(rpmes); err != nil { - log.Errorf("Failed to encode protobuf message: %v", err) - return nil +func (d *Diagnostics) startDiag(id string) error { + d.diagLock.Lock() + _, found := d.diagMap[id] + if found { + d.diagLock.Unlock() + return ErrAlreadyRunning } - + d.diagMap[id] = time.Now() + d.diagLock.Unlock() return nil }