From 9391d1a36d7df79595629d5fc29e82c679419000 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 7 Jun 2016 14:17:58 -0700 Subject: [PATCH] Ignore plus and semicolon and anything following in Content-Type --- transport/handler_server.go | 17 +++++++++++++++-- transport/handler_server_test.go | 22 ++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/transport/handler_server.go b/transport/handler_server.go index 7a4ae07b..efea686b 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -55,6 +55,19 @@ import ( "google.golang.org/grpc/peer" ) +func isGrpcContentType(t string) bool { + e := "application/grpc" + if !strings.HasPrefix(t, e) { + return false + } + // Support variations on the content-type + // (e.g. "application/grpc+blah", "application/grpc;blah"). + if len(t) > len(e) && t[len(e)] != '+' && t[len(e)] != ';' { + return false + } + return true +} + // NewServerHandlerTransport returns a ServerTransport handling gRPC // from inside an http.Handler. It requires that the http Server // supports HTTP/2. @@ -65,7 +78,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr if r.Method != "POST" { return nil, errors.New("invalid gRPC request method") } - if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + if !isGrpcContentType(r.Header.Get("Content-Type")) { return nil, errors.New("invalid gRPC request content-type") } if _, ok := w.(http.Flusher); !ok { @@ -97,7 +110,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr } for k, vv := range r.Header { k = strings.ToLower(k) - if isReservedHeader(k) && !isWhitelistedPseudoHeader(k){ + if isReservedHeader(k) && !isWhitelistedPseudoHeader(k) { continue } for _, v := range vv { diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 1fee72ff..4179cc5c 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -387,3 +387,25 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) } } + +func TestIsGrpcContentType(t *testing.T) { + tests := []struct { + h string + want bool + }{ + {"application/grpc", true}, + {"application/grpc+", true}, + {"application/grpc+blah", true}, + {"application/grpc;", true}, + {"application/grpc;blah", true}, + {"application/grpcd", false}, + {"application/grpd", false}, + {"application/grp", false}, + } + for _, tt := range tests { + got := isGrpcContentType(tt.h) + if got != tt.want { + t.Errorf("isGrpcContentType(%q) = %v; want %v", tt.h, got, tt.want) + } + } +}