Avoid reliance on fs.ErrClosed in SparseWriter users

Neither of the SparseWriter users actually _wants_ the underlying
WriteSeeker to be closed; so, don't.

That makes it clear where the responsibility for closing the file
lies, and allows us to remove the reliance on the destinations
reliably returning ErrClosed.

Signed-off-by: Miloslav Trmač <mitr@redhat.com>
This commit is contained in:
Miloslav Trmač
2024-03-07 14:23:29 +01:00
parent 4c6505be5f
commit 5e0b7e54c0
10 changed files with 18 additions and 24 deletions

View File

@ -1,7 +1,6 @@
package compression package compression
import ( import (
"errors"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@ -25,7 +24,7 @@ type decompressor interface {
compressedFileSize() int64 compressedFileSize() int64
compressedFileMode() os.FileMode compressedFileMode() os.FileMode
compressedFileReader() (io.ReadCloser, error) compressedFileReader() (io.ReadCloser, error)
decompress(w WriteSeekCloser, r io.Reader) error decompress(w io.WriteSeeker, r io.Reader) error
close() close()
} }
@ -99,7 +98,7 @@ func runDecompression(d decompressor, decompressedFilePath string) error {
return err return err
} }
defer func() { defer func() {
if err := decompressedFileWriter.Close(); err != nil && !errors.Is(err, os.ErrClosed) { if err := decompressedFileWriter.Close(); err != nil {
logrus.Warnf("Unable to to close destination file %s: %q", decompressedFilePath, err) logrus.Warnf("Unable to to close destination file %s: %q", decompressedFilePath, err)
} }
}() }()

View File

@ -43,7 +43,7 @@ func (d *genericDecompressor) compressedFileReader() (io.ReadCloser, error) {
return compressedFile, nil return compressedFile, nil
} }
func (d *genericDecompressor) decompress(w WriteSeekCloser, r io.Reader) error { func (d *genericDecompressor) decompress(w io.WriteSeeker, r io.Reader) error {
decompressedFileReader, _, err := compression.AutoDecompress(r) decompressedFileReader, _, err := compression.AutoDecompress(r)
if err != nil { if err != nil {
return err return err
@ -64,7 +64,7 @@ func (d *genericDecompressor) close() {
} }
} }
func (d *genericDecompressor) sparseOptimizedCopy(w WriteSeekCloser, r io.Reader) error { func (d *genericDecompressor) sparseOptimizedCopy(w io.WriteSeeker, r io.Reader) error {
var err error var err error
sparseWriter := NewSparseWriter(w) sparseWriter := NewSparseWriter(w)
defer func() { defer func() {

View File

@ -16,7 +16,7 @@ func newGzipDecompressor(compressedFilePath string) (*gzipDecompressor, error) {
return &gzipDecompressor{*d}, err return &gzipDecompressor{*d}, err
} }
func (d *gzipDecompressor) decompress(w WriteSeekCloser, r io.Reader) error { func (d *gzipDecompressor) decompress(w io.WriteSeeker, r io.Reader) error {
gzReader, err := image.GzipDecompressor(r) gzReader, err := image.GzipDecompressor(r)
if err != nil { if err != nil {
return err return err

View File

@ -14,14 +14,18 @@ type WriteSeekCloser interface {
} }
type sparseWriter struct { type sparseWriter struct {
file WriteSeekCloser file io.WriteSeeker
// Invariant between method calls: // Invariant between method calls:
// The contents of the file match the contents passed to Write, except that pendingZeroes trailing zeroes have not been written. // The contents of the file match the contents passed to Write, except that pendingZeroes trailing zeroes have not been written.
// Also, the data that _has_ been written does not end with a zero byte (i.e. pendingZeroes is the largest possible value. // Also, the data that _has_ been written does not end with a zero byte (i.e. pendingZeroes is the largest possible value.
pendingZeroes int64 pendingZeroes int64
} }
func NewSparseWriter(file WriteSeekCloser) *sparseWriter { // NewSparseWriter returns a WriteCloser for underlying file which creates
// holes where appropriate.
// NOTE: The caller must .Close() both the returned sparseWriter AND the underlying file,
// in that order.
func NewSparseWriter(file io.WriteSeeker) *sparseWriter {
return &sparseWriter{ return &sparseWriter{
file: file, file: file,
pendingZeroes: 0, pendingZeroes: 0,
@ -121,18 +125,15 @@ func (sw *sparseWriter) Close() error {
if sw.pendingZeroes != 0 { if sw.pendingZeroes != 0 {
if holeSize := sw.pendingZeroes - 1; holeSize >= zerosThreshold { if holeSize := sw.pendingZeroes - 1; holeSize >= zerosThreshold {
if err := sw.createHole(holeSize); err != nil { if err := sw.createHole(holeSize); err != nil {
sw.file.Close()
return err return err
} }
sw.pendingZeroes -= holeSize sw.pendingZeroes -= holeSize
} }
var zeroArray [zerosThreshold]byte var zeroArray [zerosThreshold]byte
if _, err := sw.file.Write(zeroArray[:sw.pendingZeroes]); err != nil { if _, err := sw.file.Write(zeroArray[:sw.pendingZeroes]); err != nil {
sw.file.Close()
return err return err
} }
} }
err := sw.file.Close()
sw.file = nil sw.file = nil
return err return nil
} }

View File

@ -58,10 +58,6 @@ func (m *memorySparseFile) Write(b []byte) (n int, err error) {
return n, err return n, err
} }
func (m *memorySparseFile) Close() error {
return nil
}
func testInputWithWriteLen(t *testing.T, input []byte, minSparse int64, chunkSize int) { func testInputWithWriteLen(t *testing.T, input []byte, minSparse int64, chunkSize int) {
m := &memorySparseFile{} m := &memorySparseFile{}
sparseWriter := NewSparseWriter(m) sparseWriter := NewSparseWriter(m)

View File

@ -13,6 +13,6 @@ func newUncompressedDecompressor(compressedFilePath string) (*uncompressedDecomp
return &uncompressedDecompressor{*d}, err return &uncompressedDecompressor{*d}, err
} }
func (d *uncompressedDecompressor) decompress(w WriteSeekCloser, r io.Reader) error { func (d *uncompressedDecompressor) decompress(w io.WriteSeeker, r io.Reader) error {
return d.sparseOptimizedCopy(w, r) return d.sparseOptimizedCopy(w, r)
} }

View File

@ -22,7 +22,7 @@ func newXzDecompressor(compressedFilePath string) (*xzDecompressor, error) {
// Will error out if file without .Xz already exists // Will error out if file without .Xz already exists
// Maybe extracting then renaming is a good idea here.. // Maybe extracting then renaming is a good idea here..
// depends on Xz: not pre-installed on mac, so it becomes a brew dependency // depends on Xz: not pre-installed on mac, so it becomes a brew dependency
func (*xzDecompressor) decompress(w WriteSeekCloser, r io.Reader) error { func (*xzDecompressor) decompress(w io.WriteSeeker, r io.Reader) error {
var cmd *exec.Cmd var cmd *exec.Cmd
var read io.Reader var read io.Reader

View File

@ -40,7 +40,7 @@ func (d *zipDecompressor) compressedFileReader() (io.ReadCloser, error) {
return z, nil return z, nil
} }
func (*zipDecompressor) decompress(w WriteSeekCloser, r io.Reader) error { func (*zipDecompressor) decompress(w io.WriteSeeker, r io.Reader) error {
_, err := io.Copy(w, r) _, err := io.Copy(w, r)
return err return err
} }

View File

@ -15,7 +15,7 @@ func newZstdDecompressor(compressedFilePath string) (*zstdDecompressor, error) {
return &zstdDecompressor{*d}, err return &zstdDecompressor{*d}, err
} }
func (d *zstdDecompressor) decompress(w WriteSeekCloser, r io.Reader) error { func (d *zstdDecompressor) decompress(w io.WriteSeeker, r io.Reader) error {
zstdReader, err := zstd.NewReader(r) zstdReader, err := zstd.NewReader(r)
if err != nil { if err != nil {
return err return err

View File

@ -1,10 +1,8 @@
package e2e_test package e2e_test
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@ -121,7 +119,7 @@ func setup() (string, *machineTestBuilder) {
Fail(fmt.Sprintf("failed to create file %s: %q", mb.imagePath, err)) Fail(fmt.Sprintf("failed to create file %s: %q", mb.imagePath, err))
} }
defer func() { defer func() {
if err := dest.Close(); err != nil && !errors.Is(err, fs.ErrClosed) { if err := dest.Close(); err != nil {
fmt.Printf("failed to close destination file %q: %q\n", dest.Name(), err) fmt.Printf("failed to close destination file %q: %q\n", dest.Name(), err)
} }
}() }()
@ -161,7 +159,7 @@ func teardown(origHomeDir string, testDir string, mb *machineTestBuilder) {
} }
// copySparse is a helper method for tests only; caller is responsible for closures // copySparse is a helper method for tests only; caller is responsible for closures
func copySparse(dst compression.WriteSeekCloser, src io.Reader) error { func copySparse(dst io.WriteSeeker, src io.Reader) error {
spWriter := compression.NewSparseWriter(dst) spWriter := compression.NewSparseWriter(dst)
defer spWriter.Close() defer spWriter.Close()
_, err := io.Copy(spWriter, src) _, err := io.Copy(spWriter, src)