diff --git a/transport/http_util.go b/transport/http_util.go index be4925e8..ea363d93 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -35,6 +35,7 @@ package transport import ( "bufio" + "bytes" "fmt" "io" "net" @@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) { return d * time.Duration(t), nil } +const ( + spaceByte = ' ' + tildaByte = '~' + percentByte = '%' +) + +// grpcMessageEncode is used to encode status code in header field +// "grpc-message". +// It checks to see if each individual byte in msg is an +// allowable byte, and then either percent encoding or passing it through. +// When percent encoding, the byte is converted into hexadecimal notation +// with a '%' prepended. +func grpcMessageEncode(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if !(c >= spaceByte && c < tildaByte && c != percentByte) { + return grpcMessageEncodeUnchecked(msg) + } + } + return msg +} + +func grpcMessageEncodeUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c >= spaceByte && c < tildaByte && c != percentByte { + buf.WriteByte(c) + } else { + buf.WriteString(fmt.Sprintf("%%%02X", c)) + } + } + return buf.String() +} + +// grpcMessageDecode decodes the msg encoded by grpcMessageEncode. +func grpcMessageDecode(msg string) string { + if msg == "" { + return "" + } + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + if msg[i] == percentByte && i+2 < lenMsg { + return grpcMessageDecodeUnchecked(msg) + } + } + return msg +} + +func grpcMessageDecodeUnchecked(msg string) string { + var buf bytes.Buffer + lenMsg := len(msg) + for i := 0; i < lenMsg; i++ { + c := msg[i] + if c == percentByte && i+2 < lenMsg { + parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8) + if err != nil { + buf.WriteByte(c) + } else { + buf.WriteByte(byte(parsed)) + i += 2 + } + } else { + buf.WriteByte(c) + } + } + return buf.String() +} + type framer struct { numWriters int32 reader io.Reader diff --git a/transport/http_util_test.go b/transport/http_util_test.go index 279acbc5..62519490 100644 --- a/transport/http_util_test.go +++ b/transport/http_util_test.go @@ -107,3 +107,29 @@ func TestValidContentType(t *testing.T) { } } } + +func TestGrpcMessageEncode(t *testing.T) { + testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00") + testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25") +} + +func TestGrpcMessageDecode(t *testing.T) { + testGrpcMessageDecode(t, "Hello", "Hello") + testGrpcMessageDecode(t, "H%61o", "Hao") + testGrpcMessageDecode(t, "H%6", "H%6") + testGrpcMessageDecode(t, "%G0", "%G0") +} + +func testGrpcMessageEncode(t *testing.T, input string, expected string) { + actual := grpcMessageEncode(input) + if expected != actual { + t.Errorf("grpcMessageEncode(%v) = %v, want %v", input, actual, expected) + } +} + +func testGrpcMessageDecode(t *testing.T, input string, expected string) { + actual := grpcMessageDecode(input) + if expected != actual { + t.Errorf("grpcMessageDncode(%v) = %v, want %v", input, actual, expected) + } +} diff --git a/transport/transport.go b/transport/transport.go index 84a3269a..9dade654 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -43,7 +43,6 @@ import ( "fmt" "io" "net" - "strconv" "sync" "time" @@ -559,77 +558,3 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <- return i, nil } } - -const ( - spaceByte = ' ' - tildaByte = '~' - percentByte = '%' -) - -// grpcMessageEncode is used to encode status code in header field -// "grpc-message". -// It checks to see if each individual byte in msg is an -// allowable byte, and then either percent encoding or passing it through. -// When percent encoding, the byte is converted into hexadecimal notation -// with a '%' prepended. -func grpcMessageEncode(msg string) string { - if msg == "" { - return "" - } - lenMsg := len(msg) - for i := 0; i < lenMsg; i++ { - c := msg[i] - if !(c >= spaceByte && c < tildaByte && c != percentByte) { - return grpcMessageEncodeUnchecked(msg) - } - } - return msg -} - -func grpcMessageEncodeUnchecked(msg string) string { - var buf bytes.Buffer - lenMsg := len(msg) - for i := 0; i < lenMsg; i++ { - c := msg[i] - if c >= spaceByte && c < tildaByte && c != percentByte { - _ = buf.WriteByte(c) - } else { - _, _ = buf.WriteString(fmt.Sprintf("%%%02X", c)) - } - } - return buf.String() -} - -// grpcMessageDecode decodes the msg encoded by grpcMessageEncode. -func grpcMessageDecode(msg string) string { - if msg == "" { - return "" - } - lenMsg := len(msg) - for i := 0; i < lenMsg; i++ { - if msg[i] == percentByte && i+2 < lenMsg { - return grpcMessageDecodeUnchecked(msg) - } - } - return msg -} - -func grpcMessageDecodeUnchecked(msg string) string { - var buf bytes.Buffer - lenMsg := len(msg) - for i := 0; i < lenMsg; i++ { - c := msg[i] - if c == percentByte && i+2 < lenMsg { - parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8) - if err != nil { - _ = buf.WriteByte(c) - } else { - _ = buf.WriteByte(byte(parsed)) - i += 2 - } - } else { - _ = buf.WriteByte(c) - } - } - return buf.String() -} diff --git a/transport/transport_test.go b/transport/transport_test.go index 6e38fe17..6ad0c782 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -769,29 +769,3 @@ func TestIsReservedHeader(t *testing.T) { } } } - -func TestGrpcMessageEncode(t *testing.T) { - testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00") - testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25") -} - -func TestGrpcMessageDecode(t *testing.T) { - testGrpcMessageDecode(t, "Hello", "Hello") - testGrpcMessageDecode(t, "H%61o", "Hao") - testGrpcMessageDecode(t, "H%6", "H%6") - testGrpcMessageDecode(t, "%G0", "%G0") -} - -func testGrpcMessageEncode(t *testing.T, input string, expected string) { - actual := grpcMessageEncode(input) - if expected != actual { - t.Errorf("grpcMessageEncode(%v) = %v, want %v", input, actual, expected) - } -} - -func testGrpcMessageDecode(t *testing.T, input string, expected string) { - actual := grpcMessageDecode(input) - if expected != actual { - t.Errorf("grpcMessageDncode(%v) = %v, want %v", input, actual, expected) - } -}