From e79007995682d7ba2a30e88898e958178aec93f0 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Mon, 29 Aug 2016 15:16:08 -0700 Subject: [PATCH] Add grpc.SetHeader and ServerStream.SetHeader --- server.go | 21 ++- stream.go | 20 ++- test/end2end_test.go | 339 +++++++++++++++++++++++++++++++++++++- transport/http2_server.go | 40 +++-- transport/transport.go | 18 ++ 5 files changed, 406 insertions(+), 32 deletions(-) diff --git a/server.go b/server.go index debbd79a..6a69bde9 100644 --- a/server.go +++ b/server.go @@ -865,12 +865,26 @@ func (s *Server) testingCloseConns() { s.mu.Unlock() } -// SendHeader sends header metadata. It may be called at most once from a unary -// RPC handler. The ctx is the RPC handler's Context or one derived from it. -func SendHeader(ctx context.Context, md metadata.MD) error { +// SetHeader sets the header metadata. +// When called multiple times, all the provided metadata will be merged. +// All the metadata will be sent out when one of the following happens: +// - grpc.SendHeader() is called; +// - The first response is sent out; +// - An RPC status is sent out (error or success). +func SetHeader(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil } + stream, ok := transport.StreamFromContext(ctx) + if !ok { + return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) + } + return stream.SetHeader(md) +} + +// SendHeader sends header metadata. It may be called at most once. +// The provided md and headers set by SetHeader() will be sent. +func SendHeader(ctx context.Context, md metadata.MD) error { stream, ok := transport.StreamFromContext(ctx) if !ok { return Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) @@ -887,7 +901,6 @@ func SendHeader(ctx context.Context, md metadata.MD) error { // SetTrailer sets the trailer metadata that will be sent when an RPC returns. // When called more than once, all the provided metadata will be merged. -// The ctx is the RPC handler's Context or one derived from it. func SetTrailer(ctx context.Context, md metadata.MD) error { if md.Len() == 0 { return nil diff --git a/stream.go b/stream.go index 68d777b5..46810544 100644 --- a/stream.go +++ b/stream.go @@ -410,9 +410,16 @@ func (cs *clientStream) finish(err error) { // ServerStream defines the interface a server stream has to satisfy. type ServerStream interface { - // SendHeader sends the header metadata. It should not be called - // after SendProto. It fails if called multiple times or if - // called after SendProto. + // SetHeader sets the header metadata. It may be called multiple times. + // When call multiple times, all the provided metadata will be merged. + // All the metadata will be sent out when one of the following happens: + // - ServerStream.SendHeader() is called; + // - The first response is sent out; + // - An RPC status is sent out (error or success). + SetHeader(metadata.MD) error + // SendHeader sends the header metadata. + // The provided md and headers set by SetHeader() will be sent. + // It fails if called multiple times. SendHeader(metadata.MD) error // SetTrailer sets the trailer metadata which will be sent with the RPC status. // When called more than once, all the provided metadata will be merged. @@ -441,6 +448,13 @@ func (ss *serverStream) Context() context.Context { return ss.s.Context() } +func (ss *serverStream) SetHeader(md metadata.MD) error { + if md.Len() == 0 { + return nil + } + return ss.s.SetHeader(md) +} + func (ss *serverStream) SendHeader(md metadata.MD) error { return ss.t.WriteHeader(ss.s, md) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 131db299..c4178ef8 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -74,6 +74,10 @@ var ( "key1": []string{"value1"}, "key2": []string{"value2"}, } + testMetadata2 = metadata.MD{ + "key1": []string{"value12"}, + "key2": []string{"value22"}, + } // For trailers: testTrailerMetadata = metadata.MD{ "tkey1": []string{"trailerValue1"}, @@ -95,6 +99,8 @@ var raceMode bool // set by race_test.go in race mode type testServer struct { security string // indicate the authentication protocol used by this server. earlyFail bool // whether to error out the execution of a service handler prematurely. + setAndSendHeader bool // whether to call setHeader and sendHeader. + setHeaderOnly bool // whether to only call setHeader, not sendHeader. multipleSetTrailer bool // whether to call setTrailer multiple times. } @@ -138,8 +144,24 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if _, exists := md[":authority"]; !exists { return nil, grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md) } - if err := grpc.SendHeader(ctx, md); err != nil { - return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil) + if s.setAndSendHeader { + if err := grpc.SetHeader(ctx, md); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SetHeader(_, %v) = %v, want ", md, err) + } + if err := grpc.SendHeader(ctx, testMetadata2); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want ", testMetadata2, err) + } + } else if s.setHeaderOnly { + if err := grpc.SetHeader(ctx, md); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SetHeader(_, %v) = %v, want ", md, err) + } + if err := grpc.SetHeader(ctx, testMetadata2); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SetHeader(_, %v) = %v, want ", testMetadata2, err) + } + } else { + if err := grpc.SendHeader(ctx, md); err != nil { + return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want ", md, err) + } } if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want ", testTrailerMetadata, err) @@ -240,8 +262,24 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { md, ok := metadata.FromContext(stream.Context()) if ok { - if err := stream.SendHeader(md); err != nil { - return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + if s.setAndSendHeader { + if err := stream.SetHeader(md); err != nil { + return grpc.Errorf(grpc.Code(err), "%v.SetHeader(_, %v) = %v, want ", stream, md, err) + } + if err := stream.SendHeader(testMetadata2); err != nil { + return grpc.Errorf(grpc.Code(err), "%v.SendHeader(_, %v) = %v, want ", stream, testMetadata2, err) + } + } else if s.setHeaderOnly { + if err := stream.SetHeader(md); err != nil { + return grpc.Errorf(grpc.Code(err), "%v.SetHeader(_, %v) = %v, want ", stream, md, err) + } + if err := stream.SetHeader(testMetadata2); err != nil { + return grpc.Errorf(grpc.Code(err), "%v.SetHeader(_, %v) = %v, want ", stream, testMetadata2, err) + } + } else { + if err := stream.SendHeader(md); err != nil { + return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + } } stream.SetTrailer(testTrailerMetadata) if s.multipleSetTrailer { @@ -1278,6 +1316,299 @@ func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) { } } +func TestSetAndSendHeaderUnaryRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testSetAndSendHeaderUnaryRPC(t, e) + } +} + +// To test header metadata is sent on SendHeader(). +func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, setAndSendHeader: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = 1 + ) + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + var header metadata.MD + ctx := metadata.NewContext(context.Background(), testMetadata) + if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil { + t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) + } + expectedHeader := metadata.Join(testMetadata, testMetadata2) + if !reflect.DeepEqual(header, expectedHeader) { + t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) + } +} + +func TestMultipleSetHeaderUnaryRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testMultipleSetHeaderUnaryRPC(t, e) + } +} + +// To test header metadata is sent when sending response. +func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, setHeaderOnly: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = 1 + ) + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + + var header metadata.MD + ctx := metadata.NewContext(context.Background(), testMetadata) + if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil { + t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) + } + expectedHeader := metadata.Join(testMetadata, testMetadata2) + if !reflect.DeepEqual(header, expectedHeader) { + t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) + } +} + +func TestMultipleSetHeaderUnaryRPCError(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testMultipleSetHeaderUnaryRPCError(t, e) + } +} + +// To test header metadata is sent when sending status. +func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, setHeaderOnly: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = -1 // Invalid respSize to make RPC fail. + ) + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(respSize), + Payload: payload, + } + var header metadata.MD + ctx := metadata.NewContext(context.Background(), testMetadata) + if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err == nil { + t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) + } + expectedHeader := metadata.Join(testMetadata, testMetadata2) + if !reflect.DeepEqual(header, expectedHeader) { + t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) + } +} + +func TestSetAndSendHeaderStreamingRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testSetAndSendHeaderStreamingRPC(t, e) + } +} + +// To test header metadata is sent on SendHeader(). +func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, setAndSendHeader: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = 1 + ) + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + if err := stream.CloseSend(); err != nil { + t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err) + } + + header, err := stream.Header() + if err != nil { + t.Fatalf("%v.Header() = _, %v, want _, ", stream, err) + } + expectedHeader := metadata.Join(testMetadata, testMetadata2) + if !reflect.DeepEqual(header, expectedHeader) { + t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) + } +} + +func TestMultipleSetHeaderStreamingRPC(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testMultipleSetHeaderStreamingRPC(t, e) + } +} + +// To test header metadata is sent when sending response. +func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, setHeaderOnly: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = 1 + ) + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: []*testpb.ResponseParameters{ + {Size: proto.Int32(respSize)}, + }, + Payload: payload, + } + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + if _, err := stream.Recv(); err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } + if err := stream.CloseSend(); err != nil { + t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err) + } + + header, err := stream.Header() + if err != nil { + t.Fatalf("%v.Header() = _, %v, want _, ", stream, err) + } + expectedHeader := metadata.Join(testMetadata, testMetadata2) + if !reflect.DeepEqual(header, expectedHeader) { + t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) + } + +} + +func TestMultipleSetHeaderStreamingRPCError(t *testing.T) { + defer leakCheck(t)() + for _, e := range listTestEnv() { + if e.name == "handler-tls" { + continue + } + testMultipleSetHeaderStreamingRPCError(t, e) + } +} + +// To test header metadata is sent when sending status. +func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) { + te := newTest(t, e) + te.startServer(&testServer{security: e.security, setHeaderOnly: true}) + defer te.tearDown() + tc := testpb.NewTestServiceClient(te.clientConn()) + + const ( + argSize = 1 + respSize = -1 + ) + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize) + if err != nil { + t.Fatal(err) + } + + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: []*testpb.ResponseParameters{ + {Size: proto.Int32(respSize)}, + }, + Payload: payload, + } + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + if _, err := stream.Recv(); err == nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } + + header, err := stream.Header() + if err != nil { + t.Fatalf("%v.Header() = _, %v, want _, ", stream, err) + } + expectedHeader := metadata.Join(testMetadata, testMetadata2) + if !reflect.DeepEqual(header, expectedHeader) { + t.Fatalf("Received header metadata %v, want %v", header, expectedHeader) + } + + if err := stream.CloseSend(); err != nil { + t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } +} + // TestMalformedHTTP2Metedata verfies the returned error when the client // sends an illegal metadata. func TestMalformedHTTP2Metadata(t *testing.T) { diff --git a/transport/http2_server.go b/transport/http2_server.go index f753c4f1..dfdd6386 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -462,6 +462,14 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { return ErrIllegalHeaderWrite } s.headerOk = true + if md.Len() > 0 { + if s.header.Len() > 0 { + s.header = metadata.Join(s.header, md) + } else { + s.header = md + } + } + md = s.header s.mu.Unlock() if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err @@ -493,7 +501,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { // TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early // OK is adopted. func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { - var headersSent bool + var headersSent, hasHeader bool s.mu.Lock() if s.state == streamDone { s.mu.Unlock() @@ -502,7 +510,16 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s if s.headerOk { headersSent = true } + if s.header.Len() > 0 { + hasHeader = true + } s.mu.Unlock() + + if !headersSent && hasHeader { + t.WriteHeader(s, nil) + headersSent = true + } + if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { return err } @@ -548,29 +565,10 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { } if !s.headerOk { writeHeaderFrame = true - s.headerOk = true } s.mu.Unlock() if writeHeaderFrame { - if _, err := wait(s.ctx, nil, nil, t.shutdownChan, t.writableChan); err != nil { - return err - } - t.hBuf.Reset() - t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) - t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) - if s.sendCompress != "" { - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: s.sendCompress}) - } - p := http2.HeadersFrameParam{ - StreamID: s.id, - BlockFragment: t.hBuf.Bytes(), - EndHeaders: true, - } - if err := t.framer.writeHeaders(false, p); err != nil { - t.Close() - return connectionErrorf(true, err, "transport: %v", err) - } - t.writableChan <- 0 + t.WriteHeader(s, nil) } r := bytes.NewBuffer(data) for { diff --git a/transport/transport.go b/transport/transport.go index 3d6b6a6d..8f42dbc8 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -286,9 +286,27 @@ func (s *Stream) StatusDesc() string { return s.statusDesc } +// SetHeader sets the header metadata. This can be called multiple times. +// Server side only. +func (s *Stream) SetHeader(md metadata.MD) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.headerOk || s.state == streamDone { + return ErrIllegalHeaderWrite + } + if md.Len() == 0 { + return nil + } + s.header = metadata.Join(s.header, md) + return nil +} + // SetTrailer sets the trailer metadata which will be sent with the RPC status // by the server. This can be called multiple times. Server side only. func (s *Stream) SetTrailer(md metadata.MD) error { + if md.Len() == 0 { + return nil + } s.mu.Lock() defer s.mu.Unlock() s.trailer = metadata.Join(s.trailer, md)