authelia/internal/authentication/ldap_user_provider.go
James Elliott cb71df5d9b
feat(authentiation): check ldap support for extended operations on startup (#2133)
* feat(authentiation): check ldap server on startup

This PR adds a startup check to the LDAP authentication backend. It additionally adds support for checking supportedExtension OIDs, currently only checking passwdModifyOID (1.3.6.1.4.1.4203.1.11.3). This can relatively easily be enhanced to add detection for other rootDSE capabilities like supportedControl and supportedCapabilities as necessary.

* test(authentication): add unit tests for new feature

* refactor(authentication): factorize ldap user provider newup

* refactor: minor adjustments
2021-07-02 09:16:16 +10:00

365 lines
10 KiB
Go

package authentication
import (
"crypto/tls"
"crypto/x509"
"fmt"
"strings"
"github.com/go-ldap/ldap/v3"
"github.com/sirupsen/logrus"
"golang.org/x/text/encoding/unicode"
"github.com/authelia/authelia/internal/configuration/schema"
"github.com/authelia/authelia/internal/logging"
"github.com/authelia/authelia/internal/utils"
)
// LDAPUserProvider is a provider using a LDAP or AD as a user database.
type LDAPUserProvider struct {
configuration schema.LDAPAuthenticationBackendConfiguration
tlsConfig *tls.Config
dialOpts ldap.DialOpt
logger *logrus.Logger
connectionFactory LDAPConnectionFactory
usersBaseDN string
groupsBaseDN string
supportExtensionPasswdModify bool
}
// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
func NewLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider, err error) {
provider = newLDAPUserProvider(configuration, certPool, nil)
err = provider.checkServer()
if err != nil {
return provider, err
}
if provider.supportExtensionPasswdModify {
provider.logger.Trace("LDAP Server does support passwdModifyOID Extension")
} else {
provider.logger.Trace("LDAP Server does not support passwdModifyOID Extension")
}
return provider, nil
}
func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
if configuration.TLS == nil {
configuration.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS
}
tlsConfig := utils.NewTLSConfig(configuration.TLS, tls.VersionTLS12, certPool)
var dialOpts ldap.DialOpt
if tlsConfig != nil {
dialOpts = ldap.DialWithTLSConfig(tlsConfig)
}
if factory == nil {
factory = NewLDAPConnectionFactoryImpl()
}
provider = &LDAPUserProvider{
configuration: configuration,
tlsConfig: tlsConfig,
dialOpts: dialOpts,
logger: logging.Logger(),
connectionFactory: factory,
}
provider.parseDynamicConfiguration()
return provider
}
func (p *LDAPUserProvider) parseDynamicConfiguration() {
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{username_attribute}", p.configuration.UsernameAttribute)
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{mail_attribute}", p.configuration.MailAttribute)
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{display_name_attribute}", p.configuration.DisplayNameAttribute)
p.logger.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter)
if p.configuration.AdditionalUsersDN != "" {
p.usersBaseDN = p.configuration.AdditionalUsersDN + "," + p.configuration.BaseDN
} else {
p.usersBaseDN = p.configuration.BaseDN
}
p.logger.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN)
if p.configuration.AdditionalGroupsDN != "" {
p.groupsBaseDN = ldap.EscapeFilter(p.configuration.AdditionalGroupsDN + "," + p.configuration.BaseDN)
} else {
p.groupsBaseDN = p.configuration.BaseDN
}
p.logger.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN)
}
func (p *LDAPUserProvider) checkServer() (err error) {
conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return err
}
searchRequest := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases,
1, 0, false, "(objectClass=*)", []string{ldapSupportedExtensionAttribute}, nil)
sr, err := conn.Search(searchRequest)
if err != nil {
return err
}
if len(sr.Entries) != 1 {
return nil
}
// Iterate the attribute values to see what the server supports.
for _, attr := range sr.Entries[0].Attributes {
if attr.Name == ldapSupportedExtensionAttribute {
p.logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
for _, oid := range attr.Values {
if oid == ldapOIDPasswdModifyExtension {
p.supportExtensionPasswdModify = true
break
}
}
break
}
}
return nil
}
func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnection, error) {
conn, err := p.connectionFactory.DialURL(p.configuration.URL, p.dialOpts)
if err != nil {
return nil, err
}
if p.configuration.StartTLS {
if err := conn.StartTLS(p.tlsConfig); err != nil {
return nil, err
}
}
if err := conn.Bind(userDN, password); err != nil {
return nil, err
}
return conn, nil
}
// CheckUserPassword checks if provided password matches for the given user.
func (p *LDAPUserProvider) CheckUserPassword(inputUsername string, password string) (bool, error) {
conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return false, err
}
defer conn.Close()
profile, err := p.getUserProfile(conn, inputUsername)
if err != nil {
return false, err
}
userConn, err := p.connect(profile.DN, password)
if err != nil {
return false, fmt.Errorf("Authentication of user %s failed. Cause: %s", inputUsername, err)
}
defer userConn.Close()
return true, nil
}
func (p *LDAPUserProvider) ldapEscape(inputUsername string) string {
inputUsername = ldap.EscapeFilter(inputUsername)
for _, c := range specialLDAPRunes {
inputUsername = strings.ReplaceAll(inputUsername, string(c), fmt.Sprintf("\\%c", c))
}
return inputUsername
}
type ldapUserProfile struct {
DN string
Emails []string
DisplayName string
Username string
}
func (p *LDAPUserProvider) resolveUsersFilter(userFilter string, inputUsername string) string {
inputUsername = p.ldapEscape(inputUsername)
// The {input} placeholder is replaced by the users username input.
userFilter = strings.ReplaceAll(userFilter, "{input}", inputUsername)
p.logger.Tracef("Computed user filter is %s", userFilter)
return userFilter
}
func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername string) (*ldapUserProfile, error) {
userFilter := p.resolveUsersFilter(p.configuration.UsersFilter, inputUsername)
attributes := []string{"dn",
p.configuration.DisplayNameAttribute,
p.configuration.MailAttribute,
p.configuration.UsernameAttribute}
// Search for the given username.
searchRequest := ldap.NewSearchRequest(
p.usersBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
1, 0, false, userFilter, attributes, nil,
)
sr, err := conn.Search(searchRequest)
if err != nil {
return nil, fmt.Errorf("Cannot find user DN of user %s. Cause: %s", inputUsername, err)
}
if len(sr.Entries) == 0 {
return nil, ErrUserNotFound
}
if len(sr.Entries) > 1 {
return nil, fmt.Errorf("Multiple users %s found", inputUsername)
}
userProfile := ldapUserProfile{
DN: sr.Entries[0].DN,
}
for _, attr := range sr.Entries[0].Attributes {
if attr.Name == p.configuration.DisplayNameAttribute {
userProfile.DisplayName = attr.Values[0]
}
if attr.Name == p.configuration.MailAttribute {
userProfile.Emails = attr.Values
}
if attr.Name == p.configuration.UsernameAttribute {
if len(attr.Values) != 1 {
return nil, fmt.Errorf("User %s cannot have multiple value for attribute %s",
inputUsername, p.configuration.UsernameAttribute)
}
userProfile.Username = attr.Values[0]
}
}
if userProfile.DN == "" {
return nil, fmt.Errorf("No DN has been found for user %s", inputUsername)
}
return &userProfile, nil
}
func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ldapUserProfile) (string, error) { //nolint:unparam
inputUsername = p.ldapEscape(inputUsername)
// The {input} placeholder is replaced by the users username input.
groupFilter := strings.ReplaceAll(p.configuration.GroupsFilter, "{input}", inputUsername)
if profile != nil {
groupFilter = strings.ReplaceAll(groupFilter, "{username}", ldap.EscapeFilter(profile.Username))
groupFilter = strings.ReplaceAll(groupFilter, "{dn}", ldap.EscapeFilter(profile.DN))
}
p.logger.Tracef("Computed groups filter is %s", groupFilter)
return groupFilter, nil
}
// GetDetails retrieve the groups a user belongs to.
func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error) {
conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return nil, err
}
defer conn.Close()
profile, err := p.getUserProfile(conn, inputUsername)
if err != nil {
return nil, err
}
groupsFilter, err := p.resolveGroupsFilter(inputUsername, profile)
if err != nil {
return nil, fmt.Errorf("Unable to create group filter for user %s. Cause: %s", inputUsername, err)
}
// Search for the given username.
searchGroupRequest := ldap.NewSearchRequest(
p.groupsBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases,
0, 0, false, groupsFilter, []string{p.configuration.GroupNameAttribute}, nil,
)
sr, err := conn.Search(searchGroupRequest)
if err != nil {
return nil, fmt.Errorf("Unable to retrieve groups of user %s. Cause: %s", inputUsername, err)
}
groups := make([]string, 0)
for _, res := range sr.Entries {
if len(res.Attributes) == 0 {
p.logger.Warningf("No groups retrieved from LDAP for user %s", inputUsername)
break
}
// Append all values of the document. Normally there should be only one per document.
groups = append(groups, res.Attributes[0].Values...)
}
return &UserDetails{
Username: profile.Username,
DisplayName: profile.DisplayName,
Emails: profile.Emails,
Groups: groups,
}, nil
}
// UpdatePassword update the password of the given user.
func (p *LDAPUserProvider) UpdatePassword(inputUsername string, newPassword string) error {
conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return fmt.Errorf("Unable to update password. Cause: %s", err)
}
defer conn.Close()
profile, err := p.getUserProfile(conn, inputUsername)
if err != nil {
return fmt.Errorf("Unable to update password. Cause: %s", err)
}
modifyRequest := ldap.NewModifyRequest(profile.DN, nil)
switch p.configuration.Implementation {
case schema.LDAPImplementationActiveDirectory:
utf16 := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)
// The password needs to be enclosed in quotes
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-adts/6e803168-f140-4d23-b2d3-c3a8ab5917d2
pwdEncoded, _ := utf16.NewEncoder().String(fmt.Sprintf("\"%s\"", newPassword))
modifyRequest.Replace("unicodePwd", []string{pwdEncoded})
default:
modifyRequest.Replace("userPassword", []string{newPassword})
}
err = conn.Modify(modifyRequest)
if err != nil {
return fmt.Errorf("Unable to update password. Cause: %s", err)
}
return nil
}