Add server.GetServiceInfo().

To replace server.ServiceMetadata() and server.AllServiceNames().
This commit is contained in:
Menghan Li
2016-06-23 16:37:55 -07:00
parent 439f11e63d
commit 26d2db5487
4 changed files with 254 additions and 154 deletions

View File

@ -69,7 +69,8 @@ import (
type serverReflectionServer struct { type serverReflectionServer struct {
s *grpc.Server s *grpc.Server
// TODO add cache if necessary // TODO add more cache if necessary
serviceInfo map[string]*grpc.ServiceInfo // cache for s.GetServiceInfo()
} }
// Register registers the server reflection service on the given gRPC server. // Register registers the server reflection service on the given gRPC server.
@ -188,6 +189,46 @@ func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte
return proto.Marshal(fd) return proto.Marshal(fd)
} }
// serviceMetadataForSymbol finds the metadata for name in s.serviceInfo.
// name should be a service name or a method name.
func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) {
if s.serviceInfo == nil {
s.serviceInfo = s.s.GetServiceInfo()
}
// Check if it's a service name.
if info, ok := s.serviceInfo[name]; ok {
return info.Metadata, nil
}
// Check if it's a method name.
pos := strings.LastIndex(name, ".")
// Not a valid method name.
if pos == -1 {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
info, ok := s.serviceInfo[name[:pos]]
// Substring before last "." is not a service name.
if !ok {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
// Search for method in info.
var found bool
for _, m := range info.Methods {
if m == name[pos+1:] {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
return info.Metadata, nil
}
// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
// does marshalling on it and returns the marshalled result. // does marshalling on it and returns the marshalled result.
// The given symbol can be a type, a service or a method. // The given symbol can be a type, a service or a method.
@ -201,28 +242,26 @@ func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) (
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else { // Check if it's a service name or a method name.
// Check if it's a service name. meta, err := s.serviceMetadataForSymbol(name)
meta := s.s.ServiceMetadata(name, "")
// Check if it's a method name. // Metadata not found.
if meta == nil { if err != nil {
if pos := strings.LastIndex(name, "."); pos != -1 { return nil, err
meta = s.s.ServiceMetadata(name[:pos], name[pos+1:])
}
} }
if meta != nil {
if enc, ok := meta.([]byte); ok { // Metadata not valid.
fd, err = s.decodeFileDesc(enc) enc, ok := meta.([]byte)
if err != nil { if !ok {
return nil, err return nil, fmt.Errorf("invalid file descriptor for symbol: %v")
} }
}
fd, err = s.decodeFileDesc(enc)
if err != nil {
return nil, err
} }
} }
if fd == nil {
return nil, fmt.Errorf("unknown symbol: %v", name)
}
return proto.Marshal(fd) return proto.Marshal(fd)
} }
@ -331,12 +370,14 @@ func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflectio
} }
} }
case *rpb.ServerReflectionRequest_ListServices: case *rpb.ServerReflectionRequest_ListServices:
services := s.s.AllServiceNames() if s.serviceInfo == nil {
serviceResponses := make([]*rpb.ServiceResponse, len(services)) s.serviceInfo = s.s.GetServiceInfo()
for i, s := range services { }
serviceResponses[i] = &rpb.ServiceResponse{ serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo))
Name: s, for n := range s.serviceInfo {
} serviceResponses = append(serviceResponses, &rpb.ServiceResponse{
Name: n,
})
} }
out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &rpb.ListServiceResponse{ ListServicesResponse: &rpb.ListServiceResponse{

View File

@ -92,7 +92,7 @@ func TestFileDescForType(t *testing.T) {
} { } {
fd, err := s.fileDescForType(test.st) fd, err := s.fileDescForType(test.st)
if err != nil || !reflect.DeepEqual(fd, test.wantFd) { if err != nil || !reflect.DeepEqual(fd, test.wantFd) {
t.Fatalf("fileDescForType(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.wantFd) t.Errorf("fileDescForType(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.wantFd)
} }
} }
} }
@ -106,7 +106,7 @@ func TestTypeForName(t *testing.T) {
} { } {
r, err := s.typeForName(test.name) r, err := s.typeForName(test.name)
if err != nil || r != test.want { if err != nil || r != test.want {
t.Fatalf("typeForName(%q) = %q, %v, want %q, <nil>", test.name, r, err, test.want) t.Errorf("typeForName(%q) = %q, %v, want %q, <nil>", test.name, r, err, test.want)
} }
} }
} }
@ -117,7 +117,7 @@ func TestTypeForNameNotFound(t *testing.T) {
} { } {
_, err := s.typeForName(test) _, err := s.typeForName(test)
if err == nil { if err == nil {
t.Fatalf("typeForName(%q) = _, %v, want _, <non-nil>", test, err) t.Errorf("typeForName(%q) = _, %v, want _, <non-nil>", test, err)
} }
} }
} }
@ -132,7 +132,7 @@ func TestFileDescContainingExtension(t *testing.T) {
} { } {
fd, err := s.fileDescContainingExtension(test.st, test.extNum) fd, err := s.fileDescContainingExtension(test.st, test.extNum)
if err != nil || !reflect.DeepEqual(fd, test.want) { if err != nil || !reflect.DeepEqual(fd, test.want) {
t.Fatalf("fileDescContainingExtension(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.want) t.Errorf("fileDescContainingExtension(%q) = %q, %v, want %q, <nil>", test.st, fd, err, test.want)
} }
} }
} }
@ -154,7 +154,7 @@ func TestAllExtensionNumbersForType(t *testing.T) {
r, err := s.allExtensionNumbersForType(test.st) r, err := s.allExtensionNumbersForType(test.st)
sort.Sort(intArray(r)) sort.Sort(intArray(r))
if err != nil || !reflect.DeepEqual(r, test.want) { if err != nil || !reflect.DeepEqual(r, test.want) {
t.Fatalf("allExtensionNumbersForType(%q) = %v, %v, want %v, <nil>", test.st, r, err, test.want) t.Errorf("allExtensionNumbersForType(%q) = %v, %v, want %v, <nil>", test.st, r, err, test.want)
} }
} }
} }
@ -194,9 +194,13 @@ func TestReflectionEnd2end(t *testing.T) {
stream, err := c.ServerReflectionInfo(context.Background()) stream, err := c.ServerReflectionInfo(context.Background())
testFileByFilename(t, stream) testFileByFilename(t, stream)
testFileByFilenameError(t, stream)
testFileContainingSymbol(t, stream) testFileContainingSymbol(t, stream)
testFileContainingSymbolError(t, stream)
testFileContainingExtension(t, stream) testFileContainingExtension(t, stream)
testFileContainingExtensionError(t, stream)
testAllExtensionNumbersOfType(t, stream) testAllExtensionNumbersOfType(t, stream)
testAllExtensionNumbersOfTypeError(t, stream)
testListServices(t, stream) testListServices(t, stream)
s.Stop() s.Stop()
@ -227,10 +231,37 @@ func testFileByFilename(t *testing.T, stream rpb.ServerReflection_ServerReflecti
switch r.MessageResponse.(type) { switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse: case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) { if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) {
t.Fatalf("FileByFilename\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) t.Errorf("FileByFilename(%v)\nreceived: %q,\nwant: %q", test.filename, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
} }
default: default:
t.Fatalf("FileByFilename = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", r.MessageResponse) t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.filename, r.MessageResponse)
}
}
}
func testFileByFilenameError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []string{
"test.poto",
"proo2.proto",
"proto2_et.proto",
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
FileByFilename: test,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("FileByFilename(%v) = %v, want type <ServerReflectionResponse_ErrorResponse>", test, r.MessageResponse)
} }
} }
} }
@ -261,10 +292,38 @@ func testFileContainingSymbol(t *testing.T, stream rpb.ServerReflection_ServerRe
switch r.MessageResponse.(type) { switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse: case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) { if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) {
t.Fatalf("FileContainingSymbol\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) t.Errorf("FileContainingSymbol(%v)\nreceived: %q,\nwant: %q", test.symbol, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
} }
default: default:
t.Fatalf("FileContainingSymbol = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", r.MessageResponse) t.Errorf("FileContainingSymbol(%v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.symbol, r.MessageResponse)
}
}
}
func testFileContainingSymbolError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []string{
"grpc.testing.SerchService",
"grpc.testing.SearchService.SearchE",
"grpc.tesing.SearchResponse",
"gpc.testing.ToBeExtened",
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{
FileContainingSymbol: test,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("FileContainingSymbol(%v) = %v, want type <ServerReflectionResponse_ErrorResponse>", test, r.MessageResponse)
} }
} }
} }
@ -296,10 +355,42 @@ func testFileContainingExtension(t *testing.T, stream rpb.ServerReflection_Serve
switch r.MessageResponse.(type) { switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_FileDescriptorResponse: case *rpb.ServerReflectionResponse_FileDescriptorResponse:
if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) { if !reflect.DeepEqual(r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) {
t.Fatalf("FileContainingExtension\nreceived: %q,\nwant: %q", r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want) t.Errorf("FileContainingExtension(%v, %v)\nreceived: %q,\nwant: %q", test.typeName, test.extNum, r.GetFileDescriptorResponse().FileDescriptorProto[0], test.want)
} }
default: default:
t.Fatalf("FileContainingExtension = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", r.MessageResponse) t.Errorf("FileContainingExtension(%v, %v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.typeName, test.extNum, r.MessageResponse)
}
}
}
func testFileContainingExtensionError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []struct {
typeName string
extNum int32
}{
{"grpc.testing.ToBExtened", 17},
{"grpc.testing.ToBeExtened", 15},
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{
FileContainingExtension: &rpb.ExtensionRequest{
ContainingType: test.typeName,
ExtensionNumber: test.extNum,
},
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("FileContainingExtension(%v, %v) = %v, want type <ServerReflectionResponse_FileDescriptorResponse>", test.typeName, test.extNum, r.MessageResponse)
} }
} }
} }
@ -330,10 +421,35 @@ func testAllExtensionNumbersOfType(t *testing.T, stream rpb.ServerReflection_Ser
sort.Sort(intArray(extNum)) sort.Sort(intArray(extNum))
if r.GetAllExtensionNumbersResponse().BaseTypeName != test.typeName || if r.GetAllExtensionNumbersResponse().BaseTypeName != test.typeName ||
!reflect.DeepEqual(extNum, test.want) { !reflect.DeepEqual(extNum, test.want) {
t.Fatalf("AllExtensionNumbersOfType\nreceived: %v,\nwant: {%q %v}", r.GetAllExtensionNumbersResponse(), test.typeName, test.want) t.Errorf("AllExtensionNumbersOfType(%v)\nreceived: %v,\nwant: {%q %v}", r.GetAllExtensionNumbersResponse(), test.typeName, test.typeName, test.want)
} }
default: default:
t.Fatalf("AllExtensionNumbersOfType = %v, want type <ServerReflectionResponse_AllExtensionNumbersResponse>", r.MessageResponse) t.Errorf("AllExtensionNumbersOfType(%v) = %v, want type <ServerReflectionResponse_AllExtensionNumbersResponse>", test.typeName, r.MessageResponse)
}
}
}
func testAllExtensionNumbersOfTypeError(t *testing.T, stream rpb.ServerReflection_ServerReflectionInfoClient) {
for _, test := range []string{
"grpc.testing.ToBeExtenedE",
} {
if err := stream.Send(&rpb.ServerReflectionRequest{
MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{
AllExtensionNumbersOfType: test,
},
}); err != nil {
t.Fatalf("failed to send request: %v", err)
}
r, err := stream.Recv()
if err != nil {
// io.EOF is not ok.
t.Fatalf("failed to recv response: %v", err)
}
switch r.MessageResponse.(type) {
case *rpb.ServerReflectionResponse_ErrorResponse:
default:
t.Errorf("AllExtensionNumbersOfType(%v) = %v, want type <ServerReflectionResponse_ErrorResponse>", test, r.MessageResponse)
} }
} }
} }
@ -356,7 +472,7 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection
want := []string{"grpc.testing.SearchService", "grpc.reflection.v1alpha.ServerReflection"} want := []string{"grpc.testing.SearchService", "grpc.reflection.v1alpha.ServerReflection"}
// Compare service names in response with want. // Compare service names in response with want.
if len(services) != len(want) { if len(services) != len(want) {
t.Fatalf("= %v, want service names: %v", services, want) t.Errorf("= %v, want service names: %v", services, want)
} }
m := make(map[string]int) m := make(map[string]int)
for _, e := range services { for _, e := range services {
@ -367,9 +483,9 @@ func testListServices(t *testing.T, stream rpb.ServerReflection_ServerReflection
m[e]-- m[e]--
continue continue
} }
t.Fatalf("ListService\nreceived: %v,\nwant: %q", services, want) t.Errorf("ListService\nreceived: %v,\nwant: %q", services, want)
} }
default: default:
t.Fatalf("ListServices = %v, want type <ServerReflectionResponse_ListServicesResponse>", r.MessageResponse) t.Errorf("ListServices = %v, want type <ServerReflectionResponse_ListServicesResponse>", r.MessageResponse)
} }
} }

View File

@ -245,32 +245,29 @@ func (s *Server) register(sd *ServiceDesc, ss interface{}) {
s.m[sd.ServiceName] = srv s.m[sd.ServiceName] = srv
} }
// ServiceMetadata returns the metadata for a service or method. // ServiceInfo contains method names and metadata for a service.
// service should be the full service name with package, in the form of <package>.<service>. type ServiceInfo struct {
// method should be the method name only. Methods []string
// If only service is important, method should be an empty string. Metadata interface{}
func (s *Server) ServiceMetadata(service, method string) interface{} {
// Check if service is registered.
if srv, ok := s.m[service]; ok {
if method == "" {
return srv.meta
}
// Check if method is part of service.
if _, ok := srv.md[method]; ok {
return srv.meta
}
if _, ok := srv.sd[method]; ok {
return srv.meta
}
}
return nil
} }
// AllServiceNames returns all the registered service names. // GetServiceInfo returns a map from service name to ServiceInfo.
func (s *Server) AllServiceNames() []string { // Service name includes the package name, in the form of <package>.<service>.
ret := make([]string, 0, len(s.m)) func (s *Server) GetServiceInfo() map[string]*ServiceInfo {
for k := range s.m { ret := make(map[string]*ServiceInfo)
ret = append(ret, k) for n, srv := range s.m {
methods := make([]string, 0, len(srv.md)+len(srv.sd))
for m := range srv.md {
methods = append(methods, m)
}
for m := range srv.sd {
methods = append(methods, m)
}
ret[n] = &ServiceInfo{
Methods: methods,
Metadata: srv.meta,
}
} }
return ret return ret
} }

View File

@ -44,29 +44,6 @@ type emptyServiceServer interface{}
type testServer struct{} type testServer struct{}
var (
testSd = ServiceDesc{
ServiceName: "grpc.testing.EmptyService",
HandlerType: (*emptyServiceServer)(nil),
Methods: []MethodDesc{
{
MethodName: "EmptyCall",
Handler: nil,
},
},
Streams: []StreamDesc{
{
StreamName: "EmptyStream",
Handler: nil,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: testFd,
}
testFd = []byte{0, 1, 2, 3}
)
func TestStopBeforeServe(t *testing.T) { func TestStopBeforeServe(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
if err != nil { if err != nil {
@ -88,73 +65,42 @@ func TestStopBeforeServe(t *testing.T) {
} }
} }
func TestServiceMetadata(t *testing.T) { func TestGetServiceInfo(t *testing.T) {
server := NewServer() testSd := ServiceDesc{
server.RegisterService(&testSd, &testServer{}) ServiceName: "grpc.testing.EmptyService",
for _, test := range []struct {
service string
method string
want []byte
}{
{"grpc.testing.EmptyService", "", testFd},
{"grpc.testing.EmptyService", "EmptyCall", testFd},
{"grpc.testing.EmptyService", "EmptyStream", testFd},
} {
meta := server.ServiceMetadata(test.service, test.method)
var (
fd []byte
ok bool
)
if fd, ok = meta.([]byte); !ok {
t.Errorf("ServiceMetadata(%q, %q) = %v, want %v", test.service, test.method, meta, test.want)
}
if !reflect.DeepEqual(fd, test.want) {
t.Errorf("ServiceMetadata(%q, %q) = %v, want %v", test.service, test.method, fd, test.want)
}
}
}
func TestServiceMetadataNotFound(t *testing.T) {
server := NewServer()
server.RegisterService(&testSd, &testServer{})
for _, test := range []struct {
service string
method string
}{
{"", "EmptyCall"},
{"grpc.EmptyService", ""},
{"grpc.EmptyService", "EmptyCall"},
{"grpc.testing.EmptyService", "EmptyCallWrong"},
{"grpc.testing.EmptyService", "EmptyStreamWrong"},
} {
meta := server.ServiceMetadata(test.service, test.method)
if meta != nil {
t.Errorf("ServiceMetadata(%q, %q) = %v, want <nil>", test.service, test.method, meta)
}
}
}
func TestAllServiceNames(t *testing.T) {
server := NewServer()
server.RegisterService(&testSd, &testServer{})
server.RegisterService(&ServiceDesc{
ServiceName: "another.EmptyService",
HandlerType: (*emptyServiceServer)(nil), HandlerType: (*emptyServiceServer)(nil),
}, &testServer{}) Methods: []MethodDesc{
services := server.AllServiceNames() {
want := []string{"grpc.testing.EmptyService", "another.EmptyService"} MethodName: "EmptyCall",
// Compare string slices. Handler: nil,
m := make(map[string]int) },
for _, s := range services { },
m[s]++ Streams: []StreamDesc{
{
StreamName: "EmptyStream",
Handler: nil,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: []int{0, 2, 1, 3},
} }
for _, s := range want {
if m[s] > 0 { server := NewServer()
m[s]-- server.RegisterService(&testSd, &testServer{})
continue
} info := server.GetServiceInfo()
t.Fatalf("AllServiceNames() = %q, want: %q", services, want) want := map[string]*ServiceInfo{
"grpc.testing.EmptyService": &ServiceInfo{
Methods: []string{
"EmptyCall",
"EmptyStream",
},
Metadata: []int{0, 2, 1, 3},
},
}
if !reflect.DeepEqual(info, want) {
t.Errorf("GetServiceInfo() = %q, want %q", info, want)
} }
} }