Tune metadata package to simplify the code.
This commit is contained in:
@ -146,11 +146,10 @@ func Invoke(ctx context.Context, method string, args, reply proto.Message, cc *C
|
||||
return toRPCErr(err)
|
||||
}
|
||||
// Try to acquire header metadata from the server if there is any.
|
||||
m, err := stream.Header()
|
||||
c.headerMD, err = stream.Header()
|
||||
if err != nil {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
c.headerMD = metadata.New(m)
|
||||
// Receive the response
|
||||
lastErr = recv(stream, reply)
|
||||
if _, ok := lastErr.(transport.ConnectionError); ok {
|
||||
|
@ -86,16 +86,6 @@ func DecodeKeyValue(k, v string) (string, string, error) {
|
||||
// MD is a mapping from metadata keys to values.
|
||||
type MD map[string]string
|
||||
|
||||
// New creates a MD from given key-value map.
|
||||
func New(m map[string]string) MD {
|
||||
md := MD{}
|
||||
for k, v := range m {
|
||||
key, val := encodeKeyValue(k, v)
|
||||
md[key] = val
|
||||
}
|
||||
return md
|
||||
}
|
||||
|
||||
// Pairs returns an MD formed by the mapping of key, value ...
|
||||
// Pairs panics if len(kv) is odd.
|
||||
func Pairs(kv ...string) MD {
|
||||
@ -115,18 +105,18 @@ func Pairs(kv ...string) MD {
|
||||
return md
|
||||
}
|
||||
|
||||
// Len returns the length of md.
|
||||
// Len returns the number of items in md.
|
||||
func (md MD) Len() int {
|
||||
return len(md)
|
||||
}
|
||||
|
||||
// Copy returns a copy of md's mapping.
|
||||
func (md MD) Copy() map[string]string {
|
||||
m := make(map[string]string)
|
||||
// Copy returns a copy of md.
|
||||
func (md MD) Copy() MD {
|
||||
out := MD{}
|
||||
for k, v := range md {
|
||||
m[k] = v
|
||||
out[k] = v
|
||||
}
|
||||
return m
|
||||
return out
|
||||
}
|
||||
|
||||
type mdKey struct{}
|
||||
|
@ -110,11 +110,7 @@ func (cs *clientStream) Context() context.Context {
|
||||
// is any. Empty metadata.MD is returned if there is no header metadata.
|
||||
// It blocks if the metadata is not ready to read.
|
||||
func (cs *clientStream) Header() (md metadata.MD, err error) {
|
||||
m, err := cs.s.Header()
|
||||
if err != nil {
|
||||
return md, err
|
||||
}
|
||||
return metadata.New(m), nil
|
||||
return cs.s.Header()
|
||||
}
|
||||
|
||||
// Trailer returns the trailer metadata from the server. It must be called
|
||||
|
@ -56,7 +56,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
testMetadata = map[string]string{
|
||||
testMetadata = metadata.MD{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
@ -208,18 +208,17 @@ func TestMetadataUnaryRPC(t *testing.T) {
|
||||
Dividend: proto.Int64(8),
|
||||
Divisor: proto.Int64(2),
|
||||
}
|
||||
md := metadata.New(testMetadata)
|
||||
ctx := metadata.NewContext(context.Background(), md)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
var header, trailer metadata.MD
|
||||
_, err := mc.Div(ctx, args, rpc.Header(&header), rpc.Trailer(&trailer))
|
||||
if err != nil {
|
||||
t.Fatalf("mathClient.Div(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||
}
|
||||
if !reflect.DeepEqual(testMetadata, header.Copy()) {
|
||||
t.Fatalf("Received header metadata %v, want %v", header.Copy(), testMetadata)
|
||||
if !reflect.DeepEqual(testMetadata, header) {
|
||||
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
|
||||
}
|
||||
if !reflect.DeepEqual(testMetadata, trailer.Copy()) {
|
||||
t.Fatalf("Received trailer metadata %v, want %v", trailer.Copy(), testMetadata)
|
||||
if !reflect.DeepEqual(testMetadata, trailer) {
|
||||
t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata)
|
||||
}
|
||||
}
|
||||
|
||||
@ -234,7 +233,7 @@ func performOneRPC(t *testing.T, mc testpb.MathClient, wg *sync.WaitGroup) {
|
||||
Remainder: proto.Int64(2),
|
||||
}
|
||||
if err != nil || !proto.Equal(reply, want) {
|
||||
t.Fatalf(`mathClient.Div(_, _) = %v, %v; want %v, <nil>`, reply, err, want)
|
||||
t.Errorf(`mathClient.Div(_, _) = %v, %v; want %v, <nil>`, reply, err, want)
|
||||
}
|
||||
wg.Done()
|
||||
}
|
||||
@ -322,7 +321,7 @@ func TestBidiStreaming(t *testing.T) {
|
||||
go func() {
|
||||
for _, args := range parseArgs(test.divs) {
|
||||
if err := stream.Send(args); err != nil {
|
||||
t.Fatal("Send failed: ", err)
|
||||
t.Errorf("Send failed: ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -367,25 +366,24 @@ func parseArgs(divs []string) (args []*testpb.DivArgs) {
|
||||
func TestMetadataStreamingRPC(t *testing.T) {
|
||||
s, mc := setUp(true, math.MaxUint32)
|
||||
defer s.Stop()
|
||||
md := metadata.New(testMetadata)
|
||||
ctx := metadata.NewContext(context.Background(), md)
|
||||
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||
stream, err := mc.DivMany(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create stream %v", err)
|
||||
}
|
||||
go func() {
|
||||
headerMD, err := stream.Header()
|
||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD.Copy()) {
|
||||
t.Fatalf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
||||
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||
}
|
||||
// test the cached value.
|
||||
headerMD, err = stream.Header()
|
||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD.Copy()) {
|
||||
t.Fatalf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
|
||||
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
|
||||
}
|
||||
for _, args := range parseArgs([]string{"1/1", "3/2", "2/3"}) {
|
||||
if err := stream.Send(args); err != nil {
|
||||
t.Fatalf("%v.Send(_) failed: %v", stream, err)
|
||||
t.Errorf("%v.Send(_) failed: %v", stream, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -399,7 +397,7 @@ func TestMetadataStreamingRPC(t *testing.T) {
|
||||
}
|
||||
}
|
||||
trailerMD := stream.Trailer()
|
||||
if !reflect.DeepEqual(testMetadata, trailerMD.Copy()) {
|
||||
if !reflect.DeepEqual(testMetadata, trailerMD) {
|
||||
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
|
||||
}
|
||||
}
|
||||
|
@ -230,7 +230,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
|
||||
}
|
||||
if md, ok := metadata.FromContext(ctx); ok {
|
||||
for k, v := range md.Copy() {
|
||||
for k, v := range md {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
|
||||
s.ctx = newContextWithStream(s.ctx, s)
|
||||
// Attach the received metadata to the context.
|
||||
if len(hDec.state.mdata) > 0 {
|
||||
s.ctx = metadata.NewContext(s.ctx, metadata.New(hDec.state.mdata))
|
||||
s.ctx = metadata.NewContext(s.ctx, hDec.state.mdata)
|
||||
}
|
||||
|
||||
s.dec = &recvBufferReader{
|
||||
@ -404,7 +404,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
|
||||
t.hBuf.Reset()
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
|
||||
for k, v := range md.Copy() {
|
||||
for k, v := range md {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
if err := t.writeHeaders(s, t.hBuf, false); err != nil {
|
||||
|
@ -181,9 +181,9 @@ type Stream struct {
|
||||
// Close headerChan to indicate the end of reception of header metadata.
|
||||
headerChan chan struct{}
|
||||
// header caches the received header metadata.
|
||||
header map[string]string
|
||||
header metadata.MD
|
||||
// The key-value map of trailer metadata.
|
||||
trailer map[string]string
|
||||
trailer metadata.MD
|
||||
|
||||
mu sync.RWMutex
|
||||
// headerOK becomes true from the first header is about to send.
|
||||
@ -200,12 +200,12 @@ type Stream struct {
|
||||
// Header acquires the key-value pairs of header metadata once it
|
||||
// is available. It blocks until i) the metadata is ready or ii) there is no
|
||||
// header metadata or iii) the stream is cancelled/expired.
|
||||
func (s *Stream) Header() (map[string]string, error) {
|
||||
func (s *Stream) Header() (metadata.MD, error) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return nil, ContextErr(s.ctx.Err())
|
||||
case <-s.headerChan:
|
||||
return s.header, nil
|
||||
return s.header.Copy(), nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,7 +215,7 @@ func (s *Stream) Header() (map[string]string, error) {
|
||||
func (s *Stream) Trailer() metadata.MD {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return metadata.New(s.trailer)
|
||||
return s.trailer.Copy()
|
||||
}
|
||||
|
||||
// ServerTransport returns the underlying ServerTransport for the stream.
|
||||
|
Reference in New Issue
Block a user