diff --git a/reflection/serverreflection.go b/reflection/serverreflection.go index e59bfb95..78ee6069 100644 --- a/reflection/serverreflection.go +++ b/reflection/serverreflection.go @@ -69,7 +69,8 @@ import ( type serverReflectionServer struct { s *grpc.Server - // TODO add cache if necessary + // TODO add more cache if necessary + serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo() } // Register registers the server reflection service on the given gRPC server. @@ -188,6 +189,46 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte return proto.Marshal(fd) } +// serviceMetadataForSymbol finds the metadata for name in s.serviceInfo. +// name should be a service name or a method name. +func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) { + if s.serviceInfo == nil { + s.serviceInfo = s.s.GetServiceInfo() + } + + // Check if it's a service name. + if info, ok := s.serviceInfo[name]; ok { + return info.Metadata, nil + } + + // Check if it's a method name. + pos := strings.LastIndex(name, ".") + // Not a valid method name. + if pos == -1 { + return nil, fmt.Errorf("unknown symbol: %v", name) + } + + info, ok := s.serviceInfo[name[:pos]] + // Substring before last "." is not a service name. + if !ok { + return nil, fmt.Errorf("unknown symbol: %v", name) + } + + // Search for method in info. + var found bool + for _, m := range info.Methods { + if m == name[pos+1:] { + found = true + break + } + } + if !found { + return nil, fmt.Errorf("unknown symbol: %v", name) + } + + return info.Metadata, nil +} + // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, // does marshalling on it and returns the marshalled result. // The given symbol can be a type, a service or a method. @@ -201,28 +242,26 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ( if err != nil { return nil, err } - } else { - // Check if it's a service name. - meta := s.s.ServiceMetadata(name, "") - // Check if it's a method name. - if meta == nil { - if pos := strings.LastIndex(name, "."); pos != -1 { - meta = s.s.ServiceMetadata(name[:pos], name[pos+1:]) - } + } else { // Check if it's a service name or a method name. + meta, err := s.serviceMetadataForSymbol(name) + + // Metadata not found. + if err != nil { + return nil, err } - if meta != nil { - if enc, ok := meta.([]byte); ok { - fd, err = s.decodeFileDesc(enc) - if err != nil { - return nil, err - } - } + + // Metadata not valid. + enc, ok := meta.([]byte) + if !ok { + return nil, fmt.Errorf("invalid file descriptor for symbol: %v") + } + + fd, err = s.decodeFileDesc(enc) + if err != nil { + return nil, err } } - if fd == nil { - return nil, fmt.Errorf("unknown symbol: %v", name) - } return proto.Marshal(fd) } @@ -331,12 +370,14 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio } } case *rpb.ServerReflectionRequest_ListServices: - services := s.s.AllServiceNames() - serviceResponses := make([]*rpb.ServiceResponse, len(services)) - for i, s := range services { - serviceResponses[i] = &rpb.ServiceResponse{ - Name: s, - } + if s.serviceInfo == nil { + s.serviceInfo = s.s.GetServiceInfo() + } + serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo)) + for n := range s.serviceInfo { + serviceResponses = append(serviceResponses, &rpb.ServiceResponse{ + Name: n, + }) } out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ ListServicesResponse: &rpb.ListServiceResponse{ diff --git a/reflection/serverreflection_test.go b/reflection/serverreflection_test.go index 3bdb2936..aeb31e14 100644 --- a/reflection/serverreflection_test.go +++ b/reflection/serverreflection_test.go @@ -92,7 +92,7 @@ func TestFileDescForType(t *testing.T) { } { fd, err := s.fileDescForType(test.st) if err != nil || !reflect.DeepEqual(fd, test.wantFd) { - t.Fatalf("fileDescForType(%q) = %q, %v, want %q, ", test.st, fd, err, test.wantFd) + t.Errorf("fileDescForType(%q) = %q, %v, want %q, ", test.st, fd, err, test.wantFd) } } } @@ -106,7 +106,7 @@ func TestTypeForName(t *testing.T) { } { r, err := s.typeForName(test.name) if err != nil || r != test.want { - t.Fatalf("typeForName(%q) = %q, %v, want %q, ", test.name, r, err, test.want) + t.Errorf("typeForName(%q) = %q, %v, want %q, ", test.name, r, err, test.want) } } } @@ -117,7 +117,7 @@ func TestTypeForNameNotFound(t *testing.T) { } { _, err := s.typeForName(test) if err == nil { - t.Fatalf("typeForName(%q) = _, %v, want _, ", test, err) + t.Errorf("typeForName(%q) = _, %v, want _, ", test, err) } } } @@ -132,7 +132,7 @@ func TestFileDescContainingExtension(t *testing.T) { } { fd, err := s.fileDescContainingExtension(test.st, test.extNum) if err != nil || !reflect.DeepEqual(fd, test.want) { - t.Fatalf("fileDescContainingExtension(%q) = %q, %v, want %q, ", test.st, fd, err, test.want) + t.Errorf("fileDescContainingExtension(%q) = %q, %v, want %q, ", test.st, fd, err, test.want) } } } @@ -154,7 +154,7 @@ func TestAllExtensionNumbersForType(t *testing.T) { r, err := s.allExtensionNumbersForType(test.st) sort.Sort(intArray(r)) if err != nil || !reflect.DeepEqual(r, test.want) { - t.Fatalf("allExtensionNumbersForType(%q) = %v, %v, want %v, ", test.st, r, err, test.want) + t.Errorf("allExtensionNumbersForType(%q) = %v, %v, want %v, ", test.st, r, err, test.want) } } } @@ -194,9 +194,13 @@ func TestReflectionEnd2end(t *testing.T) { stream, err := c.ServerReflectionInfo(context.Background()) testFileByFilename(t, stream) + testFileByFilenameError(t, stream) testFileContainingSymbol(t, stream) + testFileContainingSymbolError(t, stream) testFileContainingExtension(t, stream) + testFileContainingExtensionError(t, stream) testAllExtensionNumbersOfType(t, stream) + testAllExtensionNumbersOfTypeError(t, stream) testListServices(t, stream) s.Stop() @@ -227,10 +231,37 @@ func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflecti switch r.MessageResponse.(type) { case *rpb.ServerReflectionResponse_FileDescriptorResponse: if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) { - t.Fatalf("FileByFilename\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) + t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", test.filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) } default: - t.Fatalf("FileByFilename = %v, want type ", r.MessageResponse) + t.Errorf("FileByFilename(%v) = %v, want type ", test.filename, r.MessageResponse) + } + } +} + +func testFileByFilenameError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) { + for _, test := range []string{ + "test.poto", + "proo2.proto", + "proto2_et.proto", + } { + if err := stream.Send(&rpb.ServerReflectionRequest{ + MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{ + FileByFilename: test, + }, + }); err != nil { + t.Fatalf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + + switch r.MessageResponse.(type) { + case *rpb.ServerReflectionResponse_ErrorResponse: + default: + t.Errorf("FileByFilename(%v) = %v, want type ", test, r.MessageResponse) } } } @@ -261,10 +292,38 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe switch r.MessageResponse.(type) { case *rpb.ServerReflectionResponse_FileDescriptorResponse: if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) { - t.Fatalf("FileContainingSymbol\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) + t.Errorf("FileContainingSymbol(%v)\nreceived: %q,\nwant: %q", test.symbol, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) } default: - t.Fatalf("FileContainingSymbol = %v, want type ", r.MessageResponse) + t.Errorf("FileContainingSymbol(%v) = %v, want type ", test.symbol, r.MessageResponse) + } + } +} + +func testFileContainingSymbolError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) { + for _, test := range []string{ + "grpc.testing.SerchService", + "grpc.testing.SearchService.SearchE", + "grpc.tesing.SearchResponse", + "gpc.testing.ToBeExtened", + } { + if err := stream.Send(&rpb.ServerReflectionRequest{ + MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{ + FileContainingSymbol: test, + }, + }); err != nil { + t.Fatalf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + + switch r.MessageResponse.(type) { + case *rpb.ServerReflectionResponse_ErrorResponse: + default: + t.Errorf("FileContainingSymbol(%v) = %v, want type ", test, r.MessageResponse) } } } @@ -296,10 +355,42 @@ func testFileContainingExtension(t *testing.T, stream rpb.ServerReflection_Serve switch r.MessageResponse.(type) { case *rpb.ServerReflectionResponse_FileDescriptorResponse: if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) { - t.Fatalf("FileContainingExtension\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) + t.Errorf("FileContainingExtension(%v, %v)\nreceived: %q,\nwant: %q", test.typeName, test.extNum, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) } default: - t.Fatalf("FileContainingExtension = %v, want type ", r.MessageResponse) + t.Errorf("FileContainingExtension(%v, %v) = %v, want type ", test.typeName, test.extNum, r.MessageResponse) + } + } +} + +func testFileContainingExtensionError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) { + for _, test := range []struct { + typeName string + extNum int32 + }{ + {"grpc.testing.ToBExtened", 17}, + {"grpc.testing.ToBeExtened", 15}, + } { + if err := stream.Send(&rpb.ServerReflectionRequest{ + MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{ + FileContainingExtension: &rpb.ExtensionRequest{ + ContainingType: test.typeName, + ExtensionNumber: test.extNum, + }, + }, + }); err != nil { + t.Fatalf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + + switch r.MessageResponse.(type) { + case *rpb.ServerReflectionResponse_ErrorResponse: + default: + t.Errorf("FileContainingExtension(%v, %v) = %v, want type ", test.typeName, test.extNum, r.MessageResponse) } } } @@ -330,10 +421,35 @@ func testAllExtensionNumbersOfType(t *testing.T, stream rpb.ServerReflection_Ser sort.Sort(intArray(extNum)) if r.GetAllExtensionNumbersResponse().BaseTypeName != test.typeName || !reflect.DeepEqual(extNum, test.want) { - t.Fatalf("AllExtensionNumbersOfType\nreceived: %v,\nwant: {%q %v}", r.GetAllExtensionNumbersResponse(), test.typeName, test.want) + t.Errorf("AllExtensionNumbersOfType(%v)\nreceived: %v,\nwant: {%q %v}", r.GetAllExtensionNumbersResponse(), test.typeName, test.typeName, test.want) } default: - t.Fatalf("AllExtensionNumbersOfType = %v, want type ", r.MessageResponse) + t.Errorf("AllExtensionNumbersOfType(%v) = %v, want type ", test.typeName, r.MessageResponse) + } + } +} + +func testAllExtensionNumbersOfTypeError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) { + for _, test := range []string{ + "grpc.testing.ToBeExtenedE", + } { + if err := stream.Send(&rpb.ServerReflectionRequest{ + MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{ + AllExtensionNumbersOfType: test, + }, + }); err != nil { + t.Fatalf("failed to send request: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + + switch r.MessageResponse.(type) { + case *rpb.ServerReflectionResponse_ErrorResponse: + default: + t.Errorf("AllExtensionNumbersOfType(%v) = %v, want type ", test, r.MessageResponse) } } } @@ -356,7 +472,7 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection want := []string{"grpc.testing.SearchService", "grpc.reflection.v1alpha.ServerReflection"} // Compare service names in response with want. if len(services) != len(want) { - t.Fatalf("= %v, want service names: %v", services, want) + t.Errorf("= %v, want service names: %v", services, want) } m := make(map[string]int) for _, e := range services { @@ -367,9 +483,9 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection m[e]-- continue } - t.Fatalf("ListService\nreceived: %v,\nwant: %q", services, want) + t.Errorf("ListService\nreceived: %v,\nwant: %q", services, want) } default: - t.Fatalf("ListServices = %v, want type ", r.MessageResponse) + t.Errorf("ListServices = %v, want type ", r.MessageResponse) } } diff --git a/server.go b/server.go index 20b3fac5..4bd8c7db 100644 --- a/server.go +++ b/server.go @@ -245,32 +245,29 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) { s.m[sd.ServiceName] = srv } -// ServiceMetadata returns the metadata for a service or method. -// service should be the full service name with package, in the form of .. -// method should be the method name only. -// If only service is important, method should be an empty string. -func (s *Server) ServiceMetadata(service, method string) interface{} { - // Check if service is registered. - if srv, ok := s.m[service]; ok { - if method == "" { - return srv.meta - } - // Check if method is part of service. - if _, ok := srv.md[method]; ok { - return srv.meta - } - if _, ok := srv.sd[method]; ok { - return srv.meta - } - } - return nil +// ServiceInfo contains method names and metadata for a service. +type ServiceInfo struct { + Methods []string + Metadata interface{} } -// AllServiceNames returns all the registered service names. -func (s *Server) AllServiceNames() []string { - ret := make([]string, 0, len(s.m)) - for k := range s.m { - ret = append(ret, k) +// GetServiceInfo returns a map from service name to ServiceInfo. +// Service name includes the package name, in the form of .. +func (s *Server) GetServiceInfo() map[string]*ServiceInfo { + ret := make(map[string]*ServiceInfo) + for n, srv := range s.m { + methods := make([]string, 0, len(srv.md)+len(srv.sd)) + for m := range srv.md { + methods = append(methods, m) + } + for m := range srv.sd { + methods = append(methods, m) + } + + ret[n] = &ServiceInfo{ + Methods: methods, + Metadata: srv.meta, + } } return ret } diff --git a/server_test.go b/server_test.go index 8bf08aa5..7c1e54dd 100644 --- a/server_test.go +++ b/server_test.go @@ -44,29 +44,6 @@ type emptyServiceServer interface{} type testServer struct{} -var ( - testSd = ServiceDesc{ - ServiceName: "grpc.testing.EmptyService", - HandlerType: (*emptyServiceServer)(nil), - Methods: []MethodDesc{ - { - MethodName: "EmptyCall", - Handler: nil, - }, - }, - Streams: []StreamDesc{ - { - StreamName: "EmptyStream", - Handler: nil, - ServerStreams: true, - ClientStreams: true, - }, - }, - Metadata: testFd, - } - testFd = []byte{0, 1, 2, 3} -) - func TestStopBeforeServe(t *testing.T) { lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -88,73 +65,42 @@ func TestStopBeforeServe(t *testing.T) { } } -func TestServiceMetadata(t *testing.T) { - server := NewServer() - server.RegisterService(&testSd, &testServer{}) - - for _, test := range []struct { - service string - method string - want []byte - }{ - {"grpc.testing.EmptyService", "", testFd}, - {"grpc.testing.EmptyService", "EmptyCall", testFd}, - {"grpc.testing.EmptyService", "EmptyStream", testFd}, - } { - meta := server.ServiceMetadata(test.service, test.method) - var ( - fd []byte - ok bool - ) - if fd, ok = meta.([]byte); !ok { - t.Errorf("ServiceMetadata(%q, %q) = %v, want %v", test.service, test.method, meta, test.want) - } - if !reflect.DeepEqual(fd, test.want) { - t.Errorf("ServiceMetadata(%q, %q) = %v, want %v", test.service, test.method, fd, test.want) - } - } -} - -func TestServiceMetadataNotFound(t *testing.T) { - server := NewServer() - server.RegisterService(&testSd, &testServer{}) - - for _, test := range []struct { - service string - method string - }{ - {"", "EmptyCall"}, - {"grpc.EmptyService", ""}, - {"grpc.EmptyService", "EmptyCall"}, - {"grpc.testing.EmptyService", "EmptyCallWrong"}, - {"grpc.testing.EmptyService", "EmptyStreamWrong"}, - } { - meta := server.ServiceMetadata(test.service, test.method) - if meta != nil { - t.Errorf("ServiceMetadata(%q, %q) = %v, want ", test.service, test.method, meta) - } - } -} - -func TestAllServiceNames(t *testing.T) { - server := NewServer() - server.RegisterService(&testSd, &testServer{}) - server.RegisterService(&ServiceDesc{ - ServiceName: "another.EmptyService", +func TestGetServiceInfo(t *testing.T) { + testSd := ServiceDesc{ + ServiceName: "grpc.testing.EmptyService", HandlerType: (*emptyServiceServer)(nil), - }, &testServer{}) - services := server.AllServiceNames() - want := []string{"grpc.testing.EmptyService", "another.EmptyService"} - // Compare string slices. - m := make(map[string]int) - for _, s := range services { - m[s]++ + Methods: []MethodDesc{ + { + MethodName: "EmptyCall", + Handler: nil, + }, + }, + Streams: []StreamDesc{ + { + StreamName: "EmptyStream", + Handler: nil, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: []int{0, 2, 1, 3}, } - for _, s := range want { - if m[s] > 0 { - m[s]-- - continue - } - t.Fatalf("AllServiceNames() = %q, want: %q", services, want) + + server := NewServer() + server.RegisterService(&testSd, &testServer{}) + + info := server.GetServiceInfo() + want := map[string]*ServiceInfo{ + "grpc.testing.EmptyService": &ServiceInfo{ + Methods: []string{ + "EmptyCall", + "EmptyStream", + }, + Metadata: []int{0, 2, 1, 3}, + }, + } + + if !reflect.DeepEqual(info, want) { + t.Errorf("GetServiceInfo() = %q, want %q", info, want) } }