Minor fixes

This commit is contained in:
Menghan Li
2016-06-20 13:32:57 -07:00
parent 451a2e416b
commit 1302eb9c41
3 changed files with 36 additions and 51 deletions

View File

@ -4,15 +4,15 @@ Package reflection implements server reflection service.
The service implemented is defined in: https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. The service implemented is defined in: https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
To install server reflection on a gRPC server: To register server reflection on a gRPC server:
```go ```go
import "google.golang.org/grpc/reflection" import "google.golang.org/grpc/reflection"
s := grpc.NewServer() s := grpc.NewServer()
pb.RegisterYourOwnServer(s, &server{}) pb.RegisterYourOwnServer(s, &server{})
// Install reflection service on gRPC server. // Register reflection service on gRPC server.
reflection.InstallOnServer(s) reflection.Register(s)
s.Serve(lis) s.Serve(lis)
``` ```

View File

@ -37,14 +37,14 @@ Package reflection implements server reflection service.
The service implemented is defined in: The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto. https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
To install server reflection on a gRPC server: To register server reflection on a gRPC server:
import "google.golang.org/grpc/reflection" import "google.golang.org/grpc/reflection"
s := grpc.NewServer() s := grpc.NewServer()
pb.RegisterYourOwnServer(s, &server{}) pb.RegisterYourOwnServer(s, &server{})
// Install reflection service on gRPC server. // Register reflection service on gRPC server.
reflection.InstallOnServer(s) reflection.Register(s)
s.Serve(lis) s.Serve(lis)
@ -71,14 +71,17 @@ type serverReflectionServer struct {
// TODO add cache if necessary // TODO add cache if necessary
} }
// InstallOnServer installs server reflection service on the given gRPC server. // Register registers the server reflection service on the given gRPC server.
func InstallOnServer(s *grpc.Server) { func Register(s *grpc.Server) {
rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
s: s, s: s,
}) })
} }
// protoMessage is the interface representing objects with function Descriptor(). // protoMessage is used for type assertion on proto messages.
// Generated proto message implements function Descriptor(), but Descriptor()
// is not part of interface proto.Message. This interface is needed to
// call Descriptor().
type protoMessage interface { type protoMessage interface {
Descriptor() ([]byte, []int) Descriptor() ([]byte, []int)
} }
@ -92,19 +95,15 @@ func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDesc
} }
enc, _ := m.Descriptor() enc, _ := m.Descriptor()
fd, err := s.decodeFileDesc(enc) return s.decodeFileDesc(enc)
if err != nil {
return nil, err
}
return fd, nil
} }
// decodeFileDesc does decompression and unmarshalling on the given // decodeFileDesc does decompression and unmarshalling on the given
// file descriptor byte slice. // file descriptor byte slice.
func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
raw := decompress(enc) raw, err := decompress(enc)
if raw == nil { if err != nil {
return nil, fmt.Errorf("failed to decompress enc") return nil, fmt.Errorf("failed to decompress enc: %v", err)
} }
fd := new(dpb.FileDescriptorProto) fd := new(dpb.FileDescriptorProto)
@ -115,18 +114,16 @@ func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptor
} }
// decompress does gzip decompression. // decompress does gzip decompression.
func decompress(b []byte) []byte { func decompress(b []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewReader(b)) r, err := gzip.NewReader(bytes.NewReader(b))
if err != nil { if err != nil {
fmt.Printf("bad gzipped descriptor: %v\n", err) return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err)
return nil
} }
out, err := ioutil.ReadAll(r) out, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
fmt.Printf("bad gzipped descriptor: %v\n", err) return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err)
return nil
} }
return out return out, nil
} }
func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) { func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) {
@ -159,11 +156,7 @@ func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ex
extT := reflect.TypeOf(extDesc.ExtensionType).Elem() extT := reflect.TypeOf(extDesc.ExtensionType).Elem()
fd, err := s.fileDescForType(extT) return s.fileDescForType(extT)
if err != nil {
return nil, err
}
return fd, nil
} }
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
@ -173,20 +166,16 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]
} }
exts := proto.RegisteredExtensions(m) exts := proto.RegisteredExtensions(m)
out := make([]int32, len(exts)) out := make([]int32, 0, len(exts))
i := 0
for id := range exts { for id := range exts {
out[i] = id out = append(out, id)
i++
} }
return out, nil return out, nil
} }
// Following are helper functions for reflection service handler. // fileDescEncodingByFilename finds the file descriptor for given filename,
// fileDescWireFormatByFilename finds the file descriptor for given filename,
// does marshalling on it and returns the marshalled result. // does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]byte, error) { func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
enc := proto.FileDescriptor(name) enc := proto.FileDescriptor(name)
if enc == nil { if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name) return nil, fmt.Errorf("unknown file: %v", name)
@ -202,10 +191,10 @@ func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]by
return b, nil return b, nil
} }
// fileDescWireFormatContainingSymbol finds the file descriptor containing the given symbol, // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
// does marshalling on it and returns the marshalled result. // does marshalling on it and returns the marshalled result.
// The given symbol can be a type, a service or a method. // The given symbol can be a type, a service or a method.
func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string) ([]byte, error) { func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
var ( var (
fd *dpb.FileDescriptorProto fd *dpb.FileDescriptorProto
) )
@ -239,9 +228,9 @@ func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string)
return nil, fmt.Errorf("unknown symbol: %v", name) return nil, fmt.Errorf("unknown symbol: %v", name)
} }
// fileDescWireFormatContainingExtension finds the file descriptor containing given extension, // fileDescEncodingContainingExtension finds the file descriptor containing given extension,
// does marshalling on it and returns the marshalled result. // does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescWireFormatContainingExtension(typeName string, extNum int32) ([]byte, error) { func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
st, err := s.typeForName(typeName) st, err := s.typeForName(typeName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -288,7 +277,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
} }
switch req := in.MessageRequest.(type) { switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename: case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescWireFormatByFilename(req.FileByFilename) b, err := s.fileDescEncodingByFilename(req.FileByFilename)
if err != nil { if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{
@ -302,7 +291,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
} }
} }
case *rpb.ServerReflectionRequest_FileContainingSymbol: case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescWireFormatContainingSymbol(req.FileContainingSymbol) b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
if err != nil { if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{
@ -318,7 +307,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
case *rpb.ServerReflectionRequest_FileContainingExtension: case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescWireFormatContainingExtension(typeName, extNum) b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
if err != nil { if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{ ErrorResponse: &rpb.ErrorResponse{

View File

@ -128,10 +128,6 @@ func TestAllExtensionNumbersForType(t *testing.T) {
// Do end2end tests. // Do end2end tests.
var (
port = ":35764"
)
type server struct{} type server struct{}
func (s *server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.SearchResponse, error) { func (s *server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.SearchResponse, error) {
@ -144,18 +140,18 @@ func (s *server) StreamingSearch(stream pb.SearchService_StreamingSearchServer)
func TestEnd2end(t *testing.T) { func TestEnd2end(t *testing.T) {
// Start server. // Start server.
lis, err := net.Listen("tcp", port) lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
t.Fatalf("failed to listen: %v", err) t.Fatalf("failed to listen: %v", err)
} }
s := grpc.NewServer() s := grpc.NewServer()
pb.RegisterSearchServiceServer(s, &server{}) pb.RegisterSearchServiceServer(s, &server{})
// Install reflection service on s. // Register reflection service on s.
InstallOnServer(s) Register(s)
go s.Serve(lis) go s.Serve(lis)
// Create client. // Create client.
conn, err := grpc.Dial("localhost"+port, grpc.WithInsecure()) conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil { if err != nil {
t.Fatalf("cannot connect to server: %v", err) t.Fatalf("cannot connect to server: %v", err)
} }