mirror of
https://github.com/ipfs/kubo.git
synced 2025-07-01 10:49:24 +08:00
net: have an explicit IdentifyConn on dial
- Make sure we call IdentifyConn on dialed out conns - we wait until the identify is **done** before return - on listening case, we can also wait. - tests now make sure dial does wait. - tests now make sure we can wait on listening case.
This commit is contained in:
65
net/id.go
65
net/id.go
@ -1,6 +1,8 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
handshake "github.com/jbenet/go-ipfs/net/handshake"
|
handshake "github.com/jbenet/go-ipfs/net/handshake"
|
||||||
pb "github.com/jbenet/go-ipfs/net/handshake/pb"
|
pb "github.com/jbenet/go-ipfs/net/handshake/pb"
|
||||||
|
|
||||||
@ -18,14 +20,54 @@ import (
|
|||||||
// * Our public Listen Addresses
|
// * Our public Listen Addresses
|
||||||
type IDService struct {
|
type IDService struct {
|
||||||
Network Network
|
Network Network
|
||||||
|
|
||||||
|
// connections undergoing identification
|
||||||
|
// for wait purposes
|
||||||
|
currid map[Conn]chan struct{}
|
||||||
|
currmu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIDService(n Network) *IDService {
|
func NewIDService(n Network) *IDService {
|
||||||
s := &IDService{Network: n}
|
s := &IDService{
|
||||||
|
Network: n,
|
||||||
|
currid: make(map[Conn]chan struct{}),
|
||||||
|
}
|
||||||
n.SetHandler(ProtocolIdentify, s.RequestHandler)
|
n.SetHandler(ProtocolIdentify, s.RequestHandler)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ids *IDService) IdentifyConn(c Conn) {
|
||||||
|
ids.currmu.Lock()
|
||||||
|
if _, found := ids.currid[c]; found {
|
||||||
|
ids.currmu.Unlock()
|
||||||
|
log.Debugf("IdentifyConn called twice on: %s", c)
|
||||||
|
return // already identifying it.
|
||||||
|
}
|
||||||
|
ids.currid[c] = make(chan struct{})
|
||||||
|
ids.currmu.Unlock()
|
||||||
|
|
||||||
|
s, err := c.NewStreamWithProtocol(ProtocolIdentify)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("network: unable to open initial stream for %s", ProtocolIdentify)
|
||||||
|
log.Event(ids.Network.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ok give the response to our handler.
|
||||||
|
ids.ResponseHandler(s)
|
||||||
|
|
||||||
|
ids.currmu.Lock()
|
||||||
|
ch, found := ids.currid[c]
|
||||||
|
delete(ids.currid, c)
|
||||||
|
ids.currmu.Unlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
log.Errorf("IdentifyConn failed to find channel (programmer error) for %s", c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(ch) // release everyone waiting.
|
||||||
|
}
|
||||||
|
|
||||||
func (ids *IDService) RequestHandler(s Stream) {
|
func (ids *IDService) RequestHandler(s Stream) {
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
c := s.Conn()
|
c := s.Conn()
|
||||||
@ -101,6 +143,7 @@ func (ids *IDService) consumeMessage(mes *pb.Handshake3, c Conn) {
|
|||||||
|
|
||||||
// update our peerstore with the addresses.
|
// update our peerstore with the addresses.
|
||||||
ids.Network.Peerstore().AddAddresses(p, lmaddrs)
|
ids.Network.Peerstore().AddAddresses(p, lmaddrs)
|
||||||
|
log.Debugf("%s received listen addrs for %s: %s", c.LocalPeer(), c.RemotePeer(), lmaddrs)
|
||||||
|
|
||||||
// get protocol versions
|
// get protocol versions
|
||||||
pv := *mes.H1.ProtocolVersion
|
pv := *mes.H1.ProtocolVersion
|
||||||
@ -108,3 +151,23 @@ func (ids *IDService) consumeMessage(mes *pb.Handshake3, c Conn) {
|
|||||||
ids.Network.Peerstore().Put(p, "ProtocolVersion", pv)
|
ids.Network.Peerstore().Put(p, "ProtocolVersion", pv)
|
||||||
ids.Network.Peerstore().Put(p, "AgentVersion", av)
|
ids.Network.Peerstore().Put(p, "AgentVersion", av)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IdentifyWait returns a channel which will be closed once
|
||||||
|
// "ProtocolIdentify" (handshake3) finishes on given conn.
|
||||||
|
// This happens async so the connection can start to be used
|
||||||
|
// even if handshake3 knowledge is not necesary.
|
||||||
|
// Users **MUST** call IdentifyWait _after_ IdentifyConn
|
||||||
|
func (ids *IDService) IdentifyWait(c Conn) <-chan struct{} {
|
||||||
|
ids.currmu.Lock()
|
||||||
|
ch, found := ids.currid[c]
|
||||||
|
ids.currmu.Unlock()
|
||||||
|
if found {
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
// if not found, it means we are already done identifying it, or
|
||||||
|
// haven't even started. either way, return a new channel closed.
|
||||||
|
ch = make(chan struct{})
|
||||||
|
close(ch)
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
@ -32,7 +32,7 @@ func DivulgeAddresses(a, b inet.Network) {
|
|||||||
b.Peerstore().AddAddresses(id, addrs)
|
b.Peerstore().AddAddresses(id, addrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIDService(t *testing.T) {
|
func subtestIDService(t *testing.T, postDialWait time.Duration) {
|
||||||
|
|
||||||
// the generated networks should have the id service wired in.
|
// the generated networks should have the id service wired in.
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@ -55,16 +55,26 @@ func TestIDService(t *testing.T) {
|
|||||||
t.Fatalf("Failed to dial:", err)
|
t.Fatalf("Failed to dial:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// this is shitty. dial should wait for connecting to end
|
// we need to wait here if Dial returns before ID service is finished.
|
||||||
<-time.After(100 * time.Millisecond)
|
if postDialWait > 0 {
|
||||||
|
<-time.After(postDialWait)
|
||||||
|
}
|
||||||
|
|
||||||
// the IDService should be opened automatically, by the network.
|
// the IDService should be opened automatically, by the network.
|
||||||
// what we should see now is that both peers know about each others listen addresses.
|
// what we should see now is that both peers know about each others listen addresses.
|
||||||
testKnowsAddrs(t, n1, n2p, n2.Peerstore().Addresses(n2p)) // has them
|
testKnowsAddrs(t, n1, n2p, n2.Peerstore().Addresses(n2p)) // has them
|
||||||
testKnowsAddrs(t, n2, n1p, n1.Peerstore().Addresses(n1p)) // has them
|
testHasProtocolVersions(t, n1, n2p)
|
||||||
|
|
||||||
|
// now, this wait we do have to do. it's the wait for the Listening side
|
||||||
|
// to be done identifying the connection.
|
||||||
|
c := n2.ConnsToPeer(n1.LocalPeer())
|
||||||
|
if len(c) < 1 {
|
||||||
|
t.Fatal("should have connection by now at least.")
|
||||||
|
}
|
||||||
|
<-n2.IdentifyProtocol().IdentifyWait(c[0])
|
||||||
|
|
||||||
// and the protocol versions.
|
// and the protocol versions.
|
||||||
testHasProtocolVersions(t, n1, n2p)
|
testKnowsAddrs(t, n2, n1p, n1.Peerstore().Addresses(n1p)) // has them
|
||||||
testHasProtocolVersions(t, n2, n1p)
|
testHasProtocolVersions(t, n2, n1p)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,18 +92,39 @@ func testKnowsAddrs(t *testing.T, n inet.Network, p peer.ID, expected []ma.Multi
|
|||||||
for _, addr := range expected {
|
for _, addr := range expected {
|
||||||
if _, found := have[addr.String()]; !found {
|
if _, found := have[addr.String()]; !found {
|
||||||
t.Errorf("%s did not have addr for %s: %s", n.LocalPeer(), p, addr)
|
t.Errorf("%s did not have addr for %s: %s", n.LocalPeer(), p, addr)
|
||||||
panic("ahhhhhhh")
|
// panic("ahhhhhhh")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testHasProtocolVersions(t *testing.T, n inet.Network, p peer.ID) {
|
func testHasProtocolVersions(t *testing.T, n inet.Network, p peer.ID) {
|
||||||
v, err := n.Peerstore().Get(p, "ProtocolVersion")
|
v, err := n.Peerstore().Get(p, "ProtocolVersion")
|
||||||
|
if v == nil {
|
||||||
|
t.Error("no protocol version")
|
||||||
|
return
|
||||||
|
}
|
||||||
if v.(string) != handshake.IpfsVersion.String() {
|
if v.(string) != handshake.IpfsVersion.String() {
|
||||||
t.Fatal("protocol mismatch", err)
|
t.Error("protocol mismatch", err)
|
||||||
}
|
}
|
||||||
v, err = n.Peerstore().Get(p, "AgentVersion")
|
v, err = n.Peerstore().Get(p, "AgentVersion")
|
||||||
if v.(string) != handshake.ClientVersion {
|
if v.(string) != handshake.ClientVersion {
|
||||||
t.Fatal("agent version mismatch", err)
|
t.Error("agent version mismatch", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIDServiceWait gives the ID service 100ms to finish after dialing
|
||||||
|
// this is becasue it used to be concurrent. Now, Dial wait till the
|
||||||
|
// id service is done.
|
||||||
|
func TestIDServiceWait(t *testing.T) {
|
||||||
|
N := 3
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
subtestIDService(t, 100*time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIDServiceNoWait(t *testing.T) {
|
||||||
|
N := 3
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
subtestIDService(t, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,6 +88,9 @@ type Network interface {
|
|||||||
// Conns returns the connections in this Netowrk
|
// Conns returns the connections in this Netowrk
|
||||||
Conns() []Conn
|
Conns() []Conn
|
||||||
|
|
||||||
|
// ConnsToPeer returns the connections in this Netowrk for given peer.
|
||||||
|
ConnsToPeer(p peer.ID) []Conn
|
||||||
|
|
||||||
// BandwidthTotals returns the total number of bytes passed through
|
// BandwidthTotals returns the total number of bytes passed through
|
||||||
// the network since it was instantiated
|
// the network since it was instantiated
|
||||||
BandwidthTotals() (uint64, uint64)
|
BandwidthTotals() (uint64, uint64)
|
||||||
@ -102,6 +105,11 @@ type Network interface {
|
|||||||
|
|
||||||
// CtxGroup returns the network's contextGroup
|
// CtxGroup returns the network's contextGroup
|
||||||
CtxGroup() ctxgroup.ContextGroup
|
CtxGroup() ctxgroup.ContextGroup
|
||||||
|
|
||||||
|
// IdentifyProtocol returns the instance of the object running the Identify
|
||||||
|
// Protocol. This is what runs the ifps handshake-- this should be removed
|
||||||
|
// if this abstracted out to its own package.
|
||||||
|
IdentifyProtocol() *IDService
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialer represents a service that can dial out to peers
|
// Dialer represents a service that can dial out to peers
|
||||||
|
@ -29,6 +29,7 @@ type peernet struct {
|
|||||||
|
|
||||||
// needed to implement inet.Network
|
// needed to implement inet.Network
|
||||||
mux inet.Mux
|
mux inet.Mux
|
||||||
|
ids *inet.IDService
|
||||||
|
|
||||||
cg ctxgroup.ContextGroup
|
cg ctxgroup.ContextGroup
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
@ -61,6 +62,11 @@ func newPeernet(ctx context.Context, m *mocknet, k ic.PrivKey,
|
|||||||
}
|
}
|
||||||
|
|
||||||
n.cg.SetTeardown(n.teardown)
|
n.cg.SetTeardown(n.teardown)
|
||||||
|
|
||||||
|
// setup a conn handler that immediately "asks the other side about them"
|
||||||
|
// this is ProtocolIdentify.
|
||||||
|
n.ids = inet.NewIDService(n)
|
||||||
|
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,6 +164,10 @@ func (pn *peernet) remoteOpenedConn(c *conn) {
|
|||||||
// addConn constructs and adds a connection
|
// addConn constructs and adds a connection
|
||||||
// to given remote peer over given link
|
// to given remote peer over given link
|
||||||
func (pn *peernet) addConn(c *conn) {
|
func (pn *peernet) addConn(c *conn) {
|
||||||
|
|
||||||
|
// run the Identify protocol/handshake.
|
||||||
|
pn.ids.IdentifyConn(c)
|
||||||
|
|
||||||
pn.Lock()
|
pn.Lock()
|
||||||
cs, found := pn.connsByPeer[c.RemotePeer()]
|
cs, found := pn.connsByPeer[c.RemotePeer()]
|
||||||
if !found {
|
if !found {
|
||||||
@ -327,3 +337,7 @@ func (pn *peernet) NewStream(pr inet.ProtocolID, p peer.ID) (inet.Stream, error)
|
|||||||
func (pn *peernet) SetHandler(p inet.ProtocolID, h inet.StreamHandler) {
|
func (pn *peernet) SetHandler(p inet.ProtocolID, h inet.StreamHandler) {
|
||||||
pn.mux.SetHandler(p, h)
|
pn.mux.SetHandler(p, h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pn *peernet) IdentifyProtocol() *inet.IDService {
|
||||||
|
return pn.ids
|
||||||
|
}
|
||||||
|
34
net/net.go
34
net/net.go
@ -129,21 +129,21 @@ func NewNetwork(ctx context.Context, listen []ma.Multiaddr, local peer.ID,
|
|||||||
|
|
||||||
func (n *network) newConnHandler(c *swarm.Conn) {
|
func (n *network) newConnHandler(c *swarm.Conn) {
|
||||||
cc := (*conn_)(c)
|
cc := (*conn_)(c)
|
||||||
s, err := cc.NewStreamWithProtocol(ProtocolIdentify)
|
n.ids.IdentifyConn(cc)
|
||||||
if err != nil {
|
|
||||||
log.Error("network: unable to open initial stream for %s", ProtocolIdentify)
|
|
||||||
log.Event(n.CtxGroup().Context(), "IdentifyOpenFailed", c.RemotePeer())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ok give the response to our handler.
|
|
||||||
n.ids.ResponseHandler(s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialPeer attempts to establish a connection to a given peer.
|
// DialPeer attempts to establish a connection to a given peer.
|
||||||
// Respects the context.
|
// Respects the context.
|
||||||
func (n *network) DialPeer(ctx context.Context, p peer.ID) error {
|
func (n *network) DialPeer(ctx context.Context, p peer.ID) error {
|
||||||
_, err := n.swarm.Dial(ctx, p)
|
sc, err := n.swarm.Dial(ctx, p)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// identify the connection before returning.
|
||||||
|
n.ids.IdentifyConn((*conn_)(sc))
|
||||||
|
log.Debugf("network for %s finished dialing %s", n.local, p)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *network) Protocols() []ProtocolID {
|
func (n *network) Protocols() []ProtocolID {
|
||||||
@ -185,6 +185,16 @@ func (n *network) Conns() []Conn {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnsToPeer returns the connections in this Netowrk for given peer.
|
||||||
|
func (n *network) ConnsToPeer(p peer.ID) []Conn {
|
||||||
|
conns1 := n.swarm.ConnectionsToPeer(p)
|
||||||
|
out := make([]Conn, len(conns1))
|
||||||
|
for i, c := range conns1 {
|
||||||
|
out[i] = (*conn_)(c)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// ClosePeer connection to peer
|
// ClosePeer connection to peer
|
||||||
func (n *network) ClosePeer(p peer.ID) error {
|
func (n *network) ClosePeer(p peer.ID) error {
|
||||||
return n.swarm.CloseConnection(p)
|
return n.swarm.CloseConnection(p)
|
||||||
@ -254,6 +264,10 @@ func (n *network) SetHandler(p ProtocolID, h StreamHandler) {
|
|||||||
n.mux.SetHandler(p, h)
|
n.mux.SetHandler(p, h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *network) IdentifyProtocol() *IDService {
|
||||||
|
return n.ids
|
||||||
|
}
|
||||||
|
|
||||||
func WriteProtocolHeader(pr ProtocolID, s Stream) error {
|
func WriteProtocolHeader(pr ProtocolID, s Stream) error {
|
||||||
if pr != "" { // only write proper protocol headers
|
if pr != "" { // only write proper protocol headers
|
||||||
if err := WriteLengthPrefix(s, string(pr)); err != nil {
|
if err := WriteLengthPrefix(s, string(pr)); err != nil {
|
||||||
|
Reference in New Issue
Block a user