Allow storing alternate transport.ServerStream implementations in context (#1904)
This commit is contained in:

committed by
dfawley

parent
031ee13cfe
commit
57640c0e6f
58
server.go
58
server.go
@ -919,7 +919,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
|
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
|
||||||
|
reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
|
||||||
if appErr != nil {
|
if appErr != nil {
|
||||||
appStatus, ok := status.FromError(appErr)
|
appStatus, ok := status.FromError(appErr)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -995,7 +996,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
sh.HandleRPC(stream.Context(), end)
|
sh.HandleRPC(stream.Context(), end)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
|
||||||
ss := &serverStream{
|
ss := &serverStream{
|
||||||
|
ctx: ctx,
|
||||||
t: t,
|
t: t,
|
||||||
s: stream,
|
s: stream,
|
||||||
p: &parser{r: stream},
|
p: &parser{r: stream},
|
||||||
@ -1089,7 +1092,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||||||
ss.mu.Unlock()
|
ss.mu.Unlock()
|
||||||
}
|
}
|
||||||
return t.WriteStatus(ss.s, status.New(codes.OK, ""))
|
return t.WriteStatus(ss.s, status.New(codes.OK, ""))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
|
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
|
||||||
@ -1171,6 +1173,40 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The key to save ServerTransportStream in the context.
|
||||||
|
type streamKey struct{}
|
||||||
|
|
||||||
|
// NewContextWithServerTransportStream creates a new context from ctx and
|
||||||
|
// attaches stream to it.
|
||||||
|
//
|
||||||
|
// This API is EXPERIMENTAL.
|
||||||
|
func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
|
||||||
|
return context.WithValue(ctx, streamKey{}, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerTransportStream is a minimal interface that a transport stream must
|
||||||
|
// implement. This can be used to mock an actual transport stream for tests of
|
||||||
|
// handler code that use, for example, grpc.SetHeader (which requires some
|
||||||
|
// stream to be in context).
|
||||||
|
//
|
||||||
|
// See also NewContextWithServerTransportStream.
|
||||||
|
//
|
||||||
|
// This API is EXPERIMENTAL.
|
||||||
|
type ServerTransportStream interface {
|
||||||
|
Method() string
|
||||||
|
SetHeader(md metadata.MD) error
|
||||||
|
SendHeader(md metadata.MD) error
|
||||||
|
SetTrailer(md metadata.MD) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// serverStreamFromContext returns the server stream saved in ctx. Returns
|
||||||
|
// nil if the given context has no stream associated with it (which implies
|
||||||
|
// it is not an RPC invocation context).
|
||||||
|
func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
||||||
|
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the gRPC server. It immediately closes all open
|
// Stop stops the gRPC server. It immediately closes all open
|
||||||
// connections and listeners.
|
// connections and listeners.
|
||||||
// It cancels all active RPCs on the server side and the corresponding
|
// It cancels all active RPCs on the server side and the corresponding
|
||||||
@ -1291,8 +1327,8 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
|||||||
if md.Len() == 0 {
|
if md.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
stream, ok := transport.StreamFromContext(ctx)
|
stream := serverTransportStreamFromContext(ctx)
|
||||||
if !ok {
|
if stream == nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||||
}
|
}
|
||||||
return stream.SetHeader(md)
|
return stream.SetHeader(md)
|
||||||
@ -1301,15 +1337,11 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
|||||||
// SendHeader sends header metadata. It may be called at most once.
|
// SendHeader sends header metadata. It may be called at most once.
|
||||||
// The provided md and headers set by SetHeader() will be sent.
|
// The provided md and headers set by SetHeader() will be sent.
|
||||||
func SendHeader(ctx context.Context, md metadata.MD) error {
|
func SendHeader(ctx context.Context, md metadata.MD) error {
|
||||||
stream, ok := transport.StreamFromContext(ctx)
|
stream := serverTransportStreamFromContext(ctx)
|
||||||
if !ok {
|
if stream == nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||||
}
|
}
|
||||||
t := stream.ServerTransport()
|
if err := stream.SendHeader(md); err != nil {
|
||||||
if t == nil {
|
|
||||||
grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
|
|
||||||
}
|
|
||||||
if err := t.WriteHeader(stream, md); err != nil {
|
|
||||||
return toRPCErr(err)
|
return toRPCErr(err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -1321,8 +1353,8 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
|||||||
if md.Len() == 0 {
|
if md.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
stream, ok := transport.StreamFromContext(ctx)
|
stream := serverTransportStreamFromContext(ctx)
|
||||||
if !ok {
|
if stream == nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||||
}
|
}
|
||||||
return stream.SetTrailer(md)
|
return stream.SetTrailer(md)
|
||||||
|
@ -25,7 +25,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc/test/leakcheck"
|
"google.golang.org/grpc/test/leakcheck"
|
||||||
|
"google.golang.org/grpc/transport"
|
||||||
)
|
)
|
||||||
|
|
||||||
type emptyServiceServer interface{}
|
type emptyServiceServer interface{}
|
||||||
@ -122,3 +124,13 @@ func TestGetServiceInfo(t *testing.T) {
|
|||||||
t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
|
t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStreamContext(t *testing.T) {
|
||||||
|
expectedStream := &transport.Stream{}
|
||||||
|
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
|
||||||
|
s := serverTransportStreamFromContext(ctx)
|
||||||
|
stream, ok := s.(*transport.Stream)
|
||||||
|
if !ok || expectedStream != stream {
|
||||||
|
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
11
stream.go
11
stream.go
@ -608,6 +608,7 @@ type ServerStream interface {
|
|||||||
|
|
||||||
// serverStream implements a server side Stream.
|
// serverStream implements a server side Stream.
|
||||||
type serverStream struct {
|
type serverStream struct {
|
||||||
|
ctx context.Context
|
||||||
t transport.ServerTransport
|
t transport.ServerTransport
|
||||||
s *transport.Stream
|
s *transport.Stream
|
||||||
p *parser
|
p *parser
|
||||||
@ -628,7 +629,7 @@ type serverStream struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ss *serverStream) Context() context.Context {
|
func (ss *serverStream) Context() context.Context {
|
||||||
return ss.s.Context()
|
return ss.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ss *serverStream) SetHeader(md metadata.MD) error {
|
func (ss *serverStream) SetHeader(md metadata.MD) error {
|
||||||
@ -731,9 +732,9 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||||||
// MethodFromServerStream returns the method string for the input stream.
|
// MethodFromServerStream returns the method string for the input stream.
|
||||||
// The returned string is in the format of "/service/method".
|
// The returned string is in the format of "/service/method".
|
||||||
func MethodFromServerStream(stream ServerStream) (string, bool) {
|
func MethodFromServerStream(stream ServerStream) (string, bool) {
|
||||||
s, ok := transport.StreamFromContext(stream.Context())
|
s := serverTransportStreamFromContext(stream.Context())
|
||||||
if !ok {
|
if s == nil {
|
||||||
return "", ok
|
return "", false
|
||||||
}
|
}
|
||||||
return s.Method(), ok
|
return s.Method(), true
|
||||||
}
|
}
|
||||||
|
@ -354,8 +354,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||||||
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
|
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
|
||||||
}
|
}
|
||||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||||
ctx = peer.NewContext(ctx, pr)
|
s.ctx = peer.NewContext(ctx, pr)
|
||||||
s.ctx = newContextWithStream(ctx, s)
|
|
||||||
if ht.stats != nil {
|
if ht.stats != nil {
|
||||||
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
||||||
inHeader := &stats.InHeader{
|
inHeader := &stats.InHeader{
|
||||||
|
@ -307,10 +307,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||||||
pr.AuthInfo = t.authInfo
|
pr.AuthInfo = t.authInfo
|
||||||
}
|
}
|
||||||
s.ctx = peer.NewContext(s.ctx, pr)
|
s.ctx = peer.NewContext(s.ctx, pr)
|
||||||
// Cache the current stream to the context so that the server application
|
|
||||||
// can find out. Required when the server wants to send some metadata
|
|
||||||
// back to the client (unary call only).
|
|
||||||
s.ctx = newContextWithStream(s.ctx, s)
|
|
||||||
// Attach the received metadata to the context.
|
// Attach the received metadata to the context.
|
||||||
if len(state.mdata) > 0 {
|
if len(state.mdata) > 0 {
|
||||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
||||||
|
@ -366,6 +366,14 @@ func (s *Stream) SetHeader(md metadata.MD) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendHeader sends the given header metadata. The given metadata is
|
||||||
|
// combined with any metadata set by previous calls to SetHeader and
|
||||||
|
// then written to the transport stream.
|
||||||
|
func (s *Stream) SendHeader(md metadata.MD) error {
|
||||||
|
t := s.ServerTransport()
|
||||||
|
return t.WriteHeader(s, md)
|
||||||
|
}
|
||||||
|
|
||||||
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
||||||
// by the server. This can be called multiple times. Server side only.
|
// by the server. This can be called multiple times. Server side only.
|
||||||
func (s *Stream) SetTrailer(md metadata.MD) error {
|
func (s *Stream) SetTrailer(md metadata.MD) error {
|
||||||
@ -445,21 +453,6 @@ func (s *Stream) GoString() string {
|
|||||||
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
|
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The key to save transport.Stream in the context.
|
|
||||||
type streamKey struct{}
|
|
||||||
|
|
||||||
// newContextWithStream creates a new context from ctx and attaches stream
|
|
||||||
// to it.
|
|
||||||
func newContextWithStream(ctx context.Context, stream *Stream) context.Context {
|
|
||||||
return context.WithValue(ctx, streamKey{}, stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StreamFromContext returns the stream saved in ctx.
|
|
||||||
func StreamFromContext(ctx context.Context) (s *Stream, ok bool) {
|
|
||||||
s, ok = ctx.Value(streamKey{}).(*Stream)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// state of transport
|
// state of transport
|
||||||
type transportState int
|
type transportState int
|
||||||
|
|
||||||
|
@ -1552,15 +1552,6 @@ func TestInvalidHeaderField(t *testing.T) {
|
|||||||
server.stop()
|
server.stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStreamContext(t *testing.T) {
|
|
||||||
expectedStream := &Stream{}
|
|
||||||
ctx := newContextWithStream(context.Background(), expectedStream)
|
|
||||||
s, ok := StreamFromContext(ctx)
|
|
||||||
if !ok || expectedStream != s {
|
|
||||||
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsReservedHeader(t *testing.T) {
|
func TestIsReservedHeader(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
h string
|
h string
|
||||||
|
Reference in New Issue
Block a user