diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index 6f6b4186..69869dfc 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -221,6 +221,40 @@ func (s *serverReflectionServer) fileDescWireFormatByFilename(name string) ([]by return b, nil } +func (s *serverReflectionServer) fileDescWireFormatContainingSymbol(name string) ([]byte, error) { + var ( + fd *dpb.FileDescriptorProto + ) + // Check if it's a type name. + if st, err := s.typeForName(name); err == nil { + fd, _, err = s.fileDescForType(st) + if err != nil { + return nil, err + } + } else { + // Check if it's a service name or method name. + meta := s.s.Metadata(name) + if meta != nil { + if enc, ok := meta.([]byte); ok { + fd, err = s.decodeFileDesc(enc) + if err != nil { + return nil, err + } + } + } + } + + // Marshal to wire format. + if fd != nil { + b, err := proto.Marshal(fd) + if err != nil { + return nil, err + } + return b, nil + } + return nil, fmt.Errorf("unknown symbol: %v", name) +} + func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) if !ok { @@ -254,6 +288,12 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } response = &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}} case *rpb.ServerReflectionRequest_FileContainingSymbol: + b, err := s.fileDescWireFormatContainingSymbol(req.FileContainingSymbol) + if err != nil { + // TODO grpc error or send message back + return err + } + response = &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}} case *rpb.ServerReflectionRequest_FileContainingExtension: case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: case *rpb.ServerReflectionRequest_ListServices: