change CallBack to handler, and move ctx out of Stats.

This commit is contained in:
Menghan Li
2016-10-31 16:42:23 -07:00
parent 46e80bf1f6
commit 7984a9c679
5 changed files with 72 additions and 97 deletions

View File

@ -559,9 +559,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
cbuf = new(bytes.Buffer) cbuf = new(bytes.Buffer)
} }
if stats.On() { if stats.On() {
outgoingPayloadStats = &stats.OutgoingPayloadStats{ outgoingPayloadStats = &stats.OutgoingPayloadStats{}
Ctx: stream.Context(),
}
} }
p, err := encode(s.opts.codec, msg, cp, cbuf, outgoingPayloadStats) p, err := encode(s.opts.codec, msg, cp, cbuf, outgoingPayloadStats)
if err != nil { if err != nil {
@ -577,7 +575,8 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
err = t.Write(stream, p, opts) err = t.Write(stream, p, opts)
if outgoingPayloadStats != nil { if outgoingPayloadStats != nil {
outgoingPayloadStats.SentTime = time.Now() outgoingPayloadStats.SentTime = time.Now()
stats.CallBack()(outgoingPayloadStats)
stats.Handle(stream.Context(), outgoingPayloadStats)
} }
return err return err
} }
@ -604,7 +603,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var incomingPayloadStats *stats.IncomingPayloadStats var incomingPayloadStats *stats.IncomingPayloadStats
if stats.On() { if stats.On() {
incomingPayloadStats = &stats.IncomingPayloadStats{ incomingPayloadStats = &stats.IncomingPayloadStats{
Ctx: stream.Context(),
ReceivedTime: time.Now(), ReceivedTime: time.Now(),
} }
} }
@ -675,7 +674,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if incomingPayloadStats != nil { if incomingPayloadStats != nil {
incomingPayloadStats.Data = req incomingPayloadStats.Data = req
incomingPayloadStats.Length = len(req) incomingPayloadStats.Length = len(req)
stats.CallBack()(incomingPayloadStats) stats.Handle(stream.Context(), incomingPayloadStats)
} }
if trInfo != nil { if trInfo != nil {
trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true)

View File

@ -52,8 +52,6 @@ type Stats interface {
// InitStats indicates an RPC just started. // InitStats indicates an RPC just started.
type InitStats struct { type InitStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// Method is the full RPC method string, i.e., /package.service/method. // Method is the full RPC method string, i.e., /package.service/method.
@ -70,8 +68,6 @@ func (s *InitStats) isStats() {}
// IncomingPayloadStats contains the information for a incoming payload. // IncomingPayloadStats contains the information for a incoming payload.
type IncomingPayloadStats struct { type IncomingPayloadStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// Data is the unencrypted message payload. // Data is the unencrypted message payload.
@ -88,8 +84,6 @@ func (s *IncomingPayloadStats) isStats() {}
// IncomingHeaderStats indicates a header is received. // IncomingHeaderStats indicates a header is received.
type IncomingHeaderStats struct { type IncomingHeaderStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// WireLength is the wire length of header. // WireLength is the wire length of header.
@ -100,8 +94,6 @@ func (s *IncomingHeaderStats) isStats() {}
// IncomingTrailerStats indicates a trailer is received. // IncomingTrailerStats indicates a trailer is received.
type IncomingTrailerStats struct { type IncomingTrailerStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// WireLength is the wire length of header. // WireLength is the wire length of header.
@ -112,8 +104,6 @@ func (s *IncomingTrailerStats) isStats() {}
// OutgoingPayloadStats contains the information for a outgoing payload. // OutgoingPayloadStats contains the information for a outgoing payload.
type OutgoingPayloadStats struct { type OutgoingPayloadStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// Data is the unencrypted message payload. // Data is the unencrypted message payload.
@ -130,8 +120,6 @@ func (s *OutgoingPayloadStats) isStats() {}
// OutgoingHeaderStats indicates a header is sent. // OutgoingHeaderStats indicates a header is sent.
type OutgoingHeaderStats struct { type OutgoingHeaderStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// WireLength is the wire length of header. // WireLength is the wire length of header.
@ -142,8 +130,6 @@ func (s *OutgoingHeaderStats) isStats() {}
// OutgoingTrailerStats indicates a trailer is sent. // OutgoingTrailerStats indicates a trailer is sent.
type OutgoingTrailerStats struct { type OutgoingTrailerStats struct {
// Ctx is the context associated with the RPC.
Ctx context.Context
// IsClient indicates if this stats is a client stats. // IsClient indicates if this stats is a client stats.
IsClient bool IsClient bool
// WireLength is the wire length of header. // WireLength is the wire length of header.
@ -153,8 +139,8 @@ type OutgoingTrailerStats struct {
func (s *OutgoingTrailerStats) isStats() {} func (s *OutgoingTrailerStats) isStats() {}
var ( var (
on = new(int32) on = new(int32)
f func(Stats) handler func(context.Context, Stats)
) )
// On indicates whether stats is started. // On indicates whether stats is started.
@ -162,15 +148,15 @@ func On() bool {
return atomic.LoadInt32(on) == 1 return atomic.LoadInt32(on) == 1
} }
// CallBack returns the call back function registered by user to process the stats. // Handle returns the call back function registered by user to process the stats.
func CallBack() func(Stats) { func Handle(ctx context.Context, s Stats) {
return f handler(ctx, s)
} }
// RegisterCallBack registers the user callback function and starts the stats collection. // RegisterHandler registers the user handler function and starts the stats collection.
// This callback function will be called to process the stats. // This handler function will be called to process the stats.
func RegisterCallBack(cb func(Stats)) { func RegisterHandler(f func(context.Context, Stats)) {
f = cb handler = f
start() start()
} }

View File

@ -48,7 +48,7 @@ import (
) )
func TestStartStop(t *testing.T) { func TestStartStop(t *testing.T) {
stats.RegisterCallBack(nil) stats.RegisterHandler(nil)
defer stats.Stop() // Stop stats in the case of the first Fatalf. defer stats.Stop() // Stop stats in the case of the first Fatalf.
if stats.On() != true { if stats.On() != true {
t.Fatalf("after start.RegisterCallBack(_), stats.On() = false, want true") t.Fatalf("after start.RegisterCallBack(_), stats.On() = false, want true")
@ -268,17 +268,21 @@ type expectedData struct {
outgoing []*testpb.SimpleResponse outgoing []*testpb.SimpleResponse
} }
func checkInitStats(t *testing.T, s stats.Stats, e *expectedData) { type gotData struct {
t.Logf(" - %T", s) ctx context.Context
s stats.Stats
}
func checkInitStats(t *testing.T, d gotData, e *expectedData) {
var ( var (
ok bool ok bool
st *stats.InitStats st *stats.InitStats
) )
if st, ok = s.(*stats.InitStats); !ok { if st, ok = d.s.(*stats.InitStats); !ok {
t.Fatalf("got %T, want InitStats", s) t.Fatalf("got %T, want InitStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st IsClient = true, want false") t.Fatalf("st IsClient = true, want false")
@ -294,17 +298,16 @@ func checkInitStats(t *testing.T, s stats.Stats, e *expectedData) {
} }
} }
func checkIncomingHeaderStats(t *testing.T, s stats.Stats, e *expectedData) { func checkIncomingHeaderStats(t *testing.T, d gotData, e *expectedData) {
t.Logf(" - %T", s)
var ( var (
ok bool ok bool
st *stats.IncomingHeaderStats st *stats.IncomingHeaderStats
) )
if st, ok = s.(*stats.IncomingHeaderStats); !ok { if st, ok = d.s.(*stats.IncomingHeaderStats); !ok {
t.Fatalf("got %T, want IncomingHeaderStats", s) t.Fatalf("got %T, want IncomingHeaderStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st.IsClient = true, want false") t.Fatalf("st.IsClient = true, want false")
@ -315,17 +318,16 @@ func checkIncomingHeaderStats(t *testing.T, s stats.Stats, e *expectedData) {
} }
} }
func checkIncomingPayloadStats(t *testing.T, s stats.Stats, e *expectedData) { func checkIncomingPayloadStats(t *testing.T, d gotData, e *expectedData) {
t.Logf(" - %T", s)
var ( var (
ok bool ok bool
st *stats.IncomingPayloadStats st *stats.IncomingPayloadStats
) )
if st, ok = s.(*stats.IncomingPayloadStats); !ok { if st, ok = d.s.(*stats.IncomingPayloadStats); !ok {
t.Fatalf("got %T, want IncomingPayloadStats", s) t.Fatalf("got %T, want IncomingPayloadStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st IsClient = true, want false") t.Fatalf("st IsClient = true, want false")
@ -347,17 +349,16 @@ func checkIncomingPayloadStats(t *testing.T, s stats.Stats, e *expectedData) {
} }
} }
func checkIncomingTrailerStats(t *testing.T, s stats.Stats, e *expectedData) { func checkIncomingTrailerStats(t *testing.T, d gotData, e *expectedData) {
t.Logf(" - %T", s)
var ( var (
ok bool ok bool
st *stats.IncomingTrailerStats st *stats.IncomingTrailerStats
) )
if st, ok = s.(*stats.IncomingTrailerStats); !ok { if st, ok = d.s.(*stats.IncomingTrailerStats); !ok {
t.Fatalf("got %T, want IncomingTrailerStats", s) t.Fatalf("got %T, want IncomingTrailerStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st.IsClient = true, want false") t.Fatalf("st.IsClient = true, want false")
@ -368,17 +369,16 @@ func checkIncomingTrailerStats(t *testing.T, s stats.Stats, e *expectedData) {
} }
} }
func checkOutgoingHeaderStats(t *testing.T, s stats.Stats, e *expectedData) { func checkOutgoingHeaderStats(t *testing.T, d gotData, e *expectedData) {
t.Logf(" - %T", s)
var ( var (
ok bool ok bool
st *stats.OutgoingHeaderStats st *stats.OutgoingHeaderStats
) )
if st, ok = s.(*stats.OutgoingHeaderStats); !ok { if st, ok = d.s.(*stats.OutgoingHeaderStats); !ok {
t.Fatalf("got %T, want OutgoingHeaderStats", s) t.Fatalf("got %T, want OutgoingHeaderStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st IsClient = true, want false") t.Fatalf("st IsClient = true, want false")
@ -389,17 +389,16 @@ func checkOutgoingHeaderStats(t *testing.T, s stats.Stats, e *expectedData) {
} }
} }
func checkOutgoingPayloadStats(t *testing.T, s stats.Stats, e *expectedData) { func checkOutgoingPayloadStats(t *testing.T, d gotData, e *expectedData) {
t.Logf(" - %T", s)
var ( var (
ok bool ok bool
st *stats.OutgoingPayloadStats st *stats.OutgoingPayloadStats
) )
if st, ok = s.(*stats.OutgoingPayloadStats); !ok { if st, ok = d.s.(*stats.OutgoingPayloadStats); !ok {
t.Fatalf("got %T, want OutgoingPayloadStats", s) t.Fatalf("got %T, want OutgoingPayloadStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st IsClient = true, want false") t.Fatalf("st IsClient = true, want false")
@ -421,17 +420,16 @@ func checkOutgoingPayloadStats(t *testing.T, s stats.Stats, e *expectedData) {
} }
} }
func checkOutgoingTrailerStats(t *testing.T, s stats.Stats, e *expectedData) { func checkOutgoingTrailerStats(t *testing.T, d gotData, e *expectedData) {
t.Logf(" - %T", s)
var ( var (
ok bool ok bool
st *stats.OutgoingTrailerStats st *stats.OutgoingTrailerStats
) )
if st, ok = s.(*stats.OutgoingTrailerStats); !ok { if st, ok = d.s.(*stats.OutgoingTrailerStats); !ok {
t.Fatalf("got %T, want OutgoingTrailerStats", s) t.Fatalf("got %T, want OutgoingTrailerStats", d.s)
} }
if st.Ctx == nil { if d.ctx == nil {
t.Fatalf("st.Ctx = nil, want <non-nil>") t.Fatalf("d.ctx = nil, want <non-nil>")
} }
if st.IsClient { if st.IsClient {
t.Fatalf("st IsClient = true, want false") t.Fatalf("st IsClient = true, want false")
@ -445,12 +443,12 @@ func checkOutgoingTrailerStats(t *testing.T, s stats.Stats, e *expectedData) {
func TestServerStatsUnaryRPC(t *testing.T) { func TestServerStatsUnaryRPC(t *testing.T) {
var ( var (
mu sync.Mutex mu sync.Mutex
got []stats.Stats got []gotData
) )
stats.RegisterCallBack(func(s stats.Stats) { stats.RegisterHandler(func(ctx context.Context, s stats.Stats) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
got = append(got, s) got = append(got, gotData{ctx, s})
}) })
te := newTest(t, "") te := newTest(t, "")
@ -467,7 +465,7 @@ func TestServerStatsUnaryRPC(t *testing.T) {
outgoing: []*testpb.SimpleResponse{resp}, outgoing: []*testpb.SimpleResponse{resp},
} }
for i, f := range []func(t *testing.T, s stats.Stats, e *expectedData){ for i, f := range []func(t *testing.T, d gotData, e *expectedData){
checkInitStats, checkInitStats,
checkIncomingHeaderStats, checkIncomingHeaderStats,
checkIncomingPayloadStats, checkIncomingPayloadStats,
@ -486,12 +484,12 @@ func TestServerStatsUnaryRPC(t *testing.T) {
func TestServerStatsStreamingRPC(t *testing.T) { func TestServerStatsStreamingRPC(t *testing.T) {
var ( var (
mu sync.Mutex mu sync.Mutex
got []stats.Stats got []gotData
) )
stats.RegisterCallBack(func(s stats.Stats) { stats.RegisterHandler(func(ctx context.Context, s stats.Stats) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
got = append(got, s) got = append(got, gotData{ctx, s})
}) })
te := newTest(t, "gzip") te := newTest(t, "gzip")
@ -510,12 +508,12 @@ func TestServerStatsStreamingRPC(t *testing.T) {
outgoing: resps, outgoing: resps,
} }
checkFuncs := []func(t *testing.T, s stats.Stats, e *expectedData){ checkFuncs := []func(t *testing.T, d gotData, e *expectedData){
checkInitStats, checkInitStats,
checkIncomingHeaderStats, checkIncomingHeaderStats,
checkOutgoingHeaderStats, checkOutgoingHeaderStats,
} }
ioPayFuncs := []func(t *testing.T, s stats.Stats, e *expectedData){ ioPayFuncs := []func(t *testing.T, d gotData, e *expectedData){
checkIncomingPayloadStats, checkIncomingPayloadStats,
checkOutgoingPayloadStats, checkOutgoingPayloadStats,
} }

View File

@ -485,9 +485,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}() }()
var outgoingPayloadStats *stats.OutgoingPayloadStats var outgoingPayloadStats *stats.OutgoingPayloadStats
if stats.On() { if stats.On() {
outgoingPayloadStats = &stats.OutgoingPayloadStats{ outgoingPayloadStats = &stats.OutgoingPayloadStats{}
Ctx: ss.s.Context(),
}
} }
out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outgoingPayloadStats) out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outgoingPayloadStats)
defer func() { defer func() {
@ -504,7 +502,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
} }
if outgoingPayloadStats != nil { if outgoingPayloadStats != nil {
outgoingPayloadStats.SentTime = time.Now() outgoingPayloadStats.SentTime = time.Now()
stats.CallBack()(outgoingPayloadStats) stats.Handle(ss.s.Context(), outgoingPayloadStats)
} }
return nil return nil
} }
@ -526,9 +524,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
}() }()
var incomingPayloadStats *stats.IncomingPayloadStats var incomingPayloadStats *stats.IncomingPayloadStats
if stats.On() { if stats.On() {
incomingPayloadStats = &stats.IncomingPayloadStats{ incomingPayloadStats = &stats.IncomingPayloadStats{}
Ctx: ss.s.Context(),
}
} }
if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, incomingPayloadStats); err != nil { if err := recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize, incomingPayloadStats); err != nil {
if err == io.EOF { if err == io.EOF {
@ -540,7 +536,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
return toRPCErr(err) return toRPCErr(err)
} }
if incomingPayloadStats != nil { if incomingPayloadStats != nil {
stats.CallBack()(incomingPayloadStats) stats.Handle(ss.s.Context(), incomingPayloadStats)
} }
return nil return nil
} }

View File

@ -237,18 +237,16 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
} }
if stats.On() { if stats.On() {
initStats := &stats.InitStats{ initStats := &stats.InitStats{
Ctx: s.ctx,
Method: s.method, Method: s.method,
RemoteAddr: t.conn.RemoteAddr(), RemoteAddr: t.conn.RemoteAddr(),
LocalAddr: t.conn.LocalAddr(), LocalAddr: t.conn.LocalAddr(),
Encryption: s.recvCompress, Encryption: s.recvCompress,
} }
stats.CallBack()(initStats) stats.Handle(s.ctx, initStats)
incomingHeaderStats := &stats.IncomingHeaderStats{ incomingHeaderStats := &stats.IncomingHeaderStats{
Ctx: s.ctx,
WireLength: int(frame.Header().Length), WireLength: int(frame.Header().Length),
} }
stats.CallBack()(incomingHeaderStats) stats.Handle(s.ctx, incomingHeaderStats)
} }
handle(s) handle(s)
return return
@ -530,10 +528,9 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
} }
if stats.On() { if stats.On() {
outgoingHeaderStats := &stats.OutgoingHeaderStats{ outgoingHeaderStats := &stats.OutgoingHeaderStats{
Ctx: s.Context(),
WireLength: bufLen, WireLength: bufLen,
} }
stats.CallBack()(outgoingHeaderStats) stats.Handle(s.Context(), outgoingHeaderStats)
} }
t.writableChan <- 0 t.writableChan <- 0
return nil return nil
@ -594,10 +591,9 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
} }
if stats.On() { if stats.On() {
outgoingTrailerStats := &stats.OutgoingTrailerStats{ outgoingTrailerStats := &stats.OutgoingTrailerStats{
Ctx: s.Context(),
WireLength: bufLen, WireLength: bufLen,
} }
stats.CallBack()(outgoingTrailerStats) stats.Handle(s.Context(), outgoingTrailerStats)
} }
t.closeStream(s) t.closeStream(s)
t.writableChan <- 0 t.writableChan <- 0