diff --git a/merkledag/merkledag.go b/merkledag/merkledag.go index fbb07c9ee..ef66a9f2e 100644 --- a/merkledag/merkledag.go +++ b/merkledag/merkledag.go @@ -332,6 +332,13 @@ func (ds *dagService) GetDAG(ctx context.Context, root *Node) <-chan *Node { break } nodes[i] = nd + for { //Check for duplicate links + ni, err := FindLink(root, blk.Key(), nodes) + if err != nil { + break + } + nodes[ni] = nd + } if next == i { sig <- nd diff --git a/merkledag/merkledag_test.go b/merkledag/merkledag_test.go index b5f170c24..621378964 100644 --- a/merkledag/merkledag_test.go +++ b/merkledag/merkledag_test.go @@ -75,6 +75,25 @@ func makeTestDag(t *testing.T) *Node { return root } +type devZero struct{} + +func (_ devZero) Read(b []byte) (int, error) { + for i, _ := range b { + b[i] = 0 + } + return len(b), nil +} + +func makeZeroDag(t *testing.T) *Node { + read := io.LimitReader(devZero{}, 1024*32) + spl := &chunk.SizeSplitter{512} + root, err := imp.NewDagFromReaderWithSplitter(read, spl) + if err != nil { + t.Fatal(err) + } + return root +} + func TestBatchFetch(t *testing.T) { var dagservs []DAGService for _, bsi := range blockservice.Mocks(t, 5) { @@ -133,3 +152,62 @@ func TestBatchFetch(t *testing.T) { <-done } } + +func TestBatchFetchDupBlock(t *testing.T) { + var dagservs []DAGService + for _, bsi := range blockservice.Mocks(t, 5) { + dagservs = append(dagservs, NewDAGService(bsi)) + } + t.Log("finished setup.") + + root := makeZeroDag(t) + read, err := uio.NewDagReader(root, nil) + if err != nil { + t.Fatal(err) + } + expected, err := ioutil.ReadAll(read) + if err != nil { + t.Fatal(err) + } + + err = dagservs[0].AddRecursive(root) + if err != nil { + t.Fatal(err) + } + + t.Log("Added file to first node.") + + k, err := root.Key() + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + for i := 1; i < len(dagservs); i++ { + go func(i int) { + first, err := dagservs[i].Get(k) + if err != nil { + t.Fatal(err) + } + fmt.Println("Got first node back.") + + read, err := uio.NewDagReader(first, dagservs[i]) + if err != nil { + t.Fatal(err) + } + datagot, err := ioutil.ReadAll(read) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(datagot, expected) { + t.Fatal("Got bad data back!") + } + done <- struct{}{} + }(i) + } + + for i := 1; i < len(dagservs); i++ { + <-done + } +}