diff --git a/pkg/machine/compression/decompress.go b/pkg/machine/compression/decompress.go index 4d362c3a5d..56816d5b95 100644 --- a/pkg/machine/compression/decompress.go +++ b/pkg/machine/compression/decompress.go @@ -18,6 +18,7 @@ const ( macOs = "darwin" progressBarPrefix = "Extracting compressed file" zipExt = ".zip" + magicNumberMaxBytes = 10 ) type decompressor interface { @@ -30,24 +31,21 @@ type decompressor interface { func Decompress(compressedVMFile *define.VMFile, decompressedFilePath string) error { compressedFilePath := compressedVMFile.GetPath() - // Are we reading full image file? - // Only few bytes are read to detect - // the compression type - compressedFileContent, err := compressedVMFile.Read() + compressedFileMagicNum, err := compressedVMFile.ReadMagicNumber(magicNumberMaxBytes) if err != nil { return err } var d decompressor - if d, err = newDecompressor(compressedFilePath, compressedFileContent); err != nil { + if d, err = newDecompressor(compressedFilePath, compressedFileMagicNum); err != nil { return err } return runDecompression(d, decompressedFilePath) } -func newDecompressor(compressedFilePath string, compressedFileContent []byte) (decompressor, error) { - compressionType := archive.DetectCompression(compressedFileContent) +func newDecompressor(compressedFilePath string, compressedFileMagicNum []byte) (decompressor, error) { + compressionType := archive.DetectCompression(compressedFileMagicNum) os := runtime.GOOS hasZipSuffix := strings.HasSuffix(compressedFilePath, zipExt) diff --git a/pkg/machine/define/vmfile.go b/pkg/machine/define/vmfile.go index 1795a4dc5a..261cc78226 100644 --- a/pkg/machine/define/vmfile.go +++ b/pkg/machine/define/vmfile.go @@ -2,6 +2,7 @@ package define import ( "errors" + "io" "os" "path/filepath" "strconv" @@ -48,6 +49,22 @@ func (m *VMFile) Read() ([]byte, error) { return os.ReadFile(m.GetPath()) } +// Read the first n bytes of a given file and return in []bytes +func (m *VMFile) ReadMagicNumber(n int) ([]byte, error) { + f, err := os.Open(m.GetPath()) + if err != nil { + return nil, err + } + defer f.Close() + b := make([]byte, n) + n, err = io.ReadFull(f, b) + if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { + return b[:n], err + } else { + return b[:n], nil + } +} + // ReadPIDFrom a file and return as int. -1 means the pid file could not // be read or had something that could not be converted to an int in it func (m *VMFile) ReadPIDFrom() (int, error) {