refactor: factorize startup checks (#2386)

* refactor: factorize startup checks

* refactor: address linting issues
This commit is contained in:
James Elliott 2021-09-17 19:53:59 +10:00 committed by GitHub
parent 8e4dc91b81
commit aed9099ce2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 205 additions and 106 deletions

View File

@ -9,6 +9,7 @@ import (
"sync"
"github.com/asaskevich/govalidator"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"github.com/authelia/authelia/v4/internal/configuration/schema"
@ -205,3 +206,8 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
return err
}
// StartupCheck implements the startup check provider interface.
func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) {
return nil
}

View File

@ -24,6 +24,8 @@ type LDAPUserProvider struct {
logger *logrus.Logger
connectionFactory LDAPConnectionFactory
disableResetPassword bool
// Automatically detected ldap features.
supportExtensionPasswdModify bool
@ -41,25 +43,13 @@ type LDAPUserProvider struct {
}
// NewLDAPUserProvider creates a new instance of LDAPUserProvider.
func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider, err error) {
provider = newLDAPUserProvider(*configuration.LDAP, certPool, nil)
func NewLDAPUserProvider(configuration schema.AuthenticationBackendConfiguration, certPool *x509.CertPool) (provider *LDAPUserProvider) {
provider = newLDAPUserProvider(*configuration.LDAP, configuration.DisableResetPassword, certPool, nil)
err = provider.checkServer()
if err != nil {
return provider, err
return provider
}
if !provider.supportExtensionPasswdModify && !configuration.DisableResetPassword &&
provider.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
provider.logger.Warnf("Your LDAP server implementation may not support a method for password hashing " +
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
"attribute when users reset their password via Authelia.")
}
return provider, nil
}
func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfiguration, disableResetPassword bool, certPool *x509.CertPool, factory LDAPConnectionFactory) (provider *LDAPUserProvider) {
if configuration.TLS == nil {
configuration.TLS = schema.DefaultLDAPAuthenticationBackendConfiguration.TLS
}
@ -84,6 +74,7 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura
dialOpts: dialOpts,
logger: logging.Logger(),
connectionFactory: factory,
disableResetPassword: disableResetPassword,
}
provider.parseDynamicUsersConfiguration()

View File

@ -4,9 +4,13 @@ import (
"strings"
"github.com/go-ldap/ldap/v3"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
func (p *LDAPUserProvider) checkServer() (err error) {
// StartupCheck implements the startup check provider interface.
func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return err
@ -29,7 +33,7 @@ func (p *LDAPUserProvider) checkServer() (err error) {
// Iterate the attribute values to see what the server supports.
for _, attr := range sr.Entries[0].Attributes {
if attr.Name == ldapSupportedExtensionAttribute {
p.logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
for _, oid := range attr.Values {
if oid == ldapOIDPasswdModifyExtension {
@ -42,6 +46,13 @@ func (p *LDAPUserProvider) checkServer() (err error) {
}
}
if !p.supportExtensionPasswdModify && !p.disableResetPassword &&
p.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
logger.Warn("Your LDAP server implementation may not support a method for password hashing " +
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
"attribute when users reset their password via Authelia.")
}
return nil
}

View File

@ -12,6 +12,7 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
)
@ -26,6 +27,7 @@ func TestShouldCreateRawConnectionWhenSchemeIsLDAP(t *testing.T) {
schema.LDAPAuthenticationBackendConfiguration{
URL: "ldap://127.0.0.1:389",
},
false,
nil,
mockFactory)
@ -55,6 +57,7 @@ func TestShouldCreateTLSConnectionWhenSchemeIsLDAPS(t *testing.T) {
schema.LDAPAuthenticationBackendConfiguration{
URL: "ldaps://127.0.0.1:389",
},
false,
nil,
mockFactory)
@ -83,6 +86,7 @@ func TestEscapeSpecialCharsFromUserInput(t *testing.T) {
schema.LDAPAuthenticationBackendConfiguration{
URL: "ldaps://127.0.0.1:389",
},
false,
nil,
mockFactory)
@ -115,6 +119,7 @@ func TestEscapeSpecialCharsInGroupsFilter(t *testing.T) {
URL: "ldaps://127.0.0.1:389",
GroupsFilter: "(|(member={dn})(uid={username})(uid={input}))",
},
false,
nil,
mockFactory)
@ -179,6 +184,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -210,7 +216,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
assert.NoError(t, err)
assert.True(t, ldapClient.supportExtensionPasswdModify)
@ -235,6 +241,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -266,7 +273,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
assert.NoError(t, err)
assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -291,6 +298,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -298,7 +306,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()).
Return(mockConn, errors.New("could not connect"))
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
assert.EqualError(t, err, "could not connect")
assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -323,6 +331,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -342,7 +351,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
assert.EqualError(t, err, "could not perform the search")
assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -384,6 +393,7 @@ func TestShouldEscapeUserInput(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -416,6 +426,7 @@ func TestShouldCombineUsernameFilterAndUsersFilter(t *testing.T) {
MailAttribute: "mail",
DisplayNameAttribute: "displayName",
},
false,
nil,
mockFactory)
@ -463,6 +474,7 @@ func TestShouldNotCrashWhenGroupsAreNotRetrievedFromLDAP(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -532,6 +544,7 @@ func TestShouldNotCrashWhenEmailsAreNotRetrievedFromLDAP(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -594,6 +607,7 @@ func TestShouldReturnUsernameFromLDAP(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -665,6 +679,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -740,7 +755,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password")
@ -767,6 +782,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -846,7 +862,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password")
@ -873,6 +889,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -949,7 +966,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.checkServer()
err := ldapClient.StartupCheck(logging.Logger())
require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password")
@ -975,6 +992,7 @@ func TestShouldCheckValidUserPassword(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -1042,6 +1060,7 @@ func TestShouldCheckInvalidUserPassword(t *testing.T) {
AdditionalUsersDN: "ou=users",
BaseDN: "dc=example,dc=com",
},
false,
nil,
mockFactory)
@ -1110,6 +1129,7 @@ func TestShouldCallStartTLSWhenEnabled(t *testing.T) {
BaseDN: "dc=example,dc=com",
StartTLS: true,
},
false,
nil,
mockFactory)
@ -1186,6 +1206,7 @@ func TestShouldParseDynamicConfiguration(t *testing.T) {
BaseDN: "dc=example,dc=com",
StartTLS: true,
},
false,
nil,
mockFactory)
@ -1224,6 +1245,7 @@ func TestShouldCallStartTLSWithInsecureSkipVerifyWhenSkipVerifyTrue(t *testing.T
SkipVerify: true,
},
},
false,
nil,
mockFactory)
@ -1306,6 +1328,7 @@ func TestShouldReturnLDAPSAlreadySecuredWhenStartTLSAttempted(t *testing.T) {
SkipVerify: true,
},
},
false,
nil,
mockFactory)

View File

@ -1,9 +1,14 @@
package authentication
import (
"github.com/sirupsen/logrus"
)
// UserProvider is the interface for checking user password and
// gathering user details.
type UserProvider interface {
CheckUserPassword(username string, password string) (bool, error)
GetDetails(username string) (*UserDetails, error)
UpdatePassword(username string, newPassword string) error
CheckUserPassword(username string, password string) (valid bool, err error)
GetDetails(username string) (details *UserDetails, err error)
UpdatePassword(username string, newPassword string) (err error)
StartupCheck(logger *logrus.Logger) (err error)
}

View File

@ -3,7 +3,9 @@ package commands
import (
"fmt"
"os"
"strings"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/authelia/authelia/v4/internal/authentication"
@ -78,13 +80,13 @@ func cmdRootRun(_ *cobra.Command, _ []string) {
logger.Fatalf("Errors occurred provisioning providers.")
}
doStartupChecks(config, &providers)
server.Start(*config, providers)
}
//nolint:gocyclo // TODO: Consider refactoring time permitting.
func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) {
logger := logging.Logger()
// TODO: Adjust this so the CertPool can be used like a provider.
autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory)
if len(warnings) != 0 || len(errors) != 0 {
return providers, warnings, errors
@ -100,6 +102,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
case config.Storage.Local != nil:
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
default:
// TODO: Add storage provider startup check and remove this.
errors = append(errors, fmt.Errorf("unrecognized storage provider"))
}
@ -112,12 +115,7 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
case config.AuthenticationBackend.File != nil:
userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
case config.AuthenticationBackend.LDAP != nil:
userProvider, err = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
if err != nil {
errors = append(errors, fmt.Errorf("failed to check LDAP authentication backend: %w", err))
}
default:
errors = append(errors, fmt.Errorf("unrecognized user provider"))
userProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
}
var notifier notification.Notifier
@ -127,14 +125,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, autheliaCertPool)
case config.Notifier.FileSystem != nil:
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
default:
errors = append(errors, fmt.Errorf("unrecognized notifier provider"))
}
if notifier != nil && !config.Notifier.DisableStartupCheck {
if _, err := notifier.StartupCheck(); err != nil {
errors = append(errors, fmt.Errorf("failed to check notification provider: %w", err))
}
}
var ntpProvider *ntp.Provider
@ -152,25 +142,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
errors = append(errors, err)
}
var failed bool
if !config.NTP.DisableStartupCheck && authorizer.IsSecondFactorEnabled() {
failed, err = ntpProvider.StartupCheck()
if err != nil {
logger.Errorf("Failed to check time against the NTP server: %+v", err)
}
if failed {
if config.NTP.DisableFailure {
logger.Error("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server, this may cause issues in validating TOTP secrets")
} else {
logger.Fatal("The system time is outside the maximum desynchronization when compared to the time reported by the NTP server")
}
} else {
logger.Debug("The system time is within the maximum desynchronization when compared to the time reported by the NTP server")
}
}
return middlewares.Providers{
Authorizer: authorizer,
UserProvider: userProvider,
@ -182,3 +153,53 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
SessionProvider: sessionProvider,
}, warnings, errors
}
func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) {
logger := logging.Logger()
var (
failures []string
err error
)
if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil {
logger.Errorf("Failure running the user provider startup check: %+v", err)
failures = append(failures, "user")
}
if err = doStartupCheck(logger, "notification", providers.Notifier, config.Notifier.DisableStartupCheck); err != nil {
logger.Errorf("Failure running the notification provider startup check: %+v", err)
failures = append(failures, "notification")
}
if !config.NTP.DisableStartupCheck && !providers.Authorizer.IsSecondFactorEnabled() {
logger.Debug("The NTP startup check was skipped due to there being no configured 2FA access control rules")
} else if err = doStartupCheck(logger, "ntp", providers.NTP, config.NTP.DisableStartupCheck); err != nil {
logger.Errorf("Failure running the user provider startup check: %+v", err)
failures = append(failures, "ntp")
}
if len(failures) != 0 {
logger.Fatalf("The following providers had fatal failures during startup: %s", strings.Join(failures, ", "))
}
}
func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) {
if disabled {
logger.Debugf("%s provider: startup check skipped as it is disabled", name)
return nil
}
if provider == nil {
return fmt.Errorf("unrecognized provider or it is not configured properly")
}
if err = provider.StartupCheck(logger); err != nil {
return err
}
return nil
}

View File

@ -28,6 +28,11 @@ type AutheliaCtx struct {
Clock utils.Clock
}
// ProviderWithStartupCheck represents a provider that has a startup check.
type ProviderWithStartupCheck interface {
StartupCheck(logger *logrus.Logger) (err error)
}
// Providers contain all provider provided to Authelia.
type Providers struct {
Authorizer *authorization.Authorizer

View File

@ -8,6 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
)
// MockNotifier is a mock of Notifier interface.
@ -48,16 +49,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go
}
// StartupCheck mocks base method.
func (m *MockNotifier) StartupCheck() (bool, error) {
func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
ret := m.ctrl.Call(m, "StartupCheck", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call {
func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck))
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0)
}

View File

@ -8,6 +8,7 @@ import (
"reflect"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication"
)
@ -78,3 +79,17 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) *
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1)
}
// StartupCheck mocks base method.
func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0)
}

View File

@ -7,6 +7,8 @@ import (
"path/filepath"
"time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
@ -22,28 +24,28 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File
}
}
// StartupCheck checks the file provider can write to the specified file.
func (n *FileNotifier) StartupCheck() (bool, error) {
// StartupCheck implements the startup check provider interface.
func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) {
dir := filepath.Dir(n.path)
if _, err := os.Stat(dir); err != nil {
if os.IsNotExist(err) {
if err = os.MkdirAll(dir, fileNotifierMode); err != nil {
return false, err
return err
}
} else {
return false, err
return err
}
} else if _, err = os.Stat(n.path); err != nil {
if !os.IsNotExist(err) {
return false, err
return err
}
}
if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil {
return false, err
return err
}
return true, nil
return nil
}
// Send send a identity verification link to a user.

View File

@ -1,7 +1,11 @@
package notification
import (
"github.com/sirupsen/logrus"
)
// Notifier interface for sending the identity verification link.
type Notifier interface {
Send(recipient, subject, body, htmlBody string) error
StartupCheck() (bool, error)
Send(recipient, subject, body, htmlBody string) (err error)
StartupCheck(logger *logrus.Logger) (err error)
}

View File

@ -10,6 +10,8 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
@ -220,39 +222,39 @@ func (n *SMTPNotifier) cleanup() {
}
}
// StartupCheck checks the server is functioning correctly and the configuration is correct.
func (n *SMTPNotifier) StartupCheck() (bool, error) {
// StartupCheck implements the startup check provider interface.
func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) {
if err := n.dial(); err != nil {
return false, err
return err
}
defer n.cleanup()
if err := n.client.Hello(n.configuration.Identifier); err != nil {
return false, err
return err
}
if err := n.startTLS(); err != nil {
return false, err
return err
}
if err := n.auth(); err != nil {
return false, err
return err
}
if err := n.client.Mail(n.configuration.Sender); err != nil {
return false, err
return err
}
if err := n.client.Rcpt(n.configuration.StartupCheckAddress); err != nil {
return false, err
return err
}
if err := n.client.Reset(); err != nil {
return false, err
return err
}
return true, nil
return nil
}
// Send is used to send an email to a recipient.

View File

@ -2,10 +2,12 @@ package ntp
import (
"encoding/binary"
"fmt"
"errors"
"net"
"time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils"
)
@ -15,17 +17,21 @@ func NewProvider(config *schema.NTPConfiguration) *Provider {
return &Provider{config}
}
// StartupCheck checks if the system clock is not out of sync.
func (p *Provider) StartupCheck() (failed bool, err error) {
// StartupCheck implements the startup check provider interface.
func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
conn, err := net.Dial("udp", p.config.Address)
if err != nil {
return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err)
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
return nil
}
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return false, fmt.Errorf("could not connect to NTP server to validate the time desync: %w", err)
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
return nil
}
version := ntpV4
@ -36,7 +42,9 @@ func (p *Provider) StartupCheck() (failed bool, err error) {
req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)}
if err := binary.Write(conn, binary.BigEndian, req); err != nil {
return false, fmt.Errorf("could not write to the NTP server socket to validate the time desync: %w", err)
logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
return nil
}
now := time.Now()
@ -44,12 +52,18 @@ func (p *Provider) StartupCheck() (failed bool, err error) {
resp := &ntpPacket{}
if err := binary.Read(conn, binary.BigEndian, resp); err != nil {
return false, fmt.Errorf("could not read from the NTP server socket to validate the time desync: %w", err)
logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
return nil
}
maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync)
ntpTime := ntpPacketToTime(resp)
return ntpIsOffsetTooLarge(maxOffset, now, ntpTime), nil
if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result {
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
}
return nil
}

View File

@ -7,6 +7,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator"
"github.com/authelia/authelia/v4/internal/logging"
)
func TestShouldCheckNTP(t *testing.T) {
@ -19,8 +20,7 @@ func TestShouldCheckNTP(t *testing.T) {
sv := schema.NewStructValidator()
validator.ValidateNTP(&config, sv)
NTP := NewProvider(&config)
ntp := NewProvider(&config)
checkfailed, _ := NTP.StartupCheck()
assert.Equal(t, false, checkfailed)
assert.NoError(t, ntp.StartupCheck(logging.Logger()))
}