diff --git a/internal/handlers/handler_firstfactor.go b/internal/handlers/handler_firstfactor.go index 2d42ac7a..b637b3ba 100644 --- a/internal/handlers/handler_firstfactor.go +++ b/internal/handlers/handler_firstfactor.go @@ -102,6 +102,7 @@ func FirstFactorPost(ctx *middlewares.AutheliaCtx) { userSession.Emails = userDetails.Emails userSession.AuthenticationLevel = authentication.OneFactor userSession.LastActivity = time.Now().Unix() + userSession.KeepMeLoggedIn = *bodyJSON.KeepMeLoggedIn err = ctx.SaveSession(userSession) if err != nil { diff --git a/internal/handlers/handler_firstfactor_test.go b/internal/handlers/handler_firstfactor_test.go index 48738199..6129592a 100644 --- a/internal/handlers/handler_firstfactor_test.go +++ b/internal/handlers/handler_firstfactor_test.go @@ -151,7 +151,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() { s.mock.Assert200KO(s.T(), "Authentication failed. Check your credentials.") } -func (s *FirstFactorSuite) TestShouldAuthenticateUser() { +func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() { s.mock.UserProviderMock. EXPECT(). CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). @@ -171,10 +171,10 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUser() { Return(nil) s.mock.Ctx.Request.SetBodyString(`{ - "username": "test", - "password": "hello", - "keepMeLoggedIn": true - }`) + "username": "test", + "password": "hello", + "keepMeLoggedIn": true + }`) FirstFactorPost(s.mock.Ctx) // Respond with 200. @@ -184,10 +184,49 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUser() { // And store authentication in session. session := s.mock.Ctx.GetSession() assert.Equal(s.T(), "test", session.Username) + assert.Equal(s.T(), true, session.KeepMeLoggedIn) assert.Equal(s.T(), authentication.OneFactor, session.AuthenticationLevel) assert.Equal(s.T(), []string{"test@example.com"}, session.Emails) assert.Equal(s.T(), []string{"dev", "admins"}, session.Groups) +} +func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() { + s.mock.UserProviderMock. + EXPECT(). + CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). + Return(true, nil) + + s.mock.UserProviderMock. + EXPECT(). + GetDetails(gomock.Eq("test")). + Return(&authentication.UserDetails{ + Emails: []string{"test@example.com"}, + Groups: []string{"dev", "admins"}, + }, nil) + + s.mock.StorageProviderMock. + EXPECT(). + AppendAuthenticationLog(gomock.Any()). + Return(nil) + + s.mock.Ctx.Request.SetBodyString(`{ + "username": "test", + "password": "hello", + "keepMeLoggedIn": false + }`) + FirstFactorPost(s.mock.Ctx) + + // Respond with 200. + assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) + assert.Equal(s.T(), []byte("{\"status\":\"OK\"}"), s.mock.Ctx.Response.Body()) + + // And store authentication in session. + session := s.mock.Ctx.GetSession() + assert.Equal(s.T(), "test", session.Username) + assert.Equal(s.T(), false, session.KeepMeLoggedIn) + assert.Equal(s.T(), authentication.OneFactor, session.AuthenticationLevel) + assert.Equal(s.T(), []string{"test@example.com"}, session.Emails) + assert.Equal(s.T(), []string{"dev", "admins"}, session.Groups) } func TestFirstFactorSuite(t *testing.T) { diff --git a/internal/handlers/handler_verify.go b/internal/handlers/handler_verify.go index ba865049..41eb49de 100644 --- a/internal/handlers/handler_verify.go +++ b/internal/handlers/handler_verify.go @@ -139,24 +139,13 @@ func setForwardedHeaders(headers *fasthttp.ResponseHeader, username string, grou // hasUserBeenInactiveLongEnough check whether the user has been inactive for too long. func hasUserBeenInactiveLongEnough(ctx *middlewares.AutheliaCtx) (bool, error) { - expiration, err := ctx.Providers.SessionProvider.GetExpiration(ctx.RequestCtx) - - if err != nil { - return false, err - } - - // If the cookie has no expiration. - if expiration == 0 { - return false, nil - } - maxInactivityPeriod := ctx.Configuration.Session.Inactivity if maxInactivityPeriod == 0 { return false, nil } lastActivity := ctx.GetSession().LastActivity - inactivityPeriod := time.Now().Unix() - lastActivity + inactivityPeriod := ctx.Clock.Now().Unix() - lastActivity ctx.Logger.Tracef("Inactivity report: Inactivity=%d, MaxInactivity=%d", inactivityPeriod, maxInactivityPeriod) @@ -178,7 +167,7 @@ func verifyFromSessionCookie(targetURL url.URL, ctx *middlewares.AutheliaCtx) (u return "", nil, authentication.NotAuthenticated, fmt.Errorf("An anonymous user cannot be authenticated. That might be the sign of a compromise") } - if !isUserAnonymous { + if !userSession.KeepMeLoggedIn && !isUserAnonymous { inactiveLongEnough, err := hasUserBeenInactiveLongEnough(ctx) if err != nil { return "", nil, authentication.NotAuthenticated, fmt.Errorf("Unable to check if user has been inactive for a long time: %s", err) diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go index eaa39483..b1e3f507 100644 --- a/internal/handlers/handler_verify_test.go +++ b/internal/handlers/handler_verify_test.go @@ -5,6 +5,7 @@ import ( "net" "net/url" "testing" + "time" "github.com/authelia/authelia/internal/authentication" "github.com/authelia/authelia/internal/authorization" @@ -426,3 +427,79 @@ func TestShouldVerifyAuthorizationsUsingSessionCookie(t *testing.T) { }) } } + +func TestShouldDestroySessionWhenInactiveForTooLong(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + clock := mocks.TestingClock{} + clock.Set(time.Now()) + + mock.Ctx.Configuration.Session.Inactivity = 10 + + userSession := mock.Ctx.GetSession() + userSession.Username = "john" + userSession.AuthenticationLevel = authentication.TwoFactor + userSession.LastActivity = clock.Now().Add(-1 * time.Hour).Unix() + mock.Ctx.SaveSession(userSession) + + mock.Ctx.Request.Header.Set("X-Original-URL", "https://two-factor.example.com") + + VerifyGet(mock.Ctx) + + // The session has been destroyed + newUserSession := mock.Ctx.GetSession() + assert.Equal(t, "", newUserSession.Username) + assert.Equal(t, authentication.NotAuthenticated, newUserSession.AuthenticationLevel) +} + +func TestShouldKeepSessionWhenUserCheckedRememberMeAndIsInactiveForTooLong(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + clock := mocks.TestingClock{} + clock.Set(time.Now()) + + mock.Ctx.Configuration.Session.Inactivity = 10 + + userSession := mock.Ctx.GetSession() + userSession.Username = "john" + userSession.AuthenticationLevel = authentication.TwoFactor + userSession.LastActivity = clock.Now().Add(-1 * time.Hour).Unix() + userSession.KeepMeLoggedIn = true + mock.Ctx.SaveSession(userSession) + + mock.Ctx.Request.Header.Set("X-Original-URL", "https://two-factor.example.com") + + VerifyGet(mock.Ctx) + + // The session has been destroyed + newUserSession := mock.Ctx.GetSession() + assert.Equal(t, "john", newUserSession.Username) + assert.Equal(t, authentication.TwoFactor, newUserSession.AuthenticationLevel) +} + +func TestShouldKeepSessionWhenInactivityTimeoutHasNotBeenExceeded(t *testing.T) { + mock := mocks.NewMockAutheliaCtx(t) + defer mock.Close() + + clock := mocks.TestingClock{} + clock.Set(time.Now()) + + mock.Ctx.Configuration.Session.Inactivity = 10 + + userSession := mock.Ctx.GetSession() + userSession.Username = "john" + userSession.AuthenticationLevel = authentication.TwoFactor + userSession.LastActivity = clock.Now().Add(-1 * time.Second).Unix() + mock.Ctx.SaveSession(userSession) + + mock.Ctx.Request.Header.Set("X-Original-URL", "https://two-factor.example.com") + + VerifyGet(mock.Ctx) + + // The session has been destroyed + newUserSession := mock.Ctx.GetSession() + assert.Equal(t, "john", newUserSession.Username) + assert.Equal(t, authentication.TwoFactor, newUserSession.AuthenticationLevel) +} diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 6d3d7aba..8d97e2c5 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -9,6 +9,7 @@ import ( "github.com/asaskevich/govalidator" "github.com/authelia/authelia/internal/configuration/schema" "github.com/authelia/authelia/internal/session" + "github.com/authelia/authelia/internal/utils" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" ) @@ -29,13 +30,7 @@ func NewAutheliaCtx(ctx *fasthttp.RequestCtx, configuration schema.Configuration autheliaCtx.Providers = providers autheliaCtx.Configuration = configuration autheliaCtx.Logger = NewRequestLogger(autheliaCtx) - - userSession, err := providers.SessionProvider.GetSession(ctx) - if err != nil { - return autheliaCtx, fmt.Errorf("Unable to retrieve user session: %s", err.Error()) - } - - autheliaCtx.userSession = userSession + autheliaCtx.Clock = utils.RealClock{} return autheliaCtx, nil } @@ -112,12 +107,16 @@ func (c *AutheliaCtx) XOriginalURL() []byte { // GetSession return the user session. Any update will be saved in cache. func (c *AutheliaCtx) GetSession() session.UserSession { - return c.userSession + userSession, err := c.Providers.SessionProvider.GetSession(c.RequestCtx) + if err != nil { + c.Logger.Error("Unable to retrieve user session") + return session.NewDefaultUserSession() + } + return userSession } // SaveSession save the content of the session. func (c *AutheliaCtx) SaveSession(userSession session.UserSession) error { - c.userSession = userSession return c.Providers.SessionProvider.SaveSession(c.RequestCtx, userSession) } diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 6d98319d..9e8b9558 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -8,6 +8,7 @@ import ( "github.com/authelia/authelia/internal/regulation" "github.com/authelia/authelia/internal/session" "github.com/authelia/authelia/internal/storage" + "github.com/authelia/authelia/internal/utils" jwt "github.com/dgrijalva/jwt-go" "github.com/sirupsen/logrus" "github.com/valyala/fasthttp" @@ -20,7 +21,8 @@ type AutheliaCtx struct { Logger *logrus.Entry Providers Providers Configuration schema.Configuration - userSession session.UserSession + + Clock utils.Clock } // Providers contain all provider provided to Authelia. diff --git a/internal/session/provider_test.go b/internal/session/provider_test.go index 183d7dd2..f7f1b683 100644 --- a/internal/session/provider_test.go +++ b/internal/session/provider_test.go @@ -6,6 +6,7 @@ import ( "github.com/authelia/authelia/internal/authentication" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" @@ -20,7 +21,8 @@ func TestShouldInitializerSession(t *testing.T) { configuration.Expiration = 40 provider := NewProvider(configuration) - session, _ := provider.GetSession(ctx) + session, err := provider.GetSession(ctx) + require.NoError(t, err) assert.Equal(t, NewDefaultUserSession(), session) } @@ -38,12 +40,45 @@ func TestShouldUpdateSession(t *testing.T) { session.Username = "john" session.AuthenticationLevel = authentication.TwoFactor - _ = provider.SaveSession(ctx, session) + err := provider.SaveSession(ctx, session) + require.NoError(t, err) - session, _ = provider.GetSession(ctx) + session, err = provider.GetSession(ctx) + require.NoError(t, err) assert.Equal(t, UserSession{ Username: "john", AuthenticationLevel: authentication.TwoFactor, }, session) } + +func TestShouldDestroySessionAndWipeSessionData(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + configuration := schema.SessionConfiguration{} + configuration.Domain = "example.com" + configuration.Name = "my_session" + configuration.Expiration = 40 + + provider := NewProvider(configuration) + session, err := provider.GetSession(ctx) + require.NoError(t, err) + + session.Username = "john" + session.AuthenticationLevel = authentication.TwoFactor + + err = provider.SaveSession(ctx, session) + require.NoError(t, err) + + newUserSession, err := provider.GetSession(ctx) + require.NoError(t, err) + assert.Equal(t, "john", newUserSession.Username) + assert.Equal(t, authentication.TwoFactor, newUserSession.AuthenticationLevel) + + err = provider.DestroySession(ctx) + require.NoError(t, err) + + newUserSession, err = provider.GetSession(ctx) + require.NoError(t, err) + assert.Equal(t, "", newUserSession.Username) + assert.Equal(t, authentication.NotAuthenticated, newUserSession.AuthenticationLevel) +} diff --git a/internal/suites/scenario_inactivity_test.go b/internal/suites/scenario_inactivity_test.go index 87c779c5..bfca5dd5 100644 --- a/internal/suites/scenario_inactivity_test.go +++ b/internal/suites/scenario_inactivity_test.go @@ -106,7 +106,7 @@ func (s *InactivityScenario) TestShouldDisableCookieExpirationAndInactivity() { s.doVisit(s.T(), HomeBaseURL) s.verifyIsHome(ctx, s.T()) - time.Sleep(9 * time.Second) + time.Sleep(10 * time.Second) s.doVisit(s.T(), targetURL) s.verifySecretAuthorized(ctx, s.T())