mirror of
https://github.com/cloudreve/cloudreve.git
synced 2025-10-30 08:07:01 +08:00
* Feat: retrieve nodes from data table * Feat: master node ping slave node in REST API * Feat: master send scheduled ping request * Feat: inactive nodes recover loop * Modify: remove database operations from aria2 RPC caller implementation * Feat: init aria2 client in master node * Feat: Round Robin load balancer * Feat: create and monitor aria2 task in master node * Feat: slave receive and handle heartbeat * Fix: Node ID will be 0 in download record generated in older version * Feat: sign request headers with all `X-` prefix * Feat: API call to slave node will carry meta data in headers * Feat: call slave aria2 rpc method from master * Feat: get slave aria2 task status * Feat: encode slave response data using gob * Feat: aria2 callback to master node / cancel or select task to slave node * Fix: use dummy aria2 client when caller initialize failed in master node * Feat: slave aria2 status event callback / slave RPC auth * Feat: prototype for slave driven filesystem * Feat: retry for init aria2 client in master node * Feat: init request client with global options * Feat: slave receive async task from master * Fix: competition write in request header * Refactor: dependency initialize order * Feat: generic message queue implementation * Feat: message queue implementation * Feat: master waiting slave transfer result * Feat: slave transfer file in stateless policy * Feat: slave transfer file in slave policy * Feat: slave transfer file in local policy * Feat: slave transfer file in OneDrive policy * Fix: failed to initialize update checker http client * Feat: list slave nodes for dashboard * Feat: test aria2 rpc connection in slave * Feat: add and save node * Feat: add and delete node in node pool * Fix: temp file cannot be removed when aria2 task fails * Fix: delete node in admin panel * Feat: edit node and get node info * Modify: delete unused settings
190 lines
5.0 KiB
Go
190 lines
5.0 KiB
Go
package onedrive
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"io/ioutil"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||
"github.com/cloudreve/Cloudreve/v3/pkg/slave"
|
||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||
)
|
||
|
||
// Error 实现error接口
|
||
func (err OAuthError) Error() string {
|
||
return err.ErrorDescription
|
||
}
|
||
|
||
// OAuthURL 获取OAuth认证页面URL
|
||
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
|
||
query := url.Values{
|
||
"client_id": {client.ClientID},
|
||
"scope": {strings.Join(scope, " ")},
|
||
"response_type": {"code"},
|
||
"redirect_uri": {client.Redirect},
|
||
}
|
||
client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode()
|
||
return client.Endpoints.OAuthEndpoints.authorize.String()
|
||
}
|
||
|
||
// getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址
|
||
func (client *Client) getOAuthEndpoint() *oauthEndpoint {
|
||
base, err := url.Parse(client.Endpoints.OAuthURL)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
var (
|
||
token *url.URL
|
||
authorize *url.URL
|
||
)
|
||
switch base.Host {
|
||
case "login.live.com":
|
||
token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
|
||
authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
|
||
case "login.chinacloudapi.cn":
|
||
client.Endpoints.isInChina = true
|
||
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
|
||
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
|
||
default:
|
||
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
|
||
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
|
||
}
|
||
|
||
return &oauthEndpoint{
|
||
token: *token,
|
||
authorize: *authorize,
|
||
}
|
||
}
|
||
|
||
// ObtainToken 通过code或refresh_token兑换token
|
||
func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
|
||
options := newDefaultOption()
|
||
for _, o := range opts {
|
||
o.apply(options)
|
||
}
|
||
|
||
body := url.Values{
|
||
"client_id": {client.ClientID},
|
||
"redirect_uri": {client.Redirect},
|
||
"client_secret": {client.ClientSecret},
|
||
}
|
||
if options.code != "" {
|
||
body.Add("grant_type", "authorization_code")
|
||
body.Add("code", options.code)
|
||
} else {
|
||
body.Add("grant_type", "refresh_token")
|
||
body.Add("refresh_token", options.refreshToken)
|
||
}
|
||
strBody := body.Encode()
|
||
|
||
res := client.Request.Request(
|
||
"POST",
|
||
client.Endpoints.OAuthEndpoints.token.String(),
|
||
ioutil.NopCloser(strings.NewReader(strBody)),
|
||
request.WithHeader(http.Header{
|
||
"Content-Type": {"application/x-www-form-urlencoded"}},
|
||
),
|
||
request.WithContentLength(int64(len(strBody))),
|
||
)
|
||
if res.Err != nil {
|
||
return nil, res.Err
|
||
}
|
||
|
||
respBody, err := res.GetResponse()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var (
|
||
errResp OAuthError
|
||
credential Credential
|
||
decodeErr error
|
||
)
|
||
|
||
if res.Response.StatusCode != 200 {
|
||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||
} else {
|
||
decodeErr = json.Unmarshal([]byte(respBody), &credential)
|
||
}
|
||
if decodeErr != nil {
|
||
return nil, decodeErr
|
||
}
|
||
|
||
if errResp.ErrorType != "" {
|
||
return nil, errResp
|
||
}
|
||
|
||
return &credential, nil
|
||
|
||
}
|
||
|
||
// UpdateCredential 更新凭证,并检查有效期
|
||
func (client *Client) UpdateCredential(ctx context.Context) error {
|
||
if conf.SystemConfig.Mode == "slave" {
|
||
return client.fetchCredentialFromMaster(ctx)
|
||
}
|
||
|
||
GlobalMutex.Lock(client.Policy.ID)
|
||
defer GlobalMutex.Unlock(client.Policy.ID)
|
||
|
||
// 如果已存在凭证
|
||
if client.Credential != nil && client.Credential.AccessToken != "" {
|
||
// 检查已有凭证是否过期
|
||
if client.Credential.ExpiresIn > time.Now().Unix() {
|
||
// 未过期,不要更新
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// 尝试从缓存中获取凭证
|
||
if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok {
|
||
credential := cacheCredential.(Credential)
|
||
if credential.ExpiresIn > time.Now().Unix() {
|
||
client.Credential = &credential
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// 获取新的凭证
|
||
if client.Credential == nil || client.Credential.RefreshToken == "" {
|
||
// 无有效的RefreshToken
|
||
util.Log().Error("上传策略[%s]凭证刷新失败,请重新授权OneDrive账号", client.Policy.Name)
|
||
return ErrInvalidRefreshToken
|
||
}
|
||
|
||
credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 更新有效期为绝对时间戳
|
||
expires := credential.ExpiresIn - 60
|
||
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
|
||
client.Credential = credential
|
||
|
||
// 更新存储策略的 RefreshToken
|
||
client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
|
||
|
||
// 更新缓存
|
||
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateCredential 更新凭证,并检查有效期
|
||
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
|
||
res, err := slave.DefaultController.GetOneDriveToken(client.Policy.MasterID, client.Policy.ID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
client.Credential = &Credential{AccessToken: res}
|
||
return nil
|
||
}
|