mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-11-04 10:12:29 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			174 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			174 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package caddyhttp
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"net/http"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
)
 | 
						|
 | 
						|
// responseWriterSpy is an http.ResponseWriter test double that can also be
// asked what it received and whether its ReadFrom method was used.
type responseWriterSpy interface {
	http.ResponseWriter

	// Written returns the data the spy has received, as a string.
	Written() string
	// CalledReadFrom reports whether the spy's ReadFrom method was invoked.
	CalledReadFrom() bool
}
 | 
						|
 | 
						|
var (
 | 
						|
	_ responseWriterSpy = (*baseRespWriter)(nil)
 | 
						|
	_ responseWriterSpy = (*readFromRespWriter)(nil)
 | 
						|
)
 | 
						|
 | 
						|
// baseRespWriter is a barebones http.ResponseWriter mock. The underlying byte
// slice accumulates everything passed to Write; headers are not tracked,
// status codes are ignored, and it has no ReadFrom method.
type baseRespWriter []byte

// Write appends p to the recorded data. It never fails and always reports
// that all of p was consumed.
func (b *baseRespWriter) Write(p []byte) (int, error) {
	grown := append(*b, p...)
	*b = grown
	return len(p), nil
}

// Header returns a nil map: this mock keeps no headers, so callers must not
// set values on the result (writing to a nil map panics).
func (b *baseRespWriter) Header() http.Header { return nil }

// WriteHeader discards the status code.
func (b *baseRespWriter) WriteHeader(statusCode int) {}

// Written returns the recorded data as a string.
func (b *baseRespWriter) Written() string {
	return string(*b)
}

// CalledReadFrom always reports false; baseRespWriter has no ReadFrom method.
func (b *baseRespWriter) CalledReadFrom() bool { return false }
 | 
						|
 | 
						|
// an http.ResponseWriter mock that supports ReadFrom
 | 
						|
type readFromRespWriter struct {
 | 
						|
	baseRespWriter
 | 
						|
	called bool
 | 
						|
}
 | 
						|
 | 
						|
func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
 | 
						|
	rf.called = true
 | 
						|
	return io.Copy(&rf.baseRespWriter, r)
 | 
						|
}
 | 
						|
 | 
						|
func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }
 | 
						|
 | 
						|
func TestResponseWriterWrapperReadFrom(t *testing.T) {
 | 
						|
	tests := map[string]struct {
 | 
						|
		responseWriter responseWriterSpy
 | 
						|
		wantReadFrom   bool
 | 
						|
	}{
 | 
						|
		"no ReadFrom": {
 | 
						|
			responseWriter: &baseRespWriter{},
 | 
						|
			wantReadFrom:   false,
 | 
						|
		},
 | 
						|
		"has ReadFrom": {
 | 
						|
			responseWriter: &readFromRespWriter{},
 | 
						|
			wantReadFrom:   true,
 | 
						|
		},
 | 
						|
	}
 | 
						|
	for name, tt := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			// what we expect middlewares to do:
 | 
						|
			type myWrapper struct {
 | 
						|
				*ResponseWriterWrapper
 | 
						|
			}
 | 
						|
 | 
						|
			wrapped := myWrapper{
 | 
						|
				ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter},
 | 
						|
			}
 | 
						|
 | 
						|
			const srcData = "boo!"
 | 
						|
			// hides everything but Read, since strings.Reader implements WriteTo it would
 | 
						|
			// take precedence over our ReadFrom.
 | 
						|
			src := struct{ io.Reader }{strings.NewReader(srcData)}
 | 
						|
 | 
						|
			fmt.Println(name)
 | 
						|
			if _, err := io.Copy(wrapped, src); err != nil {
 | 
						|
				t.Errorf("Copy() err = %v", err)
 | 
						|
			}
 | 
						|
 | 
						|
			if got := tt.responseWriter.Written(); got != srcData {
 | 
						|
				t.Errorf("data = %q, want %q", got, srcData)
 | 
						|
			}
 | 
						|
 | 
						|
			if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
 | 
						|
				if tt.wantReadFrom {
 | 
						|
					t.Errorf("ReadFrom() should have been called")
 | 
						|
				} else {
 | 
						|
					t.Errorf("ReadFrom() should not have been called")
 | 
						|
				}
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestResponseWriterWrapperUnwrap(t *testing.T) {
 | 
						|
	w := &ResponseWriterWrapper{&baseRespWriter{}}
 | 
						|
 | 
						|
	if _, ok := w.Unwrap().(*baseRespWriter); !ok {
 | 
						|
		t.Errorf("Unwrap() doesn't return the underlying ResponseWriter")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestResponseRecorderReadFrom(t *testing.T) {
 | 
						|
	tests := map[string]struct {
 | 
						|
		responseWriter responseWriterSpy
 | 
						|
		shouldBuffer   bool
 | 
						|
		wantReadFrom   bool
 | 
						|
	}{
 | 
						|
		"buffered plain": {
 | 
						|
			responseWriter: &baseRespWriter{},
 | 
						|
			shouldBuffer:   true,
 | 
						|
			wantReadFrom:   false,
 | 
						|
		},
 | 
						|
		"streamed plain": {
 | 
						|
			responseWriter: &baseRespWriter{},
 | 
						|
			shouldBuffer:   false,
 | 
						|
			wantReadFrom:   false,
 | 
						|
		},
 | 
						|
		"buffered ReadFrom": {
 | 
						|
			responseWriter: &readFromRespWriter{},
 | 
						|
			shouldBuffer:   true,
 | 
						|
			wantReadFrom:   false,
 | 
						|
		},
 | 
						|
		"streamed ReadFrom": {
 | 
						|
			responseWriter: &readFromRespWriter{},
 | 
						|
			shouldBuffer:   false,
 | 
						|
			wantReadFrom:   true,
 | 
						|
		},
 | 
						|
	}
 | 
						|
	for name, tt := range tests {
 | 
						|
		t.Run(name, func(t *testing.T) {
 | 
						|
			var buf bytes.Buffer
 | 
						|
 | 
						|
			rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool {
 | 
						|
				return tt.shouldBuffer
 | 
						|
			})
 | 
						|
 | 
						|
			const srcData = "boo!"
 | 
						|
			// hides everything but Read, since strings.Reader implements WriteTo it would
 | 
						|
			// take precedence over our ReadFrom.
 | 
						|
			src := struct{ io.Reader }{strings.NewReader(srcData)}
 | 
						|
 | 
						|
			if _, err := io.Copy(rr, src); err != nil {
 | 
						|
				t.Errorf("Copy() err = %v", err)
 | 
						|
			}
 | 
						|
 | 
						|
			wantStreamed := srcData
 | 
						|
			wantBuffered := ""
 | 
						|
			if tt.shouldBuffer {
 | 
						|
				wantStreamed = ""
 | 
						|
				wantBuffered = srcData
 | 
						|
			}
 | 
						|
 | 
						|
			if got := tt.responseWriter.Written(); got != wantStreamed {
 | 
						|
				t.Errorf("streamed data = %q, want %q", got, wantStreamed)
 | 
						|
			}
 | 
						|
			if got := buf.String(); got != wantBuffered {
 | 
						|
				t.Errorf("buffered data = %q, want %q", got, wantBuffered)
 | 
						|
			}
 | 
						|
 | 
						|
			if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
 | 
						|
				if tt.wantReadFrom {
 | 
						|
					t.Errorf("ReadFrom() should have been called")
 | 
						|
				} else {
 | 
						|
					t.Errorf("ReadFrom() should not have been called")
 | 
						|
				}
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 |