server: improve chained interceptors performance (#4524)

This commit is contained in:
Aliaksandr Mianzhynski
2021-06-25 08:11:47 +03:00
committed by GitHub
parent e24ede5936
commit 9b2fa9f8d3
2 changed files with 83 additions and 22 deletions

View File

@ -1115,22 +1115,24 @@ func chainUnaryServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler))
}
chainedInt = chainUnaryInterceptors(interceptors)
}
s.opts.unaryInt = chainedInt
}
// getChainUnaryHandler recursively generate the chained UnaryHandler
func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(ctx context.Context, req interface{}) (interface{}, error) {
return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler))
func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) {
var i int
var next UnaryHandler
next = func(ctx context.Context, req interface{}) (interface{}, error) {
if i == len(interceptors)-1 {
return interceptors[i](ctx, req, info, handler)
}
i++
return interceptors[i-1](ctx, req, info, next)
}
return next(ctx, req)
}
}
@ -1398,22 +1400,24 @@ func chainStreamServerInterceptors(s *Server) {
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler))
}
chainedInt = chainStreamInterceptors(interceptors)
}
s.opts.streamInt = chainedInt
}
// getChainStreamHandler recursively generate the chained StreamHandler
func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler {
if curr == len(interceptors)-1 {
return finalHandler
}
return func(srv interface{}, ss ServerStream) error {
return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler))
func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor {
return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error {
var i int
var next StreamHandler
next = func(srv interface{}, ss ServerStream) error {
if i == len(interceptors)-1 {
return interceptors[i](srv, ss, info, handler)
}
i++
return interceptors[i-1](srv, ss, info, next)
}
return next(srv, ss)
}
}