mirror of
https://github.com/grafana/grafana.git
synced 2025-08-02 00:01:48 +08:00
Add oauth pass-thru option for datasources
This commit is contained in:
@ -165,6 +165,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) {
|
||||
|
||||
extUser := &m.ExternalUserInfo{
|
||||
AuthModule: "oauth_" + name,
|
||||
OAuthToken: token,
|
||||
AuthId: userInfo.Id,
|
||||
Name: userInfo.Name,
|
||||
Login: userInfo.Login,
|
||||
|
@ -14,11 +14,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/log"
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/social"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
@ -215,6 +218,44 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) {
|
||||
if proxy.route != nil {
|
||||
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
|
||||
}
|
||||
|
||||
if proxy.ds.JsonData != nil && proxy.ds.JsonData.Get("oauthPassThru").MustBool() {
|
||||
provider := proxy.ds.JsonData.Get("oauthPassThruProvider").MustString()
|
||||
connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
|
||||
if !ok {
|
||||
logger.Error("Failed to find oauth provider with given name", "provider", provider)
|
||||
}
|
||||
cmd := &m.GetAuthInfoQuery{UserId: proxy.ctx.UserId, AuthModule: provider}
|
||||
if err := bus.Dispatch(cmd); err != nil {
|
||||
logger.Error("Error feching oauth information for user", "error", err)
|
||||
}
|
||||
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(proxy.ctx.Req.Context(), &oauth2.Token{
|
||||
AccessToken: cmd.Result.OAuthAccessToken,
|
||||
Expiry: cmd.Result.OAuthExpiry,
|
||||
RefreshToken: cmd.Result.OAuthRefreshToken,
|
||||
TokenType: cmd.Result.OAuthTokenType,
|
||||
}).Token()
|
||||
if err != nil {
|
||||
logger.Error("Failed to retrieve access token from oauth provider", "provider", cmd.Result.AuthModule)
|
||||
}
|
||||
|
||||
// If the tokens are not the same, update the entry in the DB
|
||||
if token.AccessToken != cmd.Result.OAuthAccessToken {
|
||||
cmd2 := &m.UpdateAuthInfoCommand{
|
||||
UserId: cmd.Result.Id,
|
||||
AuthModule: cmd.Result.AuthModule,
|
||||
AuthId: cmd.Result.AuthId,
|
||||
OAuthToken: token,
|
||||
}
|
||||
if err := bus.Dispatch(cmd2); err != nil {
|
||||
logger.Error("Failed to update access token during token refresh", "error", err)
|
||||
}
|
||||
}
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,13 +9,16 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
macaron "gopkg.in/macaron.v1"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/log"
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/social"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
@ -388,6 +391,55 @@ func TestDSRouteRule(t *testing.T) {
|
||||
So(req.Header.Get("X-Canary"), ShouldEqual, "stillthere")
|
||||
})
|
||||
})
|
||||
|
||||
Convey("When proxying a datasource that has oauth token pass-thru enabled", func() {
|
||||
social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{
|
||||
SocialBase: &social.SocialBase{
|
||||
Config: &oauth2.Config{},
|
||||
},
|
||||
}
|
||||
|
||||
bus.AddHandler("test", func(query *m.GetAuthInfoQuery) error {
|
||||
query.Result = &m.UserAuth{
|
||||
Id: 1,
|
||||
UserId: 1,
|
||||
AuthModule: "generic_oauth",
|
||||
OAuthAccessToken: "testtoken",
|
||||
OAuthRefreshToken: "testrefreshtoken",
|
||||
OAuthTokenType: "Bearer",
|
||||
OAuthExpiry: time.Now().AddDate(0, 0, 1),
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
plugin := &plugins.DataSourcePlugin{}
|
||||
ds := &m.DataSource{
|
||||
Type: "custom-datasource",
|
||||
Url: "http://host/root/",
|
||||
JsonData: simplejson.NewFromAny(map[string]interface{}{
|
||||
"oauthPassThru": true,
|
||||
"oauthPassThruProvider": "oauth_generic_oauth",
|
||||
}),
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://localhost/asd", nil)
|
||||
ctx := &m.ReqContext{
|
||||
SignedInUser: &m.SignedInUser{UserId: 1},
|
||||
Context: &macaron.Context{
|
||||
Req: macaron.Request{Request: req},
|
||||
},
|
||||
}
|
||||
proxy := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/")
|
||||
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
proxy.getDirector()(req)
|
||||
|
||||
Convey("Should have access token in header", func() {
|
||||
So(req.Header.Get("Authorization"), ShouldEqual, fmt.Sprintf("%s %s", "Bearer", "testtoken"))
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -51,11 +51,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if extUser.AuthModule != "" && extUser.AuthId != "" {
|
||||
if extUser.AuthModule != "" {
|
||||
cmd2 := &m.SetAuthInfoCommand{
|
||||
UserId: cmd.Result.Id,
|
||||
AuthModule: extUser.AuthModule,
|
||||
AuthId: extUser.AuthId,
|
||||
OAuthToken: extUser.OAuthToken,
|
||||
}
|
||||
if err := bus.Dispatch(cmd2); err != nil {
|
||||
return err
|
||||
@ -69,6 +70,14 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Always persist the latest token at log-in
|
||||
if extUser.AuthModule != "" && extUser.OAuthToken != nil {
|
||||
err = updateUserAuth(cmd.Result, extUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = syncOrgRoles(cmd.Result, extUser)
|
||||
@ -143,6 +152,18 @@ func updateUser(user *m.User, extUser *m.ExternalUserInfo) error {
|
||||
return bus.Dispatch(updateCmd)
|
||||
}
|
||||
|
||||
func updateUserAuth(user *m.User, extUser *m.ExternalUserInfo) error {
|
||||
updateCmd := &m.UpdateAuthInfoCommand{
|
||||
AuthModule: extUser.AuthModule,
|
||||
AuthId: extUser.AuthId,
|
||||
UserId: user.Id,
|
||||
OAuthToken: extUser.OAuthToken,
|
||||
}
|
||||
|
||||
log.Debug("Updating user_auth info for user_id %d", user.Id)
|
||||
return bus.Dispatch(updateCmd)
|
||||
}
|
||||
|
||||
func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error {
|
||||
// don't sync org roles if none are specified
|
||||
if len(extUser.OrgRoles) == 0 {
|
||||
|
@ -2,17 +2,24 @@ package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type UserAuth struct {
|
||||
Id int64
|
||||
UserId int64
|
||||
AuthModule string
|
||||
AuthId string
|
||||
Created time.Time
|
||||
Id int64
|
||||
UserId int64
|
||||
AuthModule string
|
||||
AuthId string
|
||||
Created time.Time
|
||||
OAuthAccessToken string
|
||||
OAuthRefreshToken string
|
||||
OAuthTokenType string
|
||||
OAuthExpiry time.Time
|
||||
}
|
||||
|
||||
type ExternalUserInfo struct {
|
||||
OAuthToken *oauth2.Token
|
||||
AuthModule string
|
||||
AuthId string
|
||||
UserId int64
|
||||
@ -39,6 +46,14 @@ type SetAuthInfoCommand struct {
|
||||
AuthModule string
|
||||
AuthId string
|
||||
UserId int64
|
||||
OAuthToken *oauth2.Token
|
||||
}
|
||||
|
||||
type UpdateAuthInfoCommand struct {
|
||||
AuthModule string
|
||||
AuthId string
|
||||
UserId int64
|
||||
OAuthToken *oauth2.Token
|
||||
}
|
||||
|
||||
type DeleteAuthInfoCommand struct {
|
||||
@ -67,6 +82,7 @@ type GetUserByAuthInfoQuery struct {
|
||||
}
|
||||
|
||||
type GetAuthInfoQuery struct {
|
||||
UserId int64
|
||||
AuthModule string
|
||||
AuthId string
|
||||
|
||||
|
@ -33,6 +33,7 @@ func AddMigrations(mg *Migrator) {
|
||||
addUserAuthMigrations(mg)
|
||||
addServerlockMigrations(mg)
|
||||
addUserAuthTokenMigrations(mg)
|
||||
addUserAuthOAuthMigrations(mg)
|
||||
}
|
||||
|
||||
func addMigrationLogMigrations(mg *Migrator) {
|
||||
|
25
pkg/services/sqlstore/migrations/user_auth_oauth_mig.go
Normal file
25
pkg/services/sqlstore/migrations/user_auth_oauth_mig.go
Normal file
@ -0,0 +1,25 @@
|
||||
package migrations
|
||||
|
||||
import . "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
|
||||
func addUserAuthOAuthMigrations(mg *Migrator) {
|
||||
userAuthV2 := Table{Name: "user_auth"}
|
||||
|
||||
mg.AddMigration("Add OAuth access token to user_auth", NewAddColumnMigration(userAuthV2, &Column{
|
||||
Name: "o_auth_access_token", Type: DB_Text, Nullable: true, Length: 255,
|
||||
}))
|
||||
mg.AddMigration("Add OAuth refresh token to user_auth", NewAddColumnMigration(userAuthV2, &Column{
|
||||
Name: "o_auth_refresh_token", Type: DB_Text, Nullable: true, Length: 255,
|
||||
}))
|
||||
mg.AddMigration("Add OAuth token type to user_auth", NewAddColumnMigration(userAuthV2, &Column{
|
||||
Name: "o_auth_token_type", Type: DB_Text, Nullable: true, Length: 255,
|
||||
}))
|
||||
mg.AddMigration("Add OAuth expiry to user_auth", NewAddColumnMigration(userAuthV2, &Column{
|
||||
Name: "o_auth_expiry", Type: DB_DateTime, Nullable: true,
|
||||
}))
|
||||
|
||||
mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV2, &Index{
|
||||
Cols: []string{"user_id"},
|
||||
}))
|
||||
|
||||
}
|
@ -5,12 +5,15 @@ import (
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
func init() {
|
||||
bus.AddHandler("sql", GetUserByAuthInfo)
|
||||
bus.AddHandler("sql", GetAuthInfo)
|
||||
bus.AddHandler("sql", SetAuthInfo)
|
||||
bus.AddHandler("sql", UpdateAuthInfo)
|
||||
bus.AddHandler("sql", DeleteAuthInfo)
|
||||
}
|
||||
|
||||
@ -94,7 +97,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
|
||||
}
|
||||
|
||||
// create authInfo record to link accounts
|
||||
if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" {
|
||||
if authQuery.Result == nil && query.AuthModule != "" {
|
||||
cmd2 := &m.SetAuthInfoCommand{
|
||||
UserId: user.Id,
|
||||
AuthModule: query.AuthModule,
|
||||
@ -111,6 +114,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
|
||||
|
||||
func GetAuthInfo(query *m.GetAuthInfoQuery) error {
|
||||
userAuth := &m.UserAuth{
|
||||
UserId: query.UserId, // TODO this doesn't have an index in the db
|
||||
AuthModule: query.AuthModule,
|
||||
AuthId: query.AuthId,
|
||||
}
|
||||
@ -122,6 +126,28 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error {
|
||||
return m.ErrUserNotFound
|
||||
}
|
||||
|
||||
if userAuth.OAuthAccessToken != "" {
|
||||
accessToken, err := util.Decrypt([]byte(userAuth.OAuthAccessToken), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.OAuthAccessToken = string(accessToken)
|
||||
}
|
||||
if userAuth.OAuthRefreshToken != "" {
|
||||
refreshToken, err := util.Decrypt([]byte(userAuth.OAuthRefreshToken), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.OAuthRefreshToken = string(refreshToken)
|
||||
}
|
||||
if userAuth.OAuthTokenType != "" {
|
||||
tokenType, err := util.Decrypt([]byte(userAuth.OAuthTokenType), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.OAuthTokenType = string(tokenType)
|
||||
}
|
||||
|
||||
query.Result = userAuth
|
||||
return nil
|
||||
}
|
||||
@ -135,11 +161,69 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
|
||||
Created: time.Now(),
|
||||
}
|
||||
|
||||
if cmd.OAuthToken != nil {
|
||||
secretAccessToken, err := util.Encrypt([]byte(cmd.OAuthToken.AccessToken), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretRefreshToken, err := util.Encrypt([]byte(cmd.OAuthToken.RefreshToken), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretTokenType, err := util.Encrypt([]byte(cmd.OAuthToken.TokenType), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authUser.OAuthAccessToken = string(secretAccessToken)
|
||||
authUser.OAuthRefreshToken = string(secretRefreshToken)
|
||||
authUser.OAuthTokenType = string(secretTokenType)
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
_, err := sess.Insert(authUser)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
authUser := &m.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
AuthId: cmd.AuthId,
|
||||
Created: time.Now(),
|
||||
}
|
||||
|
||||
if cmd.OAuthToken != nil {
|
||||
secretAccessToken, err := util.Encrypt([]byte(cmd.OAuthToken.AccessToken), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretRefreshToken, err := util.Encrypt([]byte(cmd.OAuthToken.RefreshToken), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secretTokenType, err := util.Encrypt([]byte(cmd.OAuthToken.TokenType), setting.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
authUser.OAuthAccessToken = string(secretAccessToken)
|
||||
authUser.OAuthRefreshToken = string(secretRefreshToken)
|
||||
authUser.OAuthTokenType = string(secretTokenType)
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
cond := &m.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
}
|
||||
|
||||
_, err := sess.Update(authUser, cond)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
|
||||
return inTransaction(func(sess *DBSession) error {
|
||||
_, err := sess.Delete(cmd.UserAuth)
|
||||
|
@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
m "github.com/grafana/grafana/pkg/models"
|
||||
)
|
||||
@ -126,5 +128,46 @@ func TestUserAuth(t *testing.T) {
|
||||
So(err, ShouldEqual, m.ErrUserNotFound)
|
||||
So(query.Result, ShouldBeNil)
|
||||
})
|
||||
|
||||
Convey("Can set & retrieve oauth token information", func() {
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now(),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
// Find a user to set tokens on
|
||||
login := "loginuser0"
|
||||
|
||||
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
|
||||
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
|
||||
err = GetUserByAuthInfo(query)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.Login, ShouldEqual, login)
|
||||
|
||||
cmd := &m.UpdateAuthInfoCommand{
|
||||
UserId: query.Result.Id,
|
||||
AuthId: query.AuthId,
|
||||
AuthModule: query.AuthModule,
|
||||
OAuthToken: token,
|
||||
}
|
||||
err = UpdateAuthInfo(cmd)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
getAuthQuery := &m.GetAuthInfoQuery{
|
||||
UserId: query.Result.Id,
|
||||
}
|
||||
|
||||
err = GetAuthInfo(getAuthQuery)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken)
|
||||
So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken)
|
||||
So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType)
|
||||
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ type SocialConnector interface {
|
||||
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
|
||||
Exchange(ctx context.Context, code string) (*oauth2.Token, error)
|
||||
Client(ctx context.Context, t *oauth2.Token) *http.Client
|
||||
TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource
|
||||
}
|
||||
|
||||
type SocialBase struct {
|
||||
|
@ -87,6 +87,19 @@
|
||||
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="Skip TLS Verify" label-class="width-10"
|
||||
checked="current.jsonData.tlsSkipVerify" switch-class="max-width-6"></gf-form-checkbox>
|
||||
</div>
|
||||
<div class="gf-form-inline">
|
||||
<gf-form-switch class="gf-form" ng-if="current.access=='proxy'" label="Forward OAuth Identity" label-class="width-13" tooltip="Forward the user's upstream OAuth identity to the datasource (Their access token gets passed along)." label-class="width-10" checked="current.jsonData.oauthPassThru" switch-class="max-width-6"></gf-form-switch>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="gf-form-group" ng-if="current.jsonData.oauthPassThru">
|
||||
<h6>OAuth Identity Forwarding Details</h6>
|
||||
<div class="gf-form max-width-30">
|
||||
<span class="gf-form-label width-10">OAuth Source</span>
|
||||
<div class="gf-form-select-wrapper max-width-24">
|
||||
<select class="gf-form-input" ng-model="current.jsonData.oauthPassThruProvider" ng-options="f.key as f.value for f in oauthProviders"></select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="gf-form-group" ng-if="current.basicAuth">
|
||||
|
@ -20,6 +20,13 @@ coreModule.directive('datasourceHttpSettings', () => {
|
||||
$scope.getSuggestUrls = () => {
|
||||
return [$scope.suggestUrl];
|
||||
};
|
||||
$scope.oauthProviders = [
|
||||
{ key: 'oauth_google', value: 'Google OAuth' },
|
||||
{ key: 'oauth_gitlab', value: 'GitLab OAuth' },
|
||||
{ key: 'oauth_generic_oauth', value: 'Generic OAuth' },
|
||||
{ key: 'oauth_grafana_com', value: 'Grafana OAuth' },
|
||||
{ key: 'oauth_github', value: 'GitHub OAuth' },
|
||||
];
|
||||
},
|
||||
},
|
||||
};
|
||||
|
Reference in New Issue
Block a user