mirror of
https://github.com/cloudreve/cloudreve.git
synced 2025-11-04 04:47:24 +08:00
Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
314
inventory/task.go
Normal file
314
inventory/task.go
Normal file
@ -0,0 +1,314 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/task"
|
||||
"github.com/cloudreve/Cloudreve/v4/inventory/types"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ctx keys for eager loading options.
|
||||
LoadTaskUser struct{}
|
||||
|
||||
TaskArgs struct {
|
||||
Status task.Status
|
||||
Type string
|
||||
PublicState *types.TaskPublicState
|
||||
PrivateState string
|
||||
OwnerID int
|
||||
CorrelationID uuid.UUID
|
||||
}
|
||||
)
|
||||
|
||||
type TaskClient interface {
|
||||
TxOperator
|
||||
// New creates a new task with the given args.
|
||||
New(ctx context.Context, task *TaskArgs) (*ent.Task, error)
|
||||
// Update updates the task with the given args.
|
||||
Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error)
|
||||
// GetPendingTasks returns all pending tasks of given type.
|
||||
GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error)
|
||||
// GetTaskByID returns the task with the given ID.
|
||||
GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error)
|
||||
// SetCompleteByID sets the task with the given ID to complete.
|
||||
SetCompleteByID(ctx context.Context, taskID int) error
|
||||
// List returns a list of tasks with the given args.
|
||||
List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error)
|
||||
// DeleteByIDs deletes the tasks with the given IDs.
|
||||
DeleteByIDs(ctx context.Context, ids ...int) error
|
||||
}
|
||||
|
||||
type (
|
||||
ListTaskArgs struct {
|
||||
*PaginationArgs
|
||||
Types []string
|
||||
Status []task.Status
|
||||
UserID int
|
||||
CorrelationID *uuid.UUID
|
||||
}
|
||||
|
||||
ListTaskResult struct {
|
||||
*PaginationResults
|
||||
Tasks []*ent.Task
|
||||
}
|
||||
)
|
||||
|
||||
func NewTaskClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) TaskClient {
|
||||
return &taskClient{client: client, maxSQlParam: sqlParamLimit(dbType), hasher: hasher}
|
||||
}
|
||||
|
||||
type taskClient struct {
|
||||
maxSQlParam int
|
||||
hasher hashid.Encoder
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
func (c *taskClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &taskClient{client: newClient, maxSQlParam: c.maxSQlParam, hasher: c.hasher}
|
||||
}
|
||||
|
||||
func (c *taskClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *taskClient) New(ctx context.Context, task *TaskArgs) (*ent.Task, error) {
|
||||
stm := c.client.Task.
|
||||
Create().
|
||||
SetType(task.Type).
|
||||
SetPublicState(task.PublicState)
|
||||
if task.PrivateState != "" {
|
||||
stm.SetPrivateState(task.PrivateState)
|
||||
}
|
||||
|
||||
if task.OwnerID != 0 {
|
||||
stm.SetUserID(task.OwnerID)
|
||||
}
|
||||
|
||||
if task.Status != "" {
|
||||
stm.SetStatus(task.Status)
|
||||
}
|
||||
|
||||
if task.CorrelationID.String() != uuid.Nil.String() {
|
||||
stm.SetCorrelationID(task.CorrelationID)
|
||||
}
|
||||
|
||||
newTask, err := stm.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create task: %w", err)
|
||||
}
|
||||
|
||||
return newTask, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) DeleteByIDs(ctx context.Context, ids ...int) error {
|
||||
_, err := c.client.Task.Delete().Where(task.IDIn(ids...)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *taskClient) Update(ctx context.Context, task *ent.Task, args *TaskArgs) (*ent.Task, error) {
|
||||
stm := c.client.Task.UpdateOne(task).
|
||||
SetPublicState(args.PublicState)
|
||||
|
||||
task.PublicState = args.PublicState
|
||||
|
||||
if task.PrivateState != "" {
|
||||
stm.SetPrivateState(task.PrivateState)
|
||||
task.PrivateState = args.PrivateState
|
||||
}
|
||||
|
||||
if task.Status != "" {
|
||||
stm.SetStatus(args.Status)
|
||||
task.Status = args.Status
|
||||
}
|
||||
|
||||
if err := stm.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to create task: %w", err)
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) GetPendingTasks(ctx context.Context, taskType ...string) ([]*ent.Task, error) {
|
||||
tasks, err := withTaskEagerLoading(ctx, c.client.Task.Query()).
|
||||
Where(task.StatusIn(task.StatusProcessing, task.StatusQueued, task.StatusSuspending)).
|
||||
Where(task.TypeIn(taskType...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Anonymous user is not loaded by default, so we need to load it manually.
|
||||
userClient := NewUserClient(c.client)
|
||||
anonymous, err := userClient.AnonymousUser(ctx)
|
||||
for _, t := range tasks {
|
||||
if t.UserTasks == 0 {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.SetUser(anonymous)
|
||||
}
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) GetTaskByID(ctx context.Context, taskID int) (*ent.Task, error) {
|
||||
return withTaskEagerLoading(ctx, c.client.Task.Query()).
|
||||
Where(task.ID(taskID)).
|
||||
First(ctx)
|
||||
}
|
||||
|
||||
func (c *taskClient) SetCompleteByID(ctx context.Context, taskID int) error {
|
||||
_, err := c.client.Task.UpdateOneID(taskID).
|
||||
SetStatus(task.StatusCompleted).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *taskClient) List(ctx context.Context, args *ListTaskArgs) (*ListTaskResult, error) {
|
||||
q := c.client.Task.Query()
|
||||
if args.UserID != 0 {
|
||||
q.Where(task.UserTasks(args.UserID))
|
||||
}
|
||||
|
||||
if args.Types != nil {
|
||||
q.Where(task.TypeIn(args.Types...))
|
||||
}
|
||||
|
||||
if args.Status != nil {
|
||||
q.Where(task.StatusIn(args.Status...))
|
||||
}
|
||||
|
||||
if args.CorrelationID != nil {
|
||||
q.Where(task.CorrelationID(*args.CorrelationID))
|
||||
}
|
||||
|
||||
q = withTaskEagerLoading(ctx, q)
|
||||
var (
|
||||
tasks []*ent.Task
|
||||
err error
|
||||
paginationRes *PaginationResults
|
||||
)
|
||||
|
||||
if args.UseCursorPagination {
|
||||
tasks, paginationRes, err = c.cursorPagination(ctx, q, args, 1)
|
||||
} else {
|
||||
tasks, paginationRes, err = c.offsetPagination(ctx, q, args, 1)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed with paginiation: %w", err)
|
||||
}
|
||||
|
||||
return &ListTaskResult{
|
||||
Tasks: tasks,
|
||||
PaginationResults: paginationRes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) cursorPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(task.ByID(sql.OrderDesc()))
|
||||
|
||||
var (
|
||||
pageToken *PageToken
|
||||
err error
|
||||
queryPaged = query
|
||||
)
|
||||
if args.PageToken != "" {
|
||||
pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.TaskID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
||||
}
|
||||
|
||||
queryPaged = query.Where(task.IDLT(pageToken.ID))
|
||||
}
|
||||
|
||||
// Use page size + 1 to determine if there are more items to come
|
||||
queryPaged.Limit(pageSize + 1)
|
||||
|
||||
tasks, err := queryPaged.
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// More items to come
|
||||
nextTokenStr := ""
|
||||
if len(tasks) > pageSize {
|
||||
lastItem := tasks[len(tasks)-2]
|
||||
nextToken, err := getTaskNextPageToken(c.hasher, lastItem)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
||||
}
|
||||
|
||||
nextTokenStr = nextToken
|
||||
}
|
||||
|
||||
return lo.Subset(tasks, 0, uint(pageSize)), &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
NextPageToken: nextTokenStr,
|
||||
IsCursor: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *taskClient) offsetPagination(ctx context.Context, query *ent.TaskQuery, args *ListTaskArgs, paramMargin int) ([]*ent.Task, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getTaskOrderOption(args)...)
|
||||
|
||||
// Count total items
|
||||
total, err := query.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
logs, err := query.
|
||||
Limit(pageSize).
|
||||
Offset(args.Page * args.PageSize).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return logs, &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func getTaskOrderOption(args *ListTaskArgs) []task.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
default:
|
||||
return []task.OrderOption{task.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
// getTaskNextPageToken returns the next page token for the given last task.
|
||||
func getTaskNextPageToken(hasher hashid.Encoder, last *ent.Task) (string, error) {
|
||||
token := &PageToken{
|
||||
ID: last.ID,
|
||||
}
|
||||
|
||||
return token.Encode(hasher, hashid.EncodeTaskID)
|
||||
}
|
||||
|
||||
func withTaskEagerLoading(ctx context.Context, q *ent.TaskQuery) *ent.TaskQuery {
|
||||
if v, ok := ctx.Value(LoadTaskUser{}).(bool); ok && v {
|
||||
q.WithUser(func(q *ent.UserQuery) {
|
||||
withUserEagerLoading(ctx, q)
|
||||
})
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
Reference in New Issue
Block a user