mirror of
				https://github.com/cloudreve/cloudreve.git
				synced 2025-11-04 13:16:02 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			438 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			438 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package onedrive
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"github.com/cloudreve/Cloudreve/v3/pkg/auth"
 | 
						|
	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
 | 
						|
	"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
 | 
						|
	"github.com/jinzhu/gorm"
 | 
						|
	"io"
 | 
						|
	"io/ioutil"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"strings"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	model "github.com/cloudreve/Cloudreve/v3/models"
 | 
						|
	"github.com/cloudreve/Cloudreve/v3/pkg/cache"
 | 
						|
	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
 | 
						|
	"github.com/cloudreve/Cloudreve/v3/pkg/request"
 | 
						|
	"github.com/stretchr/testify/assert"
 | 
						|
	testMock "github.com/stretchr/testify/mock"
 | 
						|
)
 | 
						|
 | 
						|
func TestDriver_Token(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	h, _ := NewDriver(&model.Policy{
 | 
						|
		AccessKey:  "ak",
 | 
						|
		SecretKey:  "sk",
 | 
						|
		BucketName: "test",
 | 
						|
		Server:     "test.com",
 | 
						|
	})
 | 
						|
	handler := h.(Driver)
 | 
						|
 | 
						|
	// 分片上传 失败
 | 
						|
	{
 | 
						|
		cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
 | 
						|
		handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
		handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
		clientMock := ClientMock{}
 | 
						|
		clientMock.On(
 | 
						|
			"Request",
 | 
						|
			"POST",
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
		).Return(&request.Response{
 | 
						|
			Err: nil,
 | 
						|
			Response: &http.Response{
 | 
						|
				StatusCode: 400,
 | 
						|
				Body:       ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
 | 
						|
			},
 | 
						|
		})
 | 
						|
		handler.Client.Request = clientMock
 | 
						|
		res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{}, &fsctx.FileStream{})
 | 
						|
		asserts.Error(err)
 | 
						|
		asserts.Nil(res)
 | 
						|
	}
 | 
						|
 | 
						|
	// 分片上传 成功
 | 
						|
	{
 | 
						|
		cache.Set("setting_siteURL", "http://test.cloudreve.org", 0)
 | 
						|
		cache.Set("setting_onedrive_monitor_timeout", "600", 0)
 | 
						|
		cache.Set("setting_onedrive_callback_check", "20", 0)
 | 
						|
		handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
		handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
		handler.Client.Credential.AccessToken = "1"
 | 
						|
		clientMock := ClientMock{}
 | 
						|
		clientMock.On(
 | 
						|
			"Request",
 | 
						|
			"POST",
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
		).Return(&request.Response{
 | 
						|
			Err: nil,
 | 
						|
			Response: &http.Response{
 | 
						|
				StatusCode: 200,
 | 
						|
				Body:       ioutil.NopCloser(strings.NewReader(`{"uploadUrl":"123321"}`)),
 | 
						|
			},
 | 
						|
		})
 | 
						|
		handler.Client.Request = clientMock
 | 
						|
		go func() {
 | 
						|
			time.Sleep(time.Duration(1) * time.Second)
 | 
						|
			mq.GlobalMQ.Publish("TestDriver_Token", mq.Message{})
 | 
						|
		}()
 | 
						|
		res, err := handler.Token(context.Background(), 10, &serializer.UploadSession{Key: "TestDriver_Token"}, &fsctx.FileStream{})
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Equal("123321", res.UploadURLs[0])
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_Source(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
	cache.Set("setting_onedrive_source_timeout", "1800", 0)
 | 
						|
 | 
						|
	// 失败
 | 
						|
	{
 | 
						|
		res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0)
 | 
						|
		asserts.Error(err)
 | 
						|
		asserts.Empty(res)
 | 
						|
	}
 | 
						|
 | 
						|
	// 命中缓存 成功
 | 
						|
	{
 | 
						|
		handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
		handler.Client.Credential.AccessToken = "1"
 | 
						|
		cache.Set("onedrive_source_0_123.jpg", "res", 1)
 | 
						|
		res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 0, true, 0)
 | 
						|
		cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Equal("res", res)
 | 
						|
	}
 | 
						|
 | 
						|
	// 命中缓存 上下文存在文件 成功
 | 
						|
	{
 | 
						|
		file := model.File{}
 | 
						|
		file.ID = 1
 | 
						|
		file.UpdatedAt = time.Now()
 | 
						|
		ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
 | 
						|
		handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
		handler.Client.Credential.AccessToken = "1"
 | 
						|
		cache.Set(fmt.Sprintf("onedrive_source_file_%d_1", file.UpdatedAt.Unix()), "res", 0)
 | 
						|
		res, err := handler.Source(ctx, "123.jpg", url.URL{}, 1, true, 0)
 | 
						|
		cache.Deletes([]string{"0_123.jpg"}, "onedrive_source_")
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Equal("res", res)
 | 
						|
	}
 | 
						|
 | 
						|
	// 成功
 | 
						|
	{
 | 
						|
		handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
		clientMock := ClientMock{}
 | 
						|
		clientMock.On(
 | 
						|
			"Request",
 | 
						|
			"GET",
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
		).Return(&request.Response{
 | 
						|
			Err: nil,
 | 
						|
			Response: &http.Response{
 | 
						|
				StatusCode: 200,
 | 
						|
				Body:       ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
 | 
						|
			},
 | 
						|
		})
 | 
						|
		handler.Client.Request = clientMock
 | 
						|
		handler.Client.Credential.AccessToken = "1"
 | 
						|
		res, err := handler.Source(context.Background(), "123.jpg", url.URL{}, 1, true, 0)
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Equal("123321", res)
 | 
						|
	}
 | 
						|
 | 
						|
	// 成功 永久直链
 | 
						|
	{
 | 
						|
		file := model.File{}
 | 
						|
		file.ID = 1
 | 
						|
		file.Name = "123.jpg"
 | 
						|
		file.UpdatedAt = time.Now()
 | 
						|
		ctx := context.WithValue(context.Background(), fsctx.FileModelCtx, file)
 | 
						|
		handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
		auth.General = auth.HMACAuth{}
 | 
						|
		handler.Client.Credential.AccessToken = "1"
 | 
						|
		res, err := handler.Source(ctx, "123.jpg", url.URL{}, 0, true, 0)
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Contains(res, "/api/v3/file/source/1/123.jpg?sign")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_List(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.AccessToken = "AccessToken"
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
 | 
						|
	// 非递归
 | 
						|
	{
 | 
						|
		clientMock := ClientMock{}
 | 
						|
		clientMock.On(
 | 
						|
			"Request",
 | 
						|
			"GET",
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
		).Return(&request.Response{
 | 
						|
			Err: nil,
 | 
						|
			Response: &http.Response{
 | 
						|
				StatusCode: 200,
 | 
						|
				Body:       ioutil.NopCloser(strings.NewReader(`{"value":[{}]}`)),
 | 
						|
			},
 | 
						|
		})
 | 
						|
		handler.Client.Request = clientMock
 | 
						|
		res, err := handler.List(context.Background(), "/", false)
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Len(res, 1)
 | 
						|
	}
 | 
						|
 | 
						|
	// 递归一次
 | 
						|
	{
 | 
						|
		clientMock := ClientMock{}
 | 
						|
		clientMock.On(
 | 
						|
			"Request",
 | 
						|
			"GET",
 | 
						|
			"me/drive/root/children?$top=999999999",
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
		).Return(&request.Response{
 | 
						|
			Err: nil,
 | 
						|
			Response: &http.Response{
 | 
						|
				StatusCode: 200,
 | 
						|
				Body:       ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"1","folder":{}}]}`)),
 | 
						|
			},
 | 
						|
		})
 | 
						|
		clientMock.On(
 | 
						|
			"Request",
 | 
						|
			"GET",
 | 
						|
			"me/drive/root:/1:/children?$top=999999999",
 | 
						|
			testMock.Anything,
 | 
						|
			testMock.Anything,
 | 
						|
		).Return(&request.Response{
 | 
						|
			Err: nil,
 | 
						|
			Response: &http.Response{
 | 
						|
				StatusCode: 200,
 | 
						|
				Body:       ioutil.NopCloser(strings.NewReader(`{"value":[{"name":"2"}]}`)),
 | 
						|
			},
 | 
						|
		})
 | 
						|
		handler.Client.Request = clientMock
 | 
						|
		res, err := handler.List(context.Background(), "/", true)
 | 
						|
		asserts.NoError(err)
 | 
						|
		asserts.Len(res, 2)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_Thumb(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
 | 
						|
	// 失败
 | 
						|
	{
 | 
						|
		ctx := context.WithValue(context.Background(), fsctx.ThumbSizeCtx, [2]uint{10, 20})
 | 
						|
		ctx = context.WithValue(ctx, fsctx.FileModelCtx, model.File{PicInfo: "1,1", Model: gorm.Model{ID: 1}})
 | 
						|
		res, err := handler.Thumb(ctx, "123.jpg")
 | 
						|
		asserts.Error(err)
 | 
						|
		asserts.Empty(res.URL)
 | 
						|
	}
 | 
						|
 | 
						|
	// 上下文错误
 | 
						|
	{
 | 
						|
		_, err := handler.Thumb(context.Background(), "123.jpg")
 | 
						|
		asserts.Error(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_Delete(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
 | 
						|
	// 失败
 | 
						|
	{
 | 
						|
		_, err := handler.Delete(context.Background(), []string{"1"})
 | 
						|
		asserts.Error(err)
 | 
						|
	}
 | 
						|
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_Put(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
 | 
						|
	// 失败
 | 
						|
	{
 | 
						|
		err := handler.Put(context.Background(), &fsctx.FileStream{})
 | 
						|
		asserts.Error(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_Get(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
 | 
						|
	// 无法获取source
 | 
						|
	{
 | 
						|
		res, err := handler.Get(context.Background(), "123.txt")
 | 
						|
		asserts.Error(err)
 | 
						|
		asserts.Nil(res)
 | 
						|
	}
 | 
						|
 | 
						|
	// 成功
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
	clientMock := ClientMock{}
 | 
						|
	clientMock.On(
 | 
						|
		"Request",
 | 
						|
		"GET",
 | 
						|
		testMock.Anything,
 | 
						|
		testMock.Anything,
 | 
						|
		testMock.Anything,
 | 
						|
	).Return(&request.Response{
 | 
						|
		Err: nil,
 | 
						|
		Response: &http.Response{
 | 
						|
			StatusCode: 200,
 | 
						|
			Body:       ioutil.NopCloser(strings.NewReader(`{"@microsoft.graph.downloadUrl":"123321"}`)),
 | 
						|
		},
 | 
						|
	})
 | 
						|
	handler.Client.Request = clientMock
 | 
						|
	handler.Client.Credential.AccessToken = "1"
 | 
						|
 | 
						|
	driverClientMock := ClientMock{}
 | 
						|
	driverClientMock.On(
 | 
						|
		"Request",
 | 
						|
		"GET",
 | 
						|
		testMock.Anything,
 | 
						|
		testMock.Anything,
 | 
						|
		testMock.Anything,
 | 
						|
	).Return(&request.Response{
 | 
						|
		Err: nil,
 | 
						|
		Response: &http.Response{
 | 
						|
			StatusCode: 200,
 | 
						|
			Body:       ioutil.NopCloser(strings.NewReader(`123`)),
 | 
						|
		},
 | 
						|
	})
 | 
						|
	handler.HTTPClient = driverClientMock
 | 
						|
	res, err := handler.Get(context.Background(), "123.txt")
 | 
						|
	clientMock.AssertExpectations(t)
 | 
						|
	asserts.NoError(err)
 | 
						|
	_, err = res.Seek(0, io.SeekEnd)
 | 
						|
	asserts.NoError(err)
 | 
						|
	content, err := ioutil.ReadAll(res)
 | 
						|
	asserts.NoError(err)
 | 
						|
	asserts.Equal("123", string(content))
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_replaceSourceHost(t *testing.T) {
 | 
						|
	tests := []struct {
 | 
						|
		name    string
 | 
						|
		origin  string
 | 
						|
		cdn     string
 | 
						|
		want    string
 | 
						|
		wantErr bool
 | 
						|
	}{
 | 
						|
		{"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false},
 | 
						|
		{"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false},
 | 
						|
		{"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true},
 | 
						|
		{"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true},
 | 
						|
	}
 | 
						|
	for _, tt := range tests {
 | 
						|
		t.Run(tt.name, func(t *testing.T) {
 | 
						|
			policy := &model.Policy{}
 | 
						|
			policy.OptionsSerialized.OdProxy = tt.cdn
 | 
						|
			handler := Driver{
 | 
						|
				Policy: policy,
 | 
						|
			}
 | 
						|
			got, err := handler.replaceSourceHost(tt.origin)
 | 
						|
			if (err != nil) != tt.wantErr {
 | 
						|
				t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr)
 | 
						|
				return
 | 
						|
			}
 | 
						|
			if got != tt.want {
 | 
						|
				t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestDriver_CancelToken(t *testing.T) {
 | 
						|
	asserts := assert.New(t)
 | 
						|
	handler := Driver{
 | 
						|
		Policy: &model.Policy{
 | 
						|
			AccessKey:  "ak",
 | 
						|
			SecretKey:  "sk",
 | 
						|
			BucketName: "test",
 | 
						|
			Server:     "test.com",
 | 
						|
		},
 | 
						|
	}
 | 
						|
	handler.Client, _ = NewClient(&model.Policy{})
 | 
						|
	handler.Client.Credential.ExpiresIn = time.Now().Add(time.Duration(100) * time.Hour).Unix()
 | 
						|
 | 
						|
	// 失败
 | 
						|
	{
 | 
						|
		err := handler.CancelToken(context.Background(), &serializer.UploadSession{})
 | 
						|
		asserts.Error(err)
 | 
						|
	}
 | 
						|
}
 |