From 93497c2d00fe26e11986a45936a54630104d04e1 Mon Sep 17 00:00:00 2001 From: Juan Batiz-Benet Date: Sat, 25 Oct 2014 00:48:56 -0700 Subject: [PATCH] muxer now uses ctxCloser --- net/mux/mux.go | 92 ++++++++++++++-------------------------- net/mux/mux_test.go | 27 ++++-------- net/net.go | 11 ++--- util/ctxcloser/closer.go | 5 +++ 4 files changed, 49 insertions(+), 86 deletions(-) diff --git a/net/mux/mux.go b/net/mux/mux.go index e717e67fb..a8865bb73 100644 --- a/net/mux/mux.go +++ b/net/mux/mux.go @@ -7,6 +7,7 @@ import ( msg "github.com/jbenet/go-ipfs/net/message" pb "github.com/jbenet/go-ipfs/net/mux/internal/pb" u "github.com/jbenet/go-ipfs/util" + ctxc "github.com/jbenet/go-ipfs/util/ctxcloser" context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context" proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto" @@ -14,6 +15,8 @@ import ( var log = u.Logger("muxer") +// ProtocolIDs used to identify each protocol. +// These should probably be defined elsewhere. var ( ProtocolID_Routing = pb.ProtocolID_Routing ProtocolID_Exchange = pb.ProtocolID_Exchange @@ -38,11 +41,6 @@ type Muxer struct { // Protocols are the multiplexed services. Protocols ProtocolMap - // cancel is the function to stop the Muxer - cancel context.CancelFunc - ctx context.Context - wg sync.WaitGroup - bwiLock sync.Mutex bwIn uint64 @@ -50,14 +48,25 @@ type Muxer struct { bwOut uint64 *msg.Pipe + ctxc.ContextCloser } // NewMuxer constructs a muxer given a protocol map. -func NewMuxer(mp ProtocolMap) *Muxer { - return &Muxer{ - Protocols: mp, - Pipe: msg.NewPipe(10), +func NewMuxer(ctx context.Context, mp ProtocolMap) *Muxer { + m := &Muxer{ + Protocols: mp, + Pipe: msg.NewPipe(10), + ContextCloser: ctxc.NewContextCloser(ctx, nil), } + + m.Children().Add(1) + go m.handleIncomingMessages() + for pid, proto := range m.Protocols { + m.Children().Add(1) + go m.handleOutgoingMessages(pid, proto) + } + + return m } // GetPipe implements the Protocol interface @@ -65,30 +74,7 @@ func (m *Muxer) GetPipe() *msg.Pipe { return m.Pipe } -// Start kicks off the Muxer goroutines. -func (m *Muxer) Start(ctx context.Context) error { - if m == nil { - panic("nix muxer") - } - - if m.cancel != nil { - return errors.New("Muxer already started.") - } - - // make a cancellable context. - m.ctx, m.cancel = context.WithCancel(ctx) - m.wg = sync.WaitGroup{} - - m.wg.Add(1) - go m.handleIncomingMessages() - for pid, proto := range m.Protocols { - m.wg.Add(1) - go m.handleOutgoingMessages(pid, proto) - } - - return nil -} - +// GetBandwidthTotals return the in/out bandwidth measured over this muxer. func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) { m.bwiLock.Lock() in = m.bwIn @@ -100,19 +86,6 @@ func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) { return } -// Stop stops muxer activity. -func (m *Muxer) Stop() { - if m.cancel == nil { - panic("muxer stopped twice.") - } - // issue cancel, and wipe func. - m.cancel() - m.cancel = context.CancelFunc(nil) - - // wait for everything to wind down. - m.wg.Wait() -} - // AddProtocol adds a Protocol with given ProtocolID to the Muxer. func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error { if _, found := m.Protocols[pid]; found { @@ -126,28 +99,26 @@ func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error { // handleIncoming consumes the messages on the m.Incoming channel and // routes them appropriately (to the protocols). func (m *Muxer) handleIncomingMessages() { - defer m.wg.Done() + defer m.Children().Done() for { - if m == nil { - panic("nil muxer") - } - select { + case <-m.Closing(): + return + case msg, more := <-m.Incoming: if !more { return } + m.Children().Add(1) go m.handleIncomingMessage(msg) - - case <-m.ctx.Done(): - return } } } // handleIncomingMessage routes message to the appropriate protocol. func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { + defer m.Children().Done() m.bwiLock.Lock() // TODO: compensate for overhead @@ -169,8 +140,7 @@ func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { select { case proto.GetPipe().Incoming <- m2: - case <-m.ctx.Done(): - log.Error(m.ctx.Err()) + case <-m.Closing(): return } } @@ -178,7 +148,7 @@ func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) { // handleOutgoingMessages consumes the messages on the proto.Outgoing channel, // wraps them and sends them out. func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { - defer m.wg.Done() + defer m.Children().Done() for { select { @@ -186,9 +156,10 @@ func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { if !more { return } + m.Children().Add(1) go m.handleOutgoingMessage(pid, msg) - case <-m.ctx.Done(): + case <-m.Closing(): return } } @@ -196,6 +167,8 @@ func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) { // handleOutgoingMessage wraps out a message and sends it out the func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) { + defer m.Children().Done() + data, err := wrapData(m1.Data(), pid) if err != nil { log.Errorf("muxer serializing error: %v", err) @@ -204,13 +177,14 @@ func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) { m.bwoLock.Lock() // TODO: compensate for overhead + // TODO(jbenet): switch this to a goroutine to prevent sync waiting. m.bwOut += uint64(len(data)) m.bwoLock.Unlock() m2 := msg.New(m1.Peer(), data) select { case m.GetPipe().Outgoing <- m2: - case <-m.ctx.Done(): + case <-m.Closing(): return } } diff --git a/net/mux/mux_test.go b/net/mux/mux_test.go index 72187893b..3b0235820 100644 --- a/net/mux/mux_test.go +++ b/net/mux/mux_test.go @@ -54,23 +54,20 @@ func testWrappedMsg(t *testing.T, m msg.NetMessage, pid pb.ProtocolID, data []by } func TestSimpleMuxer(t *testing.T) { + ctx := context.Background() // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Routing - mux1 := NewMuxer(ProtocolMap{ + mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - // run muxer - ctx := context.Background() - mux1.Start(ctx) - // test outgoing p1 for _, s := range []string{"foo", "bar", "baz"} { p1.Outgoing <- msg.New(peer1, []byte(s)) @@ -105,23 +102,21 @@ func TestSimpleMuxer(t *testing.T) { } func TestSimultMuxer(t *testing.T) { + // run muxer + ctx, cancel := context.WithCancel(context.Background()) // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Identify - mux1 := NewMuxer(ProtocolMap{ + mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - // run muxer - ctx, cancel := context.WithCancel(context.Background()) - mux1.Start(ctx) - // counts total := 10000 speed := time.Microsecond * 1 @@ -214,22 +209,20 @@ func TestSimultMuxer(t *testing.T) { } func TestStopping(t *testing.T) { + ctx := context.Background() // setup p1 := &TestProtocol{Pipe: msg.NewPipe(10)} p2 := &TestProtocol{Pipe: msg.NewPipe(10)} pid1 := pb.ProtocolID_Test pid2 := pb.ProtocolID_Identify - mux1 := NewMuxer(ProtocolMap{ + mux1 := NewMuxer(ctx, ProtocolMap{ pid1: p1, pid2: p2, }) peer1 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275aaaaaa") // peer2 := newPeer(t, "11140beec7b5ea3f0fdbc95d0dd47f3c5bc275bbbbbb") - // run muxer - mux1.Start(context.Background()) - // test outgoing p1 for _, s := range []string{"foo1", "bar1", "baz1"} { p1.Outgoing <- msg.New(peer1, []byte(s)) @@ -246,10 +239,7 @@ func TestStopping(t *testing.T) { testMsg(t, <-p1.Incoming, []byte(s)) } - mux1.Stop() - if mux1.cancel != nil { - t.Error("mux.cancel should be nil") - } + mux1.Close() // waits // test outgoing p1 for _, s := range []string{"foo3", "bar3", "baz3"} { @@ -274,5 +264,4 @@ func TestStopping(t *testing.T) { case <-time.After(time.Millisecond): } } - } diff --git a/net/net.go b/net/net.go index de433546a..8d6d34d30 100644 --- a/net/net.go +++ b/net/net.go @@ -36,17 +36,12 @@ func NewIpfsNetwork(ctx context.Context, local peer.Peer, in := &IpfsNetwork{ local: local, - muxer: mux.NewMuxer(*pmap), + muxer: mux.NewMuxer(ctx, *pmap), ctx: ctx, cancel: cancel, } - err := in.muxer.Start(ctx) - if err != nil { - cancel() - return nil, err - } - + var err error in.swarm, err = swarm.NewSwarm(ctx, local, peers) if err != nil { cancel() @@ -101,7 +96,7 @@ func (n *IpfsNetwork) Close() error { } n.swarm.Close() - n.muxer.Stop() + n.muxer.Close() n.cancel() n.cancel = nil diff --git a/util/ctxcloser/closer.go b/util/ctxcloser/closer.go index e04178c24..348951ad6 100644 --- a/util/ctxcloser/closer.go +++ b/util/ctxcloser/closer.go @@ -9,6 +9,8 @@ import ( // CloseFunc is a function used to close a ContextCloser type CloseFunc func() error +var nilCloseFunc = func() error { return nil } + // ContextCloser is an interface for services able to be opened and closed. // It has a parent Context, and Children. But ContextCloser is not a proper // "tree" like the Context tree. It is more like a Context-WaitGroup hybrid. @@ -92,6 +94,9 @@ type contextCloser struct { // NewContextCloser constructs and returns a ContextCloser. It will call // cf CloseFunc before its Done() Wait signals fire. func NewContextCloser(ctx context.Context, cf CloseFunc) ContextCloser { + if cf == nil { + cf = nilCloseFunc + } ctx, cancel := context.WithCancel(ctx) c := &contextCloser{ ctx: ctx,