From ad971cc9beb78251ea7edda17c24b1de49d6099a Mon Sep 17 00:00:00 2001 From: Emil Tullstedt Date: Tue, 14 Sep 2021 10:49:37 +0200 Subject: [PATCH] LDAP: Search all DNs for users (#38891) --- pkg/services/ldap/ldap.go | 47 ++- pkg/services/ldap/ldap_helpers_test.go | 268 ++++++------- pkg/services/ldap/ldap_login_test.go | 432 +++++++++++---------- pkg/services/ldap/ldap_private_test.go | 477 +++++++++++------------ pkg/services/ldap/ldap_test.go | 513 +++++++++++++++---------- pkg/services/ldap/testing.go | 23 +- 6 files changed, 896 insertions(+), 864 deletions(-) diff --git a/pkg/services/ldap/ldap.go b/pkg/services/ldap/ldap.go index 6caf5494205..1ad3fa15905 100644 --- a/pkg/services/ldap/ldap.go +++ b/pkg/services/ldap/ldap.go @@ -12,9 +12,10 @@ import ( "strings" "github.com/davecgh/go-spew/spew" + "gopkg.in/ldap.v3" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" - "gopkg.in/ldap.v3" ) // IConnection is interface for LDAP connection manipulation @@ -252,16 +253,11 @@ func (server *Server) Users(logins []string) ( []*models.ExternalUserInfo, error, ) { - var users []*ldap.Entry + var users [][]*ldap.Entry err := getUsersIteration(logins, func(previous, current int) error { - entries, err := server.users(logins[previous:current]) - if err != nil { - return err - } - - users = append(users, entries...) - - return nil + var err error + users, err = server.users(logins[previous:current]) + return err }) if err != nil { return nil, err @@ -308,13 +304,15 @@ func getUsersIteration(logins []string, fn func(int, int) error) error { // users is helper method for the Users() func (server *Server) users(logins []string) ( - []*ldap.Entry, + [][]*ldap.Entry, error, ) { var result *ldap.SearchResult var Config = server.Config var err error + var entries = make([][]*ldap.Entry, 0, len(Config.SearchBaseDNs)) + for _, base := range Config.SearchBaseDNs { result, err = server.Connection.Search( server.getSearchRequest(base, logins), @@ -324,11 +322,11 @@ func (server *Server) users(logins []string) ( } if len(result.Entries) > 0 { - break + entries = append(entries, result.Entries) } } - return result.Entries, nil + return entries, nil } // validateGrafanaUser validates user access. @@ -557,17 +555,26 @@ func (server *Server) requestMemberOf(entry *ldap.Entry) ([]string, error) { // serializeUsers serializes the users // from LDAP result to ExternalInfo struct func (server *Server) serializeUsers( - entries []*ldap.Entry, + entries [][]*ldap.Entry, ) ([]*models.ExternalUserInfo, error) { var serialized []*models.ExternalUserInfo + var users = map[string]struct{}{} - for _, user := range entries { - extUser, err := server.buildGrafanaUser(user) - if err != nil { - return nil, err + for _, dn := range entries { + for _, user := range dn { + extUser, err := server.buildGrafanaUser(user) + if err != nil { + return nil, err + } + + if _, exists := users[extUser.Login]; exists { + // ignore duplicates + continue + } + users[extUser.Login] = struct{}{} + + serialized = append(serialized, extUser) } - - serialized = append(serialized, extUser) } return serialized, nil diff --git a/pkg/services/ldap/ldap_helpers_test.go b/pkg/services/ldap/ldap_helpers_test.go index 5062623d546..e276917c973 100644 --- a/pkg/services/ldap/ldap_helpers_test.go +++ b/pkg/services/ldap/ldap_helpers_test.go @@ -1,191 +1,141 @@ package ldap import ( + "fmt" "testing" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" "gopkg.in/ldap.v3" ) -func TestLDAPHelpers(t *testing.T) { - Convey("isMemberOf()", t, func() { - Convey("Wildcard", func() { - result := isMemberOf([]string{}, "*") - So(result, ShouldBeTrue) - }) +func TestIsMemberOf(t *testing.T) { + tests := []struct { + memberOf []string + group string + expected bool + }{ + {memberOf: []string{}, group: "*", expected: true}, + {memberOf: []string{"one", "Two", "three"}, group: "two", expected: true}, + {memberOf: []string{"one", "Two", "three"}, group: "twos", expected: false}, + } - Convey("Should find one", func() { - result := isMemberOf([]string{"one", "Two", "three"}, "two") - So(result, ShouldBeTrue) + for _, tc := range tests { + t.Run(fmt.Sprintf("isMemberOf(%v, \"%s\") = %v", tc.memberOf, tc.group, tc.expected), func(t *testing.T) { + assert.Equal(t, tc.expected, isMemberOf(tc.memberOf, tc.group)) }) + } +} - Convey("Should not find one", func() { - result := isMemberOf([]string{"one", "Two", "three"}, "twos") - So(result, ShouldBeFalse) - }) - }) +func TestGetUsersIteration(t *testing.T) { + const pageSize = UsersMaxRequest + iterations := map[int]int{ + 0: 0, + 400: 1, + 600: 2, + 1500: 3, + } + + for userCount, expectedIterations := range iterations { + t.Run(fmt.Sprintf("getUserIteration iterates %d times for %d users", expectedIterations, userCount), func(t *testing.T) { + logins := make([]string, userCount) - Convey("getUsersIteration()", t, func() { - Convey("it should execute twice for 600 users", func() { - logins := make([]string, 600) i := 0 + _ = getUsersIteration(logins, func(first int, last int) error { + assert.Equal(t, pageSize*i, first) - result := getUsersIteration(logins, func(previous, current int) error { - i++ - - if i == 1 { - So(previous, ShouldEqual, 0) - So(current, ShouldEqual, 500) - } else { - So(previous, ShouldEqual, 500) - So(current, ShouldEqual, 600) + expectedLast := pageSize*i + pageSize + if expectedLast > userCount { + expectedLast = userCount } - return nil - }) + assert.Equal(t, expectedLast, last) - So(i, ShouldEqual, 2) - So(result, ShouldBeNil) - }) - - Convey("it should execute three times for 1500 users", func() { - logins := make([]string, 1500) - i := 0 - - result := getUsersIteration(logins, func(previous, current int) error { - i++ - switch i { - case 1: - So(previous, ShouldEqual, 0) - So(current, ShouldEqual, 500) - case 2: - So(previous, ShouldEqual, 500) - So(current, ShouldEqual, 1000) - default: - So(previous, ShouldEqual, 1000) - So(current, ShouldEqual, 1500) - } - - return nil - }) - - So(i, ShouldEqual, 3) - So(result, ShouldBeNil) - }) - - Convey("it should execute once for 400 users", func() { - logins := make([]string, 400) - i := 0 - - result := getUsersIteration(logins, func(previous, current int) error { - i++ - if i == 1 { - So(previous, ShouldEqual, 0) - So(current, ShouldEqual, 400) - } - - return nil - }) - - So(i, ShouldEqual, 1) - So(result, ShouldBeNil) - }) - - Convey("it should not execute for 0 users", func() { - logins := make([]string, 0) - i := 0 - - result := getUsersIteration(logins, func(previous, current int) error { i++ return nil }) - So(i, ShouldEqual, 0) - So(result, ShouldBeNil) + assert.Equal(t, expectedIterations, i) }) + } +} + +func TestGetAttribute(t *testing.T) { + t.Run("DN", func(t *testing.T) { + entry := &ldap.Entry{ + DN: "test", + } + + result := getAttribute("dn", entry) + assert.Equal(t, "test", result) }) - Convey("getAttribute()", t, func() { - Convey("Should get DN", func() { - entry := &ldap.Entry{ - DN: "test", - } - - result := getAttribute("dn", entry) - - So(result, ShouldEqual, "test") - }) - - Convey("Should get username", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "username", Values: value, - }, + t.Run("username", func(t *testing.T) { + value := "roelgerrits" + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "username", Values: []string{value}, }, - } + }, + } - result := getAttribute("username", entry) - - So(result, ShouldEqual, value[0]) - }) - - Convey("Should not get anything", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "killa", Values: value, - }, - }, - } - - result := getAttribute("username", entry) - - So(result, ShouldEqual, "") - }) + result := getAttribute("username", entry) + assert.Equal(t, value, result) }) - Convey("getArrayAttribute()", t, func() { - Convey("Should get DN", func() { - entry := &ldap.Entry{ - DN: "test", - } - - result := getArrayAttribute("dn", entry) - - So(result, ShouldResemble, []string{"test"}) - }) - - Convey("Should get username", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "username", Values: value, - }, + t.Run("no result", func(t *testing.T) { + value := []string{"roelgerrits"} + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "killa", Values: value, }, - } + }, + } - result := getArrayAttribute("username", entry) - - So(result, ShouldResemble, value) - }) - - Convey("Should not get anything", func() { - value := []string{"roelgerrits"} - entry := &ldap.Entry{ - Attributes: []*ldap.EntryAttribute{ - { - Name: "username", Values: value, - }, - }, - } - - result := getArrayAttribute("something", entry) - - So(result, ShouldResemble, []string{}) - }) + result := getAttribute("username", entry) + assert.Empty(t, result) + }) +} + +func TestGetArrayAttribute(t *testing.T) { + t.Run("DN", func(t *testing.T) { + entry := &ldap.Entry{ + DN: "test", + } + + result := getArrayAttribute("dn", entry) + + assert.EqualValues(t, []string{"test"}, result) + }) + + t.Run("username", func(t *testing.T) { + value := []string{"roelgerrits"} + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "username", Values: value, + }, + }, + } + + result := getArrayAttribute("username", entry) + + assert.EqualValues(t, value, result) + }) + + t.Run("no result", func(t *testing.T) { + value := []string{"roelgerrits"} + entry := &ldap.Entry{ + Attributes: []*ldap.EntryAttribute{ + { + Name: "username", Values: value, + }, + }, + } + + result := getArrayAttribute("something", entry) + + assert.Empty(t, result) }) } diff --git a/pkg/services/ldap/ldap_login_test.go b/pkg/services/ldap/ldap_login_test.go index dea64fab48c..7b552a8edfa 100644 --- a/pkg/services/ldap/ldap_login_test.go +++ b/pkg/services/ldap/ldap_login_test.go @@ -4,231 +4,227 @@ import ( "errors" "testing" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/ldap.v3" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" ) -func TestLDAPLogin(t *testing.T) { - defaultLogin := &models.LoginUserQuery{ - Username: "user", - Password: "pwd", - IpAddress: "192.168.1.1:56433", +var defaultLogin = &models.LoginUserQuery{ + Username: "user", + Password: "pwd", + IpAddress: "192.168.1.1:56433", +} + +func TestServer_Login_UserBind_Fail(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{} + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + connection.BindProvider = func(username, password string) error { + return &ldap.Error{ + ResultCode: 49, + } + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), } - Convey("Login()", t, func() { - Convey("Should get invalid credentials when userBind fails", func() { - connection := &MockConnection{} - entry := ldap.Entry{} - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) + _, err := server.Login(defaultLogin) - connection.BindProvider = func(username, password string) error { - return &ldap.Error{ - ResultCode: 49, - } - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldEqual, ErrInvalidCredentials) - }) - - Convey("Returns an error when search didn't find anything", func() { - connection := &MockConnection{} - result := ldap.SearchResult{Entries: []*ldap.Entry{}} - connection.setSearchResult(&result) - - connection.BindProvider = func(username, password string) error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldEqual, ErrCouldNotFindUser) - }) - - Convey("When search returns an error", func() { - connection := &MockConnection{} - expected := errors.New("Killa-gorilla") - connection.setSearchError(expected) - - connection.BindProvider = func(username, password string) error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldEqual, expected) - }) - - Convey("When login with valid credentials", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "dn", Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"markelog"}}, - {Name: "surname", Values: []string{"Gaidarenko"}}, - {Name: "email", Values: []string{"markelog@gmail.com"}}, - {Name: "name", Values: []string{"Oleg"}}, - {Name: "memberof", Values: []string{"admins"}}, - }, - } - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - connection.BindProvider = func(username, password string) error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - resp, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - So(resp.Login, ShouldEqual, "markelog") - }) - - Convey("Should perform unauthenticated bind without admin", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "test", - } - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - connection.UnauthenticatedBindProvider = func() error { - return nil - } - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - user, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - So(user.AuthId, ShouldEqual, "test") - So(connection.UnauthenticatedBindCalled, ShouldBeTrue) - }) - - Convey("Should perform authenticated binds", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "test", - } - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - connection.setSearchResult(&result) - - adminUsername := "" - adminPassword := "" - username := "" - password := "" - - i := 0 - connection.BindProvider = func(name, pass string) error { - i++ - if i == 1 { - adminUsername = name - adminPassword = pass - } - - if i == 2 { - username = name - password = pass - } - - return nil - } - server := &Server{ - Config: &ServerConfig{ - BindDN: "killa", - BindPassword: "gorilla", - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - user, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - - So(user.AuthId, ShouldEqual, "test") - So(connection.BindCalled, ShouldBeTrue) - - So(adminUsername, ShouldEqual, "killa") - So(adminPassword, ShouldEqual, "gorilla") - - So(username, ShouldEqual, "test") - So(password, ShouldEqual, "pwd") - }) - Convey("Should bind with user if %s exists in the bind_dn", func() { - connection := &MockConnection{} - entry := ldap.Entry{ - DN: "test", - } - connection.setSearchResult(&ldap.SearchResult{Entries: []*ldap.Entry{&entry}}) - - authBindUser := "" - authBindPassword := "" - - connection.BindProvider = func(name, pass string) error { - authBindUser = name - authBindPassword = pass - return nil - } - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=%s,ou=users,dc=grafana,dc=org", - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - log: log.New("test-logger"), - } - - _, err := server.Login(defaultLogin) - - So(err, ShouldBeNil) - - So(authBindUser, ShouldEqual, "cn=user,ou=users,dc=grafana,dc=org") - So(authBindPassword, ShouldEqual, "pwd") - So(connection.BindCalled, ShouldBeTrue) - }) - }) + assert.ErrorIs(t, err, ErrInvalidCredentials) +} + +func TestServer_Login_Search_NoResult(t *testing.T) { + connection := &MockConnection{} + result := ldap.SearchResult{Entries: []*ldap.Entry{}} + connection.setSearchResult(&result) + + connection.BindProvider = func(username, password string) error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + assert.ErrorIs(t, err, ErrCouldNotFindUser) +} + +func TestServer_Login_Search_Error(t *testing.T) { + connection := &MockConnection{} + expected := errors.New("Killa-gorilla") + connection.setSearchError(expected) + + connection.BindProvider = func(username, password string) error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + assert.ErrorIs(t, err, expected) +} + +func TestServer_Login_ValidCredentials(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"markelog"}}, + {Name: "surname", Values: []string{"Gaidarenko"}}, + {Name: "email", Values: []string{"markelog@gmail.com"}}, + {Name: "name", Values: []string{"Oleg"}}, + {Name: "memberof", Values: []string{"admins"}}, + }, + } + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + connection.BindProvider = func(username, password string) error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + }, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + resp, err := server.Login(defaultLogin) + require.NoError(t, err) + assert.Equal(t, "markelog", resp.Login) +} + +// TestServer_Login_UnauthenticatedBind tests that unauthenticated bind +// is called when there is no admin password or user wildcard in the +// bind_dn. +func TestServer_Login_UnauthenticatedBind(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "test", + } + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + connection.UnauthenticatedBindProvider = func() error { + return nil + } + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + user, err := server.Login(defaultLogin) + require.NoError(t, err) + assert.Equal(t, "test", user.AuthId) + assert.True(t, connection.UnauthenticatedBindCalled) +} + +func TestServer_Login_AuthenticatedBind(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "test", + } + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + connection.setSearchResult(&result) + + adminUsername := "" + adminPassword := "" + username := "" + password := "" + + i := 0 + connection.BindProvider = func(name, pass string) error { + i++ + if i == 1 { + adminUsername = name + adminPassword = pass + } + + if i == 2 { + username = name + password = pass + } + + return nil + } + server := &Server{ + Config: &ServerConfig{ + BindDN: "killa", + BindPassword: "gorilla", + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + user, err := server.Login(defaultLogin) + require.NoError(t, err) + + assert.Equal(t, "test", user.AuthId) + assert.True(t, connection.BindCalled) + + assert.Equal(t, "killa", adminUsername) + assert.Equal(t, "gorilla", adminPassword) + + assert.Equal(t, "test", username) + assert.Equal(t, "pwd", password) +} + +func TestServer_Login_UserWildcardBind(t *testing.T) { + connection := &MockConnection{} + entry := ldap.Entry{ + DN: "test", + } + connection.setSearchResult(&ldap.SearchResult{Entries: []*ldap.Entry{&entry}}) + + authBindUser := "" + authBindPassword := "" + + connection.BindProvider = func(name, pass string) error { + authBindUser = name + authBindPassword = pass + return nil + } + server := &Server{ + Config: &ServerConfig{ + BindDN: "cn=%s,ou=users,dc=grafana,dc=org", + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: connection, + log: log.New("test-logger"), + } + + _, err := server.Login(defaultLogin) + require.NoError(t, err) + + assert.Equal(t, "cn=user,ou=users,dc=grafana,dc=org", authBindUser) + assert.Equal(t, "pwd", authBindPassword) + assert.True(t, connection.BindCalled) } diff --git a/pkg/services/ldap/ldap_private_test.go b/pkg/services/ldap/ldap_private_test.go index 431f94f0d94..d4d0f1c238d 100644 --- a/pkg/services/ldap/ldap_private_test.go +++ b/pkg/services/ldap/ldap_private_test.go @@ -3,271 +3,252 @@ package ldap import ( "testing" + "github.com/stretchr/testify/require" + + "github.com/stretchr/testify/assert" + + "gopkg.in/ldap.v3" + "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" - . "github.com/smartystreets/goconvey/convey" - "gopkg.in/ldap.v3" ) -func TestLDAPPrivateMethods(t *testing.T) { - Convey("getSearchRequest()", t, func() { - Convey("with enabled GroupSearchFilterUserAttribute setting", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - Email: "email", - }, - GroupSearchFilterUserAttribute: "gansta", - SearchBaseDNs: []string{"BaseDNHere"}, - }, - log: log.New("test-logger"), - } +func TestServer_getSearchRequest(t *testing.T) { + expected := &ldap.SearchRequest{ + BaseDN: "killa", + Scope: 2, + DerefAliases: 0, + SizeLimit: 0, + TimeLimit: 0, + TypesOnly: false, + Filter: "(|)", + Attributes: []string{ + "username", + "email", + "name", + "memberof", + "gansta", + }, + Controls: nil, + } - result := server.getSearchRequest("killa", []string{"gorilla"}) + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + Email: "email", + }, + GroupSearchFilterUserAttribute: "gansta", + SearchBaseDNs: []string{"BaseDNHere"}, + }, + log: log.New("test-logger"), + } - So(result, ShouldResemble, &ldap.SearchRequest{ - BaseDN: "killa", - Scope: 2, - DerefAliases: 0, - SizeLimit: 0, - TimeLimit: 0, - TypesOnly: false, - Filter: "(|)", - Attributes: []string{ - "username", - "email", - "name", - "memberof", - "gansta", + result := server.getSearchRequest("killa", []string{"gorilla"}) + + assert.EqualValues(t, expected, result) +} + +func TestSerializeUsers(t *testing.T) { + t.Run("simple case", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + Email: "email", }, - Controls: nil, - }) - }) + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: &MockConnection{}, + log: log.New("test-logger"), + } + + entry := ldap.Entry{ + DN: "dn", + Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"roelgerrits"}}, + {Name: "surname", Values: []string{"Gerrits"}}, + {Name: "email", Values: []string{"roel@test.com"}}, + {Name: "name", Values: []string{"Roel"}}, + {Name: "memberof", Values: []string{"admins"}}, + }, + } + users := [][]*ldap.Entry{{&entry}} + + result, err := server.serializeUsers(users) + require.NoError(t, err) + + assert.Equal(t, "roelgerrits", result[0].Login) + assert.Equal(t, "roel@test.com", result[0].Email) + assert.Contains(t, result[0].Groups, "admins") }) - Convey("serializeUsers()", t, func() { - Convey("simple case", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - Email: "email", - }, - SearchBaseDNs: []string{"BaseDNHere"}, + t.Run("without lastname", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + Email: "email", }, - Connection: &MockConnection{}, - log: log.New("test-logger"), - } + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: &MockConnection{}, + log: log.New("test-logger"), + } - entry := ldap.Entry{ - DN: "dn", - Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"roelgerrits"}}, - {Name: "surname", Values: []string{"Gerrits"}}, - {Name: "email", Values: []string{"roel@test.com"}}, - {Name: "name", Values: []string{"Roel"}}, - {Name: "memberof", Values: []string{"admins"}}, - }, - } - users := []*ldap.Entry{&entry} + entry := ldap.Entry{ + DN: "dn", + Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"roelgerrits"}}, + {Name: "email", Values: []string{"roel@test.com"}}, + {Name: "name", Values: []string{"Roel"}}, + {Name: "memberof", Values: []string{"admins"}}, + }, + } + users := [][]*ldap.Entry{{&entry}} - result, err := server.serializeUsers(users) + result, err := server.serializeUsers(users) + require.NoError(t, err) - So(err, ShouldBeNil) - So(result[0].Login, ShouldEqual, "roelgerrits") - So(result[0].Email, ShouldEqual, "roel@test.com") - So(result[0].Groups, ShouldContain, "admins") - }) - - Convey("without lastname", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - Email: "email", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: &MockConnection{}, - log: log.New("test-logger"), - } - - entry := ldap.Entry{ - DN: "dn", - Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"roelgerrits"}}, - {Name: "email", Values: []string{"roel@test.com"}}, - {Name: "name", Values: []string{"Roel"}}, - {Name: "memberof", Values: []string{"admins"}}, - }, - } - users := []*ldap.Entry{&entry} - - result, err := server.serializeUsers(users) - - So(err, ShouldBeNil) - So(result[0].IsDisabled, ShouldBeFalse) - So(result[0].Name, ShouldEqual, "Roel") - }) - - Convey("a user without matching groups should be marked as disabled", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{{ - GroupDN: "foo", - OrgId: 1, - OrgRole: models.ROLE_EDITOR, - }}, - }, - Connection: &MockConnection{}, - log: log.New("test-logger"), - } - - entry := ldap.Entry{ - DN: "dn", - Attributes: []*ldap.EntryAttribute{ - {Name: "memberof", Values: []string{"admins"}}, - }, - } - users := []*ldap.Entry{&entry} - - result, err := server.serializeUsers(users) - - So(err, ShouldBeNil) - So(len(result), ShouldEqual, 1) - So(result[0].IsDisabled, ShouldBeTrue) - }) + assert.False(t, result[0].IsDisabled) + assert.Equal(t, "Roel", result[0].Name) }) - Convey("validateGrafanaUser()", t, func() { - Convey("Returns error when user does not belong in any of the specified LDAP groups", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{ - { - OrgId: 1, - }, - }, - }, - log: logger.New("test"), - } + t.Run("mark user without matching group as disabled", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{{ + GroupDN: "foo", + OrgId: 1, + OrgRole: models.ROLE_EDITOR, + }}, + }, + Connection: &MockConnection{}, + log: log.New("test-logger"), + } - user := &models.ExternalUserInfo{ - Login: "markelog", - } + entry := ldap.Entry{ + DN: "dn", + Attributes: []*ldap.EntryAttribute{ + {Name: "memberof", Values: []string{"admins"}}, + }, + } + users := [][]*ldap.Entry{{&entry}} - result := server.validateGrafanaUser(user) + result, err := server.serializeUsers(users) + require.NoError(t, err) - So(result, ShouldEqual, ErrInvalidCredentials) - }) - - Convey("Does not return error when group config is empty", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{}, - }, - log: logger.New("test"), - } - - user := &models.ExternalUserInfo{ - Login: "markelog", - } - - result := server.validateGrafanaUser(user) - - So(result, ShouldBeNil) - }) - - Convey("Does not return error when groups are there", func() { - server := &Server{ - Config: &ServerConfig{ - Groups: []*GroupToOrgRole{ - { - OrgId: 1, - }, - }, - }, - log: logger.New("test"), - } - - user := &models.ExternalUserInfo{ - Login: "markelog", - OrgRoles: map[int64]models.RoleType{ - 1: "test", - }, - } - - result := server.validateGrafanaUser(user) - - So(result, ShouldBeNil) - }) - }) - - Convey("shouldAdminBind()", t, func() { - Convey("it should require admin userBind", func() { - server := &Server{ - Config: &ServerConfig{ - BindPassword: "test", - }, - } - - result := server.shouldAdminBind() - So(result, ShouldBeTrue) - }) - - Convey("it should not require admin userBind", func() { - server := &Server{ - Config: &ServerConfig{ - BindPassword: "", - }, - } - - result := server.shouldAdminBind() - So(result, ShouldBeFalse) - }) - }) - - Convey("shouldSingleBind()", t, func() { - Convey("it should allow single bind", func() { - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=%s,dc=grafana,dc=org", - }, - } - - result := server.shouldSingleBind() - So(result, ShouldBeTrue) - }) - - Convey("it should not allow single bind", func() { - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=admin,dc=grafana,dc=org", - }, - } - - result := server.shouldSingleBind() - So(result, ShouldBeFalse) - }) - }) - - Convey("singleBindDN()", t, func() { - Convey("it should allow single bind", func() { - server := &Server{ - Config: &ServerConfig{ - BindDN: "cn=%s,dc=grafana,dc=org", - }, - } - - result := server.singleBindDN("test") - So(result, ShouldEqual, "cn=test,dc=grafana,dc=org") - }) + assert.Len(t, result, 1) + assert.True(t, result[0].IsDisabled) + }) +} + +func TestServer_validateGrafanaUser(t *testing.T) { + t.Run("no group config", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{}, + }, + log: logger.New("test"), + } + + user := &models.ExternalUserInfo{ + Login: "markelog", + } + + err := server.validateGrafanaUser(user) + require.NoError(t, err) + }) + + t.Run("user in group", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{ + { + OrgId: 1, + }, + }, + }, + log: logger.New("test"), + } + + user := &models.ExternalUserInfo{ + Login: "markelog", + OrgRoles: map[int64]models.RoleType{ + 1: "test", + }, + } + + err := server.validateGrafanaUser(user) + require.NoError(t, err) + }) + + t.Run("user not in group", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Groups: []*GroupToOrgRole{ + { + OrgId: 1, + }, + }, + }, + log: logger.New("test"), + } + + user := &models.ExternalUserInfo{ + Login: "markelog", + } + + err := server.validateGrafanaUser(user) + require.ErrorIs(t, err, ErrInvalidCredentials) + }) +} + +func TestServer_binds(t *testing.T) { + t.Run("single bind with cn wildcard", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindDN: "cn=%s,dc=grafana,dc=org", + }, + } + + assert.True(t, server.shouldSingleBind()) + assert.Equal(t, "cn=test,dc=grafana,dc=org", server.singleBindDN("test")) + }) + + t.Run("don't single bind", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindDN: "cn=admin,dc=grafana,dc=org", + }, + } + + assert.False(t, server.shouldSingleBind()) + }) + + t.Run("admin user bind", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindPassword: "test", + }, + } + + assert.True(t, server.shouldAdminBind()) + }) + + t.Run("don't admin user bind", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + BindPassword: "", + }, + } + + assert.False(t, server.shouldAdminBind()) }) } diff --git a/pkg/services/ldap/ldap_test.go b/pkg/services/ldap/ldap_test.go index ea1fd049bf3..042ac045506 100644 --- a/pkg/services/ldap/ldap_test.go +++ b/pkg/services/ldap/ldap_test.go @@ -2,226 +2,319 @@ package ldap import ( "errors" + "fmt" "testing" - "github.com/grafana/grafana/pkg/infra/log" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/ldap.v3" + + "github.com/grafana/grafana/pkg/infra/log" ) -func TestPublicAPI(t *testing.T) { - Convey("New()", t, func() { - Convey("Should return ", func() { - result := New(&ServerConfig{ +func TestNew(t *testing.T) { + result := New(&ServerConfig{ + Attr: AttributeMap{}, + SearchBaseDNs: []string{"BaseDNHere"}, + }) + + assert.Implements(t, (*IServer)(nil), result) +} + +func TestServer_Close(t *testing.T) { + t.Run("close the connection", func(t *testing.T) { + connection := &MockConnection{} + + server := &Server{ + Config: &ServerConfig{ Attr: AttributeMap{}, SearchBaseDNs: []string{"BaseDNHere"}, - }) + }, + Connection: connection, + } - So(result, ShouldImplement, (*IServer)(nil)) - }) + assert.NotPanics(t, server.Close) + assert.True(t, connection.CloseCalled) }) - Convey("Close()", t, func() { - Convey("Should close the connection", func() { - connection := &MockConnection{} + t.Run("panic if no connection", func(t *testing.T) { + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{}, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: nil, + } - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{}, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: connection, - } - - So(server.Close, ShouldNotPanic) - So(connection.CloseCalled, ShouldBeTrue) - }) - - Convey("Should panic if no connection is established", func() { - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{}, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: nil, - } - - So(server.Close, ShouldPanic) - }) - }) - Convey("Users()", t, func() { - Convey("Finds one user", func() { - MockConnection := &MockConnection{} - entry := ldap.Entry{ - DN: "dn", Attributes: []*ldap.EntryAttribute{ - {Name: "username", Values: []string{"roelgerrits"}}, - {Name: "surname", Values: []string{"Gerrits"}}, - {Name: "email", Values: []string{"roel@test.com"}}, - {Name: "name", Values: []string{"Roel"}}, - {Name: "memberof", Values: []string{"admins"}}, - }} - result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} - MockConnection.setSearchResult(&result) - - // Set up attribute map without surname and email - server := &Server{ - Config: &ServerConfig{ - Attr: AttributeMap{ - Username: "username", - Name: "name", - MemberOf: "memberof", - }, - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: MockConnection, - log: log.New("test-logger"), - } - - searchResult, err := server.Users([]string{"roelgerrits"}) - - So(err, ShouldBeNil) - So(searchResult, ShouldNotBeNil) - - // User should be searched in ldap - So(MockConnection.SearchCalled, ShouldBeTrue) - - // No empty attributes should be added to the search request - So(len(MockConnection.SearchAttributes), ShouldEqual, 3) - }) - - Convey("Handles a error", func() { - expected := errors.New("Killa-gorilla") - MockConnection := &MockConnection{} - MockConnection.setSearchError(expected) - - // Set up attribute map without surname and email - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: MockConnection, - log: log.New("test-logger"), - } - - _, err := server.Users([]string{"roelgerrits"}) - - So(err, ShouldEqual, expected) - }) - - Convey("Should return empty slice if none were found", func() { - MockConnection := &MockConnection{} - result := ldap.SearchResult{Entries: []*ldap.Entry{}} - MockConnection.setSearchResult(&result) - - // Set up attribute map without surname and email - server := &Server{ - Config: &ServerConfig{ - SearchBaseDNs: []string{"BaseDNHere"}, - }, - Connection: MockConnection, - log: log.New("test-logger"), - } - - searchResult, err := server.Users([]string{"roelgerrits"}) - - So(err, ShouldBeNil) - So(searchResult, ShouldBeEmpty) - }) - }) - - Convey("UserBind()", t, func() { - Convey("Should use provided DN and password", func() { - connection := &MockConnection{} - var actualUsername, actualPassword string - connection.BindProvider = func(username, password string) error { - actualUsername = username - actualPassword = password - return nil - } - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindDN: "cn=admin,dc=grafana,dc=org", - }, - } - - dn := "cn=user,ou=users,dc=grafana,dc=org" - err := server.UserBind(dn, "pwd") - - So(err, ShouldBeNil) - So(actualUsername, ShouldEqual, dn) - So(actualPassword, ShouldEqual, "pwd") - }) - - Convey("Should handle an error", func() { - connection := &MockConnection{} - expected := &ldap.Error{ - ResultCode: uint16(25), - } - connection.BindProvider = func(username, password string) error { - return expected - } - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindDN: "cn=%s,ou=users,dc=grafana,dc=org", - }, - log: log.New("test-logger"), - } - err := server.UserBind("user", "pwd") - So(err, ShouldEqual, expected) - }) - }) - - Convey("AdminBind()", t, func() { - Convey("Should use admin DN and password", func() { - connection := &MockConnection{} - var actualUsername, actualPassword string - connection.BindProvider = func(username, password string) error { - actualUsername = username - actualPassword = password - return nil - } - - dn := "cn=admin,dc=grafana,dc=org" - - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindPassword: "pwd", - BindDN: dn, - }, - } - - err := server.AdminBind() - - So(err, ShouldBeNil) - So(actualUsername, ShouldEqual, dn) - So(actualPassword, ShouldEqual, "pwd") - }) - - Convey("Should handle an error", func() { - connection := &MockConnection{} - expected := &ldap.Error{ - ResultCode: uint16(25), - } - connection.BindProvider = func(username, password string) error { - return expected - } - - dn := "cn=admin,dc=grafana,dc=org" - - server := &Server{ - Connection: connection, - Config: &ServerConfig{ - BindPassword: "pwd", - BindDN: dn, - }, - log: log.New("test-logger"), - } - - err := server.AdminBind() - So(err, ShouldEqual, expected) - }) + assert.Panics(t, server.Close) + }) +} + +func TestServer_Users(t *testing.T) { + t.Run("one user", func(t *testing.T) { + conn := &MockConnection{} + entry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"roelgerrits"}}, + {Name: "surname", Values: []string{"Gerrits"}}, + {Name: "email", Values: []string{"roel@test.com"}}, + {Name: "name", Values: []string{"Roel"}}, + {Name: "memberof", Values: []string{"admins"}}, + }} + result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} + conn.setSearchResult(&result) + + // Set up attribute map without surname and email + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + MemberOf: "memberof", + }, + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: conn, + log: log.New("test-logger"), + } + + searchResult, err := server.Users([]string{"roelgerrits"}) + + require.NoError(t, err) + assert.NotNil(t, searchResult) + + // User should be searched in ldap + assert.True(t, conn.SearchCalled) + // No empty attributes should be added to the search request + assert.Len(t, conn.SearchAttributes, 3) + }) + + t.Run("error", func(t *testing.T) { + expected := errors.New("Killa-gorilla") + conn := &MockConnection{} + conn.setSearchError(expected) + + // Set up attribute map without surname and email + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: conn, + log: log.New("test-logger"), + } + + _, err := server.Users([]string{"roelgerrits"}) + + assert.ErrorIs(t, err, expected) + }) + + t.Run("no user", func(t *testing.T) { + conn := &MockConnection{} + result := ldap.SearchResult{Entries: []*ldap.Entry{}} + conn.setSearchResult(&result) + + // Set up attribute map without surname and email + server := &Server{ + Config: &ServerConfig{ + SearchBaseDNs: []string{"BaseDNHere"}, + }, + Connection: conn, + log: log.New("test-logger"), + } + + searchResult, err := server.Users([]string{"roelgerrits"}) + + require.NoError(t, err) + assert.Empty(t, searchResult) + }) + + t.Run("multiple DNs", func(t *testing.T) { + conn := &MockConnection{} + serviceDN := "dc=svc,dc=example,dc=org" + serviceEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"imgrenderer"}}, + {Name: "name", Values: []string{"Image renderer"}}, + }} + services := ldap.SearchResult{Entries: []*ldap.Entry{&serviceEntry}} + + userDN := "dc=users,dc=example,dc=org" + userEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"grot"}}, + {Name: "name", Values: []string{"Grot"}}, + }} + users := ldap.SearchResult{Entries: []*ldap.Entry{&userEntry}} + + conn.setSearchFunc(func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + switch request.BaseDN { + case userDN: + return &users, nil + case serviceDN: + return &services, nil + default: + return nil, fmt.Errorf("test case not defined for baseDN: '%s'", request.BaseDN) + } + }) + + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + }, + SearchBaseDNs: []string{serviceDN, userDN}, + }, + Connection: conn, + log: log.New("test-logger"), + } + + searchResult, err := server.Users([]string{"imgrenderer", "grot"}) + require.NoError(t, err) + + assert.Len(t, searchResult, 2) + }) + + t.Run("same user in multiple DNs", func(t *testing.T) { + conn := &MockConnection{} + firstDN := "dc=users1,dc=example,dc=org" + firstEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"grot"}}, + {Name: "name", Values: []string{"Grot the First"}}, + }} + firsts := ldap.SearchResult{Entries: []*ldap.Entry{&firstEntry}} + + secondDN := "dc=users2,dc=example,dc=org" + secondEntry := ldap.Entry{ + DN: "dn", Attributes: []*ldap.EntryAttribute{ + {Name: "username", Values: []string{"grot"}}, + {Name: "name", Values: []string{"Grot the Second"}}, + }} + seconds := ldap.SearchResult{Entries: []*ldap.Entry{&secondEntry}} + + conn.setSearchFunc(func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + switch request.BaseDN { + case secondDN: + return &seconds, nil + case firstDN: + return &firsts, nil + default: + return nil, fmt.Errorf("test case not defined for baseDN: '%s'", request.BaseDN) + } + }) + + server := &Server{ + Config: &ServerConfig{ + Attr: AttributeMap{ + Username: "username", + Name: "name", + }, + SearchBaseDNs: []string{firstDN, secondDN}, + }, + Connection: conn, + log: log.New("test-logger"), + } + + res, err := server.Users([]string{"grot"}) + require.NoError(t, err) + require.Len(t, res, 1) + assert.Equal(t, "Grot the First", res[0].Name) + }) +} + +func TestServer_UserBind(t *testing.T) { + t.Run("use provided DN and password", func(t *testing.T) { + connection := &MockConnection{} + var actualUsername, actualPassword string + connection.BindProvider = func(username, password string) error { + actualUsername = username + actualPassword = password + return nil + } + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindDN: "cn=admin,dc=grafana,dc=org", + }, + } + + dn := "cn=user,ou=users,dc=grafana,dc=org" + err := server.UserBind(dn, "pwd") + + require.NoError(t, err) + assert.Equal(t, dn, actualUsername) + assert.Equal(t, "pwd", actualPassword) + }) + + t.Run("error", func(t *testing.T) { + connection := &MockConnection{} + expected := &ldap.Error{ + ResultCode: uint16(25), + } + connection.BindProvider = func(username, password string) error { + return expected + } + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindDN: "cn=%s,ou=users,dc=grafana,dc=org", + }, + log: log.New("test-logger"), + } + err := server.UserBind("user", "pwd") + assert.ErrorIs(t, err, expected) + }) +} + +func TestServer_AdminBind(t *testing.T) { + t.Run("use admin DN and password", func(t *testing.T) { + connection := &MockConnection{} + var actualUsername, actualPassword string + connection.BindProvider = func(username, password string) error { + actualUsername = username + actualPassword = password + return nil + } + + dn := "cn=admin,dc=grafana,dc=org" + + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindPassword: "pwd", + BindDN: dn, + }, + } + + err := server.AdminBind() + require.NoError(t, err) + + assert.Equal(t, dn, actualUsername) + assert.Equal(t, "pwd", actualPassword) + }) + + t.Run("error", func(t *testing.T) { + connection := &MockConnection{} + expected := &ldap.Error{ + ResultCode: uint16(25), + } + connection.BindProvider = func(username, password string) error { + return expected + } + + dn := "cn=admin,dc=grafana,dc=org" + + server := &Server{ + Connection: connection, + Config: &ServerConfig{ + BindPassword: "pwd", + BindDN: dn, + }, + log: log.New("test-logger"), + } + + err := server.AdminBind() + assert.ErrorIs(t, err, expected) }) } diff --git a/pkg/services/ldap/testing.go b/pkg/services/ldap/testing.go index 8bad83a2d92..cd9ff9184f4 100644 --- a/pkg/services/ldap/testing.go +++ b/pkg/services/ldap/testing.go @@ -6,10 +6,11 @@ import ( "gopkg.in/ldap.v3" ) +type searchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) + // MockConnection struct for testing type MockConnection struct { - SearchResult *ldap.SearchResult - SearchError error + SearchFunc searchFunc SearchCalled bool SearchAttributes []string @@ -56,11 +57,19 @@ func (c *MockConnection) Close() { } func (c *MockConnection) setSearchResult(result *ldap.SearchResult) { - c.SearchResult = result + c.SearchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + return result, nil + } } func (c *MockConnection) setSearchError(err error) { - c.SearchError = err + c.SearchFunc = func(request *ldap.SearchRequest) (*ldap.SearchResult, error) { + return nil, err + } +} + +func (c *MockConnection) setSearchFunc(fn searchFunc) { + c.SearchFunc = fn } // Search mocks Search connection function @@ -68,11 +77,7 @@ func (c *MockConnection) Search(sr *ldap.SearchRequest) (*ldap.SearchResult, err c.SearchCalled = true c.SearchAttributes = sr.Attributes - if c.SearchError != nil { - return nil, c.SearchError - } - - return c.SearchResult, nil + return c.SearchFunc(sr) } // Add mocks Add connection function