Allow storing alternate transport.ServerStream implementations in context (#1904)
This commit is contained in:

committed by
dfawley

parent
031ee13cfe
commit
57640c0e6f
58
server.go
58
server.go
@ -919,7 +919,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
|
||||
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
|
||||
reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
|
||||
if appErr != nil {
|
||||
appStatus, ok := status.FromError(appErr)
|
||||
if !ok {
|
||||
@ -995,7 +996,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
||||
sh.HandleRPC(stream.Context(), end)
|
||||
}()
|
||||
}
|
||||
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
|
||||
ss := &serverStream{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{r: stream},
|
||||
@ -1089,7 +1092,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
||||
ss.mu.Unlock()
|
||||
}
|
||||
return t.WriteStatus(ss.s, status.New(codes.OK, ""))
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
|
||||
@ -1171,6 +1173,40 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
||||
}
|
||||
}
|
||||
|
||||
// The key to save ServerTransportStream in the context.
|
||||
type streamKey struct{}
|
||||
|
||||
// NewContextWithServerTransportStream creates a new context from ctx and
|
||||
// attaches stream to it.
|
||||
//
|
||||
// This API is EXPERIMENTAL.
|
||||
func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
|
||||
return context.WithValue(ctx, streamKey{}, stream)
|
||||
}
|
||||
|
||||
// ServerTransportStream is a minimal interface that a transport stream must
|
||||
// implement. This can be used to mock an actual transport stream for tests of
|
||||
// handler code that use, for example, grpc.SetHeader (which requires some
|
||||
// stream to be in context).
|
||||
//
|
||||
// See also NewContextWithServerTransportStream.
|
||||
//
|
||||
// This API is EXPERIMENTAL.
|
||||
type ServerTransportStream interface {
|
||||
Method() string
|
||||
SetHeader(md metadata.MD) error
|
||||
SendHeader(md metadata.MD) error
|
||||
SetTrailer(md metadata.MD) error
|
||||
}
|
||||
|
||||
// serverStreamFromContext returns the server stream saved in ctx. Returns
|
||||
// nil if the given context has no stream associated with it (which implies
|
||||
// it is not an RPC invocation context).
|
||||
func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
||||
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
|
||||
return s
|
||||
}
|
||||
|
||||
// Stop stops the gRPC server. It immediately closes all open
|
||||
// connections and listeners.
|
||||
// It cancels all active RPCs on the server side and the corresponding
|
||||
@ -1291,8 +1327,8 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
||||
if md.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
stream, ok := transport.StreamFromContext(ctx)
|
||||
if !ok {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
return stream.SetHeader(md)
|
||||
@ -1301,15 +1337,11 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
||||
// SendHeader sends header metadata. It may be called at most once.
|
||||
// The provided md and headers set by SetHeader() will be sent.
|
||||
func SendHeader(ctx context.Context, md metadata.MD) error {
|
||||
stream, ok := transport.StreamFromContext(ctx)
|
||||
if !ok {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
t := stream.ServerTransport()
|
||||
if t == nil {
|
||||
grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
|
||||
}
|
||||
if err := t.WriteHeader(stream, md); err != nil {
|
||||
if err := stream.SendHeader(md); err != nil {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
return nil
|
||||
@ -1321,8 +1353,8 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
||||
if md.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
stream, ok := transport.StreamFromContext(ctx)
|
||||
if !ok {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
return stream.SetTrailer(md)
|
||||
|
Reference in New Issue
Block a user