[CI] Add wsl linter (#980)

* [CI] Add wsl linter

* Implement wsl recommendations

Co-authored-by: Clément Michaud <clement.michaud34@gmail.com>
This commit is contained in:
Amir Zarrinkafsh 2020-05-06 05:35:32 +10:00 committed by GitHub
parent c13196a86e
commit 1600e0f7da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
107 changed files with 441 additions and 19 deletions

View File

@ -30,6 +30,7 @@ linters:
- unconvert - unconvert
- unparam - unparam
- whitespace - whitespace
- wsl
issues: issues:
exclude: exclude:

View File

@ -99,6 +99,7 @@ func prepareHostsFile() {
for _, entry := range hostEntries { for _, entry := range hostEntries {
domainInHostFile := false domainInHostFile := false
for i, line := range lines { for i, line := range lines {
domainFound := strings.Contains(line, entry.Domain) domainFound := strings.Contains(line, entry.Domain)
ipFound := strings.Contains(line, entry.IP) ipFound := strings.Contains(line, entry.IP)
@ -154,6 +155,7 @@ func readHostsFile() ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return bs, nil return bs, nil
} }
@ -188,6 +190,7 @@ func Bootstrap(cobraCmd *cobra.Command, args []string) {
bootstrapPrintln("Checking if GOPATH is set") bootstrapPrintln("Checking if GOPATH is set")
goPathFound := false goPathFound := false
for _, v := range os.Environ() { for _, v := range os.Environ() {
if strings.HasPrefix(v, "GOPATH=") { if strings.HasPrefix(v, "GOPATH=") {
goPathFound = true goPathFound = true

View File

@ -12,6 +12,7 @@ import (
func buildAutheliaBinary() { func buildAutheliaBinary() {
cmd := utils.CommandWithStdout("go", "build", "-o", "../../"+OutputDir+"/authelia") cmd := utils.CommandWithStdout("go", "build", "-o", "../../"+OutputDir+"/authelia")
cmd.Dir = "cmd/authelia" cmd.Dir = "cmd/authelia"
cmd.Env = append(os.Environ(), cmd.Env = append(os.Environ(),
"GOOS=linux", "GOARCH=amd64", "CGO_ENABLED=1") "GOOS=linux", "GOARCH=amd64", "CGO_ENABLED=1")
@ -34,6 +35,7 @@ func buildFrontend() {
// Then build the frontend. // Then build the frontend.
cmd = utils.CommandWithStdout("yarn", "build") cmd = utils.CommandWithStdout("yarn", "build")
cmd.Dir = webDirectory cmd.Dir = webDirectory
cmd.Env = append(os.Environ(), "INLINE_RUNTIME_CHUNK=false") cmd.Env = append(os.Environ(), "INLINE_RUNTIME_CHUNK=false")
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {

View File

@ -10,11 +10,13 @@ import (
// RunCI run the CI scripts. // RunCI run the CI scripts.
func RunCI(cmd *cobra.Command, args []string) { func RunCI(cmd *cobra.Command, args []string) {
log.Info("=====> Build stage <=====") log.Info("=====> Build stage <=====")
if err := utils.CommandWithStdout("authelia-scripts", "--log-level", "debug", "build").Run(); err != nil { if err := utils.CommandWithStdout("authelia-scripts", "--log-level", "debug", "build").Run(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
log.Info("=====> Unit testing stage <=====") log.Info("=====> Unit testing stage <=====")
if err := utils.CommandWithStdout("authelia-scripts", "--log-level", "debug", "unittest").Run(); err != nil { if err := utils.CommandWithStdout("authelia-scripts", "--log-level", "debug", "unittest").Run(); err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -37,6 +37,7 @@ func checkArchIsSupported(arch string) {
return return
} }
} }
log.Fatal("Architecture is not supported. Please select one of " + strings.Join(supportedArch, ", ") + ".") log.Fatal("Architecture is not supported. Please select one of " + strings.Join(supportedArch, ", ") + ".")
} }
@ -90,9 +91,11 @@ func dockerBuildOfficialImage(arch string) error {
cmd.Stdout = nil cmd.Stdout = nil
cmd.Stderr = nil cmd.Stderr = nil
commitBytes, err := cmd.Output() commitBytes, err := cmd.Output()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
commitHash := strings.Trim(string(commitBytes), "\n") commitHash := strings.Trim(string(commitBytes), "\n")
return docker.Build(IntermediateDockerImageName, dockerfile, ".", gitTag, commitHash) return docker.Build(IntermediateDockerImageName, dockerfile, ".", gitTag, commitHash)
@ -202,9 +205,9 @@ func publishDockerImage(arch string) {
if ciTag != "" { if ciTag != "" {
if len(tags) == 4 { if len(tags) == 4 {
log.Infof("Detected tags: '%s' | '%s' | '%s'", tags[1], tags[2], tags[3]) log.Infof("Detected tags: '%s' | '%s' | '%s'", tags[1], tags[2], tags[3])
login(docker) login(docker)
deploy(docker, tags[1]+"-"+arch) deploy(docker, tags[1]+"-"+arch)
if !ignoredSuffixes.MatchString(ciTag) { if !ignoredSuffixes.MatchString(ciTag) {
deploy(docker, tags[2]+"-"+arch) deploy(docker, tags[2]+"-"+arch)
deploy(docker, tags[3]+"-"+arch) deploy(docker, tags[3]+"-"+arch)
@ -233,7 +236,6 @@ func publishDockerManifest() {
if ciTag != "" { if ciTag != "" {
if len(tags) == 4 { if len(tags) == 4 {
log.Infof("Detected tags: '%s' | '%s' | '%s'", tags[1], tags[2], tags[3]) log.Infof("Detected tags: '%s' | '%s' | '%s'", tags[1], tags[2], tags[3])
login(docker) login(docker)
deployManifest(docker, tags[1], tags[1]+"-amd64", tags[1]+"-arm32v7", tags[1]+"-arm64v8") deployManifest(docker, tags[1], tags[1]+"-amd64", tags[1]+"-arm32v7", tags[1]+"-arm64v8")
publishDockerReadme(docker) publishDockerReadme(docker)

View File

@ -108,6 +108,7 @@ func listSuites() []string {
suiteNames := make([]string, 0) suiteNames := make([]string, 0)
suiteNames = append(suiteNames, suites.GlobalRegistry.Suites()...) suiteNames = append(suiteNames, suites.GlobalRegistry.Suites()...)
sort.Strings(suiteNames) sort.Strings(suiteNames)
return suiteNames return suiteNames
} }
@ -119,6 +120,7 @@ func checkSuiteAvailable(suite string) error {
return nil return nil
} }
} }
return ErrNotAvailableSuite return ErrNotAvailableSuite
} }
@ -130,6 +132,7 @@ func runSuiteSetupTeardown(command string, suite string) error {
if err == ErrNotAvailableSuite { if err == ErrNotAvailableSuite {
log.Fatal(errors.New("Suite named " + selectedSuite + " does not exist")) log.Fatal(errors.New("Suite named " + selectedSuite + " does not exist"))
} }
log.Fatal(err) log.Fatal(err)
} }
@ -139,6 +142,7 @@ func runSuiteSetupTeardown(command string, suite string) error {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
cmd.Env = os.Environ() cmd.Env = os.Environ()
return utils.RunCommandWithTimeout(cmd, s.SetUpTimeout) return utils.RunCommandWithTimeout(cmd, s.SetUpTimeout)
} }
@ -147,6 +151,7 @@ func runOnSetupTimeout(suite string) error {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
cmd.Env = os.Environ() cmd.Env = os.Environ()
return utils.RunCommandWithTimeout(cmd, 15*time.Second) return utils.RunCommandWithTimeout(cmd, 15*time.Second)
} }
@ -155,11 +160,13 @@ func runOnError(suite string) error {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
cmd.Env = os.Environ() cmd.Env = os.Environ()
return utils.RunCommandWithTimeout(cmd, 15*time.Second) return utils.RunCommandWithTimeout(cmd, 15*time.Second)
} }
func setupSuite(suiteName string) error { func setupSuite(suiteName string) error {
log.Infof("Setup environment for suite %s...", suiteName) log.Infof("Setup environment for suite %s...", suiteName)
signalChannel := make(chan os.Signal) signalChannel := make(chan os.Signal)
signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM) signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM)
@ -167,6 +174,7 @@ func setupSuite(suiteName string) error {
go func() { go func() {
<-signalChannel <-signalChannel
interrupted = true interrupted = true
}() }()
@ -174,7 +182,9 @@ func setupSuite(suiteName string) error {
if errSetup == utils.ErrTimeoutReached { if errSetup == utils.ErrTimeoutReached {
runOnSetupTimeout(suiteName) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. runOnSetupTimeout(suiteName) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
} }
teardownSuite(suiteName) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. teardownSuite(suiteName) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
return errSetup return errSetup
} }
@ -230,6 +240,7 @@ func getRunningSuite() (string, error) {
} }
b, err := ioutil.ReadFile(runningSuiteFile) b, err := ioutil.ReadFile(runningSuiteFile)
return string(b), err return string(b), err
} }
@ -247,6 +258,7 @@ func runSuiteTests(suiteName string, withEnv bool) error {
if suite.TestTimeout > 0 { if suite.TestTimeout > 0 {
timeout = fmt.Sprintf("%ds", int64(suite.TestTimeout/time.Second)) timeout = fmt.Sprintf("%ds", int64(suite.TestTimeout/time.Second))
} }
testCmdLine := fmt.Sprintf("go test -count=1 -v ./internal/suites -timeout %s ", timeout) testCmdLine := fmt.Sprintf("go test -count=1 -v ./internal/suites -timeout %s ", timeout)
if testPattern != "" { if testPattern != "" {
@ -262,6 +274,7 @@ func runSuiteTests(suiteName string, withEnv bool) error {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
cmd.Env = os.Environ() cmd.Env = os.Environ()
if headless { if headless {
cmd.Env = append(cmd.Env, "HEADLESS=y") cmd.Env = append(cmd.Env, "HEADLESS=y")
} }
@ -293,16 +306,20 @@ func runMultipleSuitesTests(suiteNames []string, withEnv bool) error {
return err return err
} }
} }
return nil return nil
} }
func runAllSuites() error { func runAllSuites() error {
log.Info("Start running all suites") log.Info("Start running all suites")
for _, s := range listSuites() { for _, s := range listSuites() {
if err := runSuiteTests(s, true); err != nil { if err := runSuiteTests(s, true); err != nil {
return err return err
} }
} }
log.Info("All suites passed successfully") log.Info("All suites passed successfully")
return nil return nil
} }

View File

@ -12,13 +12,16 @@ import (
// RunUnitTest run the unit tests. // RunUnitTest run the unit tests.
func RunUnitTest(cobraCmd *cobra.Command, args []string) { func RunUnitTest(cobraCmd *cobra.Command, args []string) {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
if err := utils.Shell("go test $(go list ./... | grep -v suites)").Run(); err != nil { if err := utils.Shell("go test $(go list ./... | grep -v suites)").Run(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
cmd := utils.Shell("yarn test") cmd := utils.Shell("yarn test")
cmd.Dir = webDirectory cmd.Dir = webDirectory
cmd.Env = append(os.Environ(), "CI=true") cmd.Env = append(os.Environ(), "CI=true")
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -85,11 +85,13 @@ func levelStringToLevel(level string) log.Level {
} else if level == "warning" { } else if level == "warning" {
return log.WarnLevel return log.WarnLevel
} }
return log.InfoLevel return log.InfoLevel
} }
func main() { func main() {
var rootCmd = &cobra.Command{Use: "authelia-scripts"} var rootCmd = &cobra.Command{Use: "authelia-scripts"}
cobraCommands := make([]*cobra.Command, 0) cobraCommands := make([]*cobra.Command, 0)
for _, autheliaCommand := range Commands { for _, autheliaCommand := range Commands {
@ -99,6 +101,7 @@ func main() {
cmdline := autheliaCommand.CommandLine cmdline := autheliaCommand.CommandLine
fn = func(cobraCmd *cobra.Command, args []string) { fn = func(cobraCmd *cobra.Command, args []string) {
cmd := utils.CommandWithStdout(cmdline, args...) cmd := utils.CommandWithStdout(cmdline, args...)
err := cmd.Run() err := cmd.Run()
if err != nil { if err != nil {
panic(err) panic(err)
@ -131,6 +134,7 @@ func main() {
cobraCommands = append(cobraCommands, command) cobraCommands = append(cobraCommands, command)
} }
cobraCommands = append(cobraCommands, commands.HashPasswordCmd) cobraCommands = append(cobraCommands, commands.HashPasswordCmd)
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "Set the log level for the command") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "Set the log level for the command")

View File

@ -55,6 +55,7 @@ func main() {
rootCmd.AddCommand(setupTimeoutCmd) rootCmd.AddCommand(setupTimeoutCmd)
rootCmd.AddCommand(errorCmd) rootCmd.AddCommand(errorCmd)
rootCmd.AddCommand(stopCmd) rootCmd.AddCommand(stopCmd)
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -125,6 +126,7 @@ func setupTimeoutSuite(cmd *cobra.Command, args []string) {
if s.OnSetupTimeout == nil { if s.OnSetupTimeout == nil {
return return
} }
if err := s.OnSetupTimeout(); err != nil { if err := s.OnSetupTimeout(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -137,6 +139,7 @@ func runErrorCallback(cmd *cobra.Command, args []string) {
if s.OnError == nil { if s.OnError == nil {
return return
} }
if err := s.OnError(); err != nil { if err := s.OnError(); err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -37,6 +37,7 @@ func startServer() {
for _, err := range errs { for _, err := range errs {
logging.Logger().Error(err) logging.Logger().Error(err)
} }
panic(errors.New("Some errors have been reported")) panic(errors.New("Some errors have been reported"))
} }
@ -89,6 +90,7 @@ func startServer() {
} else { } else {
log.Fatalf("Unrecognized notifier") log.Fatalf("Unrecognized notifier")
} }
if !config.Notifier.DisableStartupCheck { if !config.Notifier.DisableStartupCheck {
_, err := notifier.StartupCheck() _, err := notifier.StartupCheck()
if err != nil { if err != nil {

View File

@ -57,6 +57,7 @@ func NewFileUserProvider(configuration *schema.FileAuthenticationBackendConfigur
if configuration.Password.Algorithm == sha512 { if configuration.Password.Algorithm == sha512 {
cryptAlgo = HashingAlgorithmSHA512 cryptAlgo = HashingAlgorithmSHA512
} }
settings := getCryptSettings(utils.RandomString(configuration.Password.SaltLength, HashingPossibleSaltCharacters), settings := getCryptSettings(utils.RandomString(configuration.Password.SaltLength, HashingPossibleSaltCharacters),
cryptAlgo, configuration.Password.Iterations, configuration.Password.Memory*1024, configuration.Password.Parallelism, cryptAlgo, configuration.Password.Iterations, configuration.Password.Memory*1024, configuration.Password.Parallelism,
configuration.Password.KeyLength) configuration.Password.KeyLength)
@ -78,6 +79,7 @@ func checkPasswordHashes(database *DatabaseModel) error {
return fmt.Errorf("Unable to parse hash of user %s: %s", u, err) return fmt.Errorf("Unable to parse hash of user %s: %s", u, err)
} }
} }
return nil return nil
} }
@ -86,7 +88,9 @@ func readDatabase(path string) (*DatabaseModel, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to read database from file %s: %s", path, err) return nil, fmt.Errorf("Unable to read database from file %s: %s", path, err)
} }
db := DatabaseModel{} db := DatabaseModel{}
err = yaml.Unmarshal(content, &db) err = yaml.Unmarshal(content, &db)
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to parse database: %s", err) return nil, fmt.Errorf("Unable to parse database: %s", err)
@ -100,6 +104,7 @@ func readDatabase(path string) (*DatabaseModel, error) {
if !ok { if !ok {
return nil, fmt.Errorf("The database format is invalid: %s", err) return nil, fmt.Errorf("The database format is invalid: %s", err)
} }
return &db, nil return &db, nil
} }
@ -107,10 +112,12 @@ func readDatabase(path string) (*DatabaseModel, error) {
func (p *FileUserProvider) CheckUserPassword(username string, password string) (bool, error) { func (p *FileUserProvider) CheckUserPassword(username string, password string) (bool, error) {
if details, ok := p.database.Users[username]; ok { if details, ok := p.database.Users[username]; ok {
hashedPassword := strings.ReplaceAll(details.HashedPassword, "{CRYPT}", "") hashedPassword := strings.ReplaceAll(details.HashedPassword, "{CRYPT}", "")
ok, err := CheckPassword(password, hashedPassword) ok, err := CheckPassword(password, hashedPassword)
if err != nil { if err != nil {
return false, err return false, err
} }
return ok, nil return ok, nil
} }
@ -130,6 +137,7 @@ func (p *FileUserProvider) GetDetails(username string) (*UserDetails, error) {
Emails: []string{details.Email}, Emails: []string{details.Email},
}, nil }, nil
} }
return nil, fmt.Errorf("User '%s' does not exist in database", username) return nil, fmt.Errorf("User '%s' does not exist in database", username)
} }
@ -153,11 +161,12 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
newPassword, "", algorithm, p.configuration.Password.Iterations, newPassword, "", algorithm, p.configuration.Password.Iterations,
p.configuration.Password.Memory*1024, p.configuration.Password.Parallelism, p.configuration.Password.Memory*1024, p.configuration.Password.Parallelism,
p.configuration.Password.KeyLength, p.configuration.Password.SaltLength) p.configuration.Password.KeyLength, p.configuration.Password.SaltLength)
if err != nil { if err != nil {
return err return err
} }
details.HashedPassword = hash details.HashedPassword = hash
p.lock.Lock() p.lock.Lock()
p.database.Users[username] = details p.database.Users[username] = details
@ -166,7 +175,9 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
p.lock.Unlock() p.lock.Unlock()
return err return err
} }
err = ioutil.WriteFile(p.configuration.Path, b, 0644) //nolint:gosec // Fixed in future PR. err = ioutil.WriteFile(p.configuration.Path, b, 0644) //nolint:gosec // Fixed in future PR.
p.lock.Unlock() p.lock.Unlock()
return err return err
} }

View File

@ -69,6 +69,7 @@ func (lcf *LDAPConnectionFactoryImpl) DialTLS(network, addr string, config *tls.
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewLDAPConnectionImpl(conn), nil return NewLDAPConnectionImpl(conn), nil
} }
@ -78,5 +79,6 @@ func (lcf *LDAPConnectionFactoryImpl) Dial(network, addr string) (LDAPConnection
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewLDAPConnectionImpl(conn), nil return NewLDAPConnectionImpl(conn), nil
} }

View File

@ -47,12 +47,14 @@ func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnecti
if url.Scheme == "ldaps" { if url.Scheme == "ldaps" {
logging.Logger().Trace("LDAP client starts a TLS session") logging.Logger().Trace("LDAP client starts a TLS session")
conn, err := p.connectionFactory.DialTLS("tcp", url.Host, &tls.Config{ conn, err := p.connectionFactory.DialTLS("tcp", url.Host, &tls.Config{
InsecureSkipVerify: p.configuration.SkipVerify, //nolint:gosec // This is a configurable option, is desirable in some situations and is off by default InsecureSkipVerify: p.configuration.SkipVerify, //nolint:gosec // This is a configurable option, is desirable in some situations and is off by default
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
newConnection = conn newConnection = conn
} else { } else {
logging.Logger().Trace("LDAP client starts a session over raw TCP") logging.Logger().Trace("LDAP client starts a session over raw TCP")
@ -66,6 +68,7 @@ func (p *LDAPUserProvider) connect(userDN string, password string) (LDAPConnecti
if err := newConnection.Bind(userDN, password); err != nil { if err := newConnection.Bind(userDN, password); err != nil {
return nil, err return nil, err
} }
return newConnection, nil return newConnection, nil
} }
@ -100,6 +103,7 @@ func (p *LDAPUserProvider) ldapEscape(inputUsername string) string {
for _, c := range specialLDAPRunes { for _, c := range specialLDAPRunes {
inputUsername = strings.ReplaceAll(inputUsername, string(c), fmt.Sprintf("\\%c", c)) inputUsername = strings.ReplaceAll(inputUsername, string(c), fmt.Sprintf("\\%c", c))
} }
return inputUsername return inputUsername
} }
@ -122,6 +126,7 @@ func (p *LDAPUserProvider) resolveUsersFilter(userFilter string, inputUsername s
// in configuration. // in configuration.
userFilter = strings.ReplaceAll(userFilter, "{username_attribute}", p.configuration.UsernameAttribute) userFilter = strings.ReplaceAll(userFilter, "{username_attribute}", p.configuration.UsernameAttribute)
userFilter = strings.ReplaceAll(userFilter, "{mail_attribute}", p.configuration.MailAttribute) userFilter = strings.ReplaceAll(userFilter, "{mail_attribute}", p.configuration.MailAttribute)
return userFilter return userFilter
} }
@ -160,15 +165,18 @@ func (p *LDAPUserProvider) getUserProfile(conn LDAPConnection, inputUsername str
userProfile := ldapUserProfile{ userProfile := ldapUserProfile{
DN: sr.Entries[0].DN, DN: sr.Entries[0].DN,
} }
for _, attr := range sr.Entries[0].Attributes { for _, attr := range sr.Entries[0].Attributes {
if attr.Name == p.configuration.MailAttribute { if attr.Name == p.configuration.MailAttribute {
userProfile.Emails = attr.Values userProfile.Emails = attr.Values
} }
if attr.Name == p.configuration.UsernameAttribute { if attr.Name == p.configuration.UsernameAttribute {
if len(attr.Values) != 1 { if len(attr.Values) != 1 {
return nil, fmt.Errorf("User %s cannot have multiple value for attribute %s", return nil, fmt.Errorf("User %s cannot have multiple value for attribute %s",
inputUsername, p.configuration.UsernameAttribute) inputUsername, p.configuration.UsernameAttribute)
} }
userProfile.Username = attr.Values[0] userProfile.Username = attr.Values[0]
} }
} }
@ -186,6 +194,7 @@ func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ld
// We temporarily keep placeholder {0} for backward compatibility. // We temporarily keep placeholder {0} for backward compatibility.
groupFilter := strings.ReplaceAll(p.configuration.GroupsFilter, "{0}", inputUsername) groupFilter := strings.ReplaceAll(p.configuration.GroupsFilter, "{0}", inputUsername)
groupFilter = strings.ReplaceAll(groupFilter, "{input}", inputUsername) groupFilter = strings.ReplaceAll(groupFilter, "{input}", inputUsername)
if profile != nil { if profile != nil {
// We temporarily keep placeholder {1} for backward compatibility. // We temporarily keep placeholder {1} for backward compatibility.
groupFilter = strings.ReplaceAll(groupFilter, "{1}", ldap.EscapeFilter(profile.Username)) groupFilter = strings.ReplaceAll(groupFilter, "{1}", ldap.EscapeFilter(profile.Username))
@ -213,6 +222,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to create group filter for user %s. Cause: %s", inputUsername, err) return nil, fmt.Errorf("Unable to create group filter for user %s. Cause: %s", inputUsername, err)
} }
logging.Logger().Tracef("Computed groups filter is %s", groupsFilter) logging.Logger().Tracef("Computed groups filter is %s", groupsFilter)
groupBaseDN := p.configuration.BaseDN groupBaseDN := p.configuration.BaseDN
@ -233,6 +243,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error
} }
groups := make([]string, 0) groups := make([]string, 0)
for _, res := range sr.Entries { for _, res := range sr.Entries {
if len(res.Attributes) == 0 { if len(res.Attributes) == 0 {
logging.Logger().Warningf("No groups retrieved from LDAP for user %s", inputUsername) logging.Logger().Warningf("No groups retrieved from LDAP for user %s", inputUsername)

View File

@ -38,6 +38,7 @@ func ParseHash(hash string) (passwordHash *PasswordHash, err error) {
if h.Key != parts[len(parts)-1] { if h.Key != parts[len(parts)-1] {
return nil, fmt.Errorf("Hash key is not the last parameter, the hash is likely malformed (%s)", hash) return nil, fmt.Errorf("Hash key is not the last parameter, the hash is likely malformed (%s)", hash)
} }
if h.Key == "" { if h.Key == "" {
return nil, fmt.Errorf("Hash key contains no characters or the field length is invalid (%s)", hash) return nil, fmt.Errorf("Hash key contains no characters or the field length is invalid (%s)", hash)
} }
@ -50,6 +51,7 @@ func ParseHash(hash string) (passwordHash *PasswordHash, err error) {
if code == HashingAlgorithmSHA512 { if code == HashingAlgorithmSHA512 {
h.Iterations = parameters.GetInt("rounds", HashingDefaultSHA512Iterations) h.Iterations = parameters.GetInt("rounds", HashingDefaultSHA512Iterations)
h.Algorithm = HashingAlgorithmSHA512 h.Algorithm = HashingAlgorithmSHA512
if parameters["rounds"] != "" && parameters["rounds"] != strconv.Itoa(h.Iterations) { if parameters["rounds"] != "" && parameters["rounds"] != strconv.Itoa(h.Iterations) {
return nil, fmt.Errorf("SHA512 iterations is not numeric (%s)", parameters["rounds"]) return nil, fmt.Errorf("SHA512 iterations is not numeric (%s)", parameters["rounds"])
} }
@ -79,6 +81,7 @@ func ParseHash(hash string) (passwordHash *PasswordHash, err error) {
} else { } else {
return nil, fmt.Errorf("Authelia only supports salted SHA512 hashing ($6$) and salted argon2id ($argon2id$), not $%s$", code) return nil, fmt.Errorf("Authelia only supports salted SHA512 hashing ($6$) and salted argon2id ($argon2id$), not $%s$", code)
} }
return h, nil return h, nil
} }
@ -110,28 +113,33 @@ func HashPassword(password, salt string, algorithm CryptAlgo, iterations, memory
if memory < 8 { if memory < 8 {
return "", fmt.Errorf("Memory (argon2id) input of %d is invalid, it must be 8 or higher", memory) return "", fmt.Errorf("Memory (argon2id) input of %d is invalid, it must be 8 or higher", memory)
} }
if parallelism < 1 { if parallelism < 1 {
return "", fmt.Errorf("Parallelism (argon2id) input of %d is invalid, it must be 1 or higher", parallelism) return "", fmt.Errorf("Parallelism (argon2id) input of %d is invalid, it must be 1 or higher", parallelism)
} }
if memory < parallelism*8 { if memory < parallelism*8 {
return "", fmt.Errorf("Memory (argon2id) input of %d is invalid with a parallelism input of %d, it must be %d (parallelism * 8) or higher", memory, parallelism, parallelism*8) return "", fmt.Errorf("Memory (argon2id) input of %d is invalid with a parallelism input of %d, it must be %d (parallelism * 8) or higher", memory, parallelism, parallelism*8)
} }
if keyLength < 16 { if keyLength < 16 {
return "", fmt.Errorf("Key length (argon2id) input of %d is invalid, it must be 16 or higher", keyLength) return "", fmt.Errorf("Key length (argon2id) input of %d is invalid, it must be 16 or higher", keyLength)
} }
if iterations < 1 { if iterations < 1 {
return "", fmt.Errorf("Iterations (argon2id) input of %d is invalid, it must be 1 or more", iterations) return "", fmt.Errorf("Iterations (argon2id) input of %d is invalid, it must be 1 or more", iterations)
} }
// Caution: Increasing any of the values in the above block has a high chance in old passwords that cannot be verified.
} }
if salt == "" { if salt == "" {
salt = utils.RandomString(saltLength, HashingPossibleSaltCharacters) salt = utils.RandomString(saltLength, HashingPossibleSaltCharacters)
} }
settings = getCryptSettings(salt, algorithm, iterations, memory, parallelism, keyLength) settings = getCryptSettings(salt, algorithm, iterations, memory, parallelism, keyLength)
// This error can be ignored because we check for it before a user gets here. // This error can be ignored because we check for it before a user gets here.
hash, _ = crypt.Crypt(password, settings) hash, _ = crypt.Crypt(password, settings)
return hash, nil return hash, nil
} }
@ -141,10 +149,12 @@ func CheckPassword(password, hash string) (ok bool, err error) {
if err != nil { if err != nil {
return false, err return false, err
} }
expectedHash, err := HashPassword(password, passwordHash.Salt, passwordHash.Algorithm, passwordHash.Iterations, passwordHash.Memory, passwordHash.Parallelism, passwordHash.KeyLength, len(passwordHash.Salt)) expectedHash, err := HashPassword(password, passwordHash.Salt, passwordHash.Algorithm, passwordHash.Iterations, passwordHash.Memory, passwordHash.Parallelism, passwordHash.KeyLength, len(passwordHash.Salt))
if err != nil { if err != nil {
return false, err return false, err
} }
return hash == expectedHash, nil return hash == expectedHash, nil
} }
@ -156,5 +166,6 @@ func getCryptSettings(salt string, algorithm CryptAlgo, iterations, memory, para
} else { } else {
panic("invalid password hashing algorithm provided") panic("invalid password hashing algorithm provided")
} }
return settings return settings
} }

View File

@ -47,10 +47,13 @@ func TestShouldHashArgon2idPassword(t *testing.T) {
// This checks the method of hashing (for argon2id) supports all the characters we allow in Authelia's hash function. // This checks the method of hashing (for argon2id) supports all the characters we allow in Authelia's hash function.
func TestArgon2idHashSaltValidValues(t *testing.T) { func TestArgon2idHashSaltValidValues(t *testing.T) {
var err error
var hash string
data := string(HashingPossibleSaltCharacters) data := string(HashingPossibleSaltCharacters)
datas := utils.SliceString(data, 16) datas := utils.SliceString(data, 16)
var hash string
var err error
for _, salt := range datas { for _, salt := range datas {
hash, err = HashPassword("password", salt, HashingAlgorithmArgon2id, 1, 8, 1, 32, 16) hash, err = HashPassword("password", salt, HashingAlgorithmArgon2id, 1, 8, 1, 32, 16)
assert.NoError(t, err) assert.NoError(t, err)
@ -60,10 +63,13 @@ func TestArgon2idHashSaltValidValues(t *testing.T) {
// This checks the method of hashing (for sha512) supports all the characters we allow in Authelia's hash function. // This checks the method of hashing (for sha512) supports all the characters we allow in Authelia's hash function.
func TestSHA512HashSaltValidValues(t *testing.T) { func TestSHA512HashSaltValidValues(t *testing.T) {
var err error
var hash string
data := string(HashingPossibleSaltCharacters) data := string(HashingPossibleSaltCharacters)
datas := utils.SliceString(data, 16) datas := utils.SliceString(data, 16)
var hash string
var err error
for _, salt := range datas { for _, salt := range datas {
hash, err = HashPassword("password", salt, HashingAlgorithmSHA512, 1000, 0, 0, 0, 16) hash, err = HashPassword("password", salt, HashingAlgorithmSHA512, 1000, 0, 0, 0, 16)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -71,6 +71,7 @@ func selectMatchingObjectRules(rules []schema.ACLRule, object Object) []schema.A
selectedRules = append(selectedRules, rule) selectedRules = append(selectedRules, rule)
} }
} }
return selectedRules return selectedRules
} }
@ -123,6 +124,7 @@ func (p *Authorizer) GetRequiredLevel(subject Subject, requestURL url.URL) Level
if len(matchingRules) > 0 { if len(matchingRules) > 0 {
return PolicyToLevel(matchingRules[0].Policy) return PolicyToLevel(matchingRules[0].Policy)
} }
logging.Logger().Tracef("No matching rule for subject %s and url %s... Applying default policy.", logging.Logger().Tracef("No matching rule for subject %s and url %s... Applying default policy.",
subject.String(), requestURL.String()) subject.String(), requestURL.String())
@ -141,5 +143,6 @@ func (p *Authorizer) IsURLMatchingRuleWithGroupSubjects(requestURL url.URL) (has
} }
} }
} }
return false return false
} }

View File

@ -10,5 +10,6 @@ func isDomainMatching(domain string, domainRules []string) bool {
return true return true
} }
} }
return false return false
} }

View File

@ -17,9 +17,12 @@ func isIPMatching(ip net.IP, networks []string) bool {
if ip.String() == network { if ip.String() == network {
return true return true
} }
continue continue
} }
_, ipNet, err := net.ParseCIDR(network) _, ipNet, err := net.ParseCIDR(network)
if err != nil { if err != nil {
// TODO(c.michaud): make sure the rule is valid at startup to // TODO(c.michaud): make sure the rule is valid at startup to
// to such a case here. // to such a case here.
@ -30,5 +33,6 @@ func isIPMatching(ip net.IP, networks []string) bool {
return true return true
} }
} }
return false return false
} }

View File

@ -20,5 +20,6 @@ func isPathMatching(path string, pathRegexps []string) bool {
return true return true
} }
} }
return false return false
} }

View File

@ -25,5 +25,6 @@ func isSubjectMatching(subject Subject, subjectRule string) bool {
return true return true
} }
} }
return false return false
} }

View File

@ -34,6 +34,7 @@ var (
func init() { func init() {
CertificatesGenerateCmd.PersistentFlags().StringVar(&host, "host", "", "Comma-separated hostnames and IPs to generate a certificate for") CertificatesGenerateCmd.PersistentFlags().StringVar(&host, "host", "", "Comma-separated hostnames and IPs to generate a certificate for")
err := CertificatesGenerateCmd.MarkPersistentFlagRequired("host") err := CertificatesGenerateCmd.MarkPersistentFlagRequired("host")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -66,7 +67,9 @@ func publicKey(priv interface{}) interface{} {
func generateSelfSignedCertificate(cmd *cobra.Command, args []string) { func generateSelfSignedCertificate(cmd *cobra.Command, args []string) {
// implementation retrieved from https://golang.org/src/crypto/tls/generate_cert.go // implementation retrieved from https://golang.org/src/crypto/tls/generate_cert.go
var priv interface{} var priv interface{}
var err error var err error
switch ecdsaCurve { switch ecdsaCurve {
case "": case "":
if ed25519Key { if ed25519Key {
@ -85,6 +88,7 @@ func generateSelfSignedCertificate(cmd *cobra.Command, args []string) {
default: default:
log.Fatalf("Unrecognized elliptic curve: %q", ecdsaCurve) log.Fatalf("Unrecognized elliptic curve: %q", ecdsaCurve)
} }
if err != nil { if err != nil {
log.Fatalf("Failed to generate private key: %v", err) log.Fatalf("Failed to generate private key: %v", err)
} }
@ -103,6 +107,7 @@ func generateSelfSignedCertificate(cmd *cobra.Command, args []string) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil { if err != nil {
log.Fatalf("Failed to generate serial number: %v", err) log.Fatalf("Failed to generate serial number: %v", err)
} }
@ -141,33 +146,42 @@ func generateSelfSignedCertificate(cmd *cobra.Command, args []string) {
certPath := path.Join(targetDirectory, "cert.pem") certPath := path.Join(targetDirectory, "cert.pem")
certOut, err := os.Create(certPath) certOut, err := os.Create(certPath)
if err != nil { if err != nil {
log.Fatalf("Failed to open %s for writing: %v", certPath, err) log.Fatalf("Failed to open %s for writing: %v", certPath, err)
} }
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
log.Fatalf("Failed to write data to cert.pem: %v", err) log.Fatalf("Failed to write data to cert.pem: %v", err)
} }
if err := certOut.Close(); err != nil { if err := certOut.Close(); err != nil {
log.Fatalf("Error closing %s: %v", certPath, err) log.Fatalf("Error closing %s: %v", certPath, err)
} }
log.Printf("wrote %s\n", certPath) log.Printf("wrote %s\n", certPath)
keyPath := path.Join(targetDirectory, "key.pem") keyPath := path.Join(targetDirectory, "key.pem")
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
log.Fatalf("Failed to open %s for writing: %v", keyPath, err) log.Fatalf("Failed to open %s for writing: %v", keyPath, err)
return return
} }
privBytes, err := x509.MarshalPKCS8PrivateKey(priv) privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil { if err != nil {
log.Fatalf("Unable to marshal private key: %v", err) log.Fatalf("Unable to marshal private key: %v", err)
} }
if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
log.Fatalf("Failed to write data to %s: %v", keyPath, err) log.Fatalf("Failed to write data to %s: %v", keyPath, err)
} }
if err := keyOut.Close(); err != nil { if err := keyOut.Close(); err != nil {
log.Fatalf("Error closing %s: %v", keyPath, err) log.Fatalf("Error closing %s: %v", keyPath, err)
} }
log.Printf("wrote %s\n", keyPath) log.Printf("wrote %s\n", keyPath)
} }

View File

@ -43,6 +43,7 @@ func Read(configPath string) (*schema.Configuration, []error) {
} }
var configuration schema.Configuration var configuration schema.Configuration
viper.Unmarshal(&configuration) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. viper.Unmarshal(&configuration) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
val := schema.NewStructValidator() val := schema.NewStructValidator()

View File

@ -58,6 +58,7 @@ func TestShouldParseConfigFile(t *testing.T) {
func TestShouldParseAltConfigFile(t *testing.T) { func TestShouldParseAltConfigFile(t *testing.T) {
require.NoError(t, os.Setenv("AUTHELIA_STORAGE_POSTGRES_PASSWORD", "postgres_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_STORAGE_POSTGRES_PASSWORD", "postgres_secret_from_env"))
config, errors := Read("./test_resources/config_alt.yml") config, errors := Read("./test_resources/config_alt.yml")
require.Len(t, errors, 0) require.Len(t, errors, 0)
@ -98,6 +99,7 @@ func TestShouldNotParseConfigFileWithOldOrUnexpectedKeys(t *testing.T) {
func TestShouldValidateConfigurationTemplate(t *testing.T) { func TestShouldValidateConfigurationTemplate(t *testing.T) {
resetEnv() resetEnv()
_, errors := Read("../../config.template.yml") _, errors := Read("../../config.template.yml")
assert.Len(t, errors, 0) assert.Len(t, errors, 0)
} }
@ -112,6 +114,7 @@ func TestShouldOnlyAllowOneEnvType(t *testing.T) {
require.NoError(t, os.Setenv("AUTHELIA_AUTHENTICATION_BACKEND_LDAP_PASSWORD", "ldap_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_AUTHENTICATION_BACKEND_LDAP_PASSWORD", "ldap_secret_from_env"))
require.NoError(t, os.Setenv("AUTHELIA_NOTIFIER_SMTP_PASSWORD", "smtp_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_NOTIFIER_SMTP_PASSWORD", "smtp_secret_from_env"))
require.NoError(t, os.Setenv("AUTHELIA_SESSION_REDIS_PASSWORD", "redis_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_SESSION_REDIS_PASSWORD", "redis_secret_from_env"))
_, errors := Read("./test_resources/config_alt.yml") _, errors := Read("./test_resources/config_alt.yml")
require.Len(t, errors, 2) require.Len(t, errors, 2)
@ -128,6 +131,7 @@ func TestShouldOnlyAllowEnvOrConfig(t *testing.T) {
require.NoError(t, os.Setenv("AUTHELIA_AUTHENTICATION_BACKEND_LDAP_PASSWORD", "ldap_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_AUTHENTICATION_BACKEND_LDAP_PASSWORD", "ldap_secret_from_env"))
require.NoError(t, os.Setenv("AUTHELIA_NOTIFIER_SMTP_PASSWORD", "smtp_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_NOTIFIER_SMTP_PASSWORD", "smtp_secret_from_env"))
require.NoError(t, os.Setenv("AUTHELIA_SESSION_REDIS_PASSWORD", "redis_secret_from_env")) require.NoError(t, os.Setenv("AUTHELIA_SESSION_REDIS_PASSWORD", "redis_secret_from_env"))
_, errors := Read("./test_resources/config_with_secret.yml") _, errors := Read("./test_resources/config_with_secret.yml")
require.Len(t, errors, 1) require.Len(t, errors, 1)

View File

@ -23,6 +23,7 @@ type Validator struct {
func NewValidator() *Validator { func NewValidator() *Validator {
validator := new(Validator) validator := new(Validator)
validator.errors = make(map[string][]error) validator.errors = make(map[string][]error)
return validator return validator
} }
@ -39,6 +40,7 @@ func (v *Validator) validateOne(item QueueItem, q *queue.Queue) error { //nolint
} }
elem := item.value.Elem() elem := item.value.Elem()
q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
value: elem, value: elem,
path: item.path, path: item.path,
@ -64,6 +66,7 @@ func (v *Validator) validateOne(item QueueItem, q *queue.Queue) error { //nolint
}) })
} }
} }
return nil return nil
} }
@ -77,12 +80,15 @@ func (v *Validator) Validate(s interface{}) error {
if err != nil { if err != nil {
return err return err
} }
item, ok := val[0].(QueueItem) item, ok := val[0].(QueueItem)
if !ok { if !ok {
return fmt.Errorf("Cannot convert item into QueueItem") return fmt.Errorf("Cannot convert item into QueueItem")
} }
v.validateOne(item, q) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. v.validateOne(item, q) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
} }
return nil return nil
} }
@ -90,6 +96,7 @@ func (v *Validator) Validate(s interface{}) error {
func (v *Validator) PrintErrors() { func (v *Validator) PrintErrors() {
for path, errs := range v.errors { for path, errs := range v.errors {
fmt.Printf("Errors at %s:\n", path) fmt.Printf("Errors at %s:\n", path)
for _, err := range errs { for _, err := range errs {
fmt.Printf("--> %s\n", err) fmt.Printf("--> %s\n", err)
} }
@ -110,6 +117,7 @@ type StructValidator struct {
func NewStructValidator() *StructValidator { func NewStructValidator() *StructValidator {
val := new(StructValidator) val := new(StructValidator)
val.errors = make([]error, 0) val.errors = make([]error, 0)
return val return val
} }

View File

@ -45,6 +45,7 @@ func ValidateConfiguration(configuration *schema.Configuration, validator *schem
if configuration.TOTP == nil { if configuration.TOTP == nil {
configuration.TOTP = &schema.DefaultTOTPConfiguration configuration.TOTP = &schema.DefaultTOTPConfiguration
} }
ValidateTOTP(configuration.TOTP, validator) ValidateTOTP(configuration.TOTP, validator)
ValidateAuthenticationBackend(&configuration.AuthenticationBackend, validator) ValidateAuthenticationBackend(&configuration.AuthenticationBackend, validator)
@ -58,6 +59,7 @@ func ValidateConfiguration(configuration *schema.Configuration, validator *schem
if configuration.Regulation == nil { if configuration.Regulation == nil {
configuration.Regulation = &schema.DefaultRegulationConfiguration configuration.Regulation = &schema.DefaultRegulationConfiguration
} }
ValidateRegulation(configuration.Regulation, validator) ValidateRegulation(configuration.Regulation, validator)
ValidateServer(&configuration.Server, validator) ValidateServer(&configuration.Server, validator)

View File

@ -30,6 +30,7 @@ func newDefaultConfig() schema.Configuration {
Filename: "/tmp/file", Filename: "/tmp/file",
}, },
} }
return config return config
} }

View File

@ -11,6 +11,7 @@ import (
// ValidateKeys determines if a provided key is valid. // ValidateKeys determines if a provided key is valid.
func ValidateKeys(validator *schema.StructValidator, keys []string) { func ValidateKeys(validator *schema.StructValidator, keys []string) {
var errStrings []string var errStrings []string
for _, key := range keys { for _, key := range keys {
if utils.IsStringInSlice(key, validKeys) { if utils.IsStringInSlice(key, validKeys) {
continue continue
@ -24,6 +25,7 @@ func ValidateKeys(validator *schema.StructValidator, keys []string) {
validator.Push(fmt.Errorf("config key not expected: %s", key)) validator.Push(fmt.Errorf("config key not expected: %s", key))
} }
} }
for _, err := range errStrings { for _, err := range errStrings {
validator.Push(errors.New(err)) validator.Push(errors.New(err))
} }

View File

@ -34,11 +34,13 @@ func TestShouldNotValidateBadKeys(t *testing.T) {
func TestAllSpecificErrorKeys(t *testing.T) { func TestAllSpecificErrorKeys(t *testing.T) {
var configKeys []string //nolint:prealloc // This is because the test is dynamic based on the keys that exist in the map var configKeys []string //nolint:prealloc // This is because the test is dynamic based on the keys that exist in the map
var uniqueValues []string var uniqueValues []string
// Setup configKeys and uniqueValues expected. // Setup configKeys and uniqueValues expected.
for key, value := range specificErrorKeys { for key, value := range specificErrorKeys {
configKeys = append(configKeys, key) configKeys = append(configKeys, key)
if !utils.IsStringInSlice(value, uniqueValues) { if !utils.IsStringInSlice(value, uniqueValues) {
uniqueValues = append(uniqueValues, value) uniqueValues = append(uniqueValues, value)
} }

View File

@ -22,6 +22,7 @@ func ValidateNotifier(configuration *schema.NotifierConfiguration, validator *sc
if configuration.FileSystem.Filename == "" { if configuration.FileSystem.Filename == "" {
validator.Push(fmt.Errorf("Filename of filesystem notifier must not be empty")) validator.Push(fmt.Errorf("Filename of filesystem notifier must not be empty"))
} }
return return
} }
@ -29,6 +30,7 @@ func ValidateNotifier(configuration *schema.NotifierConfiguration, validator *sc
if configuration.SMTP.StartupCheckAddress == "" { if configuration.SMTP.StartupCheckAddress == "" {
configuration.SMTP.StartupCheckAddress = "test@authelia.com" configuration.SMTP.StartupCheckAddress = "test@authelia.com"
} }
if configuration.SMTP.Host == "" { if configuration.SMTP.Host == "" {
validator.Push(fmt.Errorf("Host of SMTP notifier must be provided")) validator.Push(fmt.Errorf("Host of SMTP notifier must be provided"))
} }
@ -44,6 +46,7 @@ func ValidateNotifier(configuration *schema.NotifierConfiguration, validator *sc
if configuration.SMTP.Subject == "" { if configuration.SMTP.Subject == "" {
configuration.SMTP.Subject = schema.DefaultSMTPNotifierConfiguration.Subject configuration.SMTP.Subject = schema.DefaultSMTPNotifierConfiguration.Subject
} }
return return
} }
} }

View File

@ -12,17 +12,21 @@ func ValidateRegulation(configuration *schema.RegulationConfiguration, validator
if configuration.FindTime == "" { if configuration.FindTime == "" {
configuration.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min configuration.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min
} }
if configuration.BanTime == "" { if configuration.BanTime == "" {
configuration.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min configuration.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min
} }
findTime, err := utils.ParseDurationString(configuration.FindTime) findTime, err := utils.ParseDurationString(configuration.FindTime)
if err != nil { if err != nil {
validator.Push(fmt.Errorf("Error occurred parsing regulation find_time string: %s", err)) validator.Push(fmt.Errorf("Error occurred parsing regulation find_time string: %s", err))
} }
banTime, err := utils.ParseDurationString(configuration.BanTime) banTime, err := utils.ParseDurationString(configuration.BanTime)
if err != nil { if err != nil {
validator.Push(fmt.Errorf("Error occurred parsing regulation ban_time string: %s", err)) validator.Push(fmt.Errorf("Error occurred parsing regulation ban_time string: %s", err))
} }
if findTime > banTime { if findTime > banTime {
validator.Push(fmt.Errorf("find_time cannot be greater than ban_time")) validator.Push(fmt.Errorf("find_time cannot be greater than ban_time"))
} }

View File

@ -50,6 +50,7 @@ func getSecretValue(name string, validator *schema.StructValidator, viper *viper
if envValue != "" && fileEnvValue != "" { if envValue != "" && fileEnvValue != "" {
validator.Push(fmt.Errorf("secret is defined in multiple areas: %s", name)) validator.Push(fmt.Errorf("secret is defined in multiple areas: %s", name))
} }
if (envValue != "" || fileEnvValue != "") && configValue != "" { if (envValue != "" || fileEnvValue != "") && configValue != "" {
validator.Push(fmt.Errorf("error loading secret (%s): it's already defined in the config file", name)) validator.Push(fmt.Errorf("error loading secret (%s): it's already defined in the config file", name))
} }
@ -63,9 +64,11 @@ func getSecretValue(name string, validator *schema.StructValidator, viper *viper
return strings.Replace(string(content), "\n", "", -1) return strings.Replace(string(content), "\n", "", -1)
} }
} }
if envValue != "" { if envValue != "" {
logging.Logger().Warnf("The following secret is defined as an environment variable, this is insecure and being removed in 4.18.0+, it's recommended to use the file secrets instead (https://docs.authelia.com/configuration/secrets.html): %s", name) logging.Logger().Warnf("The following secret is defined as an environment variable, this is insecure and being removed in 4.18.0+, it's recommended to use the file secrets instead (https://docs.authelia.com/configuration/secrets.html): %s", name)
return envValue return envValue
} }
return configValue return configValue
} }

View File

@ -12,6 +12,7 @@ func newDefaultSessionConfig() schema.SessionConfiguration {
config := schema.SessionConfiguration{} config := schema.SessionConfiguration{}
config.Secret = testJWTSecret config.Secret = testJWTSecret
config.Domain = "example.com" config.Domain = "example.com"
return config return config
} }

View File

@ -11,6 +11,7 @@ func ValidateTOTP(configuration *schema.TOTPConfiguration, validator *schema.Str
if configuration.Issuer == "" { if configuration.Issuer == "" {
configuration.Issuer = schema.DefaultTOTPConfiguration.Issuer configuration.Issuer = schema.DefaultTOTPConfiguration.Issuer
} }
if configuration.Period == 0 { if configuration.Period == 0 {
configuration.Period = schema.DefaultTOTPConfiguration.Period configuration.Period = schema.DefaultTOTPConfiguration.Period
} else if configuration.Period < 0 { } else if configuration.Period < 0 {

View File

@ -23,6 +23,7 @@ func TestShouldSetDefaultTOTPValues(t *testing.T) {
func TestShouldRaiseErrorWhenInvalidTOTPMinimumValues(t *testing.T) { func TestShouldRaiseErrorWhenInvalidTOTPMinimumValues(t *testing.T) {
var badSkew = -1 var badSkew = -1
validator := schema.NewStructValidator() validator := schema.NewStructValidator()
config := schema.TOTPConfiguration{ config := schema.TOTPConfiguration{
Period: -5, Period: -5,

View File

@ -13,23 +13,25 @@ import (
func NewDuoAPI(duoAPI *duoapi.DuoApi) *APIImpl { func NewDuoAPI(duoAPI *duoapi.DuoApi) *APIImpl {
api := new(APIImpl) api := new(APIImpl)
api.DuoApi = duoAPI api.DuoApi = duoAPI
return api return api
} }
// Call call to the DuoAPI. // Call call to the DuoAPI.
func (d *APIImpl) Call(values url.Values, ctx *middlewares.AutheliaCtx) (*Response, error) { func (d *APIImpl) Call(values url.Values, ctx *middlewares.AutheliaCtx) (*Response, error) {
_, responseBytes, err := d.DuoApi.SignedCall("POST", "/auth/v2/auth", values) var response Response
_, responseBytes, err := d.DuoApi.SignedCall("POST", "/auth/v2/auth", values)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx.Logger.Tracef("Duo Push Auth Response Raw Data for %s from IP %s: %s", ctx.GetSession().Username, ctx.RemoteIP().String(), string(responseBytes)) ctx.Logger.Tracef("Duo Push Auth Response Raw Data for %s from IP %s: %s", ctx.GetSession().Username, ctx.RemoteIP().String(), string(responseBytes))
var response Response
err = json.Unmarshal(responseBytes, &response) err = json.Unmarshal(responseBytes, &response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &response, nil return &response, nil
} }

View File

@ -38,6 +38,7 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldServeDefaultMethods() {
SecondFactorEnabled: false, SecondFactorEnabled: false,
TOTPPeriod: schema.DefaultTOTPConfiguration.Period, TOTPPeriod: schema.DefaultTOTPConfiguration.Period,
} }
ExtendedConfigurationGet(s.mock.Ctx) ExtendedConfigurationGet(s.mock.Ctx)
s.mock.Assert200OK(s.T(), expectedBody) s.mock.Assert200OK(s.T(), expectedBody)
} }
@ -54,6 +55,7 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldServeDefaultMethodsAndMo
SecondFactorEnabled: false, SecondFactorEnabled: false,
TOTPPeriod: schema.DefaultTOTPConfiguration.Period, TOTPPeriod: schema.DefaultTOTPConfiguration.Period,
} }
ExtendedConfigurationGet(s.mock.Ctx) ExtendedConfigurationGet(s.mock.Ctx)
s.mock.Assert200OK(s.T(), expectedBody) s.mock.Assert200OK(s.T(), expectedBody)
} }

View File

@ -28,7 +28,9 @@ func FirstFactorPost(ctx *middlewares.AutheliaCtx) {
ctx.Error(fmt.Errorf("User %s is banned until %s", bodyJSON.Username, bannedUntil), userBannedMessage) ctx.Error(fmt.Errorf("User %s is banned until %s", bodyJSON.Username, bannedUntil), userBannedMessage)
return return
} }
ctx.Error(fmt.Errorf("Unable to regulate authentication: %s", err), authenticationFailedMessage) ctx.Error(fmt.Errorf("Unable to regulate authentication: %s", err), authenticationFailedMessage)
return return
} }
@ -39,6 +41,7 @@ func FirstFactorPost(ctx *middlewares.AutheliaCtx) {
ctx.Providers.Regulator.Mark(bodyJSON.Username, false) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. ctx.Providers.Regulator.Mark(bodyJSON.Username, false) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
ctx.Error(fmt.Errorf("Error while checking password for user %s: %s", bodyJSON.Username, err.Error()), authenticationFailedMessage) ctx.Error(fmt.Errorf("Error while checking password for user %s: %s", bodyJSON.Username, err.Error()), authenticationFailedMessage)
return return
} }
@ -47,6 +50,7 @@ func FirstFactorPost(ctx *middlewares.AutheliaCtx) {
ctx.Providers.Regulator.Mark(bodyJSON.Username, false) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. ctx.Providers.Regulator.Mark(bodyJSON.Username, false) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
ctx.ReplyError(fmt.Errorf("Credentials are wrong for user %s", bodyJSON.Username), authenticationFailedMessage) ctx.ReplyError(fmt.Errorf("Credentials are wrong for user %s", bodyJSON.Username), authenticationFailedMessage)
return return
} }
@ -106,9 +110,11 @@ func FirstFactorPost(ctx *middlewares.AutheliaCtx) {
userSession.LastActivity = time.Now().Unix() userSession.LastActivity = time.Now().Unix()
userSession.KeepMeLoggedIn = keepMeLoggedIn userSession.KeepMeLoggedIn = keepMeLoggedIn
refresh, refreshInterval := getProfileRefreshSettings(ctx.Configuration.AuthenticationBackend) refresh, refreshInterval := getProfileRefreshSettings(ctx.Configuration.AuthenticationBackend)
if refresh { if refresh {
userSession.RefreshTTL = ctx.Clock.Now().Add(refreshInterval) userSession.RefreshTTL = ctx.Clock.Now().Add(refreshInterval)
} }
err = ctx.SaveSession(userSession) err = ctx.SaveSession(userSession)
if err != nil { if err != nil {

View File

@ -36,6 +36,7 @@ func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string
appID := fmt.Sprintf("%s://%s", ctx.XForwardedProto(), ctx.XForwardedHost()) appID := fmt.Sprintf("%s://%s", ctx.XForwardedProto(), ctx.XForwardedHost())
ctx.Logger.Tracef("U2F appID is %s", appID) ctx.Logger.Tracef("U2F appID is %s", appID)
var trustedFacets = []string{appID} var trustedFacets = []string{appID}
challenge, err := u2f.NewChallenge(appID, trustedFacets) challenge, err := u2f.NewChallenge(appID, trustedFacets)

View File

@ -43,6 +43,7 @@ func createToken(secret string, username string, action string, expiresAt time.T
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, _ := token.SignedString([]byte(secret)) ss, _ := token.SignedString([]byte(secret))
return ss return ss
} }

View File

@ -31,6 +31,7 @@ func SecondFactorDuoPost(duoAPI duo.API) middlewares.RequestHandler {
values.Set("ipaddr", remoteIP) values.Set("ipaddr", remoteIP)
values.Set("factor", "push") values.Set("factor", "push")
values.Set("device", "auto") values.Set("device", "auto")
if requestBody.TargetURL != "" { if requestBody.TargetURL != "" {
values.Set("pushinfo", fmt.Sprintf("target%%20url=%s", requestBody.TargetURL)) values.Set("pushinfo", fmt.Sprintf("target%%20url=%s", requestBody.TargetURL))
} }

View File

@ -19,6 +19,7 @@ func SecondFactorTOTPPost(totpVerifier TOTPVerifier) middlewares.RequestHandler
} }
userSession := ctx.GetSession() userSession := ctx.GetSession()
secret, err := ctx.Providers.StorageProvider.LoadTOTPSecret(userSession.Username) secret, err := ctx.Providers.StorageProvider.LoadTOTPSecret(userSession.Username)
if err != nil { if err != nil {
ctx.Error(fmt.Errorf("Unable to load TOTP secret: %s", err), mfaValidationFailedMessage) ctx.Error(fmt.Errorf("Unable to load TOTP secret: %s", err), mfaValidationFailedMessage)

View File

@ -24,6 +24,7 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
} }
appID := fmt.Sprintf("%s://%s", ctx.XForwardedProto(), ctx.XForwardedHost()) appID := fmt.Sprintf("%s://%s", ctx.XForwardedProto(), ctx.XForwardedHost())
var trustedFacets = []string{appID} var trustedFacets = []string{appID}
challenge, err := u2f.NewChallenge(appID, trustedFacets) challenge, err := u2f.NewChallenge(appID, trustedFacets)
@ -40,7 +41,9 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
ctx.Error(fmt.Errorf("No device handle found for user %s", userSession.Username), mfaValidationFailedMessage) ctx.Error(fmt.Errorf("No device handle found for user %s", userSession.Username), mfaValidationFailedMessage)
return return
} }
ctx.Error(fmt.Errorf("Unable to retrieve U2F device handle: %s", err), mfaValidationFailedMessage) ctx.Error(fmt.Errorf("Unable to retrieve U2F device handle: %s", err), mfaValidationFailedMessage)
return return
} }

View File

@ -15,17 +15,22 @@ import (
func loadInfo(username string, storageProvider storage.Provider, preferences *UserPreferences, logger *logrus.Entry) []error { func loadInfo(username string, storageProvider storage.Provider, preferences *UserPreferences, logger *logrus.Entry) []error {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(3) wg.Add(3)
errors := make([]error, 0) errors := make([]error, 0)
go func() { go func() {
defer wg.Done() defer wg.Done()
method, err := storageProvider.LoadPreferred2FAMethod(username) method, err := storageProvider.LoadPreferred2FAMethod(username)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
logger.Error(err) logger.Error(err)
return return
} }
if method == "" { if method == "" {
preferences.Method = authentication.PossibleMethods[0] preferences.Method = authentication.PossibleMethods[0]
} else { } else {
@ -35,33 +40,42 @@ func loadInfo(username string, storageProvider storage.Provider, preferences *Us
go func() { go func() {
defer wg.Done() defer wg.Done()
_, _, err := storageProvider.LoadU2FDeviceHandle(username) _, _, err := storageProvider.LoadU2FDeviceHandle(username)
if err != nil { if err != nil {
if err == storage.ErrNoU2FDeviceHandle { if err == storage.ErrNoU2FDeviceHandle {
return return
} }
errors = append(errors, err) errors = append(errors, err)
logger.Error(err) logger.Error(err)
return return
} }
preferences.HasU2F = true preferences.HasU2F = true
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := storageProvider.LoadTOTPSecret(username) _, err := storageProvider.LoadTOTPSecret(username)
if err != nil { if err != nil {
if err == storage.ErrNoTOTPSecret { if err == storage.ErrNoTOTPSecret {
return return
} }
errors = append(errors, err) errors = append(errors, err)
logger.Error(err) logger.Error(err)
return return
} }
preferences.HasTOTP = true preferences.HasTOTP = true
}() }()
wg.Wait() wg.Wait()
return errors return errors
} }
@ -76,6 +90,7 @@ func UserInfoGet(ctx *middlewares.AutheliaCtx) {
ctx.Error(fmt.Errorf("Unable to load user information"), operationFailedMessage) ctx.Error(fmt.Errorf("Unable to load user information"), operationFailedMessage)
return return
} }
ctx.SetJSONBody(preferences) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. ctx.SetJSONBody(preferences) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
} }
@ -87,6 +102,7 @@ type MethodBody struct {
// MethodPreferencePost update the user preferences regarding 2FA method. // MethodPreferencePost update the user preferences regarding 2FA method.
func MethodPreferencePost(ctx *middlewares.AutheliaCtx) { func MethodPreferencePost(ctx *middlewares.AutheliaCtx) {
bodyJSON := MethodBody{} bodyJSON := MethodBody{}
err := ctx.ParseBody(&bodyJSON) err := ctx.ParseBody(&bodyJSON)
if err != nil { if err != nil {
ctx.Error(err, operationFailedMessage) ctx.Error(err, operationFailedMessage)

View File

@ -38,7 +38,9 @@ func getOriginalURL(ctx *middlewares.AutheliaCtx) (*url.URL, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to parse URL extracted from X-Original-URL header: %v", err) return nil, fmt.Errorf("Unable to parse URL extracted from X-Original-URL header: %v", err)
} }
ctx.Logger.Trace("Using X-Original-URL header content as targeted site URL") ctx.Logger.Trace("Using X-Original-URL header content as targeted site URL")
return url, nil return url, nil
} }
@ -55,6 +57,7 @@ func getOriginalURL(ctx *middlewares.AutheliaCtx) (*url.URL, error) {
} }
var requestURI string var requestURI string
scheme := append(forwardedProto, protoHostSeparator...) scheme := append(forwardedProto, protoHostSeparator...)
requestURI = string(append(scheme, requestURI = string(append(scheme,
append(forwardedHost, forwardedURI...)...)) append(forwardedHost, forwardedURI...)...))
@ -63,8 +66,10 @@ func getOriginalURL(ctx *middlewares.AutheliaCtx) (*url.URL, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to parse URL %s: %v", requestURI, err) return nil, fmt.Errorf("Unable to parse URL %s: %v", requestURI, err)
} }
ctx.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " + ctx.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " +
"to construct targeted site URL") "to construct targeted site URL")
return url, nil return url, nil
} }
@ -74,15 +79,19 @@ func parseBasicAuth(auth string) (username, password string, err error) {
if !strings.HasPrefix(auth, authPrefix) { if !strings.HasPrefix(auth, authPrefix) {
return "", "", fmt.Errorf("%s prefix not found in %s header", strings.Trim(authPrefix, " "), AuthorizationHeader) return "", "", fmt.Errorf("%s prefix not found in %s header", strings.Trim(authPrefix, " "), AuthorizationHeader)
} }
c, err := base64.StdEncoding.DecodeString(auth[len(authPrefix):]) c, err := base64.StdEncoding.DecodeString(auth[len(authPrefix):])
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
cs := string(c) cs := string(c)
s := strings.IndexByte(cs, ':') s := strings.IndexByte(cs, ':')
if s < 0 { if s < 0 {
return "", "", fmt.Errorf("Format of %s header must be user:password", AuthorizationHeader) return "", "", fmt.Errorf("Format of %s header must be user:password", AuthorizationHeader)
} }
return cs[:s], cs[s+1:], nil return cs[:s], cs[s+1:], nil
} }
@ -114,6 +123,7 @@ func isTargetURLAuthorized(authorizer *authorization.Authorizer, targetURL url.U
return Authorized return Authorized
} }
} }
return NotAuthorized return NotAuthorized
} }
@ -208,8 +218,10 @@ func verifySessionCookie(ctx *middlewares.AutheliaCtx, targetURL *url.URL, userS
if err != nil { if err != nil {
ctx.Logger.Error(fmt.Errorf("Unable to destroy user session after provider refresh didn't find the user: %s", err)) ctx.Logger.Error(fmt.Errorf("Unable to destroy user session after provider refresh didn't find the user: %s", err))
} }
return userSession.Username, userSession.Groups, authentication.NotAuthenticated, err return userSession.Username, userSession.Groups, authentication.NotAuthenticated, err
} }
ctx.Logger.Warnf("Error occurred while attempting to update user details from LDAP: %s", err) ctx.Logger.Warnf("Error occurred while attempting to update user details from LDAP: %s", err)
} }
@ -226,6 +238,7 @@ func handleUnauthorized(ctx *middlewares.AutheliaCtx, targetURL fmt.Stringer, us
if strings.Contains(redirectionURL, "/%23/") { if strings.Contains(redirectionURL, "/%23/") {
ctx.Logger.Warn("Characters /%23/ have been detected in redirection URL. This is not needed anymore, please strip it") ctx.Logger.Warn("Characters /%23/ have been detected in redirection URL. This is not needed anymore, please strip it")
} }
ctx.Logger.Infof("Access to %s is not authorized to user %s, redirecting to %s", targetURL.String(), username, redirectionURL) ctx.Logger.Infof("Access to %s is not authorized to user %s, redirecting to %s", targetURL.String(), username, redirectionURL)
ctx.Redirect(redirectionURL, 302) ctx.Redirect(redirectionURL, 302)
ctx.SetBodyString(fmt.Sprintf("Found. Redirecting to %s", redirectionURL)) ctx.SetBodyString(fmt.Sprintf("Found. Redirecting to %s", redirectionURL))
@ -248,6 +261,7 @@ func updateActivityTimestamp(ctx *middlewares.AutheliaCtx, isBasicAuth bool, use
// Mark current activity. // Mark current activity.
userSession.LastActivity = ctx.Clock.Now().Unix() userSession.LastActivity = ctx.Clock.Now().Unix()
return ctx.SaveSession(userSession) return ctx.SaveSession(userSession)
} }
@ -263,9 +277,11 @@ func generateVerifySessionHasUpToDateProfileTraceLogs(ctx *middlewares.AutheliaC
if len(groupsAdded) != 0 { if len(groupsAdded) != 0 {
groupsDelta = append(groupsDelta, fmt.Sprintf("Added: %s.", strings.Join(groupsAdded, ", "))) groupsDelta = append(groupsDelta, fmt.Sprintf("Added: %s.", strings.Join(groupsAdded, ", ")))
} }
if len(groupsRemoved) != 0 { if len(groupsRemoved) != 0 {
groupsDelta = append(groupsDelta, fmt.Sprintf("Removed: %s.", strings.Join(groupsRemoved, ", "))) groupsDelta = append(groupsDelta, fmt.Sprintf("Removed: %s.", strings.Join(groupsRemoved, ", ")))
} }
if len(groupsDelta) != 0 { if len(groupsDelta) != 0 {
ctx.Logger.Tracef("Updated groups detected for %s. %s", userSession.Username, strings.Join(groupsDelta, " ")) ctx.Logger.Tracef("Updated groups detected for %s. %s", userSession.Username, strings.Join(groupsDelta, " "))
} else { } else {
@ -277,9 +293,11 @@ func generateVerifySessionHasUpToDateProfileTraceLogs(ctx *middlewares.AutheliaC
if len(emailsAdded) != 0 { if len(emailsAdded) != 0 {
emailsDelta = append(emailsDelta, fmt.Sprintf("Added: %s.", strings.Join(emailsAdded, ", "))) emailsDelta = append(emailsDelta, fmt.Sprintf("Added: %s.", strings.Join(emailsAdded, ", ")))
} }
if len(emailsRemoved) != 0 { if len(emailsRemoved) != 0 {
emailsDelta = append(emailsDelta, fmt.Sprintf("Removed: %s.", strings.Join(emailsRemoved, ", "))) emailsDelta = append(emailsDelta, fmt.Sprintf("Removed: %s.", strings.Join(emailsRemoved, ", ")))
} }
if len(emailsDelta) != 0 { if len(emailsDelta) != 0 {
ctx.Logger.Tracef("Updated emails detected for %s. %s", userSession.Username, strings.Join(emailsDelta, " ")) ctx.Logger.Tracef("Updated emails detected for %s. %s", userSession.Username, strings.Join(emailsDelta, " "))
} else { } else {
@ -291,8 +309,8 @@ func verifySessionHasUpToDateProfile(ctx *middlewares.AutheliaCtx, targetURL *ur
refreshProfile bool, refreshProfileInterval time.Duration) error { refreshProfile bool, refreshProfileInterval time.Duration) error {
// TODO: Add a check for LDAP password changes based on a time format attribute. // TODO: Add a check for LDAP password changes based on a time format attribute.
// See https://docs.authelia.com/security/threat-model.html#potential-future-guarantees // See https://docs.authelia.com/security/threat-model.html#potential-future-guarantees
ctx.Logger.Tracef("Checking if we need check the authentication backend for an updated profile for %s.", userSession.Username) ctx.Logger.Tracef("Checking if we need check the authentication backend for an updated profile for %s.", userSession.Username)
if refreshProfile && userSession.Username != "" && targetURL != nil && if refreshProfile && userSession.Username != "" && targetURL != nil &&
ctx.Providers.Authorizer.IsURLMatchingRuleWithGroupSubjects(*targetURL) && ctx.Providers.Authorizer.IsURLMatchingRuleWithGroupSubjects(*targetURL) &&
(refreshProfileInterval == schema.RefreshIntervalAlways || userSession.RefreshTTL.Before(ctx.Clock.Now())) { (refreshProfileInterval == schema.RefreshIntervalAlways || userSession.RefreshTTL.Before(ctx.Clock.Now())) {
@ -305,6 +323,7 @@ func verifySessionHasUpToDateProfile(ctx *middlewares.AutheliaCtx, targetURL *ur
groupsDiff := utils.IsStringSlicesDifferent(userSession.Groups, details.Groups) groupsDiff := utils.IsStringSlicesDifferent(userSession.Groups, details.Groups)
emailsDiff := utils.IsStringSlicesDifferent(userSession.Emails, details.Emails) emailsDiff := utils.IsStringSlicesDifferent(userSession.Emails, details.Emails)
if !groupsDiff && !emailsDiff { if !groupsDiff && !emailsDiff {
ctx.Logger.Tracef("Updated profile not detected for %s.", userSession.Username) ctx.Logger.Tracef("Updated profile not detected for %s.", userSession.Username)
} else { } else {
@ -329,6 +348,7 @@ func verifySessionHasUpToDateProfile(ctx *middlewares.AutheliaCtx, targetURL *ur
return ctx.SaveSession(*userSession) return ctx.SaveSession(*userSession)
} }
} }
return nil return nil
} }
@ -336,6 +356,7 @@ func getProfileRefreshSettings(cfg schema.AuthenticationBackendConfiguration) (r
if cfg.Ldap != nil { if cfg.Ldap != nil {
if cfg.RefreshInterval != schema.ProfileRefreshDisabled { if cfg.RefreshInterval != schema.ProfileRefreshDisabled {
refresh = true refresh = true
if cfg.RefreshInterval != schema.ProfileRefreshAlways { if cfg.RefreshInterval != schema.ProfileRefreshAlways {
// Skip Error Check since validator checks it // Skip Error Check since validator checks it
refreshInterval, _ = utils.ParseDurationString(cfg.RefreshInterval) refreshInterval, _ = utils.ParseDurationString(cfg.RefreshInterval)
@ -344,6 +365,7 @@ func getProfileRefreshSettings(cfg schema.AuthenticationBackendConfiguration) (r
} }
} }
} }
return refresh, refreshInterval return refresh, refreshInterval
} }
@ -364,6 +386,7 @@ func VerifyGet(cfg schema.AuthenticationBackendConfiguration) middlewares.Reques
ctx.Logger.Error(fmt.Errorf("Scheme of target URL %s must be secure since cookies are "+ ctx.Logger.Error(fmt.Errorf("Scheme of target URL %s must be secure since cookies are "+
"only transported over a secure connection for security reasons", targetURL.String())) "only transported over a secure connection for security reasons", targetURL.String()))
ctx.ReplyUnauthorized() ctx.ReplyUnauthorized()
return return
} }
@ -371,11 +394,14 @@ func VerifyGet(cfg schema.AuthenticationBackendConfiguration) middlewares.Reques
ctx.Logger.Error(fmt.Errorf("The target URL %s is not under the protected domain %s", ctx.Logger.Error(fmt.Errorf("The target URL %s is not under the protected domain %s",
targetURL.String(), ctx.Configuration.Session.Domain)) targetURL.String(), ctx.Configuration.Session.Domain))
ctx.ReplyUnauthorized() ctx.ReplyUnauthorized()
return return
} }
var username string var username string
var groups []string var groups []string
var authLevel authentication.Level var authLevel authentication.Level
proxyAuthorization := ctx.Request.Header.Peek(AuthorizationHeader) proxyAuthorization := ctx.Request.Header.Peek(AuthorizationHeader)
@ -391,11 +417,14 @@ func VerifyGet(cfg schema.AuthenticationBackendConfiguration) middlewares.Reques
if err != nil { if err != nil {
ctx.Logger.Error(fmt.Sprintf("Error caught when verifying user authorization: %s", err)) ctx.Logger.Error(fmt.Sprintf("Error caught when verifying user authorization: %s", err))
if err := updateActivityTimestamp(ctx, isBasicAuth, username); err != nil { if err := updateActivityTimestamp(ctx, isBasicAuth, username); err != nil {
ctx.Error(fmt.Errorf("Unable to update last activity: %s", err), operationFailedMessage) ctx.Error(fmt.Errorf("Unable to update last activity: %s", err), operationFailedMessage)
return return
} }
handleUnauthorized(ctx, targetURL, username) handleUnauthorized(ctx, targetURL, username)
return return
} }

View File

@ -153,6 +153,7 @@ func TestShouldCheckAuthorizationMatching(t *testing.T) {
AuthLevel authentication.Level AuthLevel authentication.Level
ExpectedMatching authorizationMatching ExpectedMatching authorizationMatching
} }
rules := []Rule{ rules := []Rule{
{"bypass", authentication.NotAuthenticated, Authorized}, {"bypass", authentication.NotAuthenticated, Authorized},
{"bypass", authentication.OneFactor, Authorized}, {"bypass", authentication.OneFactor, Authorized},
@ -679,6 +680,7 @@ func TestIsDomainProtected(t *testing.T) {
GetURL := func(u string) *url.URL { GetURL := func(u string) *url.URL {
x, err := url.ParseRequestURI(u) x, err := url.ParseRequestURI(u)
require.NoError(t, err) require.NoError(t, err)
return x return x
} }
@ -701,6 +703,7 @@ func TestSchemeIsHTTPS(t *testing.T) {
GetURL := func(u string) *url.URL { GetURL := func(u string) *url.URL {
x, err := url.ParseRequestURI(u) x, err := url.ParseRequestURI(u)
require.NoError(t, err) require.NoError(t, err)
return x return x
} }
@ -718,6 +721,7 @@ func TestSchemeIsWSS(t *testing.T) {
GetURL := func(u string) *url.URL { GetURL := func(u string) *url.URL {
x, err := url.ParseRequestURI(u) x, err := url.ParseRequestURI(u)
require.NoError(t, err) require.NoError(t, err)
return x return x
} }
@ -854,6 +858,7 @@ func TestShouldGetRemovedUserGroupsFromBackend(t *testing.T) {
} }
verifyGet := VerifyGet(verifyGetCfg) verifyGet := VerifyGet(verifyGetCfg)
mock.UserProviderMock.EXPECT().GetDetails("john").Return(user, nil).Times(2) mock.UserProviderMock.EXPECT().GetDetails("john").Return(user, nil).Times(2)
clock := mocks.TestingClock{} clock := mocks.TestingClock{}
@ -968,6 +973,7 @@ func TestShouldGetAddedUserGroupsFromBackend(t *testing.T) {
// Reset otherwise we get the last 403 when we check the Response. Is there a better way to do this? // Reset otherwise we get the last 403 when we check the Response. Is there a better way to do this?
mock.Close() mock.Close()
mock = mocks.NewMockAutheliaCtx(t) mock = mocks.NewMockAutheliaCtx(t)
defer mock.Close() defer mock.Close()
err = mock.Ctx.SaveSession(userSession) err = mock.Ctx.SaveSession(userSession)

View File

@ -17,6 +17,7 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI string, username
} else { } else {
ctx.ReplyOK() ctx.ReplyOK()
} }
return return
} }
@ -37,6 +38,7 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI string, username
if requiredLevel == authorization.TwoFactor { if requiredLevel == authorization.TwoFactor {
ctx.Logger.Warnf("%s requires 2FA, cannot be redirected yet", targetURI) ctx.Logger.Warnf("%s requires 2FA, cannot be redirected yet", targetURI)
ctx.ReplyOK() ctx.ReplyOK()
return return
} }
@ -48,10 +50,12 @@ func Handle1FAResponse(ctx *middlewares.AutheliaCtx, targetURI string, username
} else { } else {
ctx.ReplyOK() ctx.ReplyOK()
} }
return return
} }
ctx.Logger.Debugf("Redirection URL %s is safe", targetURI) ctx.Logger.Debugf("Redirection URL %s is safe", targetURI)
response := redirectResponse{Redirect: targetURI} response := redirectResponse{Redirect: targetURI}
ctx.SetJSONBody(response) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. ctx.SetJSONBody(response) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
} }
@ -64,6 +68,7 @@ func Handle2FAResponse(ctx *middlewares.AutheliaCtx, targetURI string) {
} else { } else {
ctx.ReplyOK() ctx.ReplyOK()
} }
return return
} }

View File

@ -26,5 +26,6 @@ func (tv *TOTPVerifierImpl) Verify(token, secret string) (bool, error) {
Digits: otp.DigitsSix, Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1, Algorithm: otp.AlgorithmSHA1,
} }
return totp.ValidateCustom(token, secret, time.Now().UTC(), opts) return totp.ValidateCustom(token, secret, time.Now().UTC(), opts)
} }

View File

@ -27,5 +27,6 @@ func (uv *U2FVerifierImpl) Verify(keyHandle []byte, publicKey []byte,
// TODO(c.michaud): store the counter to help detecting cloned U2F keys. // TODO(c.michaud): store the counter to help detecting cloned U2F keys.
_, err := registration.Authenticate( _, err := registration.Authenticate(
signResponse, challenge, 0) signResponse, challenge, 0)
return err return err
} }

View File

@ -28,7 +28,9 @@ func InitializeLogger(filename string) error {
if err != nil { if err != nil {
return err return err
} }
logrus.SetOutput(f) logrus.SetOutput(f)
} }
return nil return nil
} }

View File

@ -16,6 +16,7 @@ func TestShouldWriteLogsToFile(t *testing.T) {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
path := fmt.Sprintf("%s/authelia.log", dir) path := fmt.Sprintf("%s/authelia.log", dir)

View File

@ -32,6 +32,7 @@ func NewAutheliaCtx(ctx *fasthttp.RequestCtx, configuration schema.Configuration
autheliaCtx.Configuration = configuration autheliaCtx.Configuration = configuration
autheliaCtx.Logger = NewRequestLogger(autheliaCtx) autheliaCtx.Logger = NewRequestLogger(autheliaCtx)
autheliaCtx.Clock = utils.RealClock{} autheliaCtx.Clock = utils.RealClock{}
return autheliaCtx, nil return autheliaCtx, nil
} }
@ -44,6 +45,7 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers)
autheliaCtx.Error(err, operationFailedMessage) autheliaCtx.Error(err, operationFailedMessage)
return return
} }
next(autheliaCtx) next(autheliaCtx)
} }
} }
@ -78,7 +80,6 @@ func (c *AutheliaCtx) ReplyError(err error, message string) {
// ReplyUnauthorized response sent when user is unauthorized. // ReplyUnauthorized response sent when user is unauthorized.
func (c *AutheliaCtx) ReplyUnauthorized() { func (c *AutheliaCtx) ReplyUnauthorized() {
c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized) c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
// c.Response.Header.Set("WWW-Authenticate", "Basic realm=Restricted")
} }
// ReplyForbidden response sent when access is forbidden to user. // ReplyForbidden response sent when access is forbidden to user.
@ -113,6 +114,7 @@ func (c *AutheliaCtx) GetSession() session.UserSession {
c.Logger.Error("Unable to retrieve user session") c.Logger.Error("Unable to retrieve user session")
return session.NewDefaultUserSession() return session.NewDefaultUserSession()
} }
return userSession return userSession
} }
@ -144,6 +146,7 @@ func (c *AutheliaCtx) ParseBody(value interface{}) error {
if !valid { if !valid {
return fmt.Errorf("Body is not valid") return fmt.Errorf("Body is not valid")
} }
return nil return nil
} }
@ -156,6 +159,7 @@ func (c *AutheliaCtx) SetJSONBody(value interface{}) error {
c.SetContentType("application/json") c.SetContentType("application/json")
c.SetBody(b) c.SetBody(b)
return nil return nil
} }
@ -169,5 +173,6 @@ func (c *AutheliaCtx) RemoteIP() net.IP {
return net.ParseIP(strings.Trim(ips[0], " ")) return net.ParseIP(strings.Trim(ips[0], " "))
} }
} }
return c.RequestCtx.RemoteIP() return c.RequestCtx.RemoteIP()
} }

View File

@ -24,6 +24,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
// In that case we reply ok to avoid user enumeration. // In that case we reply ok to avoid user enumeration.
ctx.Logger.Error(err) ctx.Logger.Error(err)
ctx.ReplyOK() ctx.ReplyOK()
return return
} }
@ -78,6 +79,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
ctx.Logger.Debugf("Sending an email to user %s (%s) to confirm identity for registering a device.", ctx.Logger.Debugf("Sending an email to user %s (%s) to confirm identity for registering a device.",
identity.Username, identity.Email) identity.Username, identity.Email)
err = ctx.Providers.Notifier.Send(identity.Email, args.MailTitle, buf.String()) err = ctx.Providers.Notifier.Send(identity.Email, args.MailTitle, buf.String())
if err != nil { if err != nil {
@ -93,6 +95,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(ctx *AutheliaCtx, username string)) RequestHandler { func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(ctx *AutheliaCtx, username string)) RequestHandler {
return func(ctx *AutheliaCtx) { return func(ctx *AutheliaCtx) {
var finishBody IdentityVerificationFinishBody var finishBody IdentityVerificationFinishBody
b := ctx.PostBody() b := ctx.PostBody()
err := json.Unmarshal(b, &finishBody) err := json.Unmarshal(b, &finishBody)
@ -139,7 +142,9 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
return return
} }
} }
ctx.Error(err, operationFailedMessage) ctx.Error(err, operationFailedMessage)
return return
} }

View File

@ -174,6 +174,7 @@ func createToken(secret string, username string, action string, expiresAt time.T
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
ss, _ := token.SignedString([]byte(secret)) ss, _ := token.SignedString([]byte(secret))
return ss return ss
} }

View File

@ -9,6 +9,7 @@ import (
func TestShouldCallNextFunction(t *testing.T) { func TestShouldCallNextFunction(t *testing.T) {
var val = false var val = false
f := func(ctx *fasthttp.RequestCtx) { val = true } f := func(ctx *fasthttp.RequestCtx) { val = true }
context := &fasthttp.RequestCtx{} context := &fasthttp.RequestCtx{}

View File

@ -11,6 +11,7 @@ func RequireFirstFactor(next RequestHandler) RequestHandler {
ctx.ReplyForbidden() ctx.ReplyForbidden()
return return
} }
next(ctx) next(ctx)
} }
} }

View File

@ -123,6 +123,7 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx {
mockAuthelia.Hook = hook mockAuthelia.Hook = hook
mockAuthelia.Ctx.Logger = logrus.NewEntry(logger) mockAuthelia.Ctx.Logger = logrus.NewEntry(logger)
return mockAuthelia return mockAuthelia
} }
@ -141,6 +142,7 @@ func (m *MockAutheliaCtx) Assert200KO(t *testing.T, message string) {
// Assert200OK assert a successful response from the service. // Assert200OK assert a successful response from the service.
func (m *MockAutheliaCtx) Assert200OK(t *testing.T, data interface{}) { func (m *MockAutheliaCtx) Assert200OK(t *testing.T, data interface{}) {
assert.Equal(t, 200, m.Ctx.Response.StatusCode()) assert.Equal(t, 200, m.Ctx.Response.StatusCode())
response := middlewares.OKResponse{ response := middlewares.OKResponse{
Status: "OK", Status: "OK",
Data: data, Data: data,

View File

@ -42,9 +42,11 @@ func (n *FileNotifier) StartupCheck() (bool, error) {
return false, err return false, err
} }
} }
if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil { if err := ioutil.WriteFile(n.path, []byte(""), fileNotifierMode); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
} }
@ -57,5 +59,6 @@ func (n *FileNotifier) Send(recipient, subject, body string) error {
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }

View File

@ -21,9 +21,11 @@ func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
if !server.TLS && !(server.Name == "localhost" || server.Name == "127.0.0.1" || server.Name == "::1") { if !server.TLS && !(server.Name == "localhost" || server.Name == "127.0.0.1" || server.Name == "::1") {
return "", nil, errors.New("connection over plain-text") return "", nil, errors.New("connection over plain-text")
} }
if server.Name != a.host { if server.Name != a.host {
return "", nil, errors.New("unexpected hostname from server") return "", nil, errors.New("unexpected hostname from server")
} }
return "LOGIN", []byte{}, nil return "LOGIN", []byte{}, nil
} }
@ -31,6 +33,7 @@ func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if !more { if !more {
return nil, nil return nil, nil
} }
switch { switch {
case bytes.Equal(fromServer, []byte("Username:")): case bytes.Equal(fromServer, []byte("Username:")):
return []byte(a.username), nil return []byte(a.username), nil

View File

@ -48,6 +48,7 @@ func NewSMTPNotifier(configuration schema.SMTPNotifierConfiguration) *SMTPNotifi
startupCheckAddress: configuration.StartupCheckAddress, startupCheckAddress: configuration.StartupCheckAddress,
} }
notifier.initializeTLSConfig() notifier.initializeTLSConfig()
return notifier return notifier
} }
@ -64,6 +65,7 @@ func (n *SMTPNotifier) initializeTLSConfig() {
if n.trustedCert != "" { if n.trustedCert != "" {
log.Debugf("Notifier SMTP client attempting to load certificate from %s", n.trustedCert) log.Debugf("Notifier SMTP client attempting to load certificate from %s", n.trustedCert)
if exists, err := utils.FileExists(n.trustedCert); exists { if exists, err := utils.FileExists(n.trustedCert); exists {
pem, err := ioutil.ReadFile(n.trustedCert) pem, err := ioutil.ReadFile(n.trustedCert)
if err != nil { if err != nil {
@ -83,6 +85,7 @@ func (n *SMTPNotifier) initializeTLSConfig() {
log.Warnf("Notifier SMTP failed to load cert from file (file does not exist) with error: %s", err) log.Warnf("Notifier SMTP failed to load cert from file (file does not exist) with error: %s", err)
} }
} }
n.tlsConfig = &tls.Config{ n.tlsConfig = &tls.Config{
InsecureSkipVerify: n.disableVerifyCert, //nolint:gosec // This is an intended config, we never default true, provide alternate options, and we constantly warn the user. InsecureSkipVerify: n.disableVerifyCert, //nolint:gosec // This is an intended config, we never default true, provide alternate options, and we constantly warn the user.
ServerName: n.host, ServerName: n.host,
@ -105,12 +108,14 @@ func (n *SMTPNotifier) startTLS() error {
if err := n.client.StartTLS(n.tlsConfig); err != nil { if err := n.client.StartTLS(n.tlsConfig); err != nil {
return err return err
} }
log.Debug("Notifier SMTP STARTTLS completed without error") log.Debug("Notifier SMTP STARTTLS completed without error")
} else if n.disableRequireTLS { } else if n.disableRequireTLS {
log.Warn("Notifier SMTP server does not support STARTTLS and SMTP configuration is set to disable the TLS requirement (only useful for unauthenticated emails over plain text)") log.Warn("Notifier SMTP server does not support STARTTLS and SMTP configuration is set to disable the TLS requirement (only useful for unauthenticated emails over plain text)")
} else { } else {
return errors.New("Notifier SMTP server does not support TLS and it is required by default (see documentation if you want to disable this highly recommended requirement)") return errors.New("Notifier SMTP server does not support TLS and it is required by default (see documentation if you want to disable this highly recommended requirement)")
} }
return nil return nil
} }
@ -126,16 +131,19 @@ func (n *SMTPNotifier) auth() error {
// Check the server supports AUTH, and get the mechanisms. // Check the server supports AUTH, and get the mechanisms.
ok, m := n.client.Extension("AUTH") ok, m := n.client.Extension("AUTH")
if ok { if ok {
var auth smtp.Auth
log.Debugf("Notifier SMTP server supports authentication with the following mechanisms: %s", m) log.Debugf("Notifier SMTP server supports authentication with the following mechanisms: %s", m)
mechanisms := strings.Split(m, " ") mechanisms := strings.Split(m, " ")
var auth smtp.Auth
// Adaptively select the AUTH mechanism to use based on what the server advertised. // Adaptively select the AUTH mechanism to use based on what the server advertised.
if utils.IsStringInSlice("PLAIN", mechanisms) { if utils.IsStringInSlice("PLAIN", mechanisms) {
auth = smtp.PlainAuth("", n.username, n.password, n.host) auth = smtp.PlainAuth("", n.username, n.password, n.host)
log.Debug("Notifier SMTP client attempting AUTH PLAIN with server") log.Debug("Notifier SMTP client attempting AUTH PLAIN with server")
} else if utils.IsStringInSlice("LOGIN", mechanisms) { } else if utils.IsStringInSlice("LOGIN", mechanisms) {
auth = newLoginAuth(n.username, n.password, n.host) auth = newLoginAuth(n.username, n.password, n.host)
log.Debug("Notifier SMTP client attempting AUTH LOGIN with server") log.Debug("Notifier SMTP client attempting AUTH LOGIN with server")
} }
@ -148,23 +156,30 @@ func (n *SMTPNotifier) auth() error {
if err := n.client.Auth(auth); err != nil { if err := n.client.Auth(auth); err != nil {
return err return err
} }
log.Debug("Notifier SMTP client authenticated successfully with the server") log.Debug("Notifier SMTP client authenticated successfully with the server")
return nil return nil
} }
return errors.New("Notifier SMTP server does not advertise the AUTH extension but config requires AUTH (password specified), either disable AUTH, or use an SMTP host that supports AUTH PLAIN or AUTH LOGIN") return errors.New("Notifier SMTP server does not advertise the AUTH extension but config requires AUTH (password specified), either disable AUTH, or use an SMTP host that supports AUTH PLAIN or AUTH LOGIN")
} }
log.Debug("Notifier SMTP config has no password specified so authentication is being skipped") log.Debug("Notifier SMTP config has no password specified so authentication is being skipped")
return nil return nil
} }
func (n *SMTPNotifier) compose(recipient, subject, body string) error { func (n *SMTPNotifier) compose(recipient, subject, body string) error {
log.Debugf("Notifier SMTP client attempting to send email body to %s", recipient) log.Debugf("Notifier SMTP client attempting to send email body to %s", recipient)
if !n.disableRequireTLS { if !n.disableRequireTLS {
_, ok := n.client.TLSConnectionState() _, ok := n.client.TLSConnectionState()
if !ok { if !ok {
return errors.New("Notifier SMTP client can't send an email over plain text connection") return errors.New("Notifier SMTP client can't send an email over plain text connection")
} }
} }
wc, err := n.client.Data() wc, err := n.client.Data()
if err != nil { if err != nil {
log.Debugf("Notifier SMTP client error while obtaining WriteCloser: %s", err) log.Debugf("Notifier SMTP client error while obtaining WriteCloser: %s", err)
@ -188,31 +203,39 @@ func (n *SMTPNotifier) compose(recipient, subject, body string) error {
log.Debugf("Notifier SMTP client error while closing the WriteCloser: %s", err) log.Debugf("Notifier SMTP client error while closing the WriteCloser: %s", err)
return err return err
} }
return nil return nil
} }
// Dial the SMTP server with the SMTPNotifier config. // Dial the SMTP server with the SMTPNotifier config.
func (n *SMTPNotifier) dial() error { func (n *SMTPNotifier) dial() error {
log.Debugf("Notifier SMTP client attempting connection to %s", n.address) log.Debugf("Notifier SMTP client attempting connection to %s", n.address)
if n.port == 465 { if n.port == 465 {
log.Warnf("Notifier SMTP client configured to connect to a SMTPS server. It's highly recommended you use a non SMTPS port and STARTTLS instead of SMTPS, as the protocol is long deprecated.") log.Warnf("Notifier SMTP client configured to connect to a SMTPS server. It's highly recommended you use a non SMTPS port and STARTTLS instead of SMTPS, as the protocol is long deprecated.")
conn, err := tls.Dial("tcp", n.address, n.tlsConfig) conn, err := tls.Dial("tcp", n.address, n.tlsConfig)
if err != nil { if err != nil {
return err return err
} }
client, err := smtp.NewClient(conn, n.host) client, err := smtp.NewClient(conn, n.host)
if err != nil { if err != nil {
return err return err
} }
n.client = client n.client = client
} else { } else {
client, err := smtp.Dial(n.address) client, err := smtp.Dial(n.address)
if err != nil { if err != nil {
return err return err
} }
n.client = client n.client = client
} }
log.Debug("Notifier SMTP client connected successfully") log.Debug("Notifier SMTP client connected successfully")
return nil return nil
} }
@ -258,6 +281,7 @@ func (n *SMTPNotifier) StartupCheck() (bool, error) {
// Send is used to send an email to a recipient. // Send is used to send an email to a recipient.
func (n *SMTPNotifier) Send(recipient, title, body string) error { func (n *SMTPNotifier) Send(recipient, title, body string) error {
subject := strings.ReplaceAll(n.subject, "{title}", title) subject := strings.ReplaceAll(n.subject, "{title}", title)
if err := n.dial(); err != nil { if err := n.dial(); err != nil {
return err return err
} }
@ -269,6 +293,7 @@ func (n *SMTPNotifier) Send(recipient, title, body string) error {
if err := n.startTLS(); err != nil { if err := n.startTLS(); err != nil {
return err return err
} }
if err := n.auth(); err != nil { if err := n.auth(); err != nil {
return err return err
} }
@ -278,6 +303,7 @@ func (n *SMTPNotifier) Send(recipient, title, body string) error {
log.Debugf("Notifier SMTP failed while sending MAIL FROM (using sender) with error: %s", err) log.Debugf("Notifier SMTP failed while sending MAIL FROM (using sender) with error: %s", err)
return err return err
} }
if err := n.client.Rcpt(recipient); err != nil { if err := n.client.Rcpt(recipient); err != nil {
log.Debugf("Notifier SMTP failed while sending RCPT TO (using recipient) with error: %s", err) log.Debugf("Notifier SMTP failed while sending RCPT TO (using recipient) with error: %s", err)
return err return err
@ -289,5 +315,6 @@ func (n *SMTPNotifier) Send(recipient, title, body string) error {
} }
log.Debug("Notifier SMTP client successfully sent email") log.Debug("Notifier SMTP client successfully sent email")
return nil return nil
} }

View File

@ -14,11 +14,13 @@ import (
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.Provider, clock utils.Clock) *Regulator { func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.Provider, clock utils.Clock) *Regulator {
regulator := &Regulator{storageProvider: provider} regulator := &Regulator{storageProvider: provider}
regulator.clock = clock regulator.clock = clock
if configuration != nil { if configuration != nil {
findTime, err := utils.ParseDurationString(configuration.FindTime) findTime, err := utils.ParseDurationString(configuration.FindTime)
if err != nil { if err != nil {
panic(err) panic(err)
} }
banTime, err := utils.ParseDurationString(configuration.BanTime) banTime, err := utils.ParseDurationString(configuration.BanTime)
if err != nil { if err != nil {
panic(err) panic(err)
@ -34,6 +36,7 @@ func NewRegulator(configuration *schema.RegulationConfiguration, provider storag
regulator.findTime = findTime regulator.findTime = findTime
regulator.banTime = banTime regulator.banTime = banTime
} }
return regulator return regulator
} }
@ -55,6 +58,7 @@ func (r *Regulator) Regulate(username string) (time.Time, error) {
if !r.enabled { if !r.enabled {
return time.Time{}, nil return time.Time{}, nil
} }
now := r.clock.Now() now := r.clock.Now()
// TODO(c.michaud): make sure FindTime < BanTime. // TODO(c.michaud): make sure FindTime < BanTime.
@ -65,6 +69,7 @@ func (r *Regulator) Regulate(username string) (time.Time, error) {
} }
latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.maxRetries) latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.maxRetries)
for _, attempt := range attempts { for _, attempt := range attempts {
if attempt.Successful || len(latestFailedAttempts) >= r.maxRetries { if attempt.Successful || len(latestFailedAttempts) >= r.maxRetries {
// We stop appending failed attempts once we find the first successful attempts or we reach // We stop appending failed attempts once we find the first successful attempts or we reach
@ -90,5 +95,6 @@ func (r *Regulator) Regulate(username string) (time.Time, error) {
bannedUntil := latestFailedAttempts[0].Time.Add(r.banTime) bannedUntil := latestFailedAttempts[0].Time.Add(r.banTime)
return bannedUntil, ErrUserIsBanned return bannedUntil, ErrUserIsBanned
} }
return time.Time{}, nil return time.Time{}, nil
} }

View File

@ -34,12 +34,15 @@ func ServeIndex(publicDir string) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) { return func(ctx *fasthttp.RequestCtx) {
nonce := utils.RandomString(32, alphaNumericRunes) nonce := utils.RandomString(32, alphaNumericRunes)
ctx.SetContentType("text/html; charset=utf-8") ctx.SetContentType("text/html; charset=utf-8")
ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("default-src 'self'; style-src 'self' 'nonce-%s'", nonce)) ctx.Response.Header.Add("Content-Security-Policy", fmt.Sprintf("default-src 'self'; style-src 'self' 'nonce-%s'", nonce))
err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ CSPNonce string }{CSPNonce: nonce}) err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ CSPNonce string }{CSPNonce: nonce})
if err != nil { if err != nil {
ctx.Error("An error occurred", 503) ctx.Error("An error occurred", 503)
logging.Logger().Errorf("Unable to execute template: %v", err) logging.Logger().Errorf("Unable to execute template: %v", err)
return return
} }
} }

View File

@ -46,6 +46,7 @@ func (e *EncryptingSerializer) Decode(dst *session.Dict, src []byte) error {
} }
dst.Reset() dst.Reset()
decryptedSrc, err := utils.Decrypt(src, &e.key) decryptedSrc, err := utils.Decrypt(src, &e.key)
if err != nil { if err != nil {
// If an error is thrown while decrypting, it's probably an old unencrypted session // If an error is thrown while decrypting, it's probably an old unencrypted session
@ -56,9 +57,11 @@ func (e *EncryptingSerializer) Decode(dst *session.Dict, src []byte) error {
if uerr != nil { if uerr != nil {
return fmt.Errorf("Unable to decrypt session: %s", err) return fmt.Errorf("Unable to decrypt session: %s", err)
} }
return nil return nil
} }
_, err = dst.UnmarshalMsg(decryptedSrc) _, err = dst.UnmarshalMsg(decryptedSrc)
return err return err
} }

View File

@ -29,18 +29,21 @@ func NewProvider(configuration schema.SessionConfiguration) *Provider {
if err != nil { if err != nil {
panic(err) panic(err)
} }
provider.RememberMe = duration provider.RememberMe = duration
duration, err = utils.ParseDurationString(configuration.Inactivity) duration, err = utils.ParseDurationString(configuration.Inactivity)
if err != nil { if err != nil {
panic(err) panic(err)
} }
provider.Inactivity = duration provider.Inactivity = duration
err = provider.sessionHolder.SetProvider(providerConfig.providerName, providerConfig.providerConfig) err = provider.sessionHolder.SetProvider(providerConfig.providerName, providerConfig.providerConfig)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return provider return provider
} }
@ -59,6 +62,7 @@ func (p *Provider) GetSession(ctx *fasthttp.RequestCtx) (UserSession, error) {
if !ok { if !ok {
userSession := NewDefaultUserSession() userSession := NewDefaultUserSession()
store.Set(userSessionStorerKey, userSession) store.Set(userSessionStorerKey, userSession)
return userSession, nil return userSession, nil
} }
@ -88,6 +92,7 @@ func (p *Provider) SaveSession(ctx *fasthttp.RequestCtx, userSession UserSession
store.Set(userSessionStorerKey, userSessionJSON) store.Set(userSessionStorerKey, userSessionJSON)
p.sessionHolder.Save(ctx, store) p.sessionHolder.Save(ctx, store)
return nil return nil
} }
@ -117,6 +122,7 @@ func (p *Provider) UpdateExpiration(ctx *fasthttp.RequestCtx, expiration time.Du
} }
p.sessionHolder.Save(ctx, store) p.sessionHolder.Save(ctx, store)
return nil return nil
} }

View File

@ -32,6 +32,7 @@ func NewProviderConfig(configuration schema.SessionConfiguration) ProviderConfig
} }
var providerConfig session.ProviderConfig var providerConfig session.ProviderConfig
var providerName string var providerName string
// If redis configuration is provided, then use the redis provider. // If redis configuration is provided, then use the redis provider.
@ -54,6 +55,7 @@ func NewProviderConfig(configuration schema.SessionConfiguration) ProviderConfig
providerName = "memory" providerName = "memory"
providerConfig = &memory.Config{} providerConfig = &memory.Config{}
} }
return ProviderConfig{ return ProviderConfig{
config: config, config: config,
providerName: providerName, providerName: providerName,

View File

@ -31,8 +31,8 @@ func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProv
if configuration.Port > 0 { if configuration.Port > 0 {
address += fmt.Sprintf(":%d", configuration.Port) address += fmt.Sprintf(":%d", configuration.Port)
} }
connectionString += fmt.Sprintf("tcp(%s)", address)
connectionString += fmt.Sprintf("tcp(%s)", address)
if configuration.Database != "" { if configuration.Database != "" {
connectionString += fmt.Sprintf("/%s", configuration.Database) connectionString += fmt.Sprintf("/%s", configuration.Database)
} }
@ -71,5 +71,6 @@ func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProv
if err := provider.initialize(db); err != nil { if err := provider.initialize(db); err != nil {
logging.Logger().Fatalf("Unable to initialize SQL database: %v", err) logging.Logger().Fatalf("Unable to initialize SQL database: %v", err)
} }
return &provider return &provider
} }

View File

@ -80,5 +80,6 @@ func NewPostgreSQLProvider(configuration schema.PostgreSQLStorageConfiguration)
if err := provider.initialize(db); err != nil { if err := provider.initialize(db); err != nil {
logging.Logger().Fatalf("Unable to initialize SQL database: %v", err) logging.Logger().Fatalf("Unable to initialize SQL database: %v", err)
} }
return &provider return &provider
} }

View File

@ -75,21 +75,26 @@ func (p *SQLProvider) initialize(db *sql.DB) error {
return fmt.Errorf("Unable to create table %s: %v", authenticationLogsTableName, err) return fmt.Errorf("Unable to create table %s: %v", authenticationLogsTableName, err)
} }
} }
return nil return nil
} }
// LoadPreferred2FAMethod load the preferred method for 2FA from sqlite db. // LoadPreferred2FAMethod load the preferred method for 2FA from sqlite db.
func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) { func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) {
var method string
rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username) rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username)
if err != nil { if err != nil {
return "", err return "", err
} }
defer rows.Close() defer rows.Close()
if !rows.Next() { if !rows.Next() {
return "", nil return "", nil
} }
var method string
err = rows.Scan(&method) err = rows.Scan(&method)
return method, err return method, err
} }
@ -102,10 +107,12 @@ func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) err
// FindIdentityVerificationToken look for an identity verification token in DB. // FindIdentityVerificationToken look for an identity verification token in DB.
func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) { func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) {
var found bool var found bool
err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found) err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found)
if err != nil { if err != nil {
return false, err return false, err
} }
return found, nil return found, nil
} }
@ -134,8 +141,10 @@ func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", ErrNoTOTPSecret return "", ErrNoTOTPSecret
} }
return "", err return "", err
} }
return secret, nil return secret, nil
} }
@ -151,6 +160,7 @@ func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, pub
username, username,
base64.StdEncoding.EncodeToString(keyHandle), base64.StdEncoding.EncodeToString(keyHandle),
base64.StdEncoding.EncodeToString(publicKey)) base64.StdEncoding.EncodeToString(publicKey))
return err return err
} }
@ -161,6 +171,7 @@ func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, erro
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil, ErrNoU2FDeviceHandle return nil, nil, ErrNoU2FDeviceHandle
} }
return nil, nil, err return nil, nil, err
} }
@ -187,6 +198,8 @@ func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttem
// LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log. // LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log.
func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) { func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
var t int64
rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username) rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username)
if err != nil { if err != nil {
@ -194,18 +207,20 @@ func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate tim
} }
attempts := make([]models.AuthenticationAttempt, 0, 10) attempts := make([]models.AuthenticationAttempt, 0, 10)
for rows.Next() { for rows.Next() {
attempt := models.AuthenticationAttempt{ attempt := models.AuthenticationAttempt{
Username: username, Username: username,
} }
var t int64
err = rows.Scan(&attempt.Successful, &t) err = rows.Scan(&attempt.Successful, &t)
attempt.Time = time.Unix(t, 0) attempt.Time = time.Unix(t, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
attempts = append(attempts, attempt) attempts = append(attempts, attempt)
} }
return attempts, nil return attempts, nil
} }

View File

@ -51,5 +51,6 @@ func NewSQLiteProvider(path string) *SQLiteProvider {
if err := provider.initialize(db); err != nil { if err := provider.initialize(db); err != nil {
logging.Logger().Fatalf("Unable to initialize SQLite database %s: %s", path, err) logging.Logger().Fatalf("Unable to initialize SQLite database %s: %s", path, err)
} }
return &provider return &provider
} }

View File

@ -19,5 +19,6 @@ func doHTTPGetQuery(t *testing.T, url string) []byte {
defer resp.Body.Close() defer resp.Body.Close()
body, _ := ioutil.ReadAll(resp.Body) body, _ := ioutil.ReadAll(resp.Body)
return body return body
} }

View File

@ -51,6 +51,7 @@ func (wds *WebDriverSession) doLoginAndRegisterTOTP(ctx context.Context, t *test
secret := wds.doRegisterTOTP(ctx, t) secret := wds.doRegisterTOTP(ctx, t)
wds.doVisit(t, LoginBaseURL) wds.doVisit(t, LoginBaseURL)
wds.verifyIsSecondFactorPage(ctx, t) wds.verifyIsSecondFactorPage(ctx, t)
return secret return secret
} }
@ -59,5 +60,6 @@ func (wds *WebDriverSession) doRegisterAndLogin2FA(ctx context.Context, t *testi
// Register TOTP secret and logout. // Register TOTP secret and logout.
secret := wds.doRegisterThenLogout(ctx, t, username, password) secret := wds.doRegisterThenLogout(ctx, t, username, password)
wds.doLoginTwoFactor(ctx, t, username, password, keepMeLoggedIn, secret, targetURL) wds.doLoginTwoFactor(ctx, t, username, password, keepMeLoggedIn, secret, targetURL)
return secret return secret
} }

View File

@ -28,5 +28,6 @@ func doGetLinkFromLastMail(t *testing.T) string {
matches := re.FindStringSubmatch(string(res)) matches := re.FindStringSubmatch(string(res))
assert.Len(t, matches, 2, "Number of match for link in email is not equal to one") assert.Len(t, matches, 2, "Number of match for link in email is not equal to one")
return matches[1] return matches[1]
} }

View File

@ -8,5 +8,6 @@ import (
func (wds *WebDriverSession) doRegisterThenLogout(ctx context.Context, t *testing.T, username, password string) string { func (wds *WebDriverSession) doRegisterThenLogout(ctx context.Context, t *testing.T, username, password string) string {
secret := wds.doLoginAndRegisterTOTP(ctx, t, username, password, false) secret := wds.doLoginAndRegisterTOTP(ctx, t, username, password, false)
wds.doLogout(ctx, t) wds.doLogout(ctx, t)
return secret return secret
} }

View File

@ -18,6 +18,7 @@ func (wds *WebDriverSession) doRegisterTOTP(ctx context.Context, t *testing.T) s
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, "", secret) assert.NotEqual(t, "", secret)
assert.NotNil(t, secret) assert.NotNil(t, secret)
return secret return secret
} }

View File

@ -23,5 +23,6 @@ func (wds *WebDriverSession) doVisitLoginPage(ctx context.Context, t *testing.T,
if targetURL != "" { if targetURL != "" {
suffix = fmt.Sprintf("?rd=%s", targetURL) suffix = fmt.Sprintf("?rd=%s", targetURL)
} }
wds.doVisitAndVerifyOneFactorStep(ctx, t, fmt.Sprintf("%s/%s", LoginBaseURL, suffix)) wds.doVisitAndVerifyOneFactorStep(ctx, t, fmt.Sprintf("%s/%s", LoginBaseURL, suffix))
} }

View File

@ -27,18 +27,21 @@ func NewDockerEnvironment(files []string) *DockerEnvironment {
files[i] = strings.ReplaceAll(files[i], "{}", "dev") files[i] = strings.ReplaceAll(files[i], "{}", "dev")
} }
} }
return &DockerEnvironment{dockerComposeFiles: files} return &DockerEnvironment{dockerComposeFiles: files}
} }
func (de *DockerEnvironment) createCommandWithStdout(cmd string) *exec.Cmd { func (de *DockerEnvironment) createCommandWithStdout(cmd string) *exec.Cmd {
dockerCmdLine := fmt.Sprintf("docker-compose -p authelia -f %s %s", strings.Join(de.dockerComposeFiles, " -f "), cmd) dockerCmdLine := fmt.Sprintf("docker-compose -p authelia -f %s %s", strings.Join(de.dockerComposeFiles, " -f "), cmd)
log.Trace(dockerCmdLine) log.Trace(dockerCmdLine)
return utils.CommandWithStdout("bash", "-c", dockerCmdLine) return utils.CommandWithStdout("bash", "-c", dockerCmdLine)
} }
func (de *DockerEnvironment) createCommand(cmd string) *exec.Cmd { func (de *DockerEnvironment) createCommand(cmd string) *exec.Cmd {
dockerCmdLine := fmt.Sprintf("docker-compose -p authelia -f %s %s", strings.Join(de.dockerComposeFiles, " -f "), cmd) dockerCmdLine := fmt.Sprintf("docker-compose -p authelia -f %s %s", strings.Join(de.dockerComposeFiles, " -f "), cmd)
log.Trace(dockerCmdLine) log.Trace(dockerCmdLine)
return utils.Command("bash", "-c", dockerCmdLine) return utils.Command("bash", "-c", dockerCmdLine)
} }
@ -61,5 +64,6 @@ func (de *DockerEnvironment) Down() error {
func (de *DockerEnvironment) Logs(service string, flags []string) (string, error) { func (de *DockerEnvironment) Logs(service string, flags []string) (string, error) {
cmd := de.createCommand(fmt.Sprintf("logs %s %s", strings.Join(flags, " "), service)) cmd := de.createCommand(fmt.Sprintf("logs %s %s", strings.Join(flags, " "), service))
content, err := cmd.Output() content, err := cmd.Output()
return string(content), err return string(content), err
} }

View File

@ -19,6 +19,7 @@ func waitUntilServiceLogDetected(
service string, service string,
logPatterns []string) error { logPatterns []string) error {
log.Debug("Waiting for service " + service + " to be ready...") log.Debug("Waiting for service " + service + " to be ready...")
err := utils.CheckUntil(5*time.Second, 1*time.Minute, func() (bool, error) { err := utils.CheckUntil(5*time.Second, 1*time.Minute, func() (bool, error) {
logs, err := dockerEnvironment.Logs(service, []string{"--tail", "20"}) logs, err := dockerEnvironment.Logs(service, []string{"--tail", "20"})
fmt.Printf(".") fmt.Printf(".")
@ -35,6 +36,7 @@ func waitUntilServiceLogDetected(
}) })
fmt.Print("\n") fmt.Print("\n")
return err return err
} }
@ -68,6 +70,8 @@ func waitUntilAutheliaIsReady(dockerEnvironment *DockerEnvironment) error {
return err return err
} }
} }
log.Info("Authelia is now ready!") log.Info("Authelia is now ready!")
return nil return nil
} }

View File

@ -12,6 +12,7 @@ func NewHTTPClient() *http.Client {
InsecureSkipVerify: true, //nolint:gosec // Needs to be enabled in suites. Not used in production. InsecureSkipVerify: true, //nolint:gosec // Needs to be enabled in suites. Not used in production.
}, },
} }
return &http.Client{ return &http.Client{
Transport: tr, Transport: tr,
CheckRedirect: func(req *http.Request, via []*http.Request) error { CheckRedirect: func(req *http.Request, via []*http.Request) error {

View File

@ -39,6 +39,7 @@ func (k Kind) CreateCluster() error {
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
return err return err
} }
return nil return nil
} }
@ -92,6 +93,7 @@ func (k Kubectl) StartDashboard() error {
if err := utils.Shell("docker-compose -p authelia -f internal/suites/docker-compose.yml -f internal/suites/example/compose/kind/docker-compose.yml up -d kube-dashboard").Run(); err != nil { if err := utils.Shell("docker-compose -p authelia -f internal/suites/docker-compose.yml -f internal/suites/example/compose/kind/docker-compose.yml up -d kube-dashboard").Run(); err != nil {
return err return err
} }
return nil return nil
} }

View File

@ -49,6 +49,7 @@ func (sr *Registry) Register(name string, suite Suite) {
if _, found := sr.registry[name]; found { if _, found := sr.registry[name]; found {
log.Fatal(fmt.Sprintf("Trying to register the suite %s multiple times", name)) log.Fatal(fmt.Sprintf("Trying to register the suite %s multiple times", name))
} }
sr.registry[name] = suite sr.registry[name] = suite
} }
@ -58,6 +59,7 @@ func (sr *Registry) Get(name string) Suite {
if !found { if !found {
log.Fatal(fmt.Sprintf("The suite %s does not exist", name)) log.Fatal(fmt.Sprintf("The suite %s does not exist", name))
} }
return s return s
} }
@ -67,5 +69,6 @@ func (sr *Registry) Suites() []string {
for k := range sr.registry { for k := range sr.registry {
suites = append(suites, k) suites = append(suites, k)
} }
return suites return suites
} }

View File

@ -54,6 +54,7 @@ func IsStringInList(str string, list []string) bool {
return true return true
} }
} }
return false return false
} }
@ -73,9 +74,11 @@ func (s *AvailableMethodsScenario) TestShouldCheckAvailableMethods() {
s.Assert().Len(options, len(s.methods)) s.Assert().Len(options, len(s.methods))
optionsList := make([]string, 0) optionsList := make([]string, 0)
for _, o := range options { for _, o := range options {
txt, err := o.Text() txt, err := o.Text()
s.Assert().NoError(err) s.Assert().NoError(err)
optionsList = append(optionsList, txt) optionsList = append(optionsList, txt)
} }

View File

@ -60,8 +60,8 @@ func (s *InactivityScenario) TestShouldRequireReauthenticationAfterInactivityPer
defer cancel() defer cancel()
targetURL := fmt.Sprintf("%s/secret.html", AdminBaseURL) targetURL := fmt.Sprintf("%s/secret.html", AdminBaseURL)
s.doLoginTwoFactor(ctx, s.T(), "john", "password", false, s.secret, "")
s.doLoginTwoFactor(ctx, s.T(), "john", "password", false, s.secret, "")
s.doVisit(s.T(), HomeBaseURL) s.doVisit(s.T(), HomeBaseURL)
s.verifyIsHome(ctx, s.T()) s.verifyIsHome(ctx, s.T())
@ -76,6 +76,7 @@ func (s *InactivityScenario) TestShouldRequireReauthenticationAfterCookieExpirat
defer cancel() defer cancel()
targetURL := fmt.Sprintf("%s/secret.html", AdminBaseURL) targetURL := fmt.Sprintf("%s/secret.html", AdminBaseURL)
s.doLoginTwoFactor(ctx, s.T(), "john", "password", false, s.secret, "") s.doLoginTwoFactor(ctx, s.T(), "john", "password", false, s.secret, "")
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
@ -83,6 +84,7 @@ func (s *InactivityScenario) TestShouldRequireReauthenticationAfterCookieExpirat
s.verifyIsHome(ctx, s.T()) s.verifyIsHome(ctx, s.T())
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
s.doVisit(s.T(), targetURL) s.doVisit(s.T(), targetURL)
s.verifySecretAuthorized(ctx, s.T()) s.verifySecretAuthorized(ctx, s.T())
} }
@ -101,8 +103,8 @@ func (s *InactivityScenario) TestShouldDisableCookieExpirationAndInactivity() {
defer cancel() defer cancel()
targetURL := fmt.Sprintf("%s/secret.html", AdminBaseURL) targetURL := fmt.Sprintf("%s/secret.html", AdminBaseURL)
s.doLoginTwoFactor(ctx, s.T(), "john", "password", true, s.secret, "")
s.doLoginTwoFactor(ctx, s.T(), "john", "password", true, s.secret, "")
s.doVisit(s.T(), HomeBaseURL) s.doVisit(s.T(), HomeBaseURL)
s.verifyIsHome(ctx, s.T()) s.verifyIsHome(ctx, s.T())

View File

@ -90,10 +90,10 @@ func (s *TwoFactorSuite) TestShouldFailTwoFactor() {
s.doRegisterThenLogout(ctx, s.T(), testUsername, testPassword) s.doRegisterThenLogout(ctx, s.T(), testUsername, testPassword)
wrongPasscode := "123456" wrongPasscode := "123456"
s.doLoginOneFactor(ctx, s.T(), testUsername, testPassword, false, "") s.doLoginOneFactor(ctx, s.T(), testUsername, testPassword, false, "")
s.verifyIsSecondFactorPage(ctx, s.T()) s.verifyIsSecondFactorPage(ctx, s.T())
s.doEnterOTP(ctx, s.T(), wrongPasscode) s.doEnterOTP(ctx, s.T(), wrongPasscode)
s.verifyNotificationDisplayed(ctx, s.T(), "The one-time password might be wrong") s.verifyNotificationDisplayed(ctx, s.T(), "The one-time password might be wrong")
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -29,13 +29,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -31,13 +31,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -36,13 +36,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := haDockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := haDockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -137,6 +137,7 @@ func (s *HighAvailabilityWebDriverSuite) TestShouldVerifyAccessControl() {
verifyUserIsAuthorized := func(ctx context.Context, t *testing.T, username, targetURL string, authorized bool) { //nolint:unparam verifyUserIsAuthorized := func(ctx context.Context, t *testing.T, username, targetURL string, authorized bool) { //nolint:unparam
s.doVisit(t, targetURL) s.doVisit(t, targetURL)
s.verifyURLIs(ctx, t, targetURL) s.verifyURLIs(ctx, t, targetURL)
if authorized { if authorized {
s.verifySecretAuthorized(ctx, t) s.verifySecretAuthorized(ctx, t)
} else { } else {
@ -182,6 +183,7 @@ func DoGetWithAuth(t *testing.T, username, password string) int {
res, err := client.Do(req) res, err := client.Do(req)
assert.NoError(t, err) assert.NoError(t, err)
return res.StatusCode return res.StatusCode
} }

View File

@ -44,6 +44,7 @@ func init() {
} }
log.Debug("Building authelia:dist image or use cache if already built...") log.Debug("Building authelia:dist image or use cache if already built...")
if os.Getenv("CI") != stringTrue { if os.Getenv("CI") != stringTrue {
if err := utils.Shell("authelia-scripts docker build").Run(); err != nil { if err := utils.Shell("authelia-scripts docker build").Run(); err != nil {
return err return err
@ -51,45 +52,54 @@ func init() {
} }
log.Debug("Loading images into Kubernetes container...") log.Debug("Loading images into Kubernetes container...")
if err := loadDockerImages(); err != nil { if err := loadDockerImages(); err != nil {
return err return err
} }
log.Debug("Starting Kubernetes dashboard...") log.Debug("Starting Kubernetes dashboard...")
if err := kubectl.StartDashboard(); err != nil { if err := kubectl.StartDashboard(); err != nil {
return err return err
} }
log.Debug("Deploying thirdparties...") log.Debug("Deploying thirdparties...")
if err := kubectl.DeployThirdparties(); err != nil { if err := kubectl.DeployThirdparties(); err != nil {
return err return err
} }
log.Debug("Waiting for services to be ready...") log.Debug("Waiting for services to be ready...")
if err := waitAllPodsAreReady(5 * time.Minute); err != nil { if err := waitAllPodsAreReady(5 * time.Minute); err != nil {
return err return err
} }
log.Debug("Deploying Authelia...") log.Debug("Deploying Authelia...")
if err = kubectl.DeployAuthelia(); err != nil { if err = kubectl.DeployAuthelia(); err != nil {
return err return err
} }
log.Debug("Waiting for services to be ready...") log.Debug("Waiting for services to be ready...")
if err := waitAllPodsAreReady(2 * time.Minute); err != nil { if err := waitAllPodsAreReady(2 * time.Minute); err != nil {
return err return err
} }
log.Debug("Starting proxy...") log.Debug("Starting proxy...")
if err := kubectl.StartProxy(); err != nil { if err := kubectl.StartProxy(); err != nil {
return err return err
} }
return nil return nil
} }
teardown := func(suitePath string) error { teardown := func(suitePath string) error {
kubectl.StopDashboard() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. kubectl.StopDashboard() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
kubectl.StopProxy() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. kubectl.StopProxy() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
return kind.DeleteCluster() return kind.DeleteCluster()
} }
@ -123,9 +133,12 @@ func waitAllPodsAreReady(timeout time.Duration) error {
// Wait in case the deployment has just been done and some services do not appear in kubectl logs. // Wait in case the deployment has just been done and some services do not appear in kubectl logs.
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
fmt.Println("Check services are running") fmt.Println("Check services are running")
if err := kubectl.WaitPodsReady(timeout); err != nil { if err := kubectl.WaitPodsReady(timeout); err != nil {
return err return err
} }
fmt.Println("All pods are ready") fmt.Println("All pods are ready")
return nil return nil
} }

View File

@ -35,13 +35,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -34,13 +34,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -23,6 +23,7 @@ func (s *NetworkACLSuite) TestShouldAccessSecretUpon2FA() {
wds, err := StartWebDriver() wds, err := StartWebDriver()
s.Require().NoError(err) s.Require().NoError(err)
defer wds.Stop() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. defer wds.Stop() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
targetURL := fmt.Sprintf("%s/secret.html", SecureBaseURL) targetURL := fmt.Sprintf("%s/secret.html", SecureBaseURL)
@ -40,6 +41,7 @@ func (s *NetworkACLSuite) TestShouldAccessSecretUpon1FA() {
wds, err := StartWebDriverWithProxy("http://proxy-client1.example.com:3128", 4444) wds, err := StartWebDriverWithProxy("http://proxy-client1.example.com:3128", 4444)
s.Require().NoError(err) s.Require().NoError(err)
defer wds.Stop() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. defer wds.Stop() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
targetURL := fmt.Sprintf("%s/secret.html", SecureBaseURL) targetURL := fmt.Sprintf("%s/secret.html", SecureBaseURL)
@ -58,6 +60,7 @@ func (s *NetworkACLSuite) TestShouldAccessSecretUpon0FA() {
wds, err := StartWebDriverWithProxy("http://proxy-client2.example.com:3128", 4444) wds, err := StartWebDriverWithProxy("http://proxy-client2.example.com:3128", 4444)
s.Require().NoError(err) s.Require().NoError(err)
defer wds.Stop() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting. defer wds.Stop() //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
wds.doVisit(s.T(), fmt.Sprintf("%s/secret.html", SecureBaseURL)) wds.doVisit(s.T(), fmt.Sprintf("%s/secret.html", SecureBaseURL))

View File

@ -30,13 +30,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -31,13 +31,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

View File

@ -33,13 +33,16 @@ func init() {
if err != nil { if err != nil {
return err return err
} }
fmt.Println(backendLogs) fmt.Println(backendLogs)
frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil) frontendLogs, err := dockerEnvironment.Logs("authelia-frontend", nil)
if err != nil { if err != nil {
return err return err
} }
fmt.Println(frontendLogs) fmt.Println(frontendLogs)
return nil return nil
} }

Some files were not shown because too many files have changed in this diff Show More