mirror of
https://github.com/cloudreve/cloudreve.git
synced 2025-11-03 04:15:59 +08:00
Init V4 community edition (#2265)
* Init V4 community edition * Init V4 community edition
This commit is contained in:
401
inventory/share.go
Normal file
401
inventory/share.go
Normal file
@ -0,0 +1,401 @@
|
||||
package inventory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"entgo.io/ent/dialect/sql"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/file"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/predicate"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/share"
|
||||
"github.com/cloudreve/Cloudreve/v4/ent/user"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v4/pkg/hashid"
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
type (
|
||||
// Ctx keys for eager loading options.
|
||||
LoadShareFile struct{}
|
||||
LoadShareUser struct{}
|
||||
)
|
||||
|
||||
var (
|
||||
ErrShareLinkExpired = fmt.Errorf("share link expired")
|
||||
ErrOwnerInactive = fmt.Errorf("owner is inactive")
|
||||
ErrSourceFileInvalid = fmt.Errorf("source file is deleted")
|
||||
)
|
||||
|
||||
type (
|
||||
ShareClient interface {
|
||||
TxOperator
|
||||
// GetByID returns the share with given id.
|
||||
GetByID(ctx context.Context, id int) (*ent.Share, error)
|
||||
// GetByIDUser returns the share with given id and user id.
|
||||
GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error)
|
||||
// GetByHashID returns the share with given hash id.
|
||||
GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error)
|
||||
// Upsert creates or update a new share record.
|
||||
Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error)
|
||||
// Viewed increase the view count of the share.
|
||||
Viewed(ctx context.Context, share *ent.Share) error
|
||||
// Downloaded increase the download count of the share.
|
||||
Downloaded(ctx context.Context, share *ent.Share) error
|
||||
// Delete deletes the share.
|
||||
Delete(ctx context.Context, shareId int) error
|
||||
// List returns a list of shares with the given args.
|
||||
List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error)
|
||||
// CountByTimeRange counts the number of shares created in the given time range.
|
||||
CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error)
|
||||
// DeleteBatch deletes the shares with the given ids.
|
||||
DeleteBatch(ctx context.Context, shareIds []int) error
|
||||
}
|
||||
|
||||
CreateShareParams struct {
|
||||
Existed *ent.Share
|
||||
Password string
|
||||
RemainDownloads int
|
||||
Expires *time.Time
|
||||
OwnerID int
|
||||
FileID int
|
||||
}
|
||||
|
||||
ListShareArgs struct {
|
||||
*PaginationArgs
|
||||
UserID int
|
||||
FileID int
|
||||
PublicOnly bool
|
||||
}
|
||||
ListShareResult struct {
|
||||
*PaginationResults
|
||||
Shares []*ent.Share
|
||||
}
|
||||
)
|
||||
|
||||
func NewShareClient(client *ent.Client, dbType conf.DBType, hasher hashid.Encoder) ShareClient {
|
||||
return &shareClient{
|
||||
client: client,
|
||||
hasher: hasher,
|
||||
maxSQlParam: sqlParamLimit(dbType),
|
||||
}
|
||||
}
|
||||
|
||||
type shareClient struct {
|
||||
maxSQlParam int
|
||||
client *ent.Client
|
||||
hasher hashid.Encoder
|
||||
}
|
||||
|
||||
func (c *shareClient) SetClient(newClient *ent.Client) TxOperator {
|
||||
return &shareClient{client: newClient, hasher: c.hasher, maxSQlParam: c.maxSQlParam}
|
||||
}
|
||||
|
||||
func (c *shareClient) GetClient() *ent.Client {
|
||||
return c.client
|
||||
}
|
||||
|
||||
func (c *shareClient) CountByTimeRange(ctx context.Context, start, end *time.Time) (int, error) {
|
||||
if start == nil || end == nil {
|
||||
return c.client.Share.Query().Count(ctx)
|
||||
}
|
||||
|
||||
return c.client.Share.Query().Where(share.CreatedAtGTE(*start), share.CreatedAtLT(*end)).Count(ctx)
|
||||
}
|
||||
|
||||
func (c *shareClient) Upsert(ctx context.Context, params *CreateShareParams) (*ent.Share, error) {
|
||||
if params.Existed != nil {
|
||||
createQuery := c.client.Share.
|
||||
UpdateOne(params.Existed)
|
||||
if params.RemainDownloads > 0 {
|
||||
createQuery.SetRemainDownloads(params.RemainDownloads)
|
||||
} else {
|
||||
createQuery.ClearRemainDownloads()
|
||||
}
|
||||
if params.Expires != nil {
|
||||
createQuery.SetNillableExpires(params.Expires)
|
||||
} else {
|
||||
createQuery.ClearExpires()
|
||||
}
|
||||
|
||||
return createQuery.Save(ctx)
|
||||
}
|
||||
|
||||
query := c.client.Share.
|
||||
Create().
|
||||
SetUserID(params.OwnerID).
|
||||
SetFileID(params.FileID)
|
||||
if params.Password != "" {
|
||||
query.SetPassword(params.Password)
|
||||
}
|
||||
if params.RemainDownloads > 0 {
|
||||
query.SetRemainDownloads(params.RemainDownloads)
|
||||
}
|
||||
if params.Expires != nil {
|
||||
query.SetNillableExpires(params.Expires)
|
||||
}
|
||||
|
||||
return query.Save(ctx)
|
||||
}
|
||||
|
||||
func (c *shareClient) GetByHashID(ctx context.Context, idRaw string) (*ent.Share, error) {
|
||||
id, err := c.hasher.Decode(idRaw, hashid.ShareID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode hash id %q: %w", idRaw, err)
|
||||
}
|
||||
|
||||
return c.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (c *shareClient) GetByID(ctx context.Context, id int) (*ent.Share, error) {
|
||||
s, err := withShareEagerLoading(ctx, c.client.Share.Query().Where(share.ID(id))).First(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query share %d: %w", id, err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) GetByIDUser(ctx context.Context, id, uid int) (*ent.Share, error) {
|
||||
s, err := withShareEagerLoading(ctx, c.client.Share.Query().
|
||||
Where(share.ID(id))).
|
||||
Where(share.HasUserWith(user.ID(uid))).First(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query share %d: %w", id, err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) DeleteBatch(ctx context.Context, shareIds []int) error {
|
||||
_, err := c.client.Share.Delete().Where(share.IDIn(shareIds...)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *shareClient) Delete(ctx context.Context, shareId int) error {
|
||||
return c.client.Share.DeleteOneID(shareId).Exec(ctx)
|
||||
}
|
||||
|
||||
// Viewed increments the view count of the share.
|
||||
func (c *shareClient) Viewed(ctx context.Context, share *ent.Share) error {
|
||||
_, err := c.client.Share.UpdateOneID(share.ID).AddViews(1).Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Downloaded increments the download count of the share.
|
||||
func (c *shareClient) Downloaded(ctx context.Context, share *ent.Share) error {
|
||||
stm := c.client.Share.
|
||||
UpdateOneID(share.ID).
|
||||
AddDownloads(1)
|
||||
if share.RemainDownloads != nil && *share.RemainDownloads >= 0 {
|
||||
stm.AddRemainDownloads(-1)
|
||||
}
|
||||
_, err := stm.Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func IsValidShare(share *ent.Share) error {
|
||||
// Check if share is expired
|
||||
if err := IsShareExpired(share); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check owner status
|
||||
owner, err := share.Edges.UserOrErr()
|
||||
if err != nil || owner.Status != user.StatusActive {
|
||||
// Owner already deleted, or not active.
|
||||
return ErrOwnerInactive
|
||||
}
|
||||
|
||||
// Check source file status
|
||||
file, err := share.Edges.FileOrErr()
|
||||
if err != nil || file.FileChildren == 0 || file.OwnerID != owner.ID {
|
||||
// Source file already deleted
|
||||
return ErrSourceFileInvalid
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsShareExpired(share *ent.Share) error {
|
||||
// Check if share is expired
|
||||
if (share.Expires != nil && share.Expires.Before(time.Now())) ||
|
||||
(share.RemainDownloads != nil && *share.RemainDownloads <= 0) {
|
||||
return ErrShareLinkExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *shareClient) List(ctx context.Context, args *ListShareArgs) (*ListShareResult, error) {
|
||||
rawQuery := c.listQuery(args)
|
||||
query := withShareEagerLoading(ctx, rawQuery)
|
||||
|
||||
var (
|
||||
shares []*ent.Share
|
||||
err error
|
||||
paginationRes *PaginationResults
|
||||
)
|
||||
if args.UseCursorPagination {
|
||||
shares, paginationRes, err = c.cursorPagination(ctx, query, args, 10)
|
||||
} else {
|
||||
shares, paginationRes, err = c.offsetPagination(ctx, query, args, 10)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed with paginiation: %w", err)
|
||||
}
|
||||
|
||||
return &ListShareResult{
|
||||
Shares: shares,
|
||||
PaginationResults: paginationRes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) cursorPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getShareOrderOption(args)...)
|
||||
|
||||
var (
|
||||
pageToken *PageToken
|
||||
err error
|
||||
)
|
||||
if args.PageToken != "" {
|
||||
pageToken, err = pageTokenFromString(args.PageToken, c.hasher, hashid.ShareID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid page token %q: %w", args.PageToken, err)
|
||||
}
|
||||
}
|
||||
queryPaged := getShareCursorQuery(args, pageToken, query)
|
||||
|
||||
// Use page size + 1 to determine if there are more items to come
|
||||
queryPaged.Limit(pageSize + 1)
|
||||
|
||||
logs, err := queryPaged.
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// More items to come
|
||||
nextTokenStr := ""
|
||||
if len(logs) > pageSize {
|
||||
lastItem := logs[len(logs)-2]
|
||||
nextToken, err := getShareNextPageToken(c.hasher, lastItem, args)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to generate next page token: %w", err)
|
||||
}
|
||||
|
||||
nextTokenStr = nextToken
|
||||
}
|
||||
|
||||
return lo.Subset(logs, 0, uint(pageSize)), &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
NextPageToken: nextTokenStr,
|
||||
IsCursor: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) offsetPagination(ctx context.Context, query *ent.ShareQuery, args *ListShareArgs, paramMargin int) ([]*ent.Share, *PaginationResults, error) {
|
||||
pageSize := capPageSize(c.maxSQlParam, args.PageSize, paramMargin)
|
||||
query.Order(getShareOrderOption(args)...)
|
||||
|
||||
total, err := query.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
logs, err := query.Limit(pageSize).Offset(args.Page * args.PageSize).All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return logs, &PaginationResults{
|
||||
PageSize: pageSize,
|
||||
TotalItems: total,
|
||||
Page: args.Page,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *shareClient) listQuery(args *ListShareArgs) *ent.ShareQuery {
|
||||
query := c.client.Share.Query()
|
||||
if args.UserID > 0 {
|
||||
query.Where(share.HasUserWith(user.ID(args.UserID)))
|
||||
}
|
||||
|
||||
if args.PublicOnly {
|
||||
query.Where(share.PasswordIsNil())
|
||||
}
|
||||
|
||||
if args.FileID > 0 {
|
||||
query.Where(share.HasFileWith(file.ID(args.FileID)))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// getShareNextPageToken returns the next page token for the given last share.
|
||||
func getShareNextPageToken(hasher hashid.Encoder, last *ent.Share, args *ListShareArgs) (string, error) {
|
||||
token := &PageToken{
|
||||
ID: last.ID,
|
||||
}
|
||||
|
||||
return token.Encode(hasher, hashid.EncodeShareID)
|
||||
}
|
||||
|
||||
func getShareCursorQuery(args *ListShareArgs, token *PageToken, query *ent.ShareQuery) *ent.ShareQuery {
|
||||
o := &sql.OrderTermOptions{}
|
||||
getOrderTerm(args.Order)(o)
|
||||
|
||||
predicates, ok := shareCursorQuery[args.OrderBy]
|
||||
if !ok {
|
||||
predicates = shareCursorQuery[share.FieldID]
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
query.Where(predicates[o.Desc](token))
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
var shareCursorQuery = map[string]map[bool]func(token *PageToken) predicate.Share{
|
||||
share.FieldID: {
|
||||
true: func(token *PageToken) predicate.Share {
|
||||
return share.IDLT(token.ID)
|
||||
},
|
||||
false: func(token *PageToken) predicate.Share {
|
||||
return share.IDGT(token.ID)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func getShareOrderOption(args *ListShareArgs) []share.OrderOption {
|
||||
orderTerm := getOrderTerm(args.Order)
|
||||
switch args.OrderBy {
|
||||
case share.FieldViews:
|
||||
return []share.OrderOption{share.ByViews(orderTerm), share.ByID(orderTerm)}
|
||||
case share.FieldDownloads:
|
||||
return []share.OrderOption{share.ByDownloads(orderTerm), share.ByID(orderTerm)}
|
||||
case share.FieldRemainDownloads:
|
||||
return []share.OrderOption{share.ByRemainDownloads(orderTerm), share.ByID(orderTerm)}
|
||||
default:
|
||||
return []share.OrderOption{share.ByID(orderTerm)}
|
||||
}
|
||||
}
|
||||
|
||||
func withShareEagerLoading(ctx context.Context, q *ent.ShareQuery) *ent.ShareQuery {
|
||||
if v, ok := ctx.Value(LoadShareFile{}).(bool); ok && v {
|
||||
q.WithFile(func(q *ent.FileQuery) {
|
||||
withFileEagerLoading(ctx, q)
|
||||
})
|
||||
}
|
||||
if v, ok := ctx.Value(LoadShareUser{}).(bool); ok && v {
|
||||
q.WithUser(func(q *ent.UserQuery) {
|
||||
withUserEagerLoading(ctx, q)
|
||||
})
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
Reference in New Issue
Block a user