From e218c924aa070fc266f235cca96daf6bfe25e83f Mon Sep 17 00:00:00 2001
From: Menghan Li <menghanl@google.com>
Date: Thu, 14 Jun 2018 13:53:31 -0700
Subject: [PATCH] status: handle invalid utf-8 characters (#2109) (#2134)

fixes #2078

A status whose message contains invalid UTF-8 can still be created, but each invalid byte is replaced with the [Unicode replacement character](https://en.wikipedia.org/wiki/Specials_(Unicode_block)#Replacement_character) (U+FFFD) before the status is sent out. The UTF-8 bytes of the replacement character are still percent-encoded, like any other non-ASCII byte.

All details added to such an invalid status will be dropped, because the status proto fails to marshal when its message is not valid UTF-8.
---
 test/end2end_test.go        | 67 +++++++++++++++++++++++++++++++++++++
 transport/http2_server.go   |  7 ++--
 transport/http_util.go      | 38 ++++++++++++++-------
 transport/http_util_test.go | 36 +++++++++++++++++---
 4 files changed, 129 insertions(+), 19 deletions(-)

diff --git a/test/end2end_test.go b/test/end2end_test.go
index ba1d672e..6574a64f 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -6143,6 +6143,73 @@ func TestServeExitsWhenListenerClosed(t *testing.T) {
 	}
 }
 
+// TestStatusInvalidUTF8Message checks that when a service handler returns a status with an invalid utf8 message, the client sees each invalid byte replaced by the Unicode replacement character.
+func TestStatusInvalidUTF8Message(t *testing.T) {
+	defer leakcheck.Check(t)
+
+	var (
+		origMsg = string([]byte{0xff, 0xfe, 0xfd})
+		wantMsg = "���"
+	)
+
+	ss := &stubServer{
+		emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
+			return nil, status.Error(codes.Internal, origMsg)
+		},
+	}
+	if err := ss.Start(nil); err != nil {
+		t.Fatalf("Error starting endpoint server: %v", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != wantMsg {
+		t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, status.Convert(err).Message(), wantMsg)
+	}
+}
+
+// TestStatusInvalidUTF8Details covers a service handler that returns a status
+// with details and an invalid utf8 message. Proto will fail to marshal the
+// status because of the invalid utf8 message, so details are dropped when sending.
+func TestStatusInvalidUTF8Details(t *testing.T) {
+	defer leakcheck.Check(t)
+
+	var (
+		origMsg = string([]byte{0xff, 0xfe, 0xfd})
+		wantMsg = "���"
+	)
+
+	ss := &stubServer{
+		emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
+			st := status.New(codes.Internal, origMsg)
+			st, err := st.WithDetails(&testpb.Empty{})
+			if err != nil {
+				return nil, err
+			}
+			return nil, st.Err()
+		},
+	}
+	if err := ss.Start(nil); err != nil {
+		t.Fatalf("Error starting endpoint server: %v", err)
+	}
+	defer ss.Stop()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	_, err := ss.client.EmptyCall(ctx, &testpb.Empty{})
+	st := status.Convert(err)
+	if st.Message() != wantMsg {
+		t.Fatalf("ss.client.EmptyCall(_, _) = _, %v (msg %q); want _, err with msg %q", err, st.Message(), wantMsg)
+	}
+	if len(st.Details()) != 0 {
+		// The client must see no details: the server drops them before sending.
+		t.Fatalf("RPC status contain details: %v, want no details", st.Details())
+	}
+}
+
 func TestClientDoesntDeadlockWhileWritingErrornousLargeMessages(t *testing.T) {
 	defer leakcheck.Check(t)
 	for _, e := range listTestEnv() {
diff --git a/transport/http2_server.go b/transport/http2_server.go
index 3643e823..3303a9b1 100644
--- a/transport/http2_server.go
+++ b/transport/http2_server.go
@@ -38,6 +38,7 @@ import (
 	"google.golang.org/grpc/channelz"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/internal/grpcrand"
 	"google.golang.org/grpc/keepalive"
 	"google.golang.org/grpc/metadata"
@@ -769,10 +770,10 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
 		stBytes, err := proto.Marshal(p)
 		if err != nil {
 			// TODO: return error instead, when callers are able to handle it.
-			panic(err)
+			grpclog.Errorf("transport: failed to marshal rpc status: %v, error: %v", p, err)
+		} else {
+			headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
 		}
-
-		headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)})
 	}
 
 	// Attach the trailer metadata.
diff --git a/transport/http_util.go b/transport/http_util.go
index 835c8126..fe555473 100644
--- a/transport/http_util.go
+++ b/transport/http_util.go
@@ -28,6 +28,7 @@ import (
 	"strconv"
 	"strings"
 	"time"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/proto"
 	"golang.org/x/net/http2"
@@ -442,11 +443,12 @@ const (
 )
 
 // encodeGrpcMessage is used to encode status code in header field
-// "grpc-message".
-// It checks to see if each individual byte in msg is an
-// allowable byte, and then either percent encoding or passing it through.
-// When percent encoding, the byte is converted into hexadecimal notation
-// with a '%' prepended.
+// "grpc-message". It percent-encodes the message and also replaces invalid
+// UTF-8 bytes with the Unicode replacement character (U+FFFD).
+//
+// It checks whether each individual byte in msg is an allowable byte, and
+// then either percent-encodes it or passes it through. When percent-encoding,
+// the byte is written in hexadecimal notation with a '%' prepended.
 func encodeGrpcMessage(msg string) string {
 	if msg == "" {
 		return ""
@@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string {
 
 func encodeGrpcMessageUnchecked(msg string) string {
 	var buf bytes.Buffer
-	lenMsg := len(msg)
-	for i := 0; i < lenMsg; i++ {
-		c := msg[i]
-		if c >= spaceByte && c < tildaByte && c != percentByte {
-			buf.WriteByte(c)
-		} else {
-			buf.WriteString(fmt.Sprintf("%%%02X", c))
+	for len(msg) > 0 {
+		r, size := utf8.DecodeRuneInString(msg)
+		for _, b := range []byte(string(r)) {
+			if size > 1 {
+				// size > 1 means r is a multi-byte (non-ASCII) rune: always percent-encode.
+				fmt.Fprintf(&buf, "%%%02X", b)
+				continue
+			}
+
+			// size == 1 is either an ASCII byte or an invalid byte, which decodes
+			// to utf8.RuneError; string(r) then expands to the multi-byte
+			// replacement character, so every byte must be visited. Formatting
+			// the rune itself, fmt.Sprintf("%%%02X", utf8.RuneError), gives "%FFFD".
+			if b >= spaceByte && b < tildaByte && b != percentByte {
+				buf.WriteByte(b)
+			} else {
+				fmt.Fprintf(&buf, "%%%02X", b)
+			}
 		}
+		msg = msg[size:]
 	}
 	return buf.String()
 }
diff --git a/transport/http_util_test.go b/transport/http_util_test.go
index c3754781..1295a2f6 100644
--- a/transport/http_util_test.go
+++ b/transport/http_util_test.go
@@ -102,12 +102,14 @@ func TestEncodeGrpcMessage(t *testing.T) {
 	}{
 		{"", ""},
 		{"Hello", "Hello"},
-		{"my favorite character is \u0000", "my favorite character is %00"},
-		{"my favorite character is %", "my favorite character is %25"},
+		{"\u0000", "%00"},
+		{"%", "%25"},
+		{"系统", "%E7%B3%BB%E7%BB%9F"},
+		{string([]byte{0xff, 0xfe, 0xfd}), "%EF%BF%BD%EF%BF%BD%EF%BF%BD"},
 	} {
 		actual := encodeGrpcMessage(tt.input)
 		if tt.expected != actual {
-			t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
+			t.Errorf("encodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
 		}
 	}
 }
@@ -123,10 +125,36 @@ func TestDecodeGrpcMessage(t *testing.T) {
 		{"H%6", "H%6"},
 		{"%G0", "%G0"},
 		{"%E7%B3%BB%E7%BB%9F", "系统"},
+		{"%EF%BF%BD", "�"},
 	} {
 		actual := decodeGrpcMessage(tt.input)
 		if tt.expected != actual {
-			t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
+			t.Errorf("dncodeGrpcMessage(%q) = %q, want %q", tt.input, actual, tt.expected)
+		}
+	}
+}
+
+// TestDecodeEncodeGrpcMessage checks that decoding an encoded message returns the
+// original string, except that invalid utf8 bytes come back as replacement characters.
+func TestDecodeEncodeGrpcMessage(t *testing.T) {
+	testCases := []struct {
+		orig string
+		want string
+	}{
+		{"", ""},
+		{"hello", "hello"},
+		{"h%6", "h%6"},
+		{"%G0", "%G0"},
+		{"系统", "系统"},
+		{"Hello, 世界", "Hello, 世界"},
+
+		{string([]byte{0xff, 0xfe, 0xfd}), "���"},
+		{string([]byte{0xff}) + "Hello" + string([]byte{0xfe}) + "世界" + string([]byte{0xfd}), "�Hello�世界�"},
+	}
+	for _, tC := range testCases {
+		got := decodeGrpcMessage(encodeGrpcMessage(tC.orig))
+		if got != tC.want {
+			t.Errorf("decodeGrpcMessage(encodeGrpcMessage(%q)) = %q, want %q", tC.orig, got, tC.want)
 		}
 	}
 }