Minor fixes

This commit is contained in:
Menghan Li
2016-06-20 13:32:57 -07:00
parent 451a2e416b
commit 1302eb9c41
3 changed files with 36 additions and 51 deletions

View File

@ -4,15 +4,15 @@ Package reflection implements server reflection service.
The service implemented is defined in: https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
To install server reflection on a gRPC server:
To register server reflection on a gRPC server:
```go
import "google.golang.org/grpc/reflection"
s := grpc.NewServer()
pb.RegisterYourOwnServer(s, &server{})
// Install reflection service on gRPC server.
reflection.InstallOnServer(s)
// Register reflection service on gRPC server.
reflection.Register(s)
s.Serve(lis)
```

View File

@ -37,14 +37,14 @@ Package reflection implements server reflection service.
The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
To install server reflection on a gRPC server:
To register server reflection on a gRPC server:
import "google.golang.org/grpc/reflection"
s := grpc.NewServer()
pb.RegisterYourOwnServer(s, &server{})
// Install reflection service on gRPC server.
reflection.InstallOnServer(s)
// Register reflection service on gRPC server.
reflection.Register(s)
s.Serve(lis)
@ -71,14 +71,17 @@ type serverReflectionServer struct {
// TODO add cache if necessary
}
// InstallOnServer installs server reflection service on the given gRPC server.
func InstallOnServer(s *grpc.Server) {
// Register registers the server reflection service on the given gRPC server.
func Register(s *grpc.Server) {
rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
s: s,
})
}
// protoMessage is the interface representing objects with function Descriptor().
// protoMessage is used for type assertion on proto messages.
// Generated proto message implements function Descriptor(), but Descriptor()
// is not part of interface proto.Message. This interface is needed to
// call Descriptor().
type protoMessage interface {
Descriptor() ([]byte, []int)
}
@ -92,19 +95,15 @@ func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDesc
}
enc, _ := m.Descriptor()
fd, err := s.decodeFileDesc(enc)
if err != nil {
return nil, err
}
return fd, nil
return s.decodeFileDesc(enc)
}
// decodeFileDesc does decompression and unmarshalling on the given
// file descriptor byte slice.
func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
raw := decompress(enc)
if raw == nil {
return nil, fmt.Errorf("failed to decompress enc")
raw, err := decompress(enc)
if err != nil {
return nil, fmt.Errorf("failed to decompress enc: %v", err)
}
fd := new(dpb.FileDescriptorProto)
@ -115,18 +114,16 @@ func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptor
}
// decompress does gzip decompression.
func decompress(b []byte) []byte {
func decompress(b []byte) ([]byte, error) {
r, err := gzip.NewReader(bytes.NewReader(b))
if err != nil {
fmt.Printf("bad gzipped descriptor: %v\n", err)
return nil
return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err)
}
out, err := ioutil.ReadAll(r)
if err != nil {
fmt.Printf("bad gzipped descriptor: %v\n", err)
return nil
return nil, fmt.Errorf("bad gzipped descriptor: %v\n", err)
}
return out
return out, nil
}
func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) {
@ -159,11 +156,7 @@ func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ex
extT := reflect.TypeOf(extDesc.ExtensionType).Elem()
fd, err := s.fileDescForType(extT)
if err != nil {
return nil, err
}
return fd, nil
return s.fileDescForType(extT)
}
func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
@ -173,20 +166,16 @@ func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]
}
exts := proto.RegisteredExtensions(m)
out := make([]int32, len(exts))
i := 0
out := make([]int32, 0, len(exts))
for id := range exts {
out[i] = id
i++
out = append(out, id)
}
return out, nil
}
// Following are helper functions for reflection service handler.
// fileDescWireFormatByFilename finds the file descriptor for given filename,
// fileDescEncodingByFilename finds the file descriptor for given filename,
// does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]byte, error) {
func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
enc := proto.FileDescriptor(name)
if enc == nil {
return nil, fmt.Errorf("unknown file: %v", name)
@ -202,10 +191,10 @@ func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]by
return b, nil
}
// fileDescWireFormatContainingSymbol finds the file descriptor containing the given symbol,
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
// does marshalling on it and returns the marshalled result.
// The given symbol can be a type, a service or a method.
func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string) ([]byte, error) {
func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
var (
fd *dpb.FileDescriptorProto
)
@ -239,9 +228,9 @@ func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string)
return nil, fmt.Errorf("unknown symbol: %v", name)
}
// fileDescWireFormatContainingExtension finds the file descriptor containing given extension,
// fileDescEncodingContainingExtension finds the file descriptor containing given extension,
// does marshalling on it and returns the marshalled result.
func (s *serverReflectionServer) fileDescWireFormatContainingExtension(typeName string, extNum int32) ([]byte, error) {
func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
st, err := s.typeForName(typeName)
if err != nil {
return nil, err
@ -288,7 +277,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
switch req := in.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
b, err := s.fileDescWireFormatByFilename(req.FileByFilename)
b, err := s.fileDescEncodingByFilename(req.FileByFilename)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
@ -302,7 +291,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
b, err := s.fileDescWireFormatContainingSymbol(req.FileContainingSymbol)
b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
@ -318,7 +307,7 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
case *rpb.ServerReflectionRequest_FileContainingExtension:
typeName := req.FileContainingExtension.ContainingType
extNum := req.FileContainingExtension.ExtensionNumber
b, err := s.fileDescWireFormatContainingExtension(typeName, extNum)
b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
if err != nil {
out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{

View File

@ -128,10 +128,6 @@ func TestAllExtensionNumbersForType(t *testing.T) {
// Do end2end tests.
var (
port = ":35764"
)
type server struct{}
func (s *server) Search(ctx context.Context, in *pb.SearchRequest) (*pb.SearchResponse, error) {
@ -144,18 +140,18 @@ func (s *server) StreamingSearch(stream pb.SearchService_StreamingSearchServer)
func TestEnd2end(t *testing.T) {
// Start server.
lis, err := net.Listen("tcp", port)
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
s := grpc.NewServer()
pb.RegisterSearchServiceServer(s, &server{})
// Install reflection service on s.
InstallOnServer(s)
// Register reflection service on s.
Register(s)
go s.Serve(lis)
// Create client.
conn, err := grpc.Dial("localhost"+port, grpc.WithInsecure())
conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure())
if err != nil {
t.Fatalf("cannot connect to server: %v", err)
}