diff --git a/merkledag/merkledag.go b/merkledag/merkledag.go index 552fc068d..aebc370ad 100644 --- a/merkledag/merkledag.go +++ b/merkledag/merkledag.go @@ -3,6 +3,7 @@ package merkledag import ( "fmt" + "sync" blocks "github.com/ipfs/go-ipfs/blocks" key "github.com/ipfs/go-ipfs/blocks/key" @@ -24,7 +25,7 @@ type DAGService interface { // GetDAG returns, in order, all the single leve child // nodes of the passed in node. - GetMany(context.Context, []key.Key) (<-chan *Node, <-chan error) + GetMany(context.Context, []key.Key) <-chan *NodeOption Batch() *Batch } @@ -145,9 +146,13 @@ func FindLinks(links []key.Key, k key.Key, start int) []int { return out } -func (ds *dagService) GetMany(ctx context.Context, keys []key.Key) (<-chan *Node, <-chan error) { - out := make(chan *Node, len(keys)) - errs := make(chan error, 1) +type NodeOption struct { + Node *Node + Err error +} + +func (ds *dagService) GetMany(ctx context.Context, keys []key.Key) <-chan *NodeOption { + out := make(chan *NodeOption, len(keys)) blocks := ds.Blocks.GetBlocks(ctx, keys) var count int @@ -158,27 +163,27 @@ func (ds *dagService) GetMany(ctx context.Context, keys []key.Key) (<-chan *Node case b, ok := <-blocks: if !ok { if count != len(keys) { - errs <- fmt.Errorf("failed to fetch all nodes") + out <- &NodeOption{Err: fmt.Errorf("failed to fetch all nodes")} } return } nd, err := Decoded(b.Data) if err != nil { - errs <- err + out <- &NodeOption{Err: err} return } // buffered, no need to select - out <- nd + out <- &NodeOption{Node: nd} count++ case <-ctx.Done(): - errs <- ctx.Err() + out <- &NodeOption{Err: ctx.Err()} return } } }() - return out, errs + return out } // GetDAG will fill out all of the links of the given Node. @@ -213,15 +218,22 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { ctx, cancel := context.WithCancel(ctx) defer cancel() - nodechan, errchan := ds.GetMany(ctx, dedupedKeys) + nodechan := ds.GetMany(ctx, dedupedKeys) for count := 0; count < len(keys); { select { - case nd, ok := <-nodechan: + case opt, ok := <-nodechan: if !ok { return } + if opt.Err != nil { + log.Error("error fetching: ", opt.Err) + return + } + + nd := opt.Node + k, err := nd.Key() if err != nil { log.Error("Failed to get node key: ", err) @@ -233,9 +245,6 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { count++ sendChans[i] <- nd } - case err := <-errchan: - log.Error("error fetching: ", err) - return case <-ctx.Done(): return } @@ -356,24 +365,30 @@ func EnumerateChildren(ctx context.Context, ds DAGService, root *Node, set key.K func EnumerateChildrenAsync(ctx context.Context, ds DAGService, root *Node, set key.KeySet) error { toprocess := make(chan []key.Key, 8) - nodes := make(chan *Node, 8) - errs := make(chan error, 1) + nodes := make(chan *NodeOption, 8) ctx, cancel := context.WithCancel(ctx) defer cancel() defer close(toprocess) - go fetchNodes(ctx, ds, toprocess, nodes, errs) + go fetchNodes(ctx, ds, toprocess, nodes) - nodes <- root + nodes <- &NodeOption{Node: root} live := 1 for { select { - case nd, ok := <-nodes: + case opt, ok := <-nodes: if !ok { return nil } + + if opt.Err != nil { + return opt.Err + } + + nd := opt.Node + // a node has been fetched live-- @@ -398,38 +413,35 @@ func EnumerateChildrenAsync(ctx context.Context, ds DAGService, root *Node, set return ctx.Err() } } - case err := <-errs: - return err case <-ctx.Done(): return ctx.Err() } } } -func fetchNodes(ctx context.Context, ds DAGService, in <-chan []key.Key, out chan<- *Node, errs chan<- error) { - defer close(out) +func fetchNodes(ctx context.Context, ds DAGService, in <-chan []key.Key, out chan<- *NodeOption) { + var wg sync.WaitGroup + defer func() { + // wait for all 'get' calls to complete so we don't accidentally send + // on a closed channel + wg.Wait() + close(out) + }() get := func(ks []key.Key) { - nodes, errch := ds.GetMany(ctx, ks) - for { + defer wg.Done() + nodes := ds.GetMany(ctx, ks) + for opt := range nodes { select { - case nd, ok := <-nodes: - if !ok { - return - } - select { - case out <- nd: - case <-ctx.Done(): - return - } - case err := <-errch: - errs <- err + case out <- opt: + case <-ctx.Done(): return } } } for ks := range in { + wg.Add(1) go get(ks) } }