diff --git a/p2p/crypto/secio/interface.go b/p2p/crypto/secio/interface.go index 0167e612a..3fc54875f 100644 --- a/p2p/crypto/secio/interface.go +++ b/p2p/crypto/secio/interface.go @@ -17,28 +17,13 @@ type SessionGenerator struct { PrivateKey ci.PrivKey } -// NewSession takes an insecure io.ReadWriter, performs a TLS-like +// NewSession takes an insecure io.ReadWriter, sets up a TLS-like // handshake with the other side, and returns a secure session. +// The handshake isn't run until the connection is read or written to. // See the source for the protocol details and security implementation. // The provided Context is only needed for the duration of this function. -func (sg *SessionGenerator) NewSession(ctx context.Context, - insecure io.ReadWriter) (Session, error) { - - ss, err := newSecureSession(sg.LocalID, sg.PrivateKey) - if err != nil { - return nil, err - } - - if ctx == nil { - ctx = context.Background() - } - ctx, cancel := context.WithCancel(ctx) - if err := ss.handshake(ctx, insecure); err != nil { - cancel() - return nil, err - } - - return ss, nil +func (sg *SessionGenerator) NewSession(ctx context.Context, insecure io.ReadWriteCloser) (Session, error) { + return newSecureSession(ctx, sg.LocalID, sg.PrivateKey, insecure) } type Session interface { @@ -64,6 +49,9 @@ type Session interface { // SecureReadWriter returns the encrypted communication channel func (s *secureSession) ReadWriter() msgio.ReadWriteCloser { + if err := s.Handshake(); err != nil { + return &closedRW{err} + } return s.secure } @@ -79,15 +67,60 @@ func (s *secureSession) LocalPrivateKey() ci.PrivKey { // RemotePeer retrieves the remote peer. func (s *secureSession) RemotePeer() peer.ID { + if err := s.Handshake(); err != nil { + return "" + } return s.remotePeer } // RemotePeer retrieves the remote peer. func (s *secureSession) RemotePublicKey() ci.PubKey { + if err := s.Handshake(); err != nil { + return nil + } return s.remote.permanentPubKey } // Close closes the secure session func (s *secureSession) Close() error { + s.cancel() + s.handshakeMu.Lock() + defer s.handshakeMu.Unlock() + if s.secure == nil { + return s.insecure.Close() // hadn't secured yet. + } return s.secure.Close() } + +// closedRW implements a stub msgio interface that's already +// closed and errored. +type closedRW struct { + err error +} + +func (c *closedRW) Read(buf []byte) (int, error) { + return 0, c.err +} + +func (c *closedRW) Write(buf []byte) (int, error) { + return 0, c.err +} + +func (c *closedRW) NextMsgLen() (int, error) { + return 0, c.err +} + +func (c *closedRW) ReadMsg() ([]byte, error) { + return nil, c.err +} + +func (c *closedRW) WriteMsg(buf []byte) error { + return c.err +} + +func (c *closedRW) Close() error { + return c.err +} + +func (c *closedRW) ReleaseMsg(m []byte) { +} diff --git a/p2p/crypto/secio/protocol.go b/p2p/crypto/secio/protocol.go index 07656c582..73cb5fdf0 100644 --- a/p2p/crypto/secio/protocol.go +++ b/p2p/crypto/secio/protocol.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "io" + "sync" + "time" msgio "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio" context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context" @@ -27,15 +29,23 @@ var ErrClosed = errors.New("connection closed") // ErrEcho is returned when we're attempting to handshake with the same keys and nonces. var ErrEcho = errors.New("same keys and nonces. one side talking to self.") +// HandshakeTimeout governs how long the handshake will be allowed to take place for. +// Making this number large means there could be many bogus connections waiting to +// timeout in flight. Typical handshakes take ~3RTTs, so it should be completed within +// seconds across a typical planet in the solar system. +var HandshakeTimeout = time.Second * 30 + // nonceSize is the size of our nonces (in bytes) const nonceSize = 16 // secureSession encapsulates all the parameters needed for encrypting // and decrypting traffic from an insecure channel. type secureSession struct { - secure msgio.ReadWriteCloser + ctx context.Context + cancel context.CancelFunc - insecure io.ReadWriter + secure msgio.ReadWriteCloser + insecure io.ReadWriteCloser insecureM msgio.ReadWriter localKey ci.PrivKey @@ -46,6 +56,10 @@ type secureSession struct { remote encParams sharedSecret []byte + + handshakeMu sync.Mutex // guards handshakeDone + handshakeErr + handshakeDone bool + handshakeErr error } func (s *secureSession) Loggable() map[string]interface{} { @@ -56,8 +70,9 @@ func (s *secureSession) Loggable() map[string]interface{} { return m } -func newSecureSession(local peer.ID, key ci.PrivKey) (*secureSession, error) { +func newSecureSession(ctx context.Context, local peer.ID, key ci.PrivKey, insecure io.ReadWriteCloser) (*secureSession, error) { s := &secureSession{localPeer: local, localKey: key} + s.ctx, s.cancel = context.WithCancel(ctx) switch { case s.localPeer == "": @@ -66,18 +81,37 @@ func newSecureSession(local peer.ID, key ci.PrivKey) (*secureSession, error) { return nil, errors.New("no local private key provided") case !s.localPeer.MatchesPrivateKey(s.localKey): return nil, fmt.Errorf("peer.ID does not match PrivateKey") + case insecure == nil: + return nil, fmt.Errorf("insecure ReadWriter is nil") } + s.ctx = ctx + s.insecure = insecure + s.insecureM = msgio.NewReadWriter(insecure) return s, nil } -// handsahke performs initial communication over insecure channel to share +func (s *secureSession) Handshake() error { + s.handshakeMu.Lock() + defer s.handshakeMu.Unlock() + + if s.handshakeErr != nil { + return s.handshakeErr + } + + if !s.handshakeDone { + s.handshakeErr = s.runHandshake() + s.handshakeDone = true + } + return s.handshakeErr +} + +// runHandshake performs initial communication over insecure channel to share // keys, IDs, and initiate communication, assigning all necessary params. // requires the duplex channel to be a msgio.ReadWriter (for framed messaging) -func (s *secureSession) handshake(ctx context.Context, insecure io.ReadWriter) error { - - s.insecure = insecure - s.insecureM = msgio.NewReadWriter(insecure) +func (s *secureSession) runHandshake() error { + ctx, cancel := context.WithTimeout(s.ctx, HandshakeTimeout) // remove + defer cancel() // ============================================================================= // step 1. Propose -- propose cipher suite + send pubkeys + nonce diff --git a/p2p/net/conn/dial_test.go b/p2p/net/conn/dial_test.go index 585e51780..4c5b584bc 100644 --- a/p2p/net/conn/dial_test.go +++ b/p2p/net/conn/dial_test.go @@ -75,18 +75,35 @@ func setupConn(t *testing.T, ctx context.Context, secure bool) (a, b Conn, p1, p done := make(chan error) go func() { + defer close(done) + var err error c2, err = d2.Dial(ctx, p1.Addr, p1.ID) if err != nil { done <- err + return + } + + // if secure, need to read + write, as that's what triggers the handshake. + if secure { + if err := sayHello(c2); err != nil { + done <- err + } } - close(done) }() c1, err := l1.Accept() if err != nil { t.Fatal("failed to accept", err) } + + // if secure, need to read + write, as that's what triggers the handshake. + if secure { + if err := sayHello(c1); err != nil { + done <- err + } + } + if err := <-done; err != nil { t.Fatal(err) } @@ -94,6 +111,20 @@ func setupConn(t *testing.T, ctx context.Context, secure bool) (a, b Conn, p1, p return c1.(Conn), c2, p1, p2 } +func sayHello(c net.Conn) error { + h := []byte("hello") + if _, err := c.Write(h); err != nil { + return err + } + if _, err := c.Read(h); err != nil { + return err + } + if string(h) != "hello" { + return fmt.Errorf("did not get hello") + } + return nil +} + func testDialer(t *testing.T, secure bool) { // t.Skip("Skipping in favor of another test") @@ -203,7 +234,7 @@ func testDialerCloseEarly(t *testing.T, secure bool) { go func() { defer func() { done <- struct{}{} }() - _, err := l1.Accept() + c, err := l1.Accept() if err != nil { if strings.Contains(err.Error(), "closed") { gotclosed <- struct{}{} @@ -211,7 +242,13 @@ func testDialerCloseEarly(t *testing.T, secure bool) { } errs <- err } - errs <- fmt.Errorf("got conn") + + if _, err := c.Write([]byte("hello")); err != nil { + gotclosed <- struct{}{} + return + } + + errs <- fmt.Errorf("wrote to conn") }() c, err := d2.Dial(ctx, p1.Addr, p1.ID) diff --git a/p2p/net/conn/secure_conn.go b/p2p/net/conn/secure_conn.go index 7a8ce6f62..f5ac698e6 100644 --- a/p2p/net/conn/secure_conn.go +++ b/p2p/net/conn/secure_conn.go @@ -5,7 +5,6 @@ import ( "net" "time" - msgio "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio" ma "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr" context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context" @@ -16,15 +15,8 @@ import ( // secureConn wraps another Conn object with an encrypted channel. type secureConn struct { - - // the wrapped conn - insecure Conn - - // secure io (wrapping insecure) - secure msgio.ReadWriteCloser - - // secure Session - session secio.Session + insecure Conn // the wrapped conn + secure secio.Session // secure Session } // newConn constructs a new connection @@ -37,23 +29,20 @@ func newSecureConn(ctx context.Context, sk ic.PrivKey, insecure Conn) (Conn, err return nil, errors.New("insecure.LocalPeer() is nil") } if sk == nil { - panic("way") return nil, errors.New("private key is nil") } // NewSession performs the secure handshake, which takes multiple RTT sessgen := secio.SessionGenerator{LocalID: insecure.LocalPeer(), PrivateKey: sk} - session, err := sessgen.NewSession(ctx, insecure) + secure, err := sessgen.NewSession(ctx, insecure) if err != nil { return nil, err } conn := &secureConn{ insecure: insecure, - session: session, - secure: session.ReadWriter(), + secure: secure, } - log.Debugf("newSecureConn: %v to %v handshake success!", conn.LocalPeer(), conn.RemotePeer()) return conn, nil } @@ -102,49 +91,49 @@ func (c *secureConn) RemoteMultiaddr() ma.Multiaddr { // LocalPeer is the Peer on this side func (c *secureConn) LocalPeer() peer.ID { - return c.session.LocalPeer() + return c.secure.LocalPeer() } // RemotePeer is the Peer on the remote side func (c *secureConn) RemotePeer() peer.ID { - return c.session.RemotePeer() + return c.secure.RemotePeer() } // LocalPrivateKey is the public key of the peer on this side func (c *secureConn) LocalPrivateKey() ic.PrivKey { - return c.session.LocalPrivateKey() + return c.secure.LocalPrivateKey() } // RemotePubKey is the public key of the peer on the remote side func (c *secureConn) RemotePublicKey() ic.PubKey { - return c.session.RemotePublicKey() + return c.secure.RemotePublicKey() } // Read reads data, net.Conn style func (c *secureConn) Read(buf []byte) (int, error) { - return c.secure.Read(buf) + return c.secure.ReadWriter().Read(buf) } // Write writes data, net.Conn style func (c *secureConn) Write(buf []byte) (int, error) { - return c.secure.Write(buf) + return c.secure.ReadWriter().Write(buf) } func (c *secureConn) NextMsgLen() (int, error) { - return c.secure.NextMsgLen() + return c.secure.ReadWriter().NextMsgLen() } // ReadMsg reads data, net.Conn style func (c *secureConn) ReadMsg() ([]byte, error) { - return c.secure.ReadMsg() + return c.secure.ReadWriter().ReadMsg() } // WriteMsg writes data, net.Conn style func (c *secureConn) WriteMsg(buf []byte) error { - return c.secure.WriteMsg(buf) + return c.secure.ReadWriter().WriteMsg(buf) } // ReleaseMsg releases a buffer func (c *secureConn) ReleaseMsg(m []byte) { - c.secure.ReleaseMsg(m) + c.secure.ReadWriter().ReleaseMsg(m) } diff --git a/p2p/net/conn/secure_conn_test.go b/p2p/net/conn/secure_conn_test.go index 79b7da572..f027b6a4c 100644 --- a/p2p/net/conn/secure_conn_test.go +++ b/p2p/net/conn/secure_conn_test.go @@ -23,6 +23,15 @@ func upgradeToSecureConn(t *testing.T, ctx context.Context, sk ic.PrivKey, c Con if err != nil { return nil, err } + + // need to read + write, as that's what triggers the handshake. + h := []byte("hello") + if _, err := s.Write(h); err != nil { + return nil, err + } + if _, err := s.Read(h); err != nil { + return nil, err + } return s, nil }