mirror of
				https://github.com/cloudreve/cloudreve.git
				synced 2025-11-04 04:47:24 +08:00 
			
		
		
		
	Modify: add general ReaderCloserSeeker interface for handler GET method to return
This commit is contained in:
		@ -4,6 +4,7 @@ import (
 | 
				
			|||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	model "github.com/HFO4/cloudreve/models"
 | 
						model "github.com/HFO4/cloudreve/models"
 | 
				
			||||||
	"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
 | 
						"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
 | 
				
			||||||
 | 
						"github.com/HFO4/cloudreve/pkg/filesystem/response"
 | 
				
			||||||
	"github.com/HFO4/cloudreve/pkg/serializer"
 | 
						"github.com/HFO4/cloudreve/pkg/serializer"
 | 
				
			||||||
	"github.com/HFO4/cloudreve/pkg/util"
 | 
						"github.com/HFO4/cloudreve/pkg/util"
 | 
				
			||||||
	"github.com/juju/ratelimit"
 | 
						"github.com/juju/ratelimit"
 | 
				
			||||||
@ -18,7 +19,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// 限速后的ReaderSeeker
 | 
					// 限速后的ReaderSeeker
 | 
				
			||||||
type lrs struct {
 | 
					type lrs struct {
 | 
				
			||||||
	io.ReadSeeker
 | 
						response.RSCloser
 | 
				
			||||||
	r io.Reader
 | 
						r io.Reader
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,7 +28,7 @@ func (r lrs) Read(p []byte) (int, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// withSpeedLimit 给原有的ReadSeeker加上限速
 | 
					// withSpeedLimit 给原有的ReadSeeker加上限速
 | 
				
			||||||
func (fs *FileSystem) withSpeedLimit(rs io.ReadSeeker) io.ReadSeeker {
 | 
					func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
 | 
				
			||||||
	// 如果用户组有速度限制,就返回限制流速的ReaderSeeker
 | 
						// 如果用户组有速度限制,就返回限制流速的ReaderSeeker
 | 
				
			||||||
	if fs.User.Group.SpeedLimit != 0 {
 | 
						if fs.User.Group.SpeedLimit != 0 {
 | 
				
			||||||
		speed := fs.User.Group.SpeedLimit
 | 
							speed := fs.User.Group.SpeedLimit
 | 
				
			||||||
@ -63,7 +64,7 @@ func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder) (*model
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetPhysicalFileContent 根据文件物理路径获取文件流
 | 
					// GetPhysicalFileContent 根据文件物理路径获取文件流
 | 
				
			||||||
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (io.ReadSeeker, error) {
 | 
					func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
 | 
				
			||||||
	// 重设上传策略
 | 
						// 重设上传策略
 | 
				
			||||||
	fs.Policy = &model.Policy{Type: "local"}
 | 
						fs.Policy = &model.Policy{Type: "local"}
 | 
				
			||||||
	_ = fs.dispatchHandler()
 | 
						_ = fs.dispatchHandler()
 | 
				
			||||||
@ -78,7 +79,7 @@ func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetDownloadContent 获取用于下载的文件流
 | 
					// GetDownloadContent 获取用于下载的文件流
 | 
				
			||||||
func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.ReadSeeker, error) {
 | 
					func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (response.RSCloser, error) {
 | 
				
			||||||
	// 获取原始文件流
 | 
						// 获取原始文件流
 | 
				
			||||||
	rs, err := fs.GetContent(ctx, path)
 | 
						rs, err := fs.GetContent(ctx, path)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -91,7 +92,7 @@ func (fs *FileSystem) GetDownloadContent(ctx context.Context, path string) (io.R
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetContent 获取文件内容,path为虚拟路径
 | 
					// GetContent 获取文件内容,path为虚拟路径
 | 
				
			||||||
func (fs *FileSystem) GetContent(ctx context.Context, path string) (io.ReadSeeker, error) {
 | 
					func (fs *FileSystem) GetContent(ctx context.Context, path string) (response.RSCloser, error) {
 | 
				
			||||||
	// 触发`下载前`钩子
 | 
						// 触发`下载前`钩子
 | 
				
			||||||
	err := fs.Trigger(ctx, fs.BeforeFileDownload)
 | 
						err := fs.Trigger(ctx, fs.BeforeFileDownload)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -27,7 +27,7 @@ type Handler interface {
 | 
				
			|||||||
	// 删除一个或多个文件
 | 
						// 删除一个或多个文件
 | 
				
			||||||
	Delete(ctx context.Context, files []string) ([]string, error)
 | 
						Delete(ctx context.Context, files []string) ([]string, error)
 | 
				
			||||||
	// 获取文件
 | 
						// 获取文件
 | 
				
			||||||
	Get(ctx context.Context, path string) (io.ReadSeeker, error)
 | 
						Get(ctx context.Context, path string) (response.RSCloser, error)
 | 
				
			||||||
	// 获取缩略图
 | 
						// 获取缩略图
 | 
				
			||||||
	Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
 | 
						Thumb(ctx context.Context, path string) (*response.ContentResponse, error)
 | 
				
			||||||
	// 获取外链地址,url
 | 
						// 获取外链地址,url
 | 
				
			||||||
 | 
				
			|||||||
@ -25,7 +25,7 @@ type Handler struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get 获取文件内容
 | 
					// Get 获取文件内容
 | 
				
			||||||
func (handler Handler) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
 | 
					func (handler Handler) Get(ctx context.Context, path string) (response.RSCloser, error) {
 | 
				
			||||||
	// 打开文件
 | 
						// 打开文件
 | 
				
			||||||
	file, err := os.Open(path)
 | 
						file, err := os.Open(path)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -10,3 +10,9 @@ type ContentResponse struct {
 | 
				
			|||||||
	Content  io.ReadSeeker
 | 
						Content  io.ReadSeeker
 | 
				
			||||||
	URL      string
 | 
						URL      string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 存储策略适配器返回的文件流,有些策略需要带有Closer
 | 
				
			||||||
 | 
					type RSCloser interface {
 | 
				
			||||||
 | 
						io.ReadSeeker
 | 
				
			||||||
 | 
						io.Closer
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -22,9 +22,9 @@ type FileHeaderMock struct {
 | 
				
			|||||||
	testMock.Mock
 | 
						testMock.Mock
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m FileHeaderMock) Get(ctx context.Context, path string) (io.ReadSeeker, error) {
 | 
					func (m FileHeaderMock) Get(ctx context.Context, path string) (response.RSCloser, error) {
 | 
				
			||||||
	args := m.Called(ctx, path)
 | 
						args := m.Called(ctx, path)
 | 
				
			||||||
	return args.Get(0).(io.ReadSeeker), args.Error(1)
 | 
						return args.Get(0).(response.RSCloser), args.Error(1)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
 | 
					func (m FileHeaderMock) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error {
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,6 @@ import (
 | 
				
			|||||||
	"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
 | 
						"github.com/HFO4/cloudreve/pkg/filesystem/fsctx"
 | 
				
			||||||
	"github.com/HFO4/cloudreve/pkg/serializer"
 | 
						"github.com/HFO4/cloudreve/pkg/serializer"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -45,6 +44,7 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// 获取文件流
 | 
						// 获取文件流
 | 
				
			||||||
	rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
 | 
						rs, err := fs.GetPhysicalFileContent(ctx, zipPath.(string))
 | 
				
			||||||
 | 
						defer rs.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return serializer.Err(serializer.CodeNotSet, err.Error(), err)
 | 
							return serializer.Err(serializer.CodeNotSet, err.Error(), err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -58,11 +58,6 @@ func (service *DownloadService) DownloadArchived(ctx context.Context, c *gin.Con
 | 
				
			|||||||
	c.Header("Content-Type", "application/zip")
 | 
						c.Header("Content-Type", "application/zip")
 | 
				
			||||||
	http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
 | 
						http.ServeContent(c.Writer, c.Request, "", time.Now(), rs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 检查是否需要关闭文件
 | 
					 | 
				
			||||||
	if fc, ok := rs.(io.Closer); ok {
 | 
					 | 
				
			||||||
		err = fc.Close()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return serializer.Response{
 | 
						return serializer.Response{
 | 
				
			||||||
		Code: 0,
 | 
							Code: 0,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -84,6 +79,7 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// 获取文件流
 | 
						// 获取文件流
 | 
				
			||||||
	rs, err := fs.GetDownloadContent(ctx, "")
 | 
						rs, err := fs.GetDownloadContent(ctx, "")
 | 
				
			||||||
 | 
						defer rs.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return serializer.Err(serializer.CodeNotSet, err.Error(), err)
 | 
							return serializer.Err(serializer.CodeNotSet, err.Error(), err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -91,11 +87,6 @@ func (service *FileAnonymousGetService) Download(ctx context.Context, c *gin.Con
 | 
				
			|||||||
	// 发送文件
 | 
						// 发送文件
 | 
				
			||||||
	http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
 | 
						http.ServeContent(c.Writer, c.Request, service.Name, fs.FileTarget[0].UpdatedAt, rs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 检查是否需要关闭文件
 | 
					 | 
				
			||||||
	if fc, ok := rs.(io.Closer); ok {
 | 
					 | 
				
			||||||
		defer fc.Close()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return serializer.Response{
 | 
						return serializer.Response{
 | 
				
			||||||
		Code: 0,
 | 
							Code: 0,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -139,6 +130,7 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
 | 
				
			|||||||
	// 开始处理下载
 | 
						// 开始处理下载
 | 
				
			||||||
	ctx = context.WithValue(ctx, fsctx.GinCtx, c)
 | 
						ctx = context.WithValue(ctx, fsctx.GinCtx, c)
 | 
				
			||||||
	rs, err := fs.GetDownloadContent(ctx, "")
 | 
						rs, err := fs.GetDownloadContent(ctx, "")
 | 
				
			||||||
 | 
						defer rs.Close()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return serializer.Err(serializer.CodeNotSet, err.Error(), err)
 | 
							return serializer.Err(serializer.CodeNotSet, err.Error(), err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -154,11 +146,6 @@ func (service *DownloadService) Download(ctx context.Context, c *gin.Context) se
 | 
				
			|||||||
	// 发送文件
 | 
						// 发送文件
 | 
				
			||||||
	http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
 | 
						http.ServeContent(c.Writer, c.Request, "", fs.FileTarget[0].UpdatedAt, rs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 检查是否需要关闭文件
 | 
					 | 
				
			||||||
	if fc, ok := rs.(io.Closer); ok {
 | 
					 | 
				
			||||||
		defer fc.Close()
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return serializer.Response{
 | 
						return serializer.Response{
 | 
				
			||||||
		Code: 0,
 | 
							Code: 0,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user