From cf7f5da42659ea1dab7e3aec6b3950cf2b91773c Mon Sep 17 00:00:00 2001 From: Jeromy Date: Wed, 10 Feb 2016 21:42:17 -0800 Subject: [PATCH] don't fail promises that already succeeded License: MIT Signed-off-by: Jeromy --- merkledag/merkledag.go | 57 ++++++++++++++++++++++++++++--------- merkledag/merkledag_test.go | 44 ++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 14 deletions(-) diff --git a/merkledag/merkledag.go b/merkledag/merkledag.go index 3761096cc..6a6ad0ecd 100644 --- a/merkledag/merkledag.go +++ b/merkledag/merkledag.go @@ -176,9 +176,8 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { } promises := make([]NodeGetter, len(keys)) - sendChans := make([]chan<- *Node, len(keys)) for i := range keys { - promises[i], sendChans[i] = newNodePromise(ctx) + promises[i] = newNodePromise(ctx) } dedupedKeys := dedupeKeys(keys) @@ -199,7 +198,9 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { } if opt.Err != nil { - log.Error("error fetching: ", opt.Err) + for _, p := range promises { + p.Fail(opt.Err) + } return } @@ -214,7 +215,7 @@ func GetNodes(ctx context.Context, ds DAGService, keys []key.Key) []NodeGetter { is := FindLinks(keys, k, 0) for _, i := range is { count++ - sendChans[i] <- nd + promises[i].Send(nd) } case <-ctx.Done(): return @@ -237,18 +238,18 @@ func dedupeKeys(ks []key.Key) []key.Key { return out } -func newNodePromise(ctx context.Context) (NodeGetter, chan<- *Node) { - ch := make(chan *Node, 1) +func newNodePromise(ctx context.Context) NodeGetter { return &nodePromise{ - recv: ch, + recv: make(chan *Node, 1), ctx: ctx, err: make(chan error, 1), - }, ch + } } type nodePromise struct { cache *Node - recv <-chan *Node + clk sync.Mutex + recv chan *Node ctx context.Context err chan error } @@ -260,20 +261,49 @@ type nodePromise struct { type NodeGetter interface { Get(context.Context) (*Node, error) Fail(err error) + Send(*Node) } func (np *nodePromise) Fail(err error) { + np.clk.Lock() + v := np.cache + np.clk.Unlock() + + // if promise has a value, don't fail it + if v != nil { + return + } + np.err <- err } -func (np *nodePromise) Get(ctx context.Context) (*Node, error) { +func (np *nodePromise) Send(nd *Node) { + var already bool + np.clk.Lock() if np.cache != nil { - return np.cache, nil + already = true + } + np.cache = nd + np.clk.Unlock() + + if already { + panic("sending twice to the same promise is an error!") + } + + np.recv <- nd +} + +func (np *nodePromise) Get(ctx context.Context) (*Node, error) { + np.clk.Lock() + c := np.cache + np.clk.Unlock() + if c != nil { + return c, nil } select { - case blk := <-np.recv: - np.cache = blk + case nd := <-np.recv: + return nd, nil case <-np.ctx.Done(): return nil, np.ctx.Err() case <-ctx.Done(): @@ -281,7 +311,6 @@ func (np *nodePromise) Get(ctx context.Context) (*Node, error) { case err := <-np.err: return nil, err } - return np.cache, nil } type Batch struct { diff --git a/merkledag/merkledag_test.go b/merkledag/merkledag_test.go index 8137496d8..e475fa680 100644 --- a/merkledag/merkledag_test.go +++ b/merkledag/merkledag_test.go @@ -20,6 +20,7 @@ import ( imp "github.com/ipfs/go-ipfs/importer" chunk "github.com/ipfs/go-ipfs/importer/chunk" . "github.com/ipfs/go-ipfs/merkledag" + dstest "github.com/ipfs/go-ipfs/merkledag/test" "github.com/ipfs/go-ipfs/pin" uio "github.com/ipfs/go-ipfs/unixfs/io" u "gx/ipfs/QmZNVWh8LLjAavuQ2JXuFmuYH3C11xo988vSgp7UQrTRj1/go-ipfs-util" @@ -323,3 +324,46 @@ func TestEnumerateChildren(t *testing.T) { traverse(root) } + +func TestFetchFailure(t *testing.T) { + ds := dstest.Mock() + ds_bad := dstest.Mock() + + top := new(Node) + for i := 0; i < 10; i++ { + nd := &Node{Data: []byte{byte('a' + i)}} + _, err := ds.Add(nd) + if err != nil { + t.Fatal(err) + } + + err = top.AddNodeLinkClean(fmt.Sprintf("AA%d", i), nd) + if err != nil { + t.Fatal(err) + } + } + + for i := 0; i < 10; i++ { + nd := &Node{Data: []byte{'f', 'a' + byte(i)}} + _, err := ds_bad.Add(nd) + if err != nil { + t.Fatal(err) + } + + err = top.AddNodeLinkClean(fmt.Sprintf("BB%d", i), nd) + if err != nil { + t.Fatal(err) + } + } + + getters := GetDAG(context.Background(), ds, top) + for i, getter := range getters { + _, err := getter.Get(context.Background()) + if err != nil && i < 10 { + t.Fatal(err) + } + if err == nil && i >= 10 { + t.Fatal("should have failed request") + } + } +}