[MISC] Add missing CLI suite test (#1607)

* [MISC] Add missing CLI suite test

* Add missing test for `authelia version` command in CLI suite.
* Standardise logger calls and swap CSP switch order
This commit is contained in:
Amir Zarrinkafsh 2021-01-17 10:23:35 +11:00 committed by GitHub
parent 8bab8d47ef
commit 296efe2b32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 85 additions and 59 deletions

View File

@ -25,11 +25,12 @@ var configPathFlag string
//nolint:gocyclo // TODO: Consider refactoring/simplifying, time permitting. //nolint:gocyclo // TODO: Consider refactoring/simplifying, time permitting.
func startServer() { func startServer() {
logger := logging.Logger()
config, errs := configuration.Read(configPathFlag) config, errs := configuration.Read(configPathFlag)
if len(errs) > 0 { if len(errs) > 0 {
for _, err := range errs { for _, err := range errs {
logging.Logger().Error(err) logger.Error(err)
} }
os.Exit(1) os.Exit(1)
@ -38,7 +39,7 @@ func startServer() {
autheliaCertPool, errs, nonFatalErrs := utils.NewX509CertPool(config.CertificatesDirectory, config) autheliaCertPool, errs, nonFatalErrs := utils.NewX509CertPool(config.CertificatesDirectory, config)
if len(errs) > 0 { if len(errs) > 0 {
for _, err := range errs { for _, err := range errs {
logging.Logger().Error(err) logger.Error(err)
} }
os.Exit(2) os.Exit(2)
@ -46,28 +47,28 @@ func startServer() {
if len(nonFatalErrs) > 0 { if len(nonFatalErrs) > 0 {
for _, err := range nonFatalErrs { for _, err := range nonFatalErrs {
logging.Logger().Warn(err) logger.Warn(err)
} }
} }
if err := logging.InitializeLogger(config.LogFormat, config.LogFilePath); err != nil { if err := logging.InitializeLogger(config.LogFormat, config.LogFilePath); err != nil {
logging.Logger().Fatalf("Cannot initialize logger: %v", err) logger.Fatalf("Cannot initialize logger: %v", err)
} }
switch config.LogLevel { switch config.LogLevel {
case "info": case "info":
logging.Logger().Info("Logging severity set to info") logger.Info("Logging severity set to info")
logging.SetLevel(logrus.InfoLevel) logging.SetLevel(logrus.InfoLevel)
case "debug": case "debug":
logging.Logger().Info("Logging severity set to debug") logger.Info("Logging severity set to debug")
logging.SetLevel(logrus.DebugLevel) logging.SetLevel(logrus.DebugLevel)
case "trace": case "trace":
logging.Logger().Info("Logging severity set to trace") logger.Info("Logging severity set to trace")
logging.SetLevel(logrus.TraceLevel) logging.SetLevel(logrus.TraceLevel)
} }
if os.Getenv("ENVIRONMENT") == "dev" { if os.Getenv("ENVIRONMENT") == "dev" {
logging.Logger().Info("===> Authelia is running in development mode. <===") logger.Info("===> Authelia is running in development mode. <===")
} }
var storageProvider storage.Provider var storageProvider storage.Provider
@ -80,7 +81,7 @@ func startServer() {
case config.Storage.Local != nil: case config.Storage.Local != nil:
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path) storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
default: default:
logging.Logger().Fatalf("Unrecognized storage backend") logger.Fatalf("Unrecognized storage backend")
} }
var userProvider authentication.UserProvider var userProvider authentication.UserProvider
@ -91,7 +92,7 @@ func startServer() {
case config.AuthenticationBackend.Ldap != nil: case config.AuthenticationBackend.Ldap != nil:
userProvider = authentication.NewLDAPUserProvider(*config.AuthenticationBackend.Ldap, autheliaCertPool) userProvider = authentication.NewLDAPUserProvider(*config.AuthenticationBackend.Ldap, autheliaCertPool)
default: default:
logging.Logger().Fatalf("Unrecognized authentication backend") logger.Fatalf("Unrecognized authentication backend")
} }
var notifier notification.Notifier var notifier notification.Notifier
@ -102,13 +103,13 @@ func startServer() {
case config.Notifier.FileSystem != nil: case config.Notifier.FileSystem != nil:
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
default: default:
logging.Logger().Fatalf("Unrecognized notifier") logger.Fatalf("Unrecognized notifier")
} }
if !config.Notifier.DisableStartupCheck { if !config.Notifier.DisableStartupCheck {
_, err := notifier.StartupCheck() _, err := notifier.StartupCheck()
if err != nil { if err != nil {
logging.Logger().Fatalf("Error during notifier startup check: %s", err) logger.Fatalf("Error during notifier startup check: %s", err)
} }
} }
@ -129,6 +130,7 @@ func startServer() {
} }
func main() { func main() {
logger := logging.Logger()
rootCmd := &cobra.Command{ rootCmd := &cobra.Command{
Use: "authelia", Use: "authelia",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
@ -150,6 +152,6 @@ func main() {
commands.ValidateConfigCmd, commands.CertificatesCmd) commands.ValidateConfigCmd, commands.CertificatesCmd)
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
logging.Logger().Fatal(err) logger.Fatal(err)
} }
} }

View File

@ -36,10 +36,12 @@ type DatabaseModel struct {
// NewFileUserProvider creates a new instance of FileUserProvider. // NewFileUserProvider creates a new instance of FileUserProvider.
func NewFileUserProvider(configuration *schema.FileAuthenticationBackendConfiguration) *FileUserProvider { func NewFileUserProvider(configuration *schema.FileAuthenticationBackendConfiguration) *FileUserProvider {
logger := logging.Logger()
errs := checkDatabase(configuration.Path) errs := checkDatabase(configuration.Path)
if errs != nil { if errs != nil {
for _, err := range errs { for _, err := range errs {
logging.Logger().Error(err) logger.Error(err)
} }
os.Exit(1) os.Exit(1)

View File

@ -166,8 +166,9 @@ func (p *LDAPUserProvider) resolveUsersFilter(userFilter string, inputUsername s
} }
func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername string) (*ldapUserProfile, error) { func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername string) (*ldapUserProfile, error) {
logger := logging.Logger()
userFilter := p.resolveUsersFilter(p.configuration.UsersFilter, inputUsername) userFilter := p.resolveUsersFilter(p.configuration.UsersFilter, inputUsername)
logging.Logger().Tracef("Computed user filter is %s", userFilter) logger.Tracef("Computed user filter is %s", userFilter)
attributes := []string{"dn", attributes := []string{"dn",
p.configuration.DisplayNameAttribute, p.configuration.DisplayNameAttribute,
@ -239,6 +240,8 @@ func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ld
// GetDetails retrieve the groups a user belongs to. // GetDetails retrieve the groups a user belongs to.
func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error) { func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error) {
logger := logging.Logger()
conn, err := p.connect(p.configuration.User, p.configuration.Password) conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil { if err != nil {
return nil, err return nil, err
@ -255,7 +258,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error
return nil, fmt.Errorf("Unable to create group filter for user %s. Cause: %s", inputUsername, err) return nil, fmt.Errorf("Unable to create group filter for user %s. Cause: %s", inputUsername, err)
} }
logging.Logger().Tracef("Computed groups filter is %s", groupsFilter) logger.Tracef("Computed groups filter is %s", groupsFilter)
// Search for the given username. // Search for the given username.
searchGroupRequest := ldap.NewSearchRequest( searchGroupRequest := ldap.NewSearchRequest(
@ -273,7 +276,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error
for _, res := range sr.Entries { for _, res := range sr.Entries {
if len(res.Attributes) == 0 { if len(res.Attributes) == 0 {
logging.Logger().Warningf("No groups retrieved from LDAP for user %s", inputUsername) logger.Warningf("No groups retrieved from LDAP for user %s", inputUsername)
break break
} }
// Append all values of the document. Normally there should be only one per document. // Append all values of the document. Normally there should be only one per document.

View File

@ -114,8 +114,8 @@ func (p *Authorizer) IsSecondFactorEnabled() bool {
// GetRequiredLevel retrieve the required level of authorization to access the object. // GetRequiredLevel retrieve the required level of authorization to access the object.
func (p *Authorizer) GetRequiredLevel(subject Subject, requestURL url.URL) Level { func (p *Authorizer) GetRequiredLevel(subject Subject, requestURL url.URL) Level {
logging.Logger().Tracef("Check authorization of subject %s and url %s.", logger := logging.Logger()
subject.String(), requestURL.String()) logger.Tracef("Check authorization of subject %s and url %s.", subject.String(), requestURL.String())
matchingRules := selectMatchingRules(p.configuration.Rules, p.configuration.Networks, subject, Object{ matchingRules := selectMatchingRules(p.configuration.Rules, p.configuration.Networks, subject, Object{
Domain: requestURL.Hostname(), Domain: requestURL.Hostname(),
@ -126,8 +126,7 @@ func (p *Authorizer) GetRequiredLevel(subject Subject, requestURL url.URL) Level
return PolicyToLevel(matchingRules[0].Policy) return PolicyToLevel(matchingRules[0].Policy)
} }
logging.Logger().Tracef("No matching rule for subject %s and url %s... Applying default policy.", logger.Tracef("No matching rule for subject %s and url %s... Applying default policy.", subject.String(), requestURL.String())
subject.String(), requestURL.String())
return PolicyToLevel(p.configuration.DefaultPolicy) return PolicyToLevel(p.configuration.DefaultPolicy)
} }

View File

@ -18,6 +18,8 @@ import (
// Read a YAML configuration and create a Configuration object out of it. // Read a YAML configuration and create a Configuration object out of it.
//go:generate broccoli -src ../../config.template.yml -var=cfg -o configuration //go:generate broccoli -src ../../config.template.yml -var=cfg -o configuration
func Read(configPath string) (*schema.Configuration, []error) { func Read(configPath string) (*schema.Configuration, []error) {
logger := logging.Logger()
if configPath == "" { if configPath == "" {
return nil, []error{errors.New("No config file path provided")} return nil, []error{errors.New("No config file path provided")}
} }
@ -81,7 +83,7 @@ func Read(configPath string) (*schema.Configuration, []error) {
if val.HasWarnings() { if val.HasWarnings() {
for _, warn := range val.Warnings() { for _, warn := range val.Warnings() {
logging.Logger().Warnf(warn.Error()) logger.Warnf(warn.Error())
} }
} }

View File

@ -51,25 +51,26 @@ func NewSMTPNotifier(configuration schema.SMTPNotifierConfiguration, certPool *x
// Do startTLS if available (some servers only provide the auth extension after, and encryption is preferred). // Do startTLS if available (some servers only provide the auth extension after, and encryption is preferred).
func (n *SMTPNotifier) startTLS() error { func (n *SMTPNotifier) startTLS() error {
logger := logging.Logger()
// Only start if not already encrypted // Only start if not already encrypted
if _, ok := n.client.TLSConnectionState(); ok { if _, ok := n.client.TLSConnectionState(); ok {
logging.Logger().Debugf("Notifier SMTP connection is already encrypted, skipping STARTTLS") logger.Debugf("Notifier SMTP connection is already encrypted, skipping STARTTLS")
return nil return nil
} }
switch ok, _ := n.client.Extension("STARTTLS"); ok { switch ok, _ := n.client.Extension("STARTTLS"); ok {
case true: case true:
logging.Logger().Debugf("Notifier SMTP server supports STARTTLS (disableVerifyCert: %t, ServerName: %s), attempting", n.tlsConfig.InsecureSkipVerify, n.tlsConfig.ServerName) logger.Debugf("Notifier SMTP server supports STARTTLS (disableVerifyCert: %t, ServerName: %s), attempting", n.tlsConfig.InsecureSkipVerify, n.tlsConfig.ServerName)
if err := n.client.StartTLS(n.tlsConfig); err != nil { if err := n.client.StartTLS(n.tlsConfig); err != nil {
return err return err
} }
logging.Logger().Debug("Notifier SMTP STARTTLS completed without error") logger.Debug("Notifier SMTP STARTTLS completed without error")
default: default:
switch n.disableRequireTLS { switch n.disableRequireTLS {
case true: case true:
logging.Logger().Warn("Notifier SMTP server does not support STARTTLS and SMTP configuration is set to disable the TLS requirement (only useful for unauthenticated emails over plain text)") logger.Warn("Notifier SMTP server does not support STARTTLS and SMTP configuration is set to disable the TLS requirement (only useful for unauthenticated emails over plain text)")
default: default:
return errors.New("Notifier SMTP server does not support TLS and it is required by default (see documentation if you want to disable this highly recommended requirement)") return errors.New("Notifier SMTP server does not support TLS and it is required by default (see documentation if you want to disable this highly recommended requirement)")
} }
@ -80,6 +81,7 @@ func (n *SMTPNotifier) startTLS() error {
// Attempt Authentication. // Attempt Authentication.
func (n *SMTPNotifier) auth() error { func (n *SMTPNotifier) auth() error {
logger := logging.Logger()
// Attempt AUTH if password is specified only. // Attempt AUTH if password is specified only.
if n.password != "" { if n.password != "" {
_, ok := n.client.TLSConnectionState() _, ok := n.client.TLSConnectionState()
@ -92,18 +94,18 @@ func (n *SMTPNotifier) auth() error {
if ok { if ok {
var auth smtp.Auth var auth smtp.Auth
logging.Logger().Debugf("Notifier SMTP server supports authentication with the following mechanisms: %s", m) logger.Debugf("Notifier SMTP server supports authentication with the following mechanisms: %s", m)
mechanisms := strings.Split(m, " ") mechanisms := strings.Split(m, " ")
// Adaptively select the AUTH mechanism to use based on what the server advertised. // Adaptively select the AUTH mechanism to use based on what the server advertised.
if utils.IsStringInSlice("PLAIN", mechanisms) { if utils.IsStringInSlice("PLAIN", mechanisms) {
auth = smtp.PlainAuth("", n.username, n.password, n.host) auth = smtp.PlainAuth("", n.username, n.password, n.host)
logging.Logger().Debug("Notifier SMTP client attempting AUTH PLAIN with server") logger.Debug("Notifier SMTP client attempting AUTH PLAIN with server")
} else if utils.IsStringInSlice("LOGIN", mechanisms) { } else if utils.IsStringInSlice("LOGIN", mechanisms) {
auth = newLoginAuth(n.username, n.password, n.host) auth = newLoginAuth(n.username, n.password, n.host)
logging.Logger().Debug("Notifier SMTP client attempting AUTH LOGIN with server") logger.Debug("Notifier SMTP client attempting AUTH LOGIN with server")
} }
// Throw error since AUTH extension is not supported. // Throw error since AUTH extension is not supported.
@ -116,7 +118,7 @@ func (n *SMTPNotifier) auth() error {
return err return err
} }
logging.Logger().Debug("Notifier SMTP client authenticated successfully with the server") logger.Debug("Notifier SMTP client authenticated successfully with the server")
return nil return nil
} }
@ -124,13 +126,14 @@ func (n *SMTPNotifier) auth() error {
return errors.New("Notifier SMTP server does not advertise the AUTH extension but config requires AUTH (password specified), either disable AUTH, or use an SMTP host that supports AUTH PLAIN or AUTH LOGIN") return errors.New("Notifier SMTP server does not advertise the AUTH extension but config requires AUTH (password specified), either disable AUTH, or use an SMTP host that supports AUTH PLAIN or AUTH LOGIN")
} }
logging.Logger().Debug("Notifier SMTP config has no password specified so authentication is being skipped") logger.Debug("Notifier SMTP config has no password specified so authentication is being skipped")
return nil return nil
} }
func (n *SMTPNotifier) compose(recipient, subject, body, htmlBody string) error { func (n *SMTPNotifier) compose(recipient, subject, body, htmlBody string) error {
logging.Logger().Debugf("Notifier SMTP client attempting to send email body to %s", recipient) logger := logging.Logger()
logger.Debugf("Notifier SMTP client attempting to send email body to %s", recipient)
if !n.disableRequireTLS { if !n.disableRequireTLS {
_, ok := n.client.TLSConnectionState() _, ok := n.client.TLSConnectionState()
@ -141,7 +144,7 @@ func (n *SMTPNotifier) compose(recipient, subject, body, htmlBody string) error
wc, err := n.client.Data() wc, err := n.client.Data()
if err != nil { if err != nil {
logging.Logger().Debugf("Notifier SMTP client error while obtaining WriteCloser: %s", err) logger.Debugf("Notifier SMTP client error while obtaining WriteCloser: %s", err)
return err return err
} }
@ -171,13 +174,13 @@ func (n *SMTPNotifier) compose(recipient, subject, body, htmlBody string) error
_, err = fmt.Fprint(wc, msg) _, err = fmt.Fprint(wc, msg)
if err != nil { if err != nil {
logging.Logger().Debugf("Notifier SMTP client error while sending email body over WriteCloser: %s", err) logger.Debugf("Notifier SMTP client error while sending email body over WriteCloser: %s", err)
return err return err
} }
err = wc.Close() err = wc.Close()
if err != nil { if err != nil {
logging.Logger().Debugf("Notifier SMTP client error while closing the WriteCloser: %s", err) logger.Debugf("Notifier SMTP client error while closing the WriteCloser: %s", err)
return err return err
} }
@ -186,10 +189,11 @@ func (n *SMTPNotifier) compose(recipient, subject, body, htmlBody string) error
// Dial the SMTP server with the SMTPNotifier config. // Dial the SMTP server with the SMTPNotifier config.
func (n *SMTPNotifier) dial() error { func (n *SMTPNotifier) dial() error {
logging.Logger().Debugf("Notifier SMTP client attempting connection to %s", n.address) logger := logging.Logger()
logger.Debugf("Notifier SMTP client attempting connection to %s", n.address)
if n.port == 465 { if n.port == 465 {
logging.Logger().Warnf("Notifier SMTP client configured to connect to a SMTPS server. It's highly recommended you use a non SMTPS port and STARTTLS instead of SMTPS, as the protocol is long deprecated.") logger.Warnf("Notifier SMTP client configured to connect to a SMTPS server. It's highly recommended you use a non SMTPS port and STARTTLS instead of SMTPS, as the protocol is long deprecated.")
conn, err := tls.Dial("tcp", n.address, n.tlsConfig) conn, err := tls.Dial("tcp", n.address, n.tlsConfig)
if err != nil { if err != nil {
@ -211,16 +215,18 @@ func (n *SMTPNotifier) dial() error {
n.client = client n.client = client
} }
logging.Logger().Debug("Notifier SMTP client connected successfully") logger.Debug("Notifier SMTP client connected successfully")
return nil return nil
} }
// Closes the connection properly. // Closes the connection properly.
func (n *SMTPNotifier) cleanup() { func (n *SMTPNotifier) cleanup() {
logger := logging.Logger()
err := n.client.Quit() err := n.client.Quit()
if err != nil { if err != nil {
logging.Logger().Warnf("Notifier SMTP client encountered error during cleanup: %s", err) logger.Warnf("Notifier SMTP client encountered error during cleanup: %s", err)
} }
} }
@ -261,6 +267,7 @@ func (n *SMTPNotifier) StartupCheck() (bool, error) {
// Send is used to send an email to a recipient. // Send is used to send an email to a recipient.
func (n *SMTPNotifier) Send(recipient, title, body, htmlBody string) error { func (n *SMTPNotifier) Send(recipient, title, body, htmlBody string) error {
logger := logging.Logger()
subject := strings.ReplaceAll(n.subject, "{title}", title) subject := strings.ReplaceAll(n.subject, "{title}", title)
if err := n.dial(); err != nil { if err := n.dial(); err != nil {
@ -285,12 +292,12 @@ func (n *SMTPNotifier) Send(recipient, title, body, htmlBody string) error {
// Set the sender and recipient first. // Set the sender and recipient first.
if err := n.client.Mail(n.sender); err != nil { if err := n.client.Mail(n.sender); err != nil {
logging.Logger().Debugf("Notifier SMTP failed while sending MAIL FROM (using sender) with error: %s", err) logger.Debugf("Notifier SMTP failed while sending MAIL FROM (using sender) with error: %s", err)
return err return err
} }
if err := n.client.Rcpt(recipient); err != nil { if err := n.client.Rcpt(recipient); err != nil {
logging.Logger().Debugf("Notifier SMTP failed while sending RCPT TO (using recipient) with error: %s", err) logger.Debugf("Notifier SMTP failed while sending RCPT TO (using recipient) with error: %s", err)
return err return err
} }
@ -299,7 +306,7 @@ func (n *SMTPNotifier) Send(recipient, title, body, htmlBody string) error {
return err return err
} }
logging.Logger().Debug("Notifier SMTP client successfully sent email") logger.Debug("Notifier SMTP client successfully sent email")
return nil return nil
} }

View File

@ -10,17 +10,19 @@ import (
// Replacement for the default error handler in fasthttp. // Replacement for the default error handler in fasthttp.
func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) { func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) {
logger := logging.Logger()
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok { if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
// Note: Getting X-Forwarded-For or Request URI is impossible for ths error. // Note: Getting X-Forwarded-For or Request URI is impossible for ths error.
logging.Logger().Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge) logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
ctx.Error("Request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge) ctx.Error("Request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
// TODO: Add X-Forwarded-For Check here. // TODO: Add X-Forwarded-For Check here.
logging.Logger().Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout) logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
ctx.Error("Request timeout", fasthttp.StatusRequestTimeout) ctx.Error("Request timeout", fasthttp.StatusRequestTimeout)
} else { } else {
// TODO: Add X-Forwarded-For Check here. // TODO: Add X-Forwarded-For Check here.
logging.Logger().Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest) logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
ctx.Error("Error when parsing request", fasthttp.StatusBadRequest) ctx.Error("Error when parsing request", fasthttp.StatusBadRequest)
} }
} }

View File

@ -24,6 +24,7 @@ import (
// StartServer start Authelia server with the given configuration and providers. // StartServer start Authelia server with the given configuration and providers.
func StartServer(configuration schema.Configuration, providers middlewares.Providers) { func StartServer(configuration schema.Configuration, providers middlewares.Providers) {
logger := logging.Logger()
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers) autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
embeddedAssets := "/public_html/" embeddedAssets := "/public_html/"
swaggerAssets := embeddedAssets + "api/" swaggerAssets := embeddedAssets + "api/"
@ -147,30 +148,30 @@ func StartServer(configuration schema.Configuration, providers middlewares.Provi
listener, err := net.Listen("tcp", addrPattern) listener, err := net.Listen("tcp", addrPattern)
if err != nil { if err != nil {
logging.Logger().Fatalf("Error initializing listener: %s", err) logger.Fatalf("Error initializing listener: %s", err)
} }
if configuration.AuthenticationBackend.File != nil && configuration.AuthenticationBackend.File.Password.Algorithm == "argon2id" && runtime.GOOS == "linux" { if configuration.AuthenticationBackend.File != nil && configuration.AuthenticationBackend.File.Password.Algorithm == "argon2id" && runtime.GOOS == "linux" {
f, err := ioutil.ReadFile("/sys/fs/cgroup/memory/memory.limit_in_bytes") f, err := ioutil.ReadFile("/sys/fs/cgroup/memory/memory.limit_in_bytes")
if err != nil { if err != nil {
logging.Logger().Warnf("Error reading hosts memory limit: %s", err) logger.Warnf("Error reading hosts memory limit: %s", err)
} else { } else {
m, _ := strconv.Atoi(strings.TrimSuffix(string(f), "\n")) m, _ := strconv.Atoi(strings.TrimSuffix(string(f), "\n"))
hostMem := float64(m) / 1024 / 1024 / 1024 hostMem := float64(m) / 1024 / 1024 / 1024
argonMem := float64(configuration.AuthenticationBackend.File.Password.Memory) / 1024 argonMem := float64(configuration.AuthenticationBackend.File.Password.Memory) / 1024
if hostMem/argonMem <= 2 { if hostMem/argonMem <= 2 {
logging.Logger().Warnf("Authelia's password hashing memory parameter is set to: %gGB this is %g%% of the available memory: %gGB", argonMem, argonMem/hostMem*100, hostMem) logger.Warnf("Authelia's password hashing memory parameter is set to: %gGB this is %g%% of the available memory: %gGB", argonMem, argonMem/hostMem*100, hostMem)
logging.Logger().Warn("Please read https://www.authelia.com/docs/configuration/authentication/file.html#memory and tune your deployment") logger.Warn("Please read https://www.authelia.com/docs/configuration/authentication/file.html#memory and tune your deployment")
} }
} }
} }
if configuration.TLSCert != "" && configuration.TLSKey != "" { if configuration.TLSCert != "" && configuration.TLSKey != "" {
logging.Logger().Infof("Authelia is listening for TLS connections on %s%s", addrPattern, configuration.Server.Path) logger.Infof("Authelia is listening for TLS connections on %s%s", addrPattern, configuration.Server.Path)
logging.Logger().Fatal(server.ServeTLS(listener, configuration.TLSCert, configuration.TLSKey)) logger.Fatal(server.ServeTLS(listener, configuration.TLSCert, configuration.TLSKey))
} else { } else {
logging.Logger().Infof("Authelia is listening for non-TLS connections on %s%s", addrPattern, configuration.Server.Path) logger.Infof("Authelia is listening for non-TLS connections on %s%s", addrPattern, configuration.Server.Path)
logging.Logger().Fatal(server.Serve(listener)) logger.Fatal(server.Serve(listener))
} }
} }

View File

@ -20,19 +20,21 @@ var alphaNumericRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUV
// and generate a nonce to support a restrictive CSP while using material-ui. // and generate a nonce to support a restrictive CSP while using material-ui.
//go:generate broccoli -src ../../public_html -o public_html //go:generate broccoli -src ../../public_html -o public_html
func ServeTemplatedFile(publicDir, file, base, session, rememberMe, resetPassword string) fasthttp.RequestHandler { func ServeTemplatedFile(publicDir, file, base, session, rememberMe, resetPassword string) fasthttp.RequestHandler {
logger := logging.Logger()
f, err := br.Open(publicDir + file) f, err := br.Open(publicDir + file)
if err != nil { if err != nil {
logging.Logger().Fatalf("Unable to open %s: %s", file, err) logger.Fatalf("Unable to open %s: %s", file, err)
} }
b, err := ioutil.ReadAll(f) b, err := ioutil.ReadAll(f)
if err != nil { if err != nil {
logging.Logger().Fatalf("Unable to read %s: %s", file, err) logger.Fatalf("Unable to read %s: %s", file, err)
} }
tmpl, err := template.New("file").Parse(string(b)) tmpl, err := template.New("file").Parse(string(b))
if err != nil { if err != nil {
logging.Logger().Fatalf("Unable to parse %s template: %s", file, err) logger.Fatalf("Unable to parse %s template: %s", file, err)
} }
return func(ctx *fasthttp.RequestCtx) { return func(ctx *fasthttp.RequestCtx) {
@ -46,10 +48,10 @@ func ServeTemplatedFile(publicDir, file, base, session, rememberMe, resetPasswor
} }
switch { switch {
case os.Getenv("ENVIRONMENT") == dev:
ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("default-src 'self' 'unsafe-eval'; object-src 'none'; style-src 'self' 'nonce-%s'", nonce))
case publicDir == "/public_html/api/": case publicDir == "/public_html/api/":
ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("base-uri 'self' ; default-src 'self' ; img-src 'self' https://validator.swagger.io data: ; object-src 'none' ; script-src 'self' 'unsafe-inline' 'nonce-%s' ; style-src 'self' 'nonce-%s'", nonce, nonce)) ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("base-uri 'self' ; default-src 'self' ; img-src 'self' https://validator.swagger.io data: ; object-src 'none' ; script-src 'self' 'unsafe-inline' 'nonce-%s' ; style-src 'self' 'nonce-%s'", nonce, nonce))
case os.Getenv("ENVIRONMENT") == dev:
ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("default-src 'self' 'unsafe-eval'; object-src 'none'; style-src 'self' 'nonce-%s'", nonce))
default: default:
ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("default-src 'self' ; object-src 'none'; style-src 'self' 'nonce-%s'", nonce)) ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("default-src 'self' ; object-src 'none'; style-src 'self' 'nonce-%s'", nonce))
} }
@ -57,7 +59,7 @@ func ServeTemplatedFile(publicDir, file, base, session, rememberMe, resetPasswor
err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, CSPNonce, Session, RememberMe, ResetPassword string }{Base: base, CSPNonce: nonce, Session: session, RememberMe: rememberMe, ResetPassword: resetPassword}) err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, CSPNonce, Session, RememberMe, ResetPassword string }{Base: base, CSPNonce: nonce, Session: session, RememberMe: rememberMe, ResetPassword: resetPassword})
if err != nil { if err != nil {
ctx.Error("An error occurred", 503) ctx.Error("An error occurred", 503)
logging.Logger().Errorf("Unable to execute template: %v", err) logger.Errorf("Unable to execute template: %v", err)
return return
} }

View File

@ -37,6 +37,12 @@ func (s *CLISuite) SetupTest() {
s.coverageArg = coverageArg s.coverageArg = coverageArg
} }
func (s *CLISuite) TestShouldPrintVersion() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "version"})
s.Assert().Nil(err)
s.Assert().Contains(output, "Authelia version")
}
func (s *CLISuite) TestShouldValidateConfig() { func (s *CLISuite) TestShouldValidateConfig() {
output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "validate-config", "/config/configuration.yml"}) output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "validate-config", "/config/configuration.yml"})
s.Assert().Nil(err) s.Assert().Nil(err)