mirror of
https://github.com/ipfs/kubo.git
synced 2025-07-01 19:24:14 +08:00
muxer now uses ctxCloser
This commit is contained in:
@ -7,6 +7,7 @@ import (
|
||||
msg "github.com/jbenet/go-ipfs/net/message"
|
||||
pb "github.com/jbenet/go-ipfs/net/mux/internal/pb"
|
||||
u "github.com/jbenet/go-ipfs/util"
|
||||
ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"
|
||||
|
||||
context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
|
||||
proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto"
|
||||
@ -14,6 +15,8 @@ import (
|
||||
|
||||
var log = u.Logger("muxer")
|
||||
|
||||
// ProtocolIDs used to identify each protocol.
|
||||
// These should probably be defined elsewhere.
|
||||
var (
|
||||
ProtocolID_Routing = pb.ProtocolID_Routing
|
||||
ProtocolID_Exchange = pb.ProtocolID_Exchange
|
||||
@ -38,11 +41,6 @@ type Muxer struct {
|
||||
// Protocols are the multiplexed services.
|
||||
Protocols ProtocolMap
|
||||
|
||||
// cancel is the function to stop the Muxer
|
||||
cancel context.CancelFunc
|
||||
ctx context.Context
|
||||
wg sync.WaitGroup
|
||||
|
||||
bwiLock sync.Mutex
|
||||
bwIn uint64
|
||||
|
||||
@ -50,14 +48,25 @@ type Muxer struct {
|
||||
bwOut uint64
|
||||
|
||||
*msg.Pipe
|
||||
ctxc.ContextCloser
|
||||
}
|
||||
|
||||
// NewMuxer constructs a muxer given a protocol map.
|
||||
func NewMuxer(mp ProtocolMap) *Muxer {
|
||||
return &Muxer{
|
||||
Protocols: mp,
|
||||
Pipe: msg.NewPipe(10),
|
||||
func NewMuxer(ctx context.Context, mp ProtocolMap) *Muxer {
|
||||
m := &Muxer{
|
||||
Protocols: mp,
|
||||
Pipe: msg.NewPipe(10),
|
||||
ContextCloser: ctxc.NewContextCloser(ctx, nil),
|
||||
}
|
||||
|
||||
m.Children().Add(1)
|
||||
go m.handleIncomingMessages()
|
||||
for pid, proto := range m.Protocols {
|
||||
m.Children().Add(1)
|
||||
go m.handleOutgoingMessages(pid, proto)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// GetPipe implements the Protocol interface
|
||||
@ -65,30 +74,7 @@ func (m *Muxer) GetPipe() *msg.Pipe {
|
||||
return m.Pipe
|
||||
}
|
||||
|
||||
// Start kicks off the Muxer goroutines.
|
||||
func (m *Muxer) Start(ctx context.Context) error {
|
||||
if m == nil {
|
||||
panic("nix muxer")
|
||||
}
|
||||
|
||||
if m.cancel != nil {
|
||||
return errors.New("Muxer already started.")
|
||||
}
|
||||
|
||||
// make a cancellable context.
|
||||
m.ctx, m.cancel = context.WithCancel(ctx)
|
||||
m.wg = sync.WaitGroup{}
|
||||
|
||||
m.wg.Add(1)
|
||||
go m.handleIncomingMessages()
|
||||
for pid, proto := range m.Protocols {
|
||||
m.wg.Add(1)
|
||||
go m.handleOutgoingMessages(pid, proto)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBandwidthTotals return the in/out bandwidth measured over this muxer.
|
||||
func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) {
|
||||
m.bwiLock.Lock()
|
||||
in = m.bwIn
|
||||
@ -100,19 +86,6 @@ func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) {
|
||||
return
|
||||
}
|
||||
|
||||
// Stop stops muxer activity.
|
||||
func (m *Muxer) Stop() {
|
||||
if m.cancel == nil {
|
||||
panic("muxer stopped twice.")
|
||||
}
|
||||
// issue cancel, and wipe func.
|
||||
m.cancel()
|
||||
m.cancel = context.CancelFunc(nil)
|
||||
|
||||
// wait for everything to wind down.
|
||||
m.wg.Wait()
|
||||
}
|
||||
|
||||
// AddProtocol adds a Protocol with given ProtocolID to the Muxer.
|
||||
func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error {
|
||||
if _, found := m.Protocols[pid]; found {
|
||||
@ -126,28 +99,26 @@ func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error {
|
||||
// handleIncoming consumes the messages on the m.Incoming channel and
|
||||
// routes them appropriately (to the protocols).
|
||||
func (m *Muxer) handleIncomingMessages() {
|
||||
defer m.wg.Done()
|
||||
defer m.Children().Done()
|
||||
|
||||
for {
|
||||
if m == nil {
|
||||
panic("nil muxer")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-m.Closing():
|
||||
return
|
||||
|
||||
case msg, more := <-m.Incoming:
|
||||
if !more {
|
||||
return
|
||||
}
|
||||
m.Children().Add(1)
|
||||
go m.handleIncomingMessage(msg)
|
||||
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleIncomingMessage routes message to the appropriate protocol.
|
||||
func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
|
||||
defer m.Children().Done()
|
||||
|
||||
m.bwiLock.Lock()
|
||||
// TODO: compensate for overhead
|
||||
@ -169,8 +140,7 @@ func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
|
||||
|
||||
select {
|
||||
case proto.GetPipe().Incoming <- m2:
|
||||
case <-m.ctx.Done():
|
||||
log.Error(m.ctx.Err())
|
||||
case <-m.Closing():
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -178,7 +148,7 @@ func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
|
||||
// handleOutgoingMessages consumes the messages on the proto.Outgoing channel,
|
||||
// wraps them and sends them out.
|
||||
func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) {
|
||||
defer m.wg.Done()
|
||||
defer m.Children().Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
@ -186,9 +156,10 @@ func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) {
|
||||
if !more {
|
||||
return
|
||||
}
|
||||
m.Children().Add(1)
|
||||
go m.handleOutgoingMessage(pid, msg)
|
||||
|
||||
case <-m.ctx.Done():
|
||||
case <-m.Closing():
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -196,6 +167,8 @@ func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) {
|
||||
|
||||
// handleOutgoingMessage wraps out a message and sends it out the
|
||||
func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) {
|
||||
defer m.Children().Done()
|
||||
|
||||
data, err := wrapData(m1.Data(), pid)
|
||||
if err != nil {
|
||||
log.Errorf("muxer serializing error: %v", err)
|
||||
@ -204,13 +177,14 @@ func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) {
|
||||
|
||||
m.bwoLock.Lock()
|
||||
// TODO: compensate for overhead
|
||||
// TODO(jbenet): switch this to a goroutine to prevent sync waiting.
|
||||
m.bwOut += uint64(len(data))
|
||||
m.bwoLock.Unlock()
|
||||
|
||||
m2 := msg.New(m1.Peer(), data)
|
||||
select {
|
||||
case m.GetPipe().Outgoing <- m2:
|
||||
case <-m.ctx.Done():
|
||||
case <-m.Closing():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -54,23 +54,20 @@ func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []by
|
||||
}
|
||||
|
||||
func TestSimpleMuxer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// setup
|
||||
p1 := &TestProtocol{Pipe: msg.NewPipe(10)}
|
||||
p2 := &TestProtocol{Pipe: msg.NewPipe(10)}
|
||||
pid1 := pb.ProtocolID_Test
|
||||
pid2 := pb.ProtocolID_Routing
|
||||
mux1 := NewMuxer(ProtocolMap{
|
||||
mux1 := NewMuxer(ctx, ProtocolMap{
|
||||
pid1: p1,
|
||||
pid2: p2,
|
||||
})
|
||||
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
|
||||
// peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
|
||||
|
||||
// run muxer
|
||||
ctx := context.Background()
|
||||
mux1.Start(ctx)
|
||||
|
||||
// test outgoing p1
|
||||
for _, s := range []string{"foo", "bar", "baz"} {
|
||||
p1.Outgoing <- msg.New(peer1, []byte(s))
|
||||
@ -105,23 +102,21 @@ func TestSimpleMuxer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSimultMuxer(t *testing.T) {
|
||||
// run muxer
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// setup
|
||||
p1 := &TestProtocol{Pipe: msg.NewPipe(10)}
|
||||
p2 := &TestProtocol{Pipe: msg.NewPipe(10)}
|
||||
pid1 := pb.ProtocolID_Test
|
||||
pid2 := pb.ProtocolID_Identify
|
||||
mux1 := NewMuxer(ProtocolMap{
|
||||
mux1 := NewMuxer(ctx, ProtocolMap{
|
||||
pid1: p1,
|
||||
pid2: p2,
|
||||
})
|
||||
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
|
||||
// peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
|
||||
|
||||
// run muxer
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
mux1.Start(ctx)
|
||||
|
||||
// counts
|
||||
total := 10000
|
||||
speed := time.Microsecond * 1
|
||||
@ -214,22 +209,20 @@ func TestSimultMuxer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStopping(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// setup
|
||||
p1 := &TestProtocol{Pipe: msg.NewPipe(10)}
|
||||
p2 := &TestProtocol{Pipe: msg.NewPipe(10)}
|
||||
pid1 := pb.ProtocolID_Test
|
||||
pid2 := pb.ProtocolID_Identify
|
||||
mux1 := NewMuxer(ProtocolMap{
|
||||
mux1 := NewMuxer(ctx, ProtocolMap{
|
||||
pid1: p1,
|
||||
pid2: p2,
|
||||
})
|
||||
peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa")
|
||||
// peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb")
|
||||
|
||||
// run muxer
|
||||
mux1.Start(context.Background())
|
||||
|
||||
// test outgoing p1
|
||||
for _, s := range []string{"foo1", "bar1", "baz1"} {
|
||||
p1.Outgoing <- msg.New(peer1, []byte(s))
|
||||
@ -246,10 +239,7 @@ func TestStopping(t *testing.T) {
|
||||
testMsg(t, <-p1.Incoming, []byte(s))
|
||||
}
|
||||
|
||||
mux1.Stop()
|
||||
if mux1.cancel != nil {
|
||||
t.Error("mux.cancel should be nil")
|
||||
}
|
||||
mux1.Close() // waits
|
||||
|
||||
// test outgoing p1
|
||||
for _, s := range []string{"foo3", "bar3", "baz3"} {
|
||||
@ -274,5 +264,4 @@ func TestStopping(t *testing.T) {
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
11
net/net.go
11
net/net.go
@ -36,17 +36,12 @@ func NewIpfsNetwork(ctx context.Context, local peer.Peer,
|
||||
|
||||
in := &IpfsNetwork{
|
||||
local: local,
|
||||
muxer: mux.NewMuxer(*pmap),
|
||||
muxer: mux.NewMuxer(ctx, *pmap),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
err := in.muxer.Start(ctx)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var err error
|
||||
in.swarm, err = swarm.NewSwarm(ctx, local, peers)
|
||||
if err != nil {
|
||||
cancel()
|
||||
@ -101,7 +96,7 @@ func (n *IpfsNetwork) Close() error {
|
||||
}
|
||||
|
||||
n.swarm.Close()
|
||||
n.muxer.Stop()
|
||||
n.muxer.Close()
|
||||
|
||||
n.cancel()
|
||||
n.cancel = nil
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
// CloseFunc is a function used to close a ContextCloser
|
||||
type CloseFunc func() error
|
||||
|
||||
var nilCloseFunc = func() error { return nil }
|
||||
|
||||
// ContextCloser is an interface for services able to be opened and closed.
|
||||
// It has a parent Context, and Children. But ContextCloser is not a proper
|
||||
// "tree" like the Context tree. It is more like a Context-WaitGroup hybrid.
|
||||
@ -92,6 +94,9 @@ type contextCloser struct {
|
||||
// NewContextCloser constructs and returns a ContextCloser. It will call
|
||||
// cf CloseFunc before its Done() Wait signals fire.
|
||||
func NewContextCloser(ctx context.Context, cf CloseFunc) ContextCloser {
|
||||
if cf == nil {
|
||||
cf = nilCloseFunc
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
c := &contextCloser{
|
||||
ctx: ctx,
|
||||
|
Reference in New Issue
Block a user