Merge pull request #20 from iamqizhao/master

Tune metadata package a bit.
This commit is contained in:
Qi Zhao
2015-01-28 16:08:37 -08:00
8 changed files with 32 additions and 49 deletions

View File

@ -146,11 +146,10 @@ func Invoke(ctx context.Context, method string, args, reply proto.Message, cc *C
return toRPCErr(err) return toRPCErr(err)
} }
// Try to acquire header metadata from the server if there is any. // Try to acquire header metadata from the server if there is any.
m, err := stream.Header() c.headerMD, err = stream.Header()
if err != nil { if err != nil {
return toRPCErr(err) return toRPCErr(err)
} }
c.headerMD = metadata.New(m)
// Receive the response // Receive the response
lastErr = recv(stream, reply) lastErr = recv(stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok { if _, ok := lastErr.(transport.ConnectionError); ok {

View File

@ -86,16 +86,6 @@ func DecodeKeyValue(k, v string) (string, string, error) {
// MD is a mapping from metadata keys to values. // MD is a mapping from metadata keys to values.
type MD map[string]string type MD map[string]string
// New creates a MD from given key-value map.
func New(m map[string]string) MD {
md := MD{}
for k, v := range m {
key, val := encodeKeyValue(k, v)
md[key] = val
}
return md
}
// Pairs returns an MD formed by the mapping of key, value ... // Pairs returns an MD formed by the mapping of key, value ...
// Pairs panics if len(kv) is odd. // Pairs panics if len(kv) is odd.
func Pairs(kv ...string) MD { func Pairs(kv ...string) MD {
@ -115,18 +105,18 @@ func Pairs(kv ...string) MD {
return md return md
} }
// Len returns the length of md. // Len returns the number of items in md.
func (md MD) Len() int { func (md MD) Len() int {
return len(md) return len(md)
} }
// Copy returns a copy of md's mapping. // Copy returns a copy of md.
func (md MD) Copy() map[string]string { func (md MD) Copy() MD {
m := make(map[string]string) out := MD{}
for k, v := range md { for k, v := range md {
m[k] = v out[k] = v
} }
return m return out
} }
type mdKey struct{} type mdKey struct{}

View File

@ -69,7 +69,7 @@ func TestPairsMD(t *testing.T) {
md MD md MD
}{ }{
{[]string{}, MD{}}, {[]string{}, MD{}},
{[]string{"k1", "v1", "k2", binaryValue}, New(map[string]string{ {[]string{"k1", "v1", "k2", binaryValue}, MD(map[string]string{
"k1": "v1", "k1": "v1",
"k2-bin": "woA=", "k2-bin": "woA=",
})}, })},

View File

@ -110,11 +110,7 @@ func (cs *clientStream) Context() context.Context {
// is any. Empty metadata.MD is returned if there is no header metadata. // is any. Empty metadata.MD is returned if there is no header metadata.
// It blocks if the metadata is not ready to read. // It blocks if the metadata is not ready to read.
func (cs *clientStream) Header() (md metadata.MD, err error) { func (cs *clientStream) Header() (md metadata.MD, err error) {
m, err := cs.s.Header() return cs.s.Header()
if err != nil {
return md, err
}
return metadata.New(m), nil
} }
// Trailer returns the trailer metadata from the server. It must be called // Trailer returns the trailer metadata from the server. It must be called

View File

@ -56,7 +56,7 @@ import (
) )
var ( var (
testMetadata = map[string]string{ testMetadata = metadata.MD{
"key1": "value1", "key1": "value1",
"key2": "value2", "key2": "value2",
} }
@ -208,18 +208,17 @@ func TestMetadataUnaryRPC(t *testing.T) {
Dividend: proto.Int64(8), Dividend: proto.Int64(8),
Divisor: proto.Int64(2), Divisor: proto.Int64(2),
} }
md := metadata.New(testMetadata) ctx := metadata.NewContext(context.Background(), testMetadata)
ctx := metadata.NewContext(context.Background(), md)
var header, trailer metadata.MD var header, trailer metadata.MD
_, err := mc.Div(ctx, args, rpc.Header(&header), rpc.Trailer(&trailer)) _, err := mc.Div(ctx, args, rpc.Header(&header), rpc.Trailer(&trailer))
if err != nil { if err != nil {
t.Fatalf("mathClient.Div(%v, _, _, _) = _, %v; want _, <nil>", ctx, err) t.Fatalf("mathClient.Div(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
} }
if !reflect.DeepEqual(testMetadata, header.Copy()) { if !reflect.DeepEqual(testMetadata, header) {
t.Fatalf("Received header metadata %v, want %v", header.Copy(), testMetadata) t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
} }
if !reflect.DeepEqual(testMetadata, trailer.Copy()) { if !reflect.DeepEqual(testMetadata, trailer) {
t.Fatalf("Received trailer metadata %v, want %v", trailer.Copy(), testMetadata) t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata)
} }
} }
@ -234,7 +233,7 @@ func performOneRPC(t *testing.T, mc testpb.MathClient, wg *sync.WaitGroup) {
Remainder: proto.Int64(2), Remainder: proto.Int64(2),
} }
if err != nil || !proto.Equal(reply, want) { if err != nil || !proto.Equal(reply, want) {
t.Fatalf(`mathClient.Div(_, _) = %v, %v; want %v, <nil>`, reply, err, want) t.Errorf(`mathClient.Div(_, _) = %v, %v; want %v, <nil>`, reply, err, want)
} }
wg.Done() wg.Done()
} }
@ -322,7 +321,7 @@ func TestBidiStreaming(t *testing.T) {
go func() { go func() {
for _, args := range parseArgs(test.divs) { for _, args := range parseArgs(test.divs) {
if err := stream.Send(args); err != nil { if err := stream.Send(args); err != nil {
t.Fatal("Send failed: ", err) t.Errorf("Send failed: ", err)
return return
} }
} }
@ -367,25 +366,24 @@ func parseArgs(divs []string) (args []*testpb.DivArgs) {
func TestMetadataStreamingRPC(t *testing.T) { func TestMetadataStreamingRPC(t *testing.T) {
s, mc := setUp(true, math.MaxUint32) s, mc := setUp(true, math.MaxUint32)
defer s.Stop() defer s.Stop()
md := metadata.New(testMetadata) ctx := metadata.NewContext(context.Background(), testMetadata)
ctx := metadata.NewContext(context.Background(), md)
stream, err := mc.DivMany(ctx) stream, err := mc.DivMany(ctx)
if err != nil { if err != nil {
t.Fatalf("Failed to create stream %v", err) t.Fatalf("Failed to create stream %v", err)
} }
go func() { go func() {
headerMD, err := stream.Header() headerMD, err := stream.Header()
if err != nil || !reflect.DeepEqual(testMetadata, headerMD.Copy()) { if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Fatalf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata) t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
} }
// test the cached value. // test the cached value.
headerMD, err = stream.Header() headerMD, err = stream.Header()
if err != nil || !reflect.DeepEqual(testMetadata, headerMD.Copy()) { if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Fatalf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata) t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
} }
for _, args := range parseArgs([]string{"1/1", "3/2", "2/3"}) { for _, args := range parseArgs([]string{"1/1", "3/2", "2/3"}) {
if err := stream.Send(args); err != nil { if err := stream.Send(args); err != nil {
t.Fatalf("%v.Send(_) failed: %v", stream, err) t.Errorf("%v.Send(_) failed: %v", stream, err)
return return
} }
} }
@ -399,7 +397,7 @@ func TestMetadataStreamingRPC(t *testing.T) {
} }
} }
trailerMD := stream.Trailer() trailerMD := stream.Trailer()
if !reflect.DeepEqual(testMetadata, trailerMD.Copy()) { if !reflect.DeepEqual(testMetadata, trailerMD) {
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata) t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
} }
} }

View File

@ -230,7 +230,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
} }
if md, ok := metadata.FromContext(ctx); ok { if md, ok := metadata.FromContext(ctx); ok {
for k, v := range md.Copy() { for k, v := range md {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
} }
} }

View File

@ -167,7 +167,7 @@ func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame header
s.ctx = newContextWithStream(s.ctx, s) s.ctx = newContextWithStream(s.ctx, s)
// Attach the received metadata to the context. // Attach the received metadata to the context.
if len(hDec.state.mdata) > 0 { if len(hDec.state.mdata) > 0 {
s.ctx = metadata.NewContext(s.ctx, metadata.New(hDec.state.mdata)) s.ctx = metadata.NewContext(s.ctx, hDec.state.mdata)
} }
s.dec = &recvBufferReader{ s.dec = &recvBufferReader{
@ -404,7 +404,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
t.hBuf.Reset() t.hBuf.Reset()
t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
for k, v := range md.Copy() { for k, v := range md {
t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
} }
if err := t.writeHeaders(s, t.hBuf, false); err != nil { if err := t.writeHeaders(s, t.hBuf, false); err != nil {

View File

@ -181,9 +181,9 @@ type Stream struct {
// Close headerChan to indicate the end of reception of header metadata. // Close headerChan to indicate the end of reception of header metadata.
headerChan chan struct{} headerChan chan struct{}
// header caches the received header metadata. // header caches the received header metadata.
header map[string]string header metadata.MD
// The key-value map of trailer metadata. // The key-value map of trailer metadata.
trailer map[string]string trailer metadata.MD
mu sync.RWMutex mu sync.RWMutex
// headerOK becomes true from the first header is about to send. // headerOK becomes true from the first header is about to send.
@ -200,12 +200,12 @@ type Stream struct {
// Header acquires the key-value pairs of header metadata once it // Header acquires the key-value pairs of header metadata once it
// is available. It blocks until i) the metadata is ready or ii) there is no // is available. It blocks until i) the metadata is ready or ii) there is no
// header metadata or iii) the stream is cancelled/expired. // header metadata or iii) the stream is cancelled/expired.
func (s *Stream) Header() (map[string]string, error) { func (s *Stream) Header() (metadata.MD, error) {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil, ContextErr(s.ctx.Err()) return nil, ContextErr(s.ctx.Err())
case <-s.headerChan: case <-s.headerChan:
return s.header, nil return s.header.Copy(), nil
} }
} }
@ -215,7 +215,7 @@ func (s *Stream) Header() (map[string]string, error) {
func (s *Stream) Trailer() metadata.MD { func (s *Stream) Trailer() metadata.MD {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
return metadata.New(s.trailer) return s.trailer.Copy()
} }
// ServerTransport returns the underlying ServerTransport for the stream. // ServerTransport returns the underlying ServerTransport for the stream.