From 60ff16b518e54bcc2cbb7a3a3441c47b1d382be6 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Tue, 2 Feb 2021 12:01:46 +1100 Subject: [PATCH] fix(handlers): refresh user details on all domains (#1642) * fix(handlers): refresh user details on all domains * previously sessions only got checked for updated details if the domain had group subjects attached * this meant disabled or deleted accounts did not get detected until the session expired or the user visited a domain protected by a group subject * this patch fixes this issue and simplifies some logic surrounding the check * add tests simplify IsStringSlicesDifferent so it only iterates once * add another test for IsStringSlicesDifferent --- internal/authorization/authorizer.go | 18 --- internal/handlers/handler_verify.go | 81 +++++++------ internal/handlers/handler_verify_test.go | 146 +++++++++++++++++++---- internal/utils/strings.go | 10 +- internal/utils/strings_test.go | 7 ++ 5 files changed, 178 insertions(+), 84 deletions(-) diff --git a/internal/authorization/authorizer.go b/internal/authorization/authorizer.go index af0ee0dc..8d602250 100644 --- a/internal/authorization/authorizer.go +++ b/internal/authorization/authorizer.go @@ -130,21 +130,3 @@ func (p *Authorizer) GetRequiredLevel(subject Subject, requestURL url.URL) Level return PolicyToLevel(p.configuration.DefaultPolicy) } - -// IsURLMatchingRuleWithGroupSubjects returns true if the request has at least one -// matching ACL with a subject of type group attached to it, otherwise false. -func (p *Authorizer) IsURLMatchingRuleWithGroupSubjects(requestURL url.URL) (hasGroupSubjects bool) { - for _, rule := range p.configuration.Rules { - if isDomainMatching(requestURL.Hostname(), rule.Domains) && isPathMatching(requestURL.Path, rule.Resources) { - for _, subjectRule := range rule.Subjects { - for _, subject := range subjectRule { - if strings.HasPrefix(subject, groupPrefix) { - return true - } - } - } - } - } - - return false -} diff --git a/internal/handlers/handler_verify.go b/internal/handlers/handler_verify.go index 65c064ce..bc0d0101 100644 --- a/internal/handlers/handler_verify.go +++ b/internal/handlers/handler_verify.go @@ -323,47 +323,51 @@ func verifySessionHasUpToDateProfile(ctx *middlewares.AutheliaCtx, targetURL *ur // See https://docs.authelia.com/security/threat-model.html#potential-future-guarantees ctx.Logger.Tracef("Checking if we need check the authentication backend for an updated profile for %s.", userSession.Username) - if refreshProfile && userSession.Username != "" && targetURL != nil && - ctx.Providers.Authorizer.IsURLMatchingRuleWithGroupSubjects(*targetURL) && - (refreshProfileInterval == schema.RefreshIntervalAlways || userSession.RefreshTTL.Before(ctx.Clock.Now())) { - ctx.Logger.Debugf("Checking the authentication backend for an updated profile for user %s", userSession.Username) - details, err := ctx.Providers.UserProvider.GetDetails(userSession.Username) - // Only update the session if we could get the new details. - if err != nil { - return err - } + if !refreshProfile || userSession.Username == "" || targetURL == nil { + return nil + } - emailsDiff := utils.IsStringSlicesDifferent(userSession.Emails, details.Emails) - groupsDiff := utils.IsStringSlicesDifferent(userSession.Groups, details.Groups) - nameDiff := userSession.DisplayName != details.DisplayName + if refreshProfileInterval != schema.RefreshIntervalAlways && userSession.RefreshTTL.After(ctx.Clock.Now()) { + return nil + } - if !groupsDiff && !emailsDiff && !nameDiff { - ctx.Logger.Tracef("Updated profile not detected for %s.", userSession.Username) - // Only update TTL if the user has a interval set. - // We get to this check when there were no changes. - // Also make sure to update the session even if no difference was found. - // This is so that we don't check every subsequent request after this one. - if refreshProfileInterval != schema.RefreshIntervalAlways { - // Update RefreshTTL and save session if refresh is not set to always. - userSession.RefreshTTL = ctx.Clock.Now().Add(refreshProfileInterval) - return ctx.SaveSession(*userSession) - } - } else { - ctx.Logger.Debugf("Updated profile detected for %s.", userSession.Username) - if ctx.Configuration.LogLevel == "trace" { - generateVerifySessionHasUpToDateProfileTraceLogs(ctx, userSession, details) - } - userSession.Emails = details.Emails - userSession.Groups = details.Groups - userSession.DisplayName = details.DisplayName + ctx.Logger.Debugf("Checking the authentication backend for an updated profile for user %s", userSession.Username) + details, err := ctx.Providers.UserProvider.GetDetails(userSession.Username) + // Only update the session if we could get the new details. + if err != nil { + return err + } - // Only update TTL if the user has a interval set. - if refreshProfileInterval != schema.RefreshIntervalAlways { - userSession.RefreshTTL = ctx.Clock.Now().Add(refreshProfileInterval) - } - // Return the result of save session if there were changes. + emailsDiff := utils.IsStringSlicesDifferent(userSession.Emails, details.Emails) + groupsDiff := utils.IsStringSlicesDifferent(userSession.Groups, details.Groups) + nameDiff := userSession.DisplayName != details.DisplayName + + if !groupsDiff && !emailsDiff && !nameDiff { + ctx.Logger.Tracef("Updated profile not detected for %s.", userSession.Username) + // Only update TTL if the user has a interval set. + // We get to this check when there were no changes. + // Also make sure to update the session even if no difference was found. + // This is so that we don't check every subsequent request after this one. + if refreshProfileInterval != schema.RefreshIntervalAlways { + // Update RefreshTTL and save session if refresh is not set to always. + userSession.RefreshTTL = ctx.Clock.Now().Add(refreshProfileInterval) return ctx.SaveSession(*userSession) } + } else { + ctx.Logger.Debugf("Updated profile detected for %s.", userSession.Username) + if ctx.Configuration.LogLevel == "trace" { + generateVerifySessionHasUpToDateProfileTraceLogs(ctx, userSession, details) + } + userSession.Emails = details.Emails + userSession.Groups = details.Groups + userSession.DisplayName = details.DisplayName + + // Only update TTL if the user has a interval set. + if refreshProfileInterval != schema.RefreshIntervalAlways { + userSession.RefreshTTL = ctx.Clock.Now().Add(refreshProfileInterval) + } + // Return the result of save session if there were changes. + return ctx.SaveSession(*userSession) } // Return nil if disabled or if no changes and refresh interval set to always. @@ -372,7 +376,10 @@ func verifySessionHasUpToDateProfile(ctx *middlewares.AutheliaCtx, targetURL *ur func getProfileRefreshSettings(cfg schema.AuthenticationBackendConfiguration) (refresh bool, refreshInterval time.Duration) { if cfg.Ldap != nil { - if cfg.RefreshInterval != schema.ProfileRefreshDisabled { + if cfg.RefreshInterval == schema.ProfileRefreshDisabled { + refresh = false + refreshInterval = 0 + } else { refresh = true if cfg.RefreshInterval != schema.ProfileRefreshAlways { diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go index e570fb57..872d2d76 100644 --- a/internal/handlers/handler_verify_test.go +++ b/internal/handlers/handler_verify_test.go @@ -417,11 +417,15 @@ func TestShouldNotCrashOnEmptyEmail(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() + mock.Clock.Set(time.Now()) + userSession := mock.Ctx.GetSession() userSession.Username = testUsername userSession.Emails = nil userSession.AuthenticationLevel = authentication.OneFactor + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) + fmt.Printf("Time is %v\n", userSession.RefreshTTL) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -475,10 +479,13 @@ func TestShouldVerifyAuthorizationsUsingSessionCookie(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() + mock.Clock.Set(time.Now()) + userSession := mock.Ctx.GetSession() userSession.Username = testCase.Username userSession.Emails = testCase.Emails userSession.AuthenticationLevel = testCase.AuthenticationLevel + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -569,8 +576,7 @@ func TestShouldKeepSessionWhenUserCheckedRememberMeAndIsInactiveForTooLong(t *te mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() - clock := mocks.TestingClock{} - clock.Set(time.Now()) + mock.Clock.Set(time.Now()) mock.Ctx.Configuration.Session.Inactivity = testInactivity @@ -580,6 +586,7 @@ func TestShouldKeepSessionWhenUserCheckedRememberMeAndIsInactiveForTooLong(t *te userSession.AuthenticationLevel = authentication.TwoFactor userSession.LastActivity = 0 userSession.KeepMeLoggedIn = true + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -601,18 +608,18 @@ func TestShouldKeepSessionWhenInactivityTimeoutHasNotBeenExceeded(t *testing.T) mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() - clock := mocks.TestingClock{} - clock.Set(time.Now()) + mock.Clock.Set(time.Now()) mock.Ctx.Configuration.Session.Inactivity = testInactivity - past := clock.Now().Add(-1 * time.Hour) + past := mock.Clock.Now().Add(-1 * time.Hour) userSession := mock.Ctx.GetSession() userSession.Username = testUsername userSession.Emails = []string{"john.doe@example.com"} userSession.AuthenticationLevel = authentication.TwoFactor userSession.LastActivity = past.Unix() + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -627,7 +634,7 @@ func TestShouldKeepSessionWhenInactivityTimeoutHasNotBeenExceeded(t *testing.T) assert.Equal(t, authentication.TwoFactor, newUserSession.AuthenticationLevel) // Check the inactivity timestamp has been updated to current time in the new session. - assert.Equal(t, clock.Now().Unix(), newUserSession.LastActivity) + assert.Equal(t, mock.Clock.Now().Unix(), newUserSession.LastActivity) } // In the case of Traefik and Nginx ingress controller in Kube, the response to an inactive @@ -672,17 +679,17 @@ func TestShouldUpdateInactivityTimestampEvenWhenHittingForbiddenResources(t *tes mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() - clock := mocks.TestingClock{} - clock.Set(time.Now()) + mock.Clock.Set(time.Now()) mock.Ctx.Configuration.Session.Inactivity = testInactivity - past := clock.Now().Add(-1 * time.Hour) + past := mock.Clock.Now().Add(-1 * time.Hour) userSession := mock.Ctx.GetSession() userSession.Username = testUsername userSession.AuthenticationLevel = authentication.TwoFactor userSession.LastActivity = past.Unix() + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -696,16 +703,19 @@ func TestShouldUpdateInactivityTimestampEvenWhenHittingForbiddenResources(t *tes // Check the inactivity timestamp has been updated to current time in the new session. newUserSession := mock.Ctx.GetSession() - assert.Equal(t, clock.Now().Unix(), newUserSession.LastActivity) + assert.Equal(t, mock.Clock.Now().Unix(), newUserSession.LastActivity) } func TestShouldURLEncodeRedirectionURLParameter(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() + mock.Clock.Set(time.Now()) + userSession := mock.Ctx.GetSession() userSession.Username = testUsername userSession.AuthenticationLevel = authentication.NotAuthenticated + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -843,7 +853,7 @@ func TestShouldNotRefreshUserGroupsFromBackend(t *testing.T) { assert.Equal(t, "users", userSession.Groups[1]) } -func TestShouldNotRefreshUserGroupsFromBackendWhenNoGroupSubject(t *testing.T) { +func TestShouldNotRefreshUserGroupsFromBackendWhenDisabled(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() @@ -877,7 +887,11 @@ func TestShouldNotRefreshUserGroupsFromBackendWhenNoGroupSubject(t *testing.T) { require.NoError(t, err) mock.Ctx.Request.Header.Set("X-Original-URL", "https://two-factor.example.com") - VerifyGet(verifyGetCfg)(mock.Ctx) + + config := verifyGetCfg + config.RefreshInterval = schema.ProfileRefreshDisabled + + VerifyGet(config)(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) // Session time should NOT have been updated, it should still have a refresh TTL 1 minute in the past. @@ -885,6 +899,65 @@ func TestShouldNotRefreshUserGroupsFromBackendWhenNoGroupSubject(t *testing.T) { assert.Equal(t, clock.Now().Add(-1*time.Minute).Unix(), userSession.RefreshTTL.Unix()) } +func TestShouldDestroySessionWhenUserNotExist(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + // Setup user john. + user := &authentication.UserDetails{ + Username: "john", + Groups: []string{ + "admin", + "users", + }, + Emails: []string{ + "john@example.com", + }, + } + + mock.UserProviderMock.EXPECT().GetDetails("john").Return(user, nil).Times(1) + + clock := mocks.TestingClock{} + clock.Set(time.Now()) + + userSession := mock.Ctx.GetSession() + userSession.Username = user.Username + userSession.AuthenticationLevel = authentication.TwoFactor + userSession.LastActivity = clock.Now().Unix() + userSession.RefreshTTL = clock.Now().Add(-1 * time.Minute) + userSession.Groups = user.Groups + userSession.Emails = user.Emails + userSession.KeepMeLoggedIn = true + err := mock.Ctx.SaveSession(userSession) + + require.NoError(t, err) + + mock.Ctx.Request.Header.Set("X-Original-URL", "https://two-factor.example.com") + + VerifyGet(verifyGetCfg)(mock.Ctx) + assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) + + // Session time should NOT have been updated, it should still have a refresh TTL 1 minute in the past. + userSession = mock.Ctx.GetSession() + assert.Equal(t, clock.Now().Add(5*time.Minute).Unix(), userSession.RefreshTTL.Unix()) + + // Simulate a Deleted User + userSession.RefreshTTL = clock.Now().Add(-1 * time.Minute) + err = mock.Ctx.SaveSession(userSession) + + require.NoError(t, err) + + mock.UserProviderMock.EXPECT().GetDetails("john").Return(nil, authentication.ErrUserNotFound).Times(1) + + VerifyGet(verifyGetCfg)(mock.Ctx) + + assert.Equal(t, 401, mock.Ctx.Response.StatusCode()) + + userSession = mock.Ctx.GetSession() + assert.Equal(t, "", userSession.Username) + assert.Equal(t, authentication.NotAuthenticated, userSession.AuthenticationLevel) +} + func TestShouldGetRemovedUserGroupsFromBackend(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() @@ -970,18 +1043,17 @@ func TestShouldGetAddedUserGroupsFromBackend(t *testing.T) { }, } - mock.UserProviderMock.EXPECT().GetDetails("john").Times(0) + mock.UserProviderMock.EXPECT().GetDetails("john").Return(user, nil).Times(1) verifyGet := VerifyGet(verifyGetCfg) - clock := mocks.TestingClock{} - clock.Set(time.Now()) + mock.Clock.Set(time.Now()) userSession := mock.Ctx.GetSession() userSession.Username = user.Username userSession.AuthenticationLevel = authentication.TwoFactor - userSession.LastActivity = clock.Now().Unix() - userSession.RefreshTTL = clock.Now().Add(-1 * time.Minute) + userSession.LastActivity = mock.Clock.Now().Unix() + userSession.RefreshTTL = mock.Clock.Now().Add(-1 * time.Minute) userSession.Groups = user.Groups userSession.Emails = user.Emails userSession.KeepMeLoggedIn = true @@ -992,9 +1064,6 @@ func TestShouldGetAddedUserGroupsFromBackend(t *testing.T) { verifyGet(mock.Ctx) assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) - // Request should get refresh user profile. - mock.UserProviderMock.EXPECT().GetDetails("john").Return(user, nil).Times(1) - mock.Ctx.Request.Header.Set("X-Original-URL", "https://grafana.example.com") verifyGet(mock.Ctx) assert.Equal(t, 403, mock.Ctx.Response.StatusCode()) @@ -1004,13 +1073,13 @@ func TestShouldGetAddedUserGroupsFromBackend(t *testing.T) { // Check user groups are correct. require.Len(t, userSession.Groups, len(user.Groups)) - assert.Equal(t, clock.Now().Add(5*time.Minute).Unix(), userSession.RefreshTTL.Unix()) + assert.Equal(t, mock.Clock.Now().Add(5*time.Minute).Unix(), userSession.RefreshTTL.Unix()) assert.Equal(t, "admin", userSession.Groups[0]) assert.Equal(t, "users", userSession.Groups[1]) // Add the grafana group, and force the next request to refresh. user.Groups = append(user.Groups, "grafana") - userSession.RefreshTTL = clock.Now().Add(-1 * time.Second) + userSession.RefreshTTL = mock.Clock.Now().Add(-1 * time.Second) err = mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -1022,6 +1091,8 @@ func TestShouldGetAddedUserGroupsFromBackend(t *testing.T) { err = mock.Ctx.SaveSession(userSession) assert.NoError(t, err) + mock.Clock.Set(time.Now()) + gomock.InOrder( mock.UserProviderMock.EXPECT().GetDetails("john").Return(user, nil).Times(1), ) @@ -1034,7 +1105,7 @@ func TestShouldGetAddedUserGroupsFromBackend(t *testing.T) { userSession = mock.Ctx.GetSession() assert.Equal(t, true, userSession.KeepMeLoggedIn) assert.Equal(t, authentication.TwoFactor, userSession.AuthenticationLevel) - assert.Equal(t, clock.Now().Add(5*time.Minute).Unix(), userSession.RefreshTTL.Unix()) + assert.Equal(t, mock.Clock.Now().Add(5*time.Minute).Unix(), userSession.RefreshTTL.Unix()) require.Len(t, userSession.Groups, 3) assert.Equal(t, "admin", userSession.Groups[0]) assert.Equal(t, "users", userSession.Groups[1]) @@ -1045,11 +1116,14 @@ func TestShouldCheckValidSessionUsernameHeaderAndReturn200(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() + mock.Clock.Set(time.Now()) + expectedStatusCode := 200 userSession := mock.Ctx.GetSession() userSession.Username = testUsername userSession.AuthenticationLevel = authentication.OneFactor + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -1066,11 +1140,14 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) { mock := mocks.NewMockAutheliaCtx(t) defer mock.Close() + mock.Clock.Set(time.Now()) + expectedStatusCode := 401 userSession := mock.Ctx.GetSession() userSession.Username = testUsername userSession.AuthenticationLevel = authentication.OneFactor + userSession.RefreshTTL = mock.Clock.Now().Add(5 * time.Minute) err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) @@ -1082,3 +1159,26 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) { assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode()) assert.Equal(t, "Unauthorized", string(mock.Ctx.Response.Body())) } + +func TestGetProfileRefreshSettings(t *testing.T) { + cfg := verifyGetCfg + + refresh, interval := getProfileRefreshSettings(cfg) + + assert.Equal(t, true, refresh) + assert.Equal(t, 5*time.Minute, interval) + + cfg.RefreshInterval = schema.ProfileRefreshDisabled + + refresh, interval = getProfileRefreshSettings(cfg) + + assert.Equal(t, false, refresh) + assert.Equal(t, time.Duration(0), interval) + + cfg.RefreshInterval = schema.ProfileRefreshAlways + + refresh, interval = getProfileRefreshSettings(cfg) + + assert.Equal(t, true, refresh) + assert.Equal(t, time.Duration(0), interval) +} diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 8f5b86c9..abe9408e 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -60,14 +60,12 @@ func SliceString(s string, d int) (array []string) { // IsStringSlicesDifferent checks two slices of strings and on the first occurrence of a string item not existing in the // other slice returns true, otherwise returns false. func IsStringSlicesDifferent(a, b []string) (different bool) { - for _, s := range a { - if !IsStringInSlice(s, b) { - return true - } + if len(a) != len(b) { + return true } - for _, s := range b { - if !IsStringInSlice(s, a) { + for _, s := range a { + if !IsStringInSlice(s, b) { return true } } diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index 8ec896a1..ad73893d 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -69,6 +69,13 @@ func TestShouldNotFindSliceDifferences(t *testing.T) { assert.False(t, diff) } +func TestShouldFindSliceDifferenceWhenDifferentLength(t *testing.T) { + a := []string{"abc", "onetwothree"} + b := []string{"abc", "onetwothree", "more"} + diff := IsStringSlicesDifferent(a, b) + assert.True(t, diff) +} + func TestShouldFindStringInSliceContains(t *testing.T) { a := "abc" b := []string{"abc", "onetwothree"}