diff --git a/server.go b/server.go index 7ccd5662..a7ef6cc2 100644 --- a/server.go +++ b/server.go @@ -1298,10 +1298,12 @@ type ServerTransportStream interface { SetTrailer(md metadata.MD) error } -// serverStreamFromContext returns the server stream saved in ctx. Returns -// nil if the given context has no stream associated with it (which implies -// it is not an RPC invocation context). -func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream { +// ServerTransportStreamFromContext returns the ServerTransportStream saved in +// ctx. Returns nil if the given context has no stream associated with it +// (which implies it is not an RPC invocation context). +// +// This API is EXPERIMENTAL. +func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream { s, _ := ctx.Value(streamKey{}).(ServerTransportStream) return s } @@ -1438,7 +1440,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream := serverTransportStreamFromContext(ctx) + stream := ServerTransportStreamFromContext(ctx) if stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } @@ -1448,7 +1450,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error { // SendHeader sends header metadata. It may be called at most once. // The provided md and headers set by SetHeader() will be sent. func SendHeader(ctx context.Context, md metadata.MD) error { - stream := serverTransportStreamFromContext(ctx) + stream := ServerTransportStreamFromContext(ctx) if stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } @@ -1464,7 +1466,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } - stream := serverTransportStreamFromContext(ctx) + stream := ServerTransportStreamFromContext(ctx) if stream == nil { return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) } @@ -1474,7 +1476,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { // Method returns the method string for the server context. The returned // string is in the format of "/service/method". func Method(ctx context.Context) (string, bool) { - s := serverTransportStreamFromContext(ctx) + s := ServerTransportStreamFromContext(ctx) if s == nil { return "", false } diff --git a/server_test.go b/server_test.go index a9fb05a3..2b379f37 100644 --- a/server_test.go +++ b/server_test.go @@ -128,7 +128,7 @@ func TestGetServiceInfo(t *testing.T) { func TestStreamContext(t *testing.T) { expectedStream := &transport.Stream{} ctx := NewContextWithServerTransportStream(context.Background(), expectedStream) - s := serverTransportStreamFromContext(ctx) + s := ServerTransportStreamFromContext(ctx) stream, ok := s.(*transport.Stream) if !ok || expectedStream != stream { t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)