Add grpc.SetHeader and ServerStream.SetHeader

This commit is contained in:
Menghan Li
2016-08-29 15:16:08 -07:00
parent 2131fedea9
commit e790079956
5 changed files with 406 additions and 32 deletions

View File

@ -865,12 +865,26 @@ func (s *Server) testingCloseConns() {
s.mu.Unlock()
}
// SendHeader sends header metadata. It may be called at most once from a unary
// RPC handler. The ctx is the RPC handler's Context or one derived from it.
func SendHeader(ctx context.Context, md metadata.MD) error {
// SetHeader sets the header metadata.
// When called multiple times, all the provided metadata will be merged.
// All the metadata will be sent out when one of the following happens:
// - grpc.SendHeader() is called;
// - The first response is sent out;
// - An RPC status is sent out (error or success).
func SetHeader(ctx context.Context, md metadata.MD) error {
if md.Len() == 0 {
return nil
}
stream, ok := transport.StreamFromContext(ctx)
if !ok {
return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
}
return stream.SetHeader(md)
}
// SendHeader sends header metadata. It may be called at most once.
// The provided md and headers set by SetHeader() will be sent.
func SendHeader(ctx context.Context, md metadata.MD) error {
stream, ok := transport.StreamFromContext(ctx)
if !ok {
return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
@ -887,7 +901,6 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
// The ctx is the RPC handler's Context or one derived from it.
func SetTrailer(ctx context.Context, md metadata.MD) error {
if md.Len() == 0 {
return nil

View File

@ -410,9 +410,16 @@ func (cs *clientStream) finish(err error) {
// ServerStream defines the interface a server stream has to satisfy.
type ServerStream interface {
// SendHeader sends the header metadata. It should not be called
// after SendProto. It fails if called multiple times or if
// called after SendProto.
// SetHeader sets the header metadata. It may be called multiple times.
// When call multiple times, all the provided metadata will be merged.
// All the metadata will be sent out when one of the following happens:
// - ServerStream.SendHeader() is called;
// - The first response is sent out;
// - An RPC status is sent out (error or success).
SetHeader(metadata.MD) error
// SendHeader sends the header metadata.
// The provided md and headers set by SetHeader() will be sent.
// It fails if called multiple times.
SendHeader(metadata.MD) error
// SetTrailer sets the trailer metadata which will be sent with the RPC status.
// When called more than once, all the provided metadata will be merged.
@ -441,6 +448,13 @@ func (ss *serverStream) Context() context.Context {
return ss.s.Context()
}
func (ss *serverStream) SetHeader(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
return ss.s.SetHeader(md)
}
func (ss *serverStream) SendHeader(md metadata.MD) error {
return ss.t.WriteHeader(ss.s, md)
}

View File

@ -74,6 +74,10 @@ var (
"key1": []string{"value1"},
"key2": []string{"value2"},
}
testMetadata2 = metadata.MD{
"key1": []string{"value12"},
"key2": []string{"value22"},
}
// For trailers:
testTrailerMetadata = metadata.MD{
"tkey1": []string{"trailerValue1"},
@ -95,6 +99,8 @@ var raceMode bool // set by race_test.go in race mode
type testServer struct {
security string // indicate the authentication protocol used by this server.
earlyFail bool // whether to error out the execution of a service handler prematurely.
setAndSendHeader bool // whether to call setHeader and sendHeader.
setHeaderOnly bool // whether to only call setHeader, not sendHeader.
multipleSetTrailer bool // whether to call setTrailer multiple times.
}
@ -138,8 +144,24 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
if _, exists := md[":authority"]; !exists {
return nil, grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md)
}
if err := grpc.SendHeader(ctx, md); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil)
if s.setAndSendHeader {
if err := grpc.SetHeader(ctx, md); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetHeader(_, %v) = %v, want <nil>", md, err)
}
if err := grpc.SendHeader(ctx, testMetadata2); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", testMetadata2, err)
}
} else if s.setHeaderOnly {
if err := grpc.SetHeader(ctx, md); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetHeader(_, %v) = %v, want <nil>", md, err)
}
if err := grpc.SetHeader(ctx, testMetadata2); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetHeader(_, %v) = %v, want <nil>", testMetadata2, err)
}
} else {
if err := grpc.SendHeader(ctx, md); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
}
}
if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
@ -240,8 +262,24 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
md, ok := metadata.FromContext(stream.Context())
if ok {
if err := stream.SendHeader(md); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
if s.setAndSendHeader {
if err := stream.SetHeader(md); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SetHeader(_, %v) = %v, want <nil>", stream, md, err)
}
if err := stream.SendHeader(testMetadata2); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(_, %v) = %v, want <nil>", stream, testMetadata2, err)
}
} else if s.setHeaderOnly {
if err := stream.SetHeader(md); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SetHeader(_, %v) = %v, want <nil>", stream, md, err)
}
if err := stream.SetHeader(testMetadata2); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SetHeader(_, %v) = %v, want <nil>", stream, testMetadata2, err)
}
} else {
if err := stream.SendHeader(md); err != nil {
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
}
}
stream.SetTrailer(testTrailerMetadata)
if s.multipleSetTrailer {
@ -1278,6 +1316,299 @@ func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) {
}
}
func TestSetAndSendHeaderUnaryRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testSetAndSendHeaderUnaryRPC(t, e)
}
}
// To test header metadata is sent on SendHeader().
func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, setAndSendHeader: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
const (
argSize = 1
respSize = 1
)
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
var header metadata.MD
ctx := metadata.NewContext(context.Background(), testMetadata)
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
}
}
func TestMultipleSetHeaderUnaryRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testMultipleSetHeaderUnaryRPC(t, e)
}
}
// To test header metadata is sent when sending response.
func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, setHeaderOnly: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
const (
argSize = 1
respSize = 1
)
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
var header metadata.MD
ctx := metadata.NewContext(context.Background(), testMetadata)
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
}
}
func TestMultipleSetHeaderUnaryRPCError(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testMultipleSetHeaderUnaryRPCError(t, e)
}
}
// To test header metadata is sent when sending status.
func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, setHeaderOnly: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
const (
argSize = 1
respSize = -1 // Invalid respSize to make RPC fail.
)
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
var header metadata.MD
ctx := metadata.NewContext(context.Background(), testMetadata)
if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err == nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <non-nil>", ctx, err)
}
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
}
}
func TestSetAndSendHeaderStreamingRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testSetAndSendHeaderStreamingRPC(t, e)
}
}
// To test header metadata is sent on SendHeader().
func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, setAndSendHeader: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
const (
argSize = 1
respSize = 1
)
ctx := metadata.NewContext(context.Background(), testMetadata)
stream, err := tc.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
if err := stream.CloseSend(); err != nil {
t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
}
if _, err := stream.Recv(); err != io.EOF {
t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err)
}
header, err := stream.Header()
if err != nil {
t.Fatalf("%v.Header() = _, %v, want _, <nil>", stream, err)
}
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
}
}
func TestMultipleSetHeaderStreamingRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testMultipleSetHeaderStreamingRPC(t, e)
}
}
// To test header metadata is sent when sending response.
func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, setHeaderOnly: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
const (
argSize = 1
respSize = 1
)
ctx := metadata.NewContext(context.Background(), testMetadata)
stream, err := tc.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: []*testpb.ResponseParameters{
{Size: proto.Int32(respSize)},
},
Payload: payload,
}
if err := stream.Send(req); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, req, err)
}
if _, err := stream.Recv(); err != nil {
t.Fatalf("%v.Recv() = %v, want <nil>", stream, err)
}
if err := stream.CloseSend(); err != nil {
t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
}
if _, err := stream.Recv(); err != io.EOF {
t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err)
}
header, err := stream.Header()
if err != nil {
t.Fatalf("%v.Header() = _, %v, want _, <nil>", stream, err)
}
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
}
}
func TestMultipleSetHeaderStreamingRPCError(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testMultipleSetHeaderStreamingRPCError(t, e)
}
}
// To test header metadata is sent when sending status.
func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security, setHeaderOnly: true})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())
const (
argSize = 1
respSize = -1
)
ctx := metadata.NewContext(context.Background(), testMetadata)
stream, err := tc.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
req := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: []*testpb.ResponseParameters{
{Size: proto.Int32(respSize)},
},
Payload: payload,
}
if err := stream.Send(req); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, req, err)
}
if _, err := stream.Recv(); err == nil {
t.Fatalf("%v.Recv() = %v, want <non-nil>", stream, err)
}
header, err := stream.Header()
if err != nil {
t.Fatalf("%v.Header() = _, %v, want _, <nil>", stream, err)
}
expectedHeader := metadata.Join(testMetadata, testMetadata2)
if !reflect.DeepEqual(header, expectedHeader) {
t.Fatalf("Received header metadata %v, want %v", header, expectedHeader)
}
if err := stream.CloseSend(); err != nil {
t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
}
}
// TestMalformedHTTP2Metedata verfies the returned error when the client
// sends an illegal metadata.
func TestMalformedHTTP2Metadata(t *testing.T) {

View File

@ -462,6 +462,14 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
return ErrIllegalHeaderWrite
}
s.headerOk = true
if md.Len() > 0 {
if s.header.Len() > 0 {
s.header = metadata.Join(s.header, md)
} else {
s.header = md
}
}
md = s.header
s.mu.Unlock()
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
@ -493,7 +501,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early
// OK is adopted.
func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error {
var headersSent bool
var headersSent, hasHeader bool
s.mu.Lock()
if s.state == streamDone {
s.mu.Unlock()
@ -502,7 +510,16 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
if s.headerOk {
headersSent = true
}
if s.header.Len() > 0 {
hasHeader = true
}
s.mu.Unlock()
if !headersSent && hasHeader {
t.WriteHeader(s, nil)
headersSent = true
}
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
@ -548,29 +565,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error {
}
if !s.headerOk {
writeHeaderFrame = true
s.headerOk = true
}
s.mu.Unlock()
if writeHeaderFrame {
if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil {
return err
}
t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
if s.sendCompress != "" {
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress})
}
p := http2.HeadersFrameParam{
StreamID: s.id,
BlockFragment: t.hBuf.Bytes(),
EndHeaders: true,
}
if err := t.framer.writeHeaders(false, p); err != nil {
t.Close()
return connectionErrorf(true, err, "transport: %v", err)
}
t.writableChan <- 0
t.WriteHeader(s, nil)
}
r := bytes.NewBuffer(data)
for {

View File

@ -286,9 +286,27 @@ func (s *Stream) StatusDesc() string {
return s.statusDesc
}
// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
func (s *Stream) SetHeader(md metadata.MD) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.headerOk || s.state == streamDone {
return ErrIllegalHeaderWrite
}
if md.Len() == 0 {
return nil
}
s.header = metadata.Join(s.header, md)
return nil
}
// SetTrailer sets the trailer metadata which will be sent with the RPC status
// by the server. This can be called multiple times. Server side only.
func (s *Stream) SetTrailer(md metadata.MD) error {
if md.Len() == 0 {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
s.trailer = metadata.Join(s.trailer, md)