400 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			400 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /*
 | |
|  *
 | |
|  * Copyright 2016 gRPC authors.
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  *
 | |
|  */
 | |
| 
 | |
| //go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto
 | |
| 
 | |
| /*
 | |
| Package reflection implements the gRPC server reflection service.
 | |
| 
 | |
| The service implemented is defined in:
 | |
| https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
 | |
| 
 | |
| To register server reflection on a gRPC server:
 | |
| 	import "google.golang.org/grpc/reflection"
 | |
| 
 | |
| 	s := grpc.NewServer()
 | |
| 	pb.RegisterYourOwnServer(s, &server{})
 | |
| 
 | |
| 	// Register reflection service on gRPC server.
 | |
| 	reflection.Register(s)
 | |
| 
 | |
| 	s.Serve(lis)
 | |
| 
 | |
| */
 | |
| package reflection // import "google.golang.org/grpc/reflection"
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"compress/gzip"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"io/ioutil"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/golang/protobuf/proto"
 | |
| 	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
 | |
| 	"google.golang.org/grpc"
 | |
| 	"google.golang.org/grpc/codes"
 | |
| 	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
 | |
| 	"google.golang.org/grpc/status"
 | |
| )
 | |
| 
 | |
| type serverReflectionServer struct {
 | |
| 	s *grpc.Server
 | |
| 	// TODO add more cache if necessary
 | |
| 	serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo()
 | |
| }
 | |
| 
 | |
| // Register registers the server reflection service on the given gRPC server.
 | |
| func Register(s *grpc.Server) {
 | |
| 	rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
 | |
| 		s: s,
 | |
| 	})
 | |
| }
 | |
| 
 | |
// protoMessage is satisfied by generated proto message types. They expose a
// Descriptor method that is not part of the proto.Message interface, so this
// package type-asserts to protoMessage in order to call it.
type protoMessage interface {
	// Descriptor returns the gzip-compressed, serialized file descriptor of
	// the file declaring the message. The second result is ignored by this
	// package (presumably the message's index path within that file).
	Descriptor() ([]byte, []int)
}
 | |
| 
 | |
| // fileDescForType gets the file descriptor for the given type.
 | |
| // The given type should be a proto message.
 | |
| func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
 | |
| 	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("failed to create message from type: %v", st)
 | |
| 	}
 | |
| 	enc, _ := m.Descriptor()
 | |
| 
 | |
| 	return s.decodeFileDesc(enc)
 | |
| }
 | |
| 
 | |
| // decodeFileDesc does decompression and unmarshalling on the given
 | |
| // file descriptor byte slice.
 | |
| func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
 | |
| 	raw, err := decompress(enc)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to decompress enc: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	fd := new(dpb.FileDescriptorProto)
 | |
| 	if err := proto.Unmarshal(raw, fd); err != nil {
 | |
| 		return nil, fmt.Errorf("bad descriptor: %v", err)
 | |
| 	}
 | |
| 	return fd, nil
 | |
| }
 | |
| 
 | |
// decompress gunzips b and returns the decompressed bytes. Both an invalid
// gzip header and a failure while reading the compressed body are reported
// as "bad gzipped descriptor" errors.
func decompress(b []byte) ([]byte, error) {
	r, err := gzip.NewReader(bytes.NewReader(b))
	if err != nil {
		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
	}
	// gzip.NewReader documents that the caller must Close the Reader when
	// done. Close's error is deliberately ignored: ReadAll below reads to
	// EOF, which is where gzip reports corruption and checksum mismatches.
	defer r.Close()

	out, err := ioutil.ReadAll(r)
	if err != nil {
		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
	}
	return out, nil
}
 | |
| 
 | |
| func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) {
 | |
| 	pt := proto.MessageType(name)
 | |
| 	if pt == nil {
 | |
| 		return nil, fmt.Errorf("unknown type: %q", name)
 | |
| 	}
 | |
| 	st := pt.Elem()
 | |
| 
 | |
| 	return st, nil
 | |
| }
 | |
| 
 | |
| func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
 | |
| 	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("failed to create message from type: %v", st)
 | |
| 	}
 | |
| 
 | |
| 	var extDesc *proto.ExtensionDesc
 | |
| 	for id, desc := range proto.RegisteredExtensions(m) {
 | |
| 		if id == ext {
 | |
| 			extDesc = desc
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if extDesc == nil {
 | |
| 		return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
 | |
| 	}
 | |
| 
 | |
| 	return s.decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
 | |
| }
 | |
| 
 | |
| func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
 | |
| 	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("failed to create message from type: %v", st)
 | |
| 	}
 | |
| 
 | |
| 	exts := proto.RegisteredExtensions(m)
 | |
| 	out := make([]int32, 0, len(exts))
 | |
| 	for id := range exts {
 | |
| 		out = append(out, id)
 | |
| 	}
 | |
| 	return out, nil
 | |
| }
 | |
| 
 | |
| // fileDescEncodingByFilename finds the file descriptor for given filename,
 | |
| // does marshalling on it and returns the marshalled result.
 | |
| func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
 | |
| 	enc := proto.FileDescriptor(name)
 | |
| 	if enc == nil {
 | |
| 		return nil, fmt.Errorf("unknown file: %v", name)
 | |
| 	}
 | |
| 	fd, err := s.decodeFileDesc(enc)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return proto.Marshal(fd)
 | |
| }
 | |
| 
 | |
| // serviceMetadataForSymbol finds the metadata for name in s.serviceInfo.
 | |
| // name should be a service name or a method name.
 | |
| func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) {
 | |
| 	if s.serviceInfo == nil {
 | |
| 		s.serviceInfo = s.s.GetServiceInfo()
 | |
| 	}
 | |
| 
 | |
| 	// Check if it's a service name.
 | |
| 	if info, ok := s.serviceInfo[name]; ok {
 | |
| 		return info.Metadata, nil
 | |
| 	}
 | |
| 
 | |
| 	// Check if it's a method name.
 | |
| 	pos := strings.LastIndex(name, ".")
 | |
| 	// Not a valid method name.
 | |
| 	if pos == -1 {
 | |
| 		return nil, fmt.Errorf("unknown symbol: %v", name)
 | |
| 	}
 | |
| 
 | |
| 	info, ok := s.serviceInfo[name[:pos]]
 | |
| 	// Substring before last "." is not a service name.
 | |
| 	if !ok {
 | |
| 		return nil, fmt.Errorf("unknown symbol: %v", name)
 | |
| 	}
 | |
| 
 | |
| 	// Search the method name in info.Methods.
 | |
| 	var found bool
 | |
| 	for _, m := range info.Methods {
 | |
| 		if m.Name == name[pos+1:] {
 | |
| 			found = true
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 	if found {
 | |
| 		return info.Metadata, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, fmt.Errorf("unknown symbol: %v", name)
 | |
| }
 | |
| 
 | |
| // parseMetadata finds the file descriptor bytes specified meta.
 | |
| // For SupportPackageIsVersion4, m is the name of the proto file, we
 | |
| // call proto.FileDescriptor to get the byte slice.
 | |
| // For SupportPackageIsVersion3, m is a byte slice itself.
 | |
| func parseMetadata(meta interface{}) ([]byte, bool) {
 | |
| 	// Check if meta is the file name.
 | |
| 	if fileNameForMeta, ok := meta.(string); ok {
 | |
| 		return proto.FileDescriptor(fileNameForMeta), true
 | |
| 	}
 | |
| 
 | |
| 	// Check if meta is the byte slice.
 | |
| 	if enc, ok := meta.([]byte); ok {
 | |
| 		return enc, true
 | |
| 	}
 | |
| 
 | |
| 	return nil, false
 | |
| }
 | |
| 
 | |
| // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
 | |
| // does marshalling on it and returns the marshalled result.
 | |
| // The given symbol can be a type, a service or a method.
 | |
| func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
 | |
| 	var (
 | |
| 		fd *dpb.FileDescriptorProto
 | |
| 	)
 | |
| 	// Check if it's a type name.
 | |
| 	if st, err := s.typeForName(name); err == nil {
 | |
| 		fd, err = s.fileDescForType(st)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	} else { // Check if it's a service name or a method name.
 | |
| 		meta, err := s.serviceMetadataForSymbol(name)
 | |
| 
 | |
| 		// Metadata not found.
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		// Metadata not valid.
 | |
| 		enc, ok := parseMetadata(meta)
 | |
| 		if !ok {
 | |
| 			return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name)
 | |
| 		}
 | |
| 
 | |
| 		fd, err = s.decodeFileDesc(enc)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return proto.Marshal(fd)
 | |
| }
 | |
| 
 | |
| // fileDescEncodingContainingExtension finds the file descriptor containing given extension,
 | |
| // does marshalling on it and returns the marshalled result.
 | |
| func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
 | |
| 	st, err := s.typeForName(typeName)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	fd, err := s.fileDescContainingExtension(st, extNum)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return proto.Marshal(fd)
 | |
| }
 | |
| 
 | |
| // allExtensionNumbersForTypeName returns all extension numbers for the given type.
 | |
| func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
 | |
| 	st, err := s.typeForName(name)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	extNums, err := s.allExtensionNumbersForType(st)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return extNums, nil
 | |
| }
 | |
| 
 | |
| // ServerReflectionInfo is the reflection service handler.
 | |
| func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
 | |
| 	for {
 | |
| 		in, err := stream.Recv()
 | |
| 		if err == io.EOF {
 | |
| 			return nil
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		out := &rpb.ServerReflectionResponse{
 | |
| 			ValidHost:       in.Host,
 | |
| 			OriginalRequest: in,
 | |
| 		}
 | |
| 		switch req := in.MessageRequest.(type) {
 | |
| 		case *rpb.ServerReflectionRequest_FileByFilename:
 | |
| 			b, err := s.fileDescEncodingByFilename(req.FileByFilename)
 | |
| 			if err != nil {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
 | |
| 					ErrorResponse: &rpb.ErrorResponse{
 | |
| 						ErrorCode:    int32(codes.NotFound),
 | |
| 						ErrorMessage: err.Error(),
 | |
| 					},
 | |
| 				}
 | |
| 			} else {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
 | |
| 					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
 | |
| 				}
 | |
| 			}
 | |
| 		case *rpb.ServerReflectionRequest_FileContainingSymbol:
 | |
| 			b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
 | |
| 			if err != nil {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
 | |
| 					ErrorResponse: &rpb.ErrorResponse{
 | |
| 						ErrorCode:    int32(codes.NotFound),
 | |
| 						ErrorMessage: err.Error(),
 | |
| 					},
 | |
| 				}
 | |
| 			} else {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
 | |
| 					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
 | |
| 				}
 | |
| 			}
 | |
| 		case *rpb.ServerReflectionRequest_FileContainingExtension:
 | |
| 			typeName := req.FileContainingExtension.ContainingType
 | |
| 			extNum := req.FileContainingExtension.ExtensionNumber
 | |
| 			b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
 | |
| 			if err != nil {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
 | |
| 					ErrorResponse: &rpb.ErrorResponse{
 | |
| 						ErrorCode:    int32(codes.NotFound),
 | |
| 						ErrorMessage: err.Error(),
 | |
| 					},
 | |
| 				}
 | |
| 			} else {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
 | |
| 					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
 | |
| 				}
 | |
| 			}
 | |
| 		case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
 | |
| 			extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
 | |
| 			if err != nil {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
 | |
| 					ErrorResponse: &rpb.ErrorResponse{
 | |
| 						ErrorCode:    int32(codes.NotFound),
 | |
| 						ErrorMessage: err.Error(),
 | |
| 					},
 | |
| 				}
 | |
| 			} else {
 | |
| 				out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
 | |
| 					AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
 | |
| 						BaseTypeName:    req.AllExtensionNumbersOfType,
 | |
| 						ExtensionNumber: extNums,
 | |
| 					},
 | |
| 				}
 | |
| 			}
 | |
| 		case *rpb.ServerReflectionRequest_ListServices:
 | |
| 			if s.serviceInfo == nil {
 | |
| 				s.serviceInfo = s.s.GetServiceInfo()
 | |
| 			}
 | |
| 			serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo))
 | |
| 			for n := range s.serviceInfo {
 | |
| 				serviceResponses = append(serviceResponses, &rpb.ServiceResponse{
 | |
| 					Name: n,
 | |
| 				})
 | |
| 			}
 | |
| 			out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
 | |
| 				ListServicesResponse: &rpb.ListServiceResponse{
 | |
| 					Service: serviceResponses,
 | |
| 				},
 | |
| 			}
 | |
| 		default:
 | |
| 			return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
 | |
| 		}
 | |
| 
 | |
| 		if err := stream.Send(out); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| }
 | 
