diff --git a/server.go b/server.go index 997b24fa..0063906e 100644 --- a/server.go +++ b/server.go @@ -919,7 +919,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. } return nil } - reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt) + ctx := NewContextWithServerTransportStream(stream.Context(), stream) + reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) if appErr != nil { appStatus, ok := status.FromError(appErr) if !ok { @@ -995,7 +996,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp sh.HandleRPC(stream.Context(), end) }() } + ctx := NewContextWithServerTransportStream(stream.Context(), stream) ss := &serverStream{ + ctx: ctx, t: t, s: stream, p: &parser{r: stream}, @@ -1089,7 +1092,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp ss.mu.Unlock() } return t.WriteStatus(ss.s, status.New(codes.OK, "")) - } func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) { @@ -1171,6 +1173,40 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str } } +// The key to save ServerTransportStream in the context. +type streamKey struct{} + +// NewContextWithServerTransportStream creates a new context from ctx and +// attaches stream to it. +// +// This API is EXPERIMENTAL. +func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context { + return context.WithValue(ctx, streamKey{}, stream) +} + +// ServerTransportStream is a minimal interface that a transport stream must +// implement. This can be used to mock an actual transport stream for tests of +// handler code that use, for example, grpc.SetHeader (which requires some +// stream to be in context). +// +// See also NewContextWithServerTransportStream. +// +// This API is EXPERIMENTAL. +type ServerTransportStream interface { + Method() string + SetHeader(md metadata.MD) error + SendHeader(md metadata.MD) error + SetTrailer(md metadata.MD) error +} + +// serverStreamFromContext returns the server stream saved in ctx. Returns +// nil if the given context has no stream associated with it (which implies +// it is not an RPC invocation context). +func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream { + s, _ := ctx.Value(streamKey{}).(ServerTransportStream) + return s +} + // Stop stops the gRPC server. It immediately closes all open // connections and listeners. // It cancels all active RPCs on the server side and the corresponding @@ -1291,8 +1327,8 @@ func SetHeader(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream, ok := transport.StreamFromContext(ctx) - if !ok { + stream := serverTransportStreamFromContext(ctx) + if stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetHeader(md) @@ -1301,15 +1337,11 @@ func SetHeader(ctx context.Context, md metadata.MD) error { // SendHeader sends header metadata. It may be called at most once. // The provided md and headers set by SetHeader() will be sent. func SendHeader(ctx context.Context, md metadata.MD) error { - stream, ok := transport.StreamFromContext(ctx) - if !ok { + stream := serverTransportStreamFromContext(ctx) + if stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } - t := stream.ServerTransport() - if t == nil { - grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream) - } - if err := t.WriteHeader(stream, md); err != nil { + if err := stream.SendHeader(md); err != nil { return toRPCErr(err) } return nil @@ -1321,8 +1353,8 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream, ok := transport.StreamFromContext(ctx) - if !ok { + stream := serverTransportStreamFromContext(ctx) + if stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } return stream.SetTrailer(md) diff --git a/server_test.go b/server_test.go index 4a7a5244..a9fb05a3 100644 --- a/server_test.go +++ b/server_test.go @@ -25,7 +25,9 @@ import ( "testing" "time" + "golang.org/x/net/context" "google.golang.org/grpc/test/leakcheck" + "google.golang.org/grpc/transport" ) type emptyServiceServer interface{} @@ -122,3 +124,13 @@ func TestGetServiceInfo(t *testing.T) { t.Errorf("GetServiceInfo() = %+v, want %+v", info, want) } } + +func TestStreamContext(t *testing.T) { + expectedStream := &transport.Stream{} + ctx := NewContextWithServerTransportStream(context.Background(), expectedStream) + s := serverTransportStreamFromContext(ctx) + stream, ok := s.(*transport.Stream) + if !ok || expectedStream != stream { + t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream) + } +} diff --git a/stream.go b/stream.go index 34d7414f..ff376cb0 100644 --- a/stream.go +++ b/stream.go @@ -608,6 +608,7 @@ type ServerStream interface { // serverStream implements a server side Stream. type serverStream struct { + ctx context.Context t transport.ServerTransport s *transport.Stream p *parser @@ -628,7 +629,7 @@ type serverStream struct { } func (ss *serverStream) Context() context.Context { - return ss.s.Context() + return ss.ctx } func (ss *serverStream) SetHeader(md metadata.MD) error { @@ -731,9 +732,9 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { // MethodFromServerStream returns the method string for the input stream. // The returned string is in the format of "/service/method". func MethodFromServerStream(stream ServerStream) (string, bool) { - s, ok := transport.StreamFromContext(stream.Context()) - if !ok { - return "", ok + s := serverTransportStreamFromContext(stream.Context()) + if s == nil { + return "", false } - return s.Method(), ok + return s.Method(), true } diff --git a/transport/handler_server.go b/transport/handler_server.go index 9937fff9..1a5e96c5 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -354,8 +354,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace pr.AuthInfo = credentials.TLSInfo{State: *req.TLS} } ctx = metadata.NewIncomingContext(ctx, ht.headerMD) - ctx = peer.NewContext(ctx, pr) - s.ctx = newContextWithStream(ctx, s) + s.ctx = peer.NewContext(ctx, pr) if ht.stats != nil { s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ diff --git a/transport/http2_server.go b/transport/http2_server.go index eb73b722..97b214c6 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -307,10 +307,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( pr.AuthInfo = t.authInfo } s.ctx = peer.NewContext(s.ctx, pr) - // Cache the current stream to the context so that the server application - // can find out. Required when the server wants to send some metadata - // back to the client (unary call only). - s.ctx = newContextWithStream(s.ctx, s) // Attach the received metadata to the context. if len(state.mdata) > 0 { s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) diff --git a/transport/transport.go b/transport/transport.go index e68f89ec..e0c1e343 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -366,6 +366,14 @@ func (s *Stream) SetHeader(md metadata.MD) error { return nil } +// SendHeader sends the given header metadata. The given metadata is +// combined with any metadata set by previous calls to SetHeader and +// then written to the transport stream. +func (s *Stream) SendHeader(md metadata.MD) error { + t := s.ServerTransport() + return t.WriteHeader(s, md) +} + // SetTrailer sets the trailer metadata which will be sent with the RPC status // by the server. This can be called multiple times. Server side only. func (s *Stream) SetTrailer(md metadata.MD) error { @@ -445,21 +453,6 @@ func (s *Stream) GoString() string { return fmt.Sprintf("", s, s.method) } -// The key to save transport.Stream in the context. -type streamKey struct{} - -// newContextWithStream creates a new context from ctx and attaches stream -// to it. -func newContextWithStream(ctx context.Context, stream *Stream) context.Context { - return context.WithValue(ctx, streamKey{}, stream) -} - -// StreamFromContext returns the stream saved in ctx. -func StreamFromContext(ctx context.Context) (s *Stream, ok bool) { - s, ok = ctx.Value(streamKey{}).(*Stream) - return -} - // state of transport type transportState int diff --git a/transport/transport_test.go b/transport/transport_test.go index ab6bc6ef..42261df9 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -1552,15 +1552,6 @@ func TestInvalidHeaderField(t *testing.T) { server.stop() } -func TestStreamContext(t *testing.T) { - expectedStream := &Stream{} - ctx := newContextWithStream(context.Background(), expectedStream) - s, ok := StreamFromContext(ctx) - if !ok || expectedStream != s { - t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream) - } -} - func TestIsReservedHeader(t *testing.T) { tests := []struct { h string