diff --git a/swarm/swarm.go b/swarm/swarm.go index cf968984c..7c952962e 100644 --- a/swarm/swarm.go +++ b/swarm/swarm.go @@ -262,8 +262,14 @@ func (s *Swarm) fanOut() { continue } + wrapped, err := Wrap(msg.Data, PBWrapper_DHT_MESSAGE) + if err != nil { + s.Error(err) + continue + } + // queue it in the connection's buffer - conn.Outgoing.MsgChan <- msg.Data + conn.Outgoing.MsgChan <- wrapped } } } @@ -288,8 +294,14 @@ func (s *Swarm) fanIn(conn *Conn) { goto out } + wrapper, err := Unwrap(data) + if err != nil { + s.Error(err) + continue + } + // wrap it for consumers. - msg := &Message{Peer: conn.Peer, Data: data} + msg := &Message{Peer: conn.Peer, Data: wrapper.GetMessage()} s.Chan.Incoming <- msg } } @@ -399,4 +411,26 @@ func (s *Swarm) GetChan() *Chan { return s.Chan } +func Wrap(data []byte, typ PBWrapper_MessageType) ([]byte, error) { + wrapper := new(PBWrapper) + wrapper.Message = data + wrapper.Type = &typ + b, err := proto.Marshal(wrapper) + if err != nil { + return nil, err + } + return b, nil +} + +func Unwrap(data []byte) (*PBWrapper, error) { + mes := new(PBWrapper) + err := proto.Unmarshal(data, mes) + if err != nil { + return nil, err + } + + return mes, nil +} + +// Temporary to ensure that the Swarm always matches the Network interface as we are changing it var _ Network = &Swarm{} diff --git a/swarm/swarm_test.go b/swarm/swarm_test.go index 609288c38..2760e9a80 100644 --- a/swarm/swarm_test.go +++ b/swarm/swarm_test.go @@ -14,7 +14,7 @@ func pingListen(listener *net.TCPListener, peer *peer.Peer) { for { c, err := listener.Accept() if err == nil { - fmt.Println("accepeted") + fmt.Println("accepted") go pong(c, peer) } } @@ -29,11 +29,21 @@ func pong(c net.Conn, peer *peer.Peer) { fmt.Printf("error %v\n", err) return } - if string(data[:n]) != "ping" { - fmt.Printf("error: didn't receive ping: '%v'\n", data[:n]) + b, err := Unwrap(data[:n]) + if err != nil { + fmt.Printf("error %v\n", err) return } - err = mrw.WriteMsg([]byte("pong")) + if string(b.GetMessage()) != "ping" { + fmt.Printf("error: didn't receive ping: '%v'\n", b.GetMessage()) + return + } + data, err = Wrap([]byte("pong"), PBWrapper_DHT_MESSAGE) + if err != nil { + fmt.Printf("error %v\n", err) + return + } + err = mrw.WriteMsg(data) if err != nil { fmt.Printf("error %v\n", err) return