mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-04 10:12:29 +08:00 
			
		
		
		
	caddyhttp: Refactor header matching
This allows response matchers to benefit from the same matching logic as the request header matchers (mainly prefix/suffix wildcards).
This commit is contained in:
		@ -17,7 +17,6 @@ package caddyhttp
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"log"
 | 
					 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/textproto"
 | 
						"net/textproto"
 | 
				
			||||||
@ -28,6 +27,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	"github.com/caddyserver/caddy/v2"
 | 
						"github.com/caddyserver/caddy/v2"
 | 
				
			||||||
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
 | 
						"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
 | 
				
			||||||
 | 
						"go.uber.org/zap"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type (
 | 
					type (
 | 
				
			||||||
@ -105,7 +105,8 @@ type (
 | 
				
			|||||||
	MatchRemoteIP struct {
 | 
						MatchRemoteIP struct {
 | 
				
			||||||
		Ranges []string `json:"ranges,omitempty"`
 | 
							Ranges []string `json:"ranges,omitempty"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		cidrs []*net.IPNet
 | 
							cidrs  []*net.IPNet
 | 
				
			||||||
 | 
							logger *zap.Logger
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// MatchNot matches requests by negating the results of its matcher
 | 
						// MatchNot matches requests by negating the results of its matcher
 | 
				
			||||||
@ -410,23 +411,28 @@ func (m *MatchHeader) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Like req.Header.Get(), but that works with Host header.
 | 
					 | 
				
			||||||
// go's http module swallows "Host" header.
 | 
					 | 
				
			||||||
func getHeader(r *http.Request, field string) []string {
 | 
					 | 
				
			||||||
	field = textproto.CanonicalMIMEHeaderKey(field)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if field == "Host" {
 | 
					 | 
				
			||||||
		return []string{r.Host}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return r.Header[field]
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Match returns true if r matches m.
 | 
					// Match returns true if r matches m.
 | 
				
			||||||
func (m MatchHeader) Match(r *http.Request) bool {
 | 
					func (m MatchHeader) Match(r *http.Request) bool {
 | 
				
			||||||
	for field, allowedFieldVals := range m {
 | 
						return matchHeaders(r.Header, http.Header(m), r.Host)
 | 
				
			||||||
		actualFieldVals := getHeader(r, field)
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// getHeaderFieldVals returns the field values for the given fieldName from input.
 | 
				
			||||||
 | 
					// The host parameter should be obtained from the http.Request.Host field since
 | 
				
			||||||
 | 
					// net/http removes it from the header map.
 | 
				
			||||||
 | 
					func getHeaderFieldVals(input http.Header, fieldName, host string) []string {
 | 
				
			||||||
 | 
						fieldName = textproto.CanonicalMIMEHeaderKey(fieldName)
 | 
				
			||||||
 | 
						if fieldName == "Host" && host != "" {
 | 
				
			||||||
 | 
							return []string{host}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return input[fieldName]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// matchHeaders returns true if input matches the criteria in against without regex.
 | 
				
			||||||
 | 
					// The host parameter should be obtained from the http.Request.Host field since
 | 
				
			||||||
 | 
					// net/http removes it from the header map.
 | 
				
			||||||
 | 
					func matchHeaders(input, against http.Header, host string) bool {
 | 
				
			||||||
 | 
						for field, allowedFieldVals := range against {
 | 
				
			||||||
 | 
							actualFieldVals := getHeaderFieldVals(input, field, host)
 | 
				
			||||||
		if allowedFieldVals != nil && len(allowedFieldVals) == 0 && actualFieldVals != nil {
 | 
							if allowedFieldVals != nil && len(allowedFieldVals) == 0 && actualFieldVals != nil {
 | 
				
			||||||
			// a non-nil but empty list of allowed values means
 | 
								// a non-nil but empty list of allowed values means
 | 
				
			||||||
			// match if the header field exists at all
 | 
								// match if the header field exists at all
 | 
				
			||||||
@ -501,8 +507,7 @@ func (m *MatchHeaderRE) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 | 
				
			|||||||
// Match returns true if r matches m.
 | 
					// Match returns true if r matches m.
 | 
				
			||||||
func (m MatchHeaderRE) Match(r *http.Request) bool {
 | 
					func (m MatchHeaderRE) Match(r *http.Request) bool {
 | 
				
			||||||
	for field, rm := range m {
 | 
						for field, rm := range m {
 | 
				
			||||||
		actualFieldVals := getHeader(r, field)
 | 
							actualFieldVals := getHeaderFieldVals(r.Header, field, r.Host)
 | 
				
			||||||
 | 
					 | 
				
			||||||
		match := false
 | 
							match := false
 | 
				
			||||||
	fieldVal:
 | 
						fieldVal:
 | 
				
			||||||
		for _, actualFieldVal := range actualFieldVals {
 | 
							for _, actualFieldVal := range actualFieldVals {
 | 
				
			||||||
@ -700,6 +705,7 @@ func (m *MatchRemoteIP) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Provision parses m's IP ranges, either from IP or CIDR expressions.
 | 
					// Provision parses m's IP ranges, either from IP or CIDR expressions.
 | 
				
			||||||
func (m *MatchRemoteIP) Provision(ctx caddy.Context) error {
 | 
					func (m *MatchRemoteIP) Provision(ctx caddy.Context) error {
 | 
				
			||||||
 | 
						m.logger = ctx.Logger(m)
 | 
				
			||||||
	for _, str := range m.Ranges {
 | 
						for _, str := range m.Ranges {
 | 
				
			||||||
		if strings.Contains(str, "/") {
 | 
							if strings.Contains(str, "/") {
 | 
				
			||||||
			_, ipNet, err := net.ParseCIDR(str)
 | 
								_, ipNet, err := net.ParseCIDR(str)
 | 
				
			||||||
@ -748,7 +754,7 @@ func (m MatchRemoteIP) getClientIP(r *http.Request) (net.IP, error) {
 | 
				
			|||||||
func (m MatchRemoteIP) Match(r *http.Request) bool {
 | 
					func (m MatchRemoteIP) Match(r *http.Request) bool {
 | 
				
			||||||
	clientIP, err := m.getClientIP(r)
 | 
						clientIP, err := m.getClientIP(r)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Printf("[ERROR] remote_ip matcher: %v", err)
 | 
							m.logger.Error("getting client IP", zap.Error(err))
 | 
				
			||||||
		return false
 | 
							return false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	for _, ipRange := range m.cidrs {
 | 
						for _, ipRange := range m.cidrs {
 | 
				
			||||||
@ -859,7 +865,9 @@ type ResponseMatcher struct {
 | 
				
			|||||||
	// in that class (e.g. 3 for all 3xx codes).
 | 
						// in that class (e.g. 3 for all 3xx codes).
 | 
				
			||||||
	StatusCode []int `json:"status_code,omitempty"`
 | 
						StatusCode []int `json:"status_code,omitempty"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If set, each header specified must be one of the specified values.
 | 
						// If set, each header specified must be one of the
 | 
				
			||||||
 | 
						// specified values, with the same logic used by the
 | 
				
			||||||
 | 
						// request header matcher.
 | 
				
			||||||
	Headers http.Header `json:"headers,omitempty"`
 | 
						Headers http.Header `json:"headers,omitempty"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -868,7 +876,7 @@ func (rm ResponseMatcher) Match(statusCode int, hdr http.Header) bool {
 | 
				
			|||||||
	if !rm.matchStatusCode(statusCode) {
 | 
						if !rm.matchStatusCode(statusCode) {
 | 
				
			||||||
		return false
 | 
							return false
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return rm.matchHeaders(hdr)
 | 
						return matchHeaders(hdr, rm.Headers, "")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
 | 
					func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
 | 
				
			||||||
@ -883,31 +891,6 @@ func (rm ResponseMatcher) matchStatusCode(statusCode int) bool {
 | 
				
			|||||||
	return false
 | 
						return false
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (rm ResponseMatcher) matchHeaders(hdr http.Header) bool {
 | 
					 | 
				
			||||||
	for field, allowedFieldVals := range rm.Headers {
 | 
					 | 
				
			||||||
		actualFieldVals, fieldExists := hdr[textproto.CanonicalMIMEHeaderKey(field)]
 | 
					 | 
				
			||||||
		if allowedFieldVals != nil && len(allowedFieldVals) == 0 && fieldExists {
 | 
					 | 
				
			||||||
			// a non-nil but empty list of allowed values means
 | 
					 | 
				
			||||||
			// match if the header field exists at all
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		var match bool
 | 
					 | 
				
			||||||
	fieldVals:
 | 
					 | 
				
			||||||
		for _, actualFieldVal := range actualFieldVals {
 | 
					 | 
				
			||||||
			for _, allowedFieldVal := range allowedFieldVals {
 | 
					 | 
				
			||||||
				if actualFieldVal == allowedFieldVal {
 | 
					 | 
				
			||||||
					match = true
 | 
					 | 
				
			||||||
					break fieldVals
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if !match {
 | 
					 | 
				
			||||||
			return false
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return true
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
var wordRE = regexp.MustCompile(`\w+`)
 | 
					var wordRE = regexp.MustCompile(`\w+`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const regexpPlaceholderPrefix = "http.regexp"
 | 
					const regexpPlaceholderPrefix = "http.regexp"
 | 
				
			||||||
 | 
				
			|||||||
@ -448,6 +448,21 @@ func TestHeaderMatcher(t *testing.T) {
 | 
				
			|||||||
			input:  http.Header{"Field2": []string{"foo"}},
 | 
								input:  http.Header{"Field2": []string{"foo"}},
 | 
				
			||||||
			expect: false,
 | 
								expect: false,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								match:  MatchHeader{"Field1": []string{"foo*"}},
 | 
				
			||||||
 | 
								input:  http.Header{"Field1": []string{"foo"}},
 | 
				
			||||||
 | 
								expect: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								match:  MatchHeader{"Field1": []string{"foo*"}},
 | 
				
			||||||
 | 
								input:  http.Header{"Field1": []string{"asdf", "foobar"}},
 | 
				
			||||||
 | 
								expect: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								match:  MatchHeader{"Field1": []string{"*bar"}},
 | 
				
			||||||
 | 
								input:  http.Header{"Field1": []string{"asdf", "foobar"}},
 | 
				
			||||||
 | 
								expect: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			match:  MatchHeader{"host": []string{"localhost"}},
 | 
								match:  MatchHeader{"host": []string{"localhost"}},
 | 
				
			||||||
			input:  http.Header{},
 | 
								input:  http.Header{},
 | 
				
			||||||
@ -814,6 +829,24 @@ func TestResponseMatcher(t *testing.T) {
 | 
				
			|||||||
			hdr:    http.Header{"Foo": []string{"bar"}, "Foo2": []string{"baz"}},
 | 
								hdr:    http.Header{"Foo": []string{"bar"}, "Foo2": []string{"baz"}},
 | 
				
			||||||
			expect: true,
 | 
								expect: true,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								require: ResponseMatcher{
 | 
				
			||||||
 | 
									Headers: http.Header{
 | 
				
			||||||
 | 
										"Foo": []string{"foo*"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								hdr:    http.Header{"Foo": []string{"foobar"}},
 | 
				
			||||||
 | 
								expect: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								require: ResponseMatcher{
 | 
				
			||||||
 | 
									Headers: http.Header{
 | 
				
			||||||
 | 
										"Foo": []string{"foo*"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								hdr:    http.Header{"Foo": []string{"foobar"}},
 | 
				
			||||||
 | 
								expect: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	} {
 | 
						} {
 | 
				
			||||||
		actual := tc.require.Match(tc.status, tc.hdr)
 | 
							actual := tc.require.Match(tc.status, tc.hdr)
 | 
				
			||||||
		if actual != tc.expect {
 | 
							if actual != tc.expect {
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user