diff --git a/transport/go16.go b/transport/go16.go index 7cffee11..5babcf9b 100644 --- a/transport/go16.go +++ b/transport/go16.go @@ -22,6 +22,7 @@ package transport import ( "net" + "net/http" "google.golang.org/grpc/codes" @@ -43,3 +44,8 @@ func ContextErr(err error) StreamError { } return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) } + +// contextFromRequest returns a background context. +func contextFromRequest(r *http.Request) context.Context { + return context.Background() +} diff --git a/transport/go17.go b/transport/go17.go index 2464e69f..b7fa6bdb 100644 --- a/transport/go17.go +++ b/transport/go17.go @@ -23,6 +23,7 @@ package transport import ( "context" "net" + "net/http" "google.golang.org/grpc/codes" @@ -44,3 +45,8 @@ func ContextErr(err error) StreamError { } return streamErrorf(codes.Internal, "Unexpected error from context packet: %v", err) } + +// contextFromRequest returns a context from the HTTP Request. +func contextFromRequest(r *http.Request) context.Context { + return r.Context() +} diff --git a/transport/handler_server.go b/transport/handler_server.go index 7e0fdb35..27c4ebb5 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -284,12 +284,12 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { // With this transport type there will be exactly 1 stream: this HTTP request. - var ctx context.Context + ctx := contextFromRequest(ht.req) var cancel context.CancelFunc if ht.timeoutSet { - ctx, cancel = context.WithTimeout(context.Background(), ht.timeout) + ctx, cancel = context.WithTimeout(ctx, ht.timeout) } else { - ctx, cancel = context.WithCancel(context.Background()) + ctx, cancel = context.WithCancel(ctx) } // requestOver is closed when either the request's context is done