diff --git a/transport/handler_server.go b/transport/handler_server.go index d7e18a0b..490de932 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -192,7 +192,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, h := ht.rw.Header() h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) if statusDesc != "" { - h.Set("Grpc-Message", statusDesc) + h.Set("Grpc-Message", grpcMessageEncode(statusDesc)) } if md := s.Trailer(); len(md) > 0 { for k, vv := range md { diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 1fee72ff..0711d012 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) "Content-Type": {"application/grpc"}, "Trailer": {"Grpc-Status", "Grpc-Message"}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, - "Grpc-Message": {msg}, + "Grpc-Message": {grpcMessageEncode(msg)}, } if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) @@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { "Content-Type": {"application/grpc"}, "Trailer": {"Grpc-Status", "Grpc-Message"}, "Grpc-Status": {"4"}, - "Grpc-Message": {"too slow"}, + "Grpc-Message": {grpcMessageEncode("too slow")}, } if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) diff --git a/transport/http2_server.go b/transport/http2_server.go index 03164236..7ecd209c 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -485,7 +485,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s Name: "grpc-status", Value: strconv.Itoa(int(statusCode)), }) - t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: grpcMessageEncode(statusDesc)}) // Attach the trailer metadata. for k, v := range s.trailer { for _, entry := range v { diff --git a/transport/http_util.go b/transport/http_util.go index 73c12d5f..7e92f63d 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -149,7 +149,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) { } d.statusCode = codes.Code(code) case "grpc-message": - d.statusDesc = f.Value + d.statusDesc = grpcMessageDecode(f.Value) case "grpc-timeout": d.timeoutSet = true var err error diff --git a/transport/transport.go b/transport/transport.go index 6eca1b3b..1d44974f 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -43,6 +43,7 @@ import ( "fmt" "io" "net" + "strconv" "sync" "time" @@ -506,3 +507,50 @@ func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int return i, nil } } + +const ( + spaceByte = byte(int(' ')) + tildaByte = byte(int('~')) + percentByte = byte(int('%')) +) + +// matching https://github.com/grpc/grpc-java/pull/1517/files +func grpcMessageEncode(grpcMessage string) string { + if grpcMessage == "" { + return "" + } + var buffer bytes.Buffer + for _, c := range []byte(grpcMessage) { + if c >= spaceByte && c < tildaByte && c != percentByte { + _ = buffer.WriteByte(c) + } else { + _, _ = buffer.WriteString(fmt.Sprintf("%%%02X", c)) + } + } + return buffer.String() +} + +// matching https://github.com/grpc/grpc-java/pull/1517/files +func grpcMessageDecode(encodedGrpcMessage string) string { + if encodedGrpcMessage == "" { + return "" + } + var buffer bytes.Buffer + data := []byte(encodedGrpcMessage) + lenData := len(data) + for i := 0; i < lenData; i++ { + c := data[i] + if c == percentByte && i+2 < lenData { + parsed, err := strconv.ParseInt(string(data[i+1:i+3]), 16, 8) + if err != nil { + _ = buffer.WriteByte(c) + } else { + _ = buffer.WriteByte(byte(parsed)) + i += 2 + } + } else { + _ = buffer.WriteByte(c) + } + } + return buffer.String() +} diff --git a/transport/transport_test.go b/transport/transport_test.go index c9a95328..6a15bf30 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -720,3 +720,29 @@ func TestIsReservedHeader(t *testing.T) { } } } + +func TestGrpcMessageEncode(t *testing.T) { + testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00") + testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25") +} + +func TestGrpcMessageDecode(t *testing.T) { + testGrpcMessageDecode(t, "Hello", "Hello") + testGrpcMessageDecode(t, "H%61o", "Hao") + testGrpcMessageDecode(t, "H%6", "H%6") + testGrpcMessageDecode(t, "%G0", "%G0") +} + +func testGrpcMessageEncode(t *testing.T, input string, expected string) { + actual := grpcMessageEncode(input) + if expected != actual { + t.Errorf("Expected %s from grpcMessageEncode, got %s", expected, actual) + } +} + +func testGrpcMessageDecode(t *testing.T, input string, expected string) { + actual := grpcMessageDecode(input) + if expected != actual { + t.Errorf("Expected %s from grpcMessageDecode, got %s", expected, actual) + } +}