mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
aed9099ce2
* refactor: factorize startup checks * refactor: address linting issues
206 lines
6.0 KiB
Go
206 lines
6.0 KiB
Go
package commands
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/spf13/cobra"
|
|
|
|
"github.com/authelia/authelia/v4/internal/authentication"
|
|
"github.com/authelia/authelia/v4/internal/authorization"
|
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
|
"github.com/authelia/authelia/v4/internal/logging"
|
|
"github.com/authelia/authelia/v4/internal/middlewares"
|
|
"github.com/authelia/authelia/v4/internal/notification"
|
|
"github.com/authelia/authelia/v4/internal/ntp"
|
|
"github.com/authelia/authelia/v4/internal/oidc"
|
|
"github.com/authelia/authelia/v4/internal/regulation"
|
|
"github.com/authelia/authelia/v4/internal/server"
|
|
"github.com/authelia/authelia/v4/internal/session"
|
|
"github.com/authelia/authelia/v4/internal/storage"
|
|
"github.com/authelia/authelia/v4/internal/utils"
|
|
)
|
|
|
|
// NewRootCmd returns a new Root Cmd.
|
|
func NewRootCmd() (cmd *cobra.Command) {
|
|
version := utils.Version()
|
|
|
|
cmd = &cobra.Command{
|
|
Use: "authelia",
|
|
Example: cmdAutheliaExample,
|
|
Short: fmt.Sprintf("authelia %s", version),
|
|
Long: fmt.Sprintf(fmtAutheliaLong, version),
|
|
Version: version,
|
|
Args: cobra.NoArgs,
|
|
PreRun: newCmdWithConfigPreRun(true, true, true),
|
|
Run: cmdRootRun,
|
|
}
|
|
|
|
cmdWithConfigFlags(cmd)
|
|
|
|
cmd.AddCommand(
|
|
newBuildInfoCmd(),
|
|
NewCertificatesCmd(),
|
|
newCompletionCmd(),
|
|
NewHashPasswordCmd(),
|
|
NewRSACmd(),
|
|
newValidateConfigCmd(),
|
|
)
|
|
|
|
return cmd
|
|
}
|
|
|
|
func cmdRootRun(_ *cobra.Command, _ []string) {
|
|
logger := logging.Logger()
|
|
|
|
logger.Infof("Authelia %s is starting", utils.Version())
|
|
|
|
if os.Getenv("ENVIRONMENT") == "dev" {
|
|
logger.Info("===> Authelia is running in development mode. <===")
|
|
}
|
|
|
|
if err := logging.InitializeLogger(config.Log, true); err != nil {
|
|
logger.Fatalf("Cannot initialize logger: %v", err)
|
|
}
|
|
|
|
providers, warnings, errors := getProviders(config)
|
|
if len(warnings) != 0 {
|
|
for _, err := range warnings {
|
|
logger.Warn(err)
|
|
}
|
|
}
|
|
|
|
if len(errors) != 0 {
|
|
for _, err := range errors {
|
|
logger.Error(err)
|
|
}
|
|
|
|
logger.Fatalf("Errors occurred provisioning providers.")
|
|
}
|
|
|
|
doStartupChecks(config, &providers)
|
|
|
|
server.Start(*config, providers)
|
|
}
|
|
|
|
func getProviders(config *schema.Configuration) (providers middlewares.Providers, warnings []error, errors []error) {
|
|
// TODO: Adjust this so the CertPool can be used like a provider.
|
|
autheliaCertPool, warnings, errors := utils.NewX509CertPool(config.CertificatesDirectory)
|
|
if len(warnings) != 0 || len(errors) != 0 {
|
|
return providers, warnings, errors
|
|
}
|
|
|
|
var storageProvider storage.Provider
|
|
|
|
switch {
|
|
case config.Storage.PostgreSQL != nil:
|
|
storageProvider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL)
|
|
case config.Storage.MySQL != nil:
|
|
storageProvider = storage.NewMySQLProvider(*config.Storage.MySQL)
|
|
case config.Storage.Local != nil:
|
|
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
|
|
default:
|
|
// TODO: Add storage provider startup check and remove this.
|
|
errors = append(errors, fmt.Errorf("unrecognized storage provider"))
|
|
}
|
|
|
|
var (
|
|
userProvider authentication.UserProvider
|
|
err error
|
|
)
|
|
|
|
switch {
|
|
case config.AuthenticationBackend.File != nil:
|
|
userProvider = authentication.NewFileUserProvider(config.AuthenticationBackend.File)
|
|
case config.AuthenticationBackend.LDAP != nil:
|
|
userProvider = authentication.NewLDAPUserProvider(config.AuthenticationBackend, autheliaCertPool)
|
|
}
|
|
|
|
var notifier notification.Notifier
|
|
|
|
switch {
|
|
case config.Notifier.SMTP != nil:
|
|
notifier = notification.NewSMTPNotifier(config.Notifier.SMTP, autheliaCertPool)
|
|
case config.Notifier.FileSystem != nil:
|
|
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
|
|
}
|
|
|
|
var ntpProvider *ntp.Provider
|
|
if config.NTP != nil {
|
|
ntpProvider = ntp.NewProvider(config.NTP)
|
|
}
|
|
|
|
clock := utils.RealClock{}
|
|
authorizer := authorization.NewAuthorizer(config)
|
|
sessionProvider := session.NewProvider(config.Session, autheliaCertPool)
|
|
regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock)
|
|
|
|
oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
|
|
return middlewares.Providers{
|
|
Authorizer: authorizer,
|
|
UserProvider: userProvider,
|
|
Regulator: regulator,
|
|
OpenIDConnect: oidcProvider,
|
|
StorageProvider: storageProvider,
|
|
NTP: ntpProvider,
|
|
Notifier: notifier,
|
|
SessionProvider: sessionProvider,
|
|
}, warnings, errors
|
|
}
|
|
|
|
func doStartupChecks(config *schema.Configuration, providers *middlewares.Providers) {
|
|
logger := logging.Logger()
|
|
|
|
var (
|
|
failures []string
|
|
err error
|
|
)
|
|
|
|
if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil {
|
|
logger.Errorf("Failure running the user provider startup check: %+v", err)
|
|
|
|
failures = append(failures, "user")
|
|
}
|
|
|
|
if err = doStartupCheck(logger, "notification", providers.Notifier, config.Notifier.DisableStartupCheck); err != nil {
|
|
logger.Errorf("Failure running the notification provider startup check: %+v", err)
|
|
|
|
failures = append(failures, "notification")
|
|
}
|
|
|
|
if !config.NTP.DisableStartupCheck && !providers.Authorizer.IsSecondFactorEnabled() {
|
|
logger.Debug("The NTP startup check was skipped due to there being no configured 2FA access control rules")
|
|
} else if err = doStartupCheck(logger, "ntp", providers.NTP, config.NTP.DisableStartupCheck); err != nil {
|
|
logger.Errorf("Failure running the user provider startup check: %+v", err)
|
|
|
|
failures = append(failures, "ntp")
|
|
}
|
|
|
|
if len(failures) != 0 {
|
|
logger.Fatalf("The following providers had fatal failures during startup: %s", strings.Join(failures, ", "))
|
|
}
|
|
}
|
|
|
|
func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) {
|
|
if disabled {
|
|
logger.Debugf("%s provider: startup check skipped as it is disabled", name)
|
|
return nil
|
|
}
|
|
|
|
if provider == nil {
|
|
return fmt.Errorf("unrecognized provider or it is not configured properly")
|
|
}
|
|
|
|
if err = provider.StartupCheck(logger); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|