mirror of
				https://github.com/cloudreve/cloudreve.git
				synced 2025-11-01 00:57:15 +08:00 
			
		
		
		
	Feat: migration DB support custom upgrade scripts
This commit is contained in:
		| @ -2,6 +2,7 @@ package bootstrap | ||||
|  | ||||
| import ( | ||||
| 	model "github.com/cloudreve/Cloudreve/v3/models" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/models/scripts" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/aria2" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/auth" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/cache" | ||||
| @ -27,6 +28,12 @@ func Init(path string) { | ||||
| 		mode    string | ||||
| 		factory func() | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"both", | ||||
| 			func() { | ||||
| 				scripts.Init() | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"both", | ||||
| 			func() { | ||||
|  | ||||
| @ -2,14 +2,14 @@ package bootstrap | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/models/scripts" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/util" | ||||
| ) | ||||
|  | ||||
| func RunScript(name string) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
| 	if err := scripts.RunDBScript(name, ctx); err != nil { | ||||
| 	if err := invoker.RunDBScript(name, ctx); err != nil { | ||||
| 		util.Log().Error("数据库脚本执行失败: %s", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @ -1,12 +1,17 @@ | ||||
| package model | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/cache" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/conf" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/util" | ||||
| 	"github.com/fatih/color" | ||||
| 	"github.com/gofrs/uuid" | ||||
| 	"github.com/hashicorp/go-version" | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // 是否需要迁移 | ||||
| @ -54,6 +59,9 @@ func migration() { | ||||
| 	// 向设置数据表添加初始设置 | ||||
| 	addDefaultSettings() | ||||
|  | ||||
| 	// 执行数据库升级脚本 | ||||
| 	execUpgradeScripts() | ||||
|  | ||||
| 	util.Log().Info("数据库初始化结束") | ||||
|  | ||||
| } | ||||
| @ -290,3 +298,17 @@ func addDefaultNode() { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func execUpgradeScripts() { | ||||
| 	s := invoker.ListPrefix("UpgradeTo") | ||||
| 	versions := make([]*version.Version, len(s)) | ||||
| 	for i, raw := range s { | ||||
| 		v, _ := version.NewVersion(strings.TrimPrefix(raw, "UpgradeTo")) | ||||
| 		versions[i] = v | ||||
| 	} | ||||
| 	sort.Sort(version.Collection(versions)) | ||||
|  | ||||
| 	for i := 0; i < len(versions); i++ { | ||||
| 		invoker.RunDBScript("UpgradeTo"+versions[i].String(), context.Background()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										9
									
								
								models/scripts/init.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								models/scripts/init.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | ||||
| package scripts | ||||
|  | ||||
| import "github.com/cloudreve/Cloudreve/v3/models/scripts/invoker" | ||||
|  | ||||
| func Init() { | ||||
| 	invoker.Register("ResetAdminPassword", ResetAdminPassword(0)) | ||||
| 	invoker.Register("CalibrateUserStorage", UserStorageCalibration(0)) | ||||
| 	invoker.Register("UpgradeTo3.4.0", UpgradeTo340(0)) | ||||
| } | ||||
| @ -1,8 +1,9 @@ | ||||
| package scripts | ||||
| package invoker | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/util" | ||||
| ) | ||||
| 
 | ||||
| type DBScript interface { | ||||
| @ -13,6 +14,7 @@ var availableScripts = make(map[string]DBScript) | ||||
| 
 | ||||
| func RunDBScript(name string, ctx context.Context) error { | ||||
| 	if script, ok := availableScripts[name]; ok { | ||||
| 		util.Log().Info("开始执行数据库脚本 [%s]", name) | ||||
| 		script.Run(ctx) | ||||
| 		return nil | ||||
| 	} | ||||
| @ -20,6 +22,16 @@ func RunDBScript(name string, ctx context.Context) error { | ||||
| 	return fmt.Errorf("数据库脚本 [%s] 不存在", name) | ||||
| } | ||||
| 
 | ||||
| func register(name string, script DBScript) { | ||||
| func Register(name string, script DBScript) { | ||||
| 	availableScripts[name] = script | ||||
| } | ||||
| 
 | ||||
| func ListPrefix(prefix string) []string { | ||||
| 	var scripts []string | ||||
| 	for name := range availableScripts { | ||||
| 		if name[:len(prefix)] == prefix { | ||||
| 			scripts = append(scripts, name) | ||||
| 		} | ||||
| 	} | ||||
| 	return scripts | ||||
| } | ||||
| @ -1,4 +1,4 @@ | ||||
| package scripts | ||||
| package invoker | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| @ -35,7 +35,7 @@ func TestMain(m *testing.M) { | ||||
| 
 | ||||
| func TestRunDBScript(t *testing.T) { | ||||
| 	asserts := assert.New(t) | ||||
| 	register("test", TestScript(0)) | ||||
| 	Register("test", TestScript(0)) | ||||
| 
 | ||||
| 	// 不存在 | ||||
| 	{ | ||||
| @ -47,3 +47,14 @@ func TestRunDBScript(t *testing.T) { | ||||
| 		asserts.NoError(RunDBScript("test", context.Background())) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestListPrefix(t *testing.T) { | ||||
| 	asserts := assert.New(t) | ||||
| 	Register("U1", TestScript(0)) | ||||
| 	Register("U2", TestScript(0)) | ||||
| 	Register("U3", TestScript(0)) | ||||
| 	Register("P1", TestScript(0)) | ||||
| 
 | ||||
| 	res := ListPrefix("U") | ||||
| 	asserts.Len(res, 3) | ||||
| } | ||||
| @ -9,10 +9,6 @@ import ( | ||||
|  | ||||
| type ResetAdminPassword int | ||||
|  | ||||
| func init() { | ||||
| 	register("ResetAdminPassword", ResetAdminPassword(0)) | ||||
| } | ||||
|  | ||||
| // Run 运行脚本从社区版升级至 Pro 版 | ||||
| func (script ResetAdminPassword) Run(ctx context.Context) { | ||||
| 	// 查找用户 | ||||
|  | ||||
| @ -8,10 +8,6 @@ import ( | ||||
|  | ||||
| type UserStorageCalibration int | ||||
|  | ||||
| func init() { | ||||
| 	register("CalibrateUserStorage", UserStorageCalibration(0)) | ||||
| } | ||||
|  | ||||
| type storageResult struct { | ||||
| 	Total uint64 | ||||
| } | ||||
|  | ||||
| @ -2,11 +2,31 @@ package scripts | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"github.com/DATA-DOG/go-sqlmock" | ||||
| 	model "github.com/cloudreve/Cloudreve/v3/models" | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| var mock sqlmock.Sqlmock | ||||
| var mockDB *gorm.DB | ||||
|  | ||||
| // TestMain 初始化数据库Mock | ||||
| func TestMain(m *testing.M) { | ||||
| 	var db *sql.DB | ||||
| 	var err error | ||||
| 	db, mock, err = sqlmock.New() | ||||
| 	if err != nil { | ||||
| 		panic("An error was not expected when opening a stub database connection") | ||||
| 	} | ||||
| 	model.DB, _ = gorm.Open("mysql", db) | ||||
| 	mockDB = model.DB | ||||
| 	defer db.Close() | ||||
| 	m.Run() | ||||
| } | ||||
|  | ||||
| func TestUserStorageCalibration_Run(t *testing.T) { | ||||
| 	asserts := assert.New(t) | ||||
| 	script := UserStorageCalibration(0) | ||||
|  | ||||
							
								
								
									
										43
									
								
								models/scripts/upgrade.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								models/scripts/upgrade.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | ||||
| package scripts | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	model "github.com/cloudreve/Cloudreve/v3/models" | ||||
| 	"github.com/cloudreve/Cloudreve/v3/pkg/util" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| type UpgradeTo340 int | ||||
|  | ||||
| // Run upgrade from older version to 3.4.0 | ||||
| func (script UpgradeTo340) Run(ctx context.Context) { | ||||
| 	// 取回老版本 aria2 设定 | ||||
| 	old := model.GetSettingByType([]string{"aria2"}) | ||||
| 	if len(old) == 0 { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 写入到新版本的节点设定 | ||||
| 	n, err := model.GetNodeByID(1) | ||||
| 	if err != nil { | ||||
| 		util.Log().Error("找不到主机节点, %s", err) | ||||
| 	} | ||||
|  | ||||
| 	n.Aria2Enabled = old["aria2_rpcurl"] != "" | ||||
| 	n.Aria2OptionsSerialized.Options = old["aria2_options"] | ||||
| 	n.Aria2OptionsSerialized.Server = old["aria2_rpcurl"] | ||||
|  | ||||
| 	interval, err := strconv.Atoi(old["aria2_interval"]) | ||||
| 	if err != nil { | ||||
| 		interval = 10 | ||||
| 	} | ||||
| 	n.Aria2OptionsSerialized.Interval = interval | ||||
| 	n.Aria2OptionsSerialized.TempPath = old["aria2_temp_path"] | ||||
| 	n.Aria2OptionsSerialized.Token = old["aria2_token"] | ||||
| 	if err := model.DB.Save(&n).Error; err != nil { | ||||
| 		util.Log().Error("无法保存主机节点 Aria2 配置信息, %s", err) | ||||
| 	} else { | ||||
| 		model.DB.Where("type = ?", "aria2").Delete(model.Setting{}) | ||||
| 		util.Log().Info("Aria2 配置信息已成功迁移至 3.4.0+ 版本的模式") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										66
									
								
								models/scripts/upgrade_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								models/scripts/upgrade_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | ||||
| package scripts | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"github.com/DATA-DOG/go-sqlmock" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestUpgradeTo340_Run(t *testing.T) { | ||||
| 	a := assert.New(t) | ||||
| 	script := UpgradeTo340(0) | ||||
|  | ||||
| 	// skip | ||||
| 	{ | ||||
| 		mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"})) | ||||
| 		script.Run(context.Background()) | ||||
| 		a.NoError(mock.ExpectationsWereMet()) | ||||
| 	} | ||||
|  | ||||
| 	// node not found | ||||
| 	{ | ||||
| 		mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow("1")) | ||||
| 		mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"})) | ||||
| 		script.Run(context.Background()) | ||||
| 		a.NoError(mock.ExpectationsWereMet()) | ||||
| 	} | ||||
|  | ||||
| 	// success | ||||
| 	{ | ||||
| 		mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). | ||||
| 			AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). | ||||
| 			AddRow("aria2_interval", "expected_aria2_interval"). | ||||
| 			AddRow("aria2_temp_path", "expected_aria2_temp_path"). | ||||
| 			AddRow("aria2_token", "expected_aria2_token"). | ||||
| 			AddRow("aria2_options", "{}")) | ||||
|  | ||||
| 		mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) | ||||
| 		mock.ExpectBegin() | ||||
| 		mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) | ||||
| 		mock.ExpectCommit() | ||||
| 		mock.ExpectBegin() | ||||
| 		mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) | ||||
| 		mock.ExpectCommit() | ||||
| 		script.Run(context.Background()) | ||||
| 		a.NoError(mock.ExpectationsWereMet()) | ||||
| 	} | ||||
|  | ||||
| 	// failed | ||||
| 	{ | ||||
| 		mock.ExpectQuery("SELECT(.+)settings").WillReturnRows(sqlmock.NewRows([]string{"name", "value"}). | ||||
| 			AddRow("aria2_rpcurl", "expected_aria2_rpcurl"). | ||||
| 			AddRow("aria2_interval", "expected_aria2_interval"). | ||||
| 			AddRow("aria2_temp_path", "expected_aria2_temp_path"). | ||||
| 			AddRow("aria2_token", "expected_aria2_token"). | ||||
| 			AddRow("aria2_options", "{}")) | ||||
|  | ||||
| 		mock.ExpectQuery("SELECT(.+)nodes").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) | ||||
| 		mock.ExpectBegin() | ||||
| 		mock.ExpectExec("UPDATE(.+)").WillReturnError(errors.New("error")) | ||||
| 		mock.ExpectRollback() | ||||
| 		script.Run(context.Background()) | ||||
| 		a.NoError(mock.ExpectationsWereMet()) | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 HFO4
					HFO4