diff --git a/routing/dht/routing.go b/routing/dht/routing.go index 9a2ca14aa..49c82f19c 100644 --- a/routing/dht/routing.go +++ b/routing/dht/routing.go @@ -122,6 +122,13 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key key.Key) ([]byte, error) { // if someone sent us a different 'less-valid' record, lets correct them if !bytes.Equal(v.Val, best) { go func(v routing.RecvdVal) { + if v.From == dht.self { + err := dht.putLocal(key, fixupRec) + if err != nil { + log.Error("Error correcting local dht entry:", err) + } + return + } ctx, cancel := context.WithTimeout(dht.Context(), time.Second*30) defer cancel() err := dht.putValueToPeer(ctx, v.From, key, fixupRec) diff --git a/routing/record/validation.go b/routing/record/validation.go index 4cce81b2a..16bf60090 100644 --- a/routing/record/validation.go +++ b/routing/record/validation.go @@ -73,13 +73,19 @@ func (v Validator) IsSigned(k key.Key) (bool, error) { // verifies that the passed in record value is the PublicKey // that matches the passed in key. func ValidatePublicKeyRecord(k key.Key, val []byte) error { - keyparts := bytes.Split([]byte(k), []byte("/")) - if len(keyparts) < 3 { - return errors.New("invalid key") + if len(k) != 38 { + return errors.New("invalid public key record key") } + prefix := string(k[:4]) + if prefix != "/pk/" { + return errors.New("key was not prefixed with /pk/") + } + + keyhash := []byte(k[4:]) + pkh := u.Hash(val) - if !bytes.Equal(keyparts[2], pkh) { + if !bytes.Equal(keyhash, pkh) { return errors.New("public key does not match storage key") } return nil diff --git a/routing/record/validation_test.go b/routing/record/validation_test.go new file mode 100644 index 000000000..ae389244e --- /dev/null +++ b/routing/record/validation_test.go @@ -0,0 +1,35 @@ +package record + +import ( + "encoding/base64" + "testing" + + key "github.com/ipfs/go-ipfs/blocks/key" + ci "gx/ipfs/QmUWER4r4qMvaCnX5zREcfyiWN7cXN9g3a7fkRqNz8qWPP/go-libp2p-crypto" +) + +var OffensiveKey = "CAASXjBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQDjXAQQMal4SB2tSnX6NJIPmC69/BT8A8jc7/gDUZNkEhdhYHvc7k7S4vntV/c92nJGxNdop9fKJyevuNMuXhhHAgMBAAE=" + +func TestValidatePublicKey(t *testing.T) { + pkb, err := base64.StdEncoding.DecodeString(OffensiveKey) + if err != nil { + t.Fatal(err) + } + + pubk, err := ci.UnmarshalPublicKey(pkb) + if err != nil { + t.Fatal(err) + } + + pkh, err := pubk.Hash() + if err != nil { + t.Fatal(err) + } + + k := key.Key("/pk/" + string(pkh)) + + err = ValidatePublicKeyRecord(k, pkb) + if err != nil { + t.Fatal(err) + } +}