fix(web): show appropriate default and available methods (#2999)

This ensures that; the method set when a user does not have a preference is a method that is available, that if a user has a preferred method that is not available it is changed to an enabled method with preference put on methods the user has configured, that the frontend does not show the method selection option when only one method is available.
This commit is contained in:
James Elliott 2022-03-28 12:26:30 +11:00 committed by GitHub
parent 2d8978c15a
commit 70ab8aab15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 632 additions and 66 deletions

View File

@ -317,6 +317,24 @@ paths:
description: Forbidden description: Forbidden
security: security:
- authelia_auth: [] - authelia_auth: []
post:
tags:
- User Information
summary: User Configuration
description: >
The user info endpoint provides detailed information including a users display name, preferred and registered
second factor method(s). The POST method also ensures the preferred method is configured correctly.
responses:
"200":
description: Successful Operation
content:
application/json:
schema:
$ref: '#/components/schemas/handlers.UserInfo'
"403":
description: Forbidden
security:
- authelia_auth: []
/api/user/info/totp: /api/user/info/totp:
get: get:
tags: tags:

View File

@ -16,15 +16,6 @@ const (
TwoFactor Level = iota TwoFactor Level = iota
) )
const (
// TOTP Method using Time-Based One-Time Password applications like Google Authenticator.
TOTP = "totp"
// Webauthn Method using Webauthn devices like YubiKeys.
Webauthn = "webauthn"
// Push Method using Duo application to receive push notifications.
Push = "mobile_push"
)
const ( const (
ldapSupportedExtensionAttribute = "supportedExtension" ldapSupportedExtensionAttribute = "supportedExtension"
ldapOIDPasswdModifyExtension = "1.3.6.1.4.1.4203.1.11.1" // http://oidref.com/1.3.6.1.4.1.4203.1.11.1 ldapOIDPasswdModifyExtension = "1.3.6.1.4.1.4203.1.11.1" // http://oidref.com/1.3.6.1.4.1.4203.1.11.1
@ -36,9 +27,6 @@ const (
ldapPlaceholderUsername = "{username}" ldapPlaceholderUsername = "{username}"
) )
// PossibleMethods is the set of all possible 2FA methods.
var PossibleMethods = []string{TOTP, Webauthn, Push}
// CryptAlgo the crypt representation of an algorithm used in the prefix of the hash. // CryptAlgo the crypt representation of an algorithm used in the prefix of the hash.
type CryptAlgo string type CryptAlgo string

View File

@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
) )
@ -12,17 +11,7 @@ func ConfigurationGet(ctx *middlewares.AutheliaCtx) {
} }
if ctx.Providers.Authorizer.IsSecondFactorEnabled() { if ctx.Providers.Authorizer.IsSecondFactorEnabled() {
if !ctx.Configuration.TOTP.Disable { body.AvailableMethods = ctx.AvailableSecondFactorMethods()
body.AvailableMethods = append(body.AvailableMethods, authentication.TOTP)
}
if !ctx.Configuration.Webauthn.Disable {
body.AvailableMethods = append(body.AvailableMethods, authentication.Webauthn)
}
if ctx.Configuration.DuoAPI != nil {
body.AvailableMethods = append(body.AvailableMethods, authentication.Push)
}
} }
ctx.Logger.Tracef("Available methods are %s", body.AvailableMethods) ctx.Logger.Tracef("Available methods are %s", body.AvailableMethods)

View File

@ -1,16 +1,61 @@
package handlers package handlers
import ( import (
"database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// UserInfoGet get the info related to the user identified by the session. // UserInfoPOST handles setting up info for users if necessary when they login.
func UserInfoGet(ctx *middlewares.AutheliaCtx) { func UserInfoPOST(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession()
var (
userInfo model.UserInfo
err error
)
if _, err = ctx.Providers.StorageProvider.LoadPreferred2FAMethod(ctx, userSession.Username); err != nil {
if errors.Is(err, sql.ErrNoRows) {
if err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, ""); err != nil {
ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed)
}
} else {
ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed)
}
}
if userInfo, err = ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username); err != nil {
ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed)
return
}
var (
changed bool
)
if changed = userInfo.SetDefaultPreferred2FAMethod(ctx.AvailableSecondFactorMethods()); changed {
if err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, userInfo.Method); err != nil {
ctx.Error(fmt.Errorf("unable to save user two factor method: %v", err), messageOperationFailed)
return
}
}
userInfo.DisplayName = userSession.DisplayName
err = ctx.SetJSONBody(userInfo)
if err != nil {
ctx.Logger.Errorf("Unable to set user info response in body: %s", err)
}
}
// UserInfoGET get the info related to the user identified by the session.
func UserInfoGET(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession() userSession := ctx.GetSession()
userInfo, err := ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username) userInfo, err := ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username)
@ -37,8 +82,8 @@ func MethodPreferencePost(ctx *middlewares.AutheliaCtx) {
return return
} }
if !utils.IsStringInSlice(bodyJSON.Method, authentication.PossibleMethods) { if !utils.IsStringInSlice(bodyJSON.Method, ctx.AvailableSecondFactorMethods()) {
ctx.Error(fmt.Errorf("unknown method '%s', it should be one of %s", bodyJSON.Method, strings.Join(authentication.PossibleMethods, ", ")), messageOperationFailed) ctx.Error(fmt.Errorf("unknown or unavailable method '%s', it should be one of %s", bodyJSON.Method, strings.Join(ctx.AvailableSecondFactorMethods(), ", ")), messageOperationFailed)
return return
} }

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
) )
@ -41,7 +42,17 @@ type expectedResponse struct {
err error err error
} }
func TestMethodSetToU2F(t *testing.T) { type expectedResponseAlt struct {
description string
db model.UserInfo
api *model.UserInfo
loadErr error
saveErr error
config *schema.Configuration
}
func TestUserInfoEndpoint_SetCorrectMethod(t *testing.T) {
expectedResponses := []expectedResponse{ expectedResponses := []expectedResponse{
{ {
db: model.UserInfo{ db: model.UserInfo{
@ -89,6 +100,9 @@ func TestMethodSetToU2F(t *testing.T) {
} }
mock := mocks.NewMockAutheliaCtx(t) mock := mocks.NewMockAutheliaCtx(t)
mock.Ctx.Configuration.DuoAPI = &schema.DuoAPIConfiguration{}
// Set the initial user session. // Set the initial user session.
userSession := mock.Ctx.GetSession() userSession := mock.Ctx.GetSession()
userSession.Username = testUsername userSession.Username = testUsername
@ -101,7 +115,7 @@ func TestMethodSetToU2F(t *testing.T) {
LoadUserInfo(mock.Ctx, gomock.Eq("john")). LoadUserInfo(mock.Ctx, gomock.Eq("john")).
Return(resp.db, resp.err) Return(resp.db, resp.err)
UserInfoGet(mock.Ctx) UserInfoGET(mock.Ctx)
if resp.err == nil { if resp.err == nil {
t.Run("expected status code", func(t *testing.T) { t.Run("expected status code", func(t *testing.T) {
@ -123,6 +137,207 @@ func TestMethodSetToU2F(t *testing.T) {
t.Run("registered totp", func(t *testing.T) { t.Run("registered totp", func(t *testing.T) {
assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP) assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP)
}) })
t.Run("registered duo", func(t *testing.T) {
assert.Equal(t, resp.api.HasDuo, actualPreferences.HasDuo)
})
} else {
t.Run("expected status code", func(t *testing.T) {
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
})
errResponse := mock.GetResponseError(t)
assert.Equal(t, "KO", errResponse.Status)
assert.Equal(t, "Operation failed.", errResponse.Message)
}
mock.Close()
}
}
func TestUserInfoEndpoint_SetDefaultMethod(t *testing.T) {
expectedResponses := []expectedResponseAlt{
{
description: "should set method to totp by default even when user doesn't have totp configured and no preferred method",
db: model.UserInfo{
Method: "",
HasTOTP: false,
HasWebauthn: false,
HasDuo: false,
},
api: &model.UserInfo{
Method: "totp",
HasTOTP: false,
HasWebauthn: false,
HasDuo: false,
},
config: &schema.Configuration{
DuoAPI: &schema.DuoAPIConfiguration{},
},
loadErr: nil,
saveErr: nil,
},
{
description: "should set method to duo by default when user has duo configured and no preferred method",
db: model.UserInfo{
Method: "",
HasTOTP: false,
HasWebauthn: false,
HasDuo: true,
},
api: &model.UserInfo{
Method: "mobile_push",
HasTOTP: false,
HasWebauthn: false,
HasDuo: true,
},
config: &schema.Configuration{
DuoAPI: &schema.DuoAPIConfiguration{},
},
loadErr: nil,
saveErr: nil,
},
{
description: "should set method to totp by default when user has duo configured and no preferred method but duo is not enabled",
db: model.UserInfo{
Method: "",
HasTOTP: false,
HasWebauthn: false,
HasDuo: true,
},
api: &model.UserInfo{
Method: "totp",
HasTOTP: false,
HasWebauthn: false,
HasDuo: true,
},
loadErr: nil,
saveErr: nil,
},
{
description: "should set method to duo by default when user has duo configured and no preferred method",
db: model.UserInfo{
Method: "",
HasTOTP: true,
HasWebauthn: true,
HasDuo: true,
},
api: &model.UserInfo{
Method: "webauthn",
HasTOTP: true,
HasWebauthn: true,
HasDuo: true,
},
config: &schema.Configuration{
TOTP: schema.TOTPConfiguration{
Disable: true,
},
DuoAPI: &schema.DuoAPIConfiguration{},
},
loadErr: nil,
saveErr: nil,
},
{
description: "should default new users to totp if all enabled",
db: model.UserInfo{
Method: "",
HasTOTP: false,
HasWebauthn: false,
HasDuo: false,
},
api: &model.UserInfo{
Method: "totp",
HasTOTP: true,
HasWebauthn: true,
HasDuo: true,
},
config: &schema.Configuration{
DuoAPI: &schema.DuoAPIConfiguration{},
},
loadErr: nil,
saveErr: errors.New("could not save"),
},
}
for _, resp := range expectedResponses {
if resp.api == nil {
resp.api = &resp.db
}
mock := mocks.NewMockAutheliaCtx(t)
if resp.config != nil {
mock.Ctx.Configuration = *resp.config
}
// Set the initial user session.
userSession := mock.Ctx.GetSession()
userSession.Username = testUsername
userSession.AuthenticationLevel = 1
err := mock.Ctx.SaveSession(userSession)
require.NoError(t, err)
if resp.db.Method == "" {
gomock.InOrder(
mock.StorageMock.
EXPECT().
LoadPreferred2FAMethod(mock.Ctx, gomock.Eq("john")).
Return("", sql.ErrNoRows),
mock.StorageMock.
EXPECT().
SavePreferred2FAMethod(mock.Ctx, gomock.Eq("john"), gomock.Eq("")).
Return(resp.saveErr),
mock.StorageMock.
EXPECT().
LoadUserInfo(mock.Ctx, gomock.Eq("john")).
Return(resp.db, nil),
mock.StorageMock.EXPECT().
SavePreferred2FAMethod(mock.Ctx, gomock.Eq("john"), gomock.Eq(resp.api.Method)).
Return(resp.saveErr),
)
} else {
gomock.InOrder(
mock.StorageMock.
EXPECT().
LoadPreferred2FAMethod(mock.Ctx, gomock.Eq("john")).
Return(resp.db.Method, nil),
mock.StorageMock.
EXPECT().
LoadUserInfo(mock.Ctx, gomock.Eq("john")).
Return(resp.db, nil),
mock.StorageMock.EXPECT().
SavePreferred2FAMethod(mock.Ctx, gomock.Eq("john"), gomock.Eq(resp.api.Method)).
Return(resp.saveErr),
)
}
UserInfoPOST(mock.Ctx)
if resp.loadErr == nil && resp.saveErr == nil {
t.Run(fmt.Sprintf("%s/%s", resp.description, "expected status code"), func(t *testing.T) {
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
})
actualPreferences := model.UserInfo{}
mock.GetResponseData(t, &actualPreferences)
t.Run(fmt.Sprintf("%s/%s", resp.description, "expected method"), func(t *testing.T) {
assert.Equal(t, resp.api.Method, actualPreferences.Method)
})
t.Run(fmt.Sprintf("%s/%s", resp.description, "registered webauthn"), func(t *testing.T) {
assert.Equal(t, resp.api.HasWebauthn, actualPreferences.HasWebauthn)
})
t.Run(fmt.Sprintf("%s/%s", resp.description, "registered totp"), func(t *testing.T) {
assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP)
})
t.Run(fmt.Sprintf("%s/%s", resp.description, "registered duo"), func(t *testing.T) {
assert.Equal(t, resp.api.HasDuo, actualPreferences.HasDuo)
})
} else { } else {
t.Run("expected status code", func(t *testing.T) { t.Run("expected status code", func(t *testing.T) {
assert.Equal(t, 200, mock.Ctx.Response.StatusCode()) assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
@ -143,7 +358,7 @@ func (s *FetchSuite) TestShouldReturnError500WhenStorageFailsToLoad() {
LoadUserInfo(s.mock.Ctx, gomock.Eq("john")). LoadUserInfo(s.mock.Ctx, gomock.Eq("john")).
Return(model.UserInfo{}, fmt.Errorf("failure")) Return(model.UserInfo{}, fmt.Errorf("failure"))
UserInfoGet(s.mock.Ctx) UserInfoGET(s.mock.Ctx)
s.mock.Assert200KO(s.T(), "Operation failed.") s.mock.Assert200KO(s.T(), "Operation failed.")
assert.Equal(s.T(), "unable to load user information: failure", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "unable to load user information: failure", s.mock.Hook.LastEntry().Message)
@ -205,7 +420,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenBadMethodProvided() {
MethodPreferencePost(s.mock.Ctx) MethodPreferencePost(s.mock.Ctx)
s.mock.Assert200KO(s.T(), "Operation failed.") s.mock.Assert200KO(s.T(), "Operation failed.")
assert.Equal(s.T(), "unknown method 'abc', it should be one of totp, webauthn, mobile_push", s.mock.Hook.LastEntry().Message) assert.Equal(s.T(), "unknown or unavailable method 'abc', it should be one of totp, webauthn", s.mock.Hook.LastEntry().Message)
assert.Equal(s.T(), logrus.ErrorLevel, s.mock.Hook.LastEntry().Level) assert.Equal(s.T(), logrus.ErrorLevel, s.mock.Hook.LastEntry().Level)
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -54,6 +55,25 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers)
} }
} }
// AvailableSecondFactorMethods returns the available 2FA methods.
func (ctx *AutheliaCtx) AvailableSecondFactorMethods() (methods []string) {
methods = make([]string, 0, 3)
if !ctx.Configuration.TOTP.Disable {
methods = append(methods, model.SecondFactorMethodTOTP)
}
if !ctx.Configuration.Webauthn.Disable {
methods = append(methods, model.SecondFactorMethodWebauthn)
}
if ctx.Configuration.DuoAPI != nil {
methods = append(methods, model.SecondFactorMethodDuo)
}
return methods
}
// Error reply with an error and display the stack trace in the logs. // Error reply with an error and display the stack trace in the logs.
func (ctx *AutheliaCtx) Error(err error, message string) { func (ctx *AutheliaCtx) Error(err error, message string) {
ctx.SetJSONError(message) ctx.SetJSONError(message)

View File

@ -11,6 +11,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
) )
@ -115,3 +116,26 @@ func TestShouldDetectNonXHR(t *testing.T) {
assert.False(t, mock.Ctx.IsXHR()) assert.False(t, mock.Ctx.IsXHR())
} }
func TestShouldReturnCorrectSecondFactorMethods(t *testing.T) {
mock := mocks.NewMockAutheliaCtx(t)
defer mock.Close()
assert.Equal(t, []string{model.SecondFactorMethodTOTP, model.SecondFactorMethodWebauthn}, mock.Ctx.AvailableSecondFactorMethods())
mock.Ctx.Configuration.DuoAPI = &schema.DuoAPIConfiguration{}
assert.Equal(t, []string{model.SecondFactorMethodTOTP, model.SecondFactorMethodWebauthn, model.SecondFactorMethodDuo}, mock.Ctx.AvailableSecondFactorMethods())
mock.Ctx.Configuration.TOTP.Disable = true
assert.Equal(t, []string{model.SecondFactorMethodWebauthn, model.SecondFactorMethodDuo}, mock.Ctx.AvailableSecondFactorMethods())
mock.Ctx.Configuration.Webauthn.Disable = true
assert.Equal(t, []string{model.SecondFactorMethodDuo}, mock.Ctx.AvailableSecondFactorMethods())
mock.Ctx.Configuration.DuoAPI = nil
assert.Equal(t, []string{}, mock.Ctx.AvailableSecondFactorMethods())
}

View File

@ -6,3 +6,14 @@ const (
errFmtScanInvalidType = "cannot scan model type '%T' from type '%T' with value '%v'" errFmtScanInvalidType = "cannot scan model type '%T' from type '%T' with value '%v'"
errFmtScanInvalidTypeErr = "cannot scan model type '%T' from type '%T' with value '%v': %w" errFmtScanInvalidTypeErr = "cannot scan model type '%T' from type '%T' with value '%v': %w"
) )
const (
// SecondFactorMethodTOTP method using Time-Based One-Time Password applications like Google Authenticator.
SecondFactorMethodTOTP = "totp"
// SecondFactorMethodWebauthn method using Webauthn devices like YubiKey's.
SecondFactorMethodWebauthn = "webauthn"
// SecondFactorMethodDuo method using Duo application to receive push notifications.
SecondFactorMethodDuo = "mobile_push"
)

View File

@ -1,5 +1,9 @@
package model package model
import (
"github.com/authelia/authelia/v4/internal/utils"
)
// UserInfo represents the user information required by the web UI. // UserInfo represents the user information required by the web UI.
type UserInfo struct { type UserInfo struct {
// The users display name. // The users display name.
@ -17,3 +21,38 @@ type UserInfo struct {
// True if a duo device has been configured as the preferred. // True if a duo device has been configured as the preferred.
HasDuo bool `db:"has_duo" json:"has_duo" valid:"required"` HasDuo bool `db:"has_duo" json:"has_duo" valid:"required"`
} }
// SetDefaultPreferred2FAMethod configures the default method based on what is configured as available and the users available methods.
func (i *UserInfo) SetDefaultPreferred2FAMethod(methods []string) (changed bool) {
if len(methods) == 0 {
// No point attempting to change the method if no methods are available.
return false
}
before := i.Method
totp, webauthn, duo := utils.IsStringInSlice(SecondFactorMethodTOTP, methods), utils.IsStringInSlice(SecondFactorMethodWebauthn, methods), utils.IsStringInSlice(SecondFactorMethodDuo, methods)
if i.Method != "" && !utils.IsStringInSlice(i.Method, methods) {
i.Method = ""
}
if i.Method == "" {
switch {
case i.HasTOTP && totp:
i.Method = SecondFactorMethodTOTP
case i.HasWebauthn && webauthn:
i.Method = SecondFactorMethodWebauthn
case i.HasDuo && duo:
i.Method = SecondFactorMethodDuo
case totp:
i.Method = SecondFactorMethodTOTP
case webauthn:
i.Method = SecondFactorMethodWebauthn
case duo:
i.Method = SecondFactorMethodDuo
}
}
return before != i.Method
}

View File

@ -0,0 +1,222 @@
package model
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestUserInfo_SetDefaultMethod_ShouldConfigureConfigDefault(t *testing.T) {
none := "none"
testName := func(i int, have UserInfo, availableMethods []string) string {
method := have.Method
if method == "" {
method = none
}
has := ""
if have.HasTOTP || have.HasDuo || have.HasWebauthn {
has += " has"
if have.HasTOTP {
has += " " + SecondFactorMethodTOTP
}
if have.HasDuo {
has += " " + SecondFactorMethodDuo
}
if have.HasWebauthn {
has += " " + SecondFactorMethodWebauthn
}
}
available := none
if len(availableMethods) != 0 {
available = strings.Join(availableMethods, " ")
}
return fmt.Sprintf("%d/method %s%s/available methods %s", i+1, method, has, available)
}
testCases := []struct {
have UserInfo
availableMethods []string
changed bool
want UserInfo
}{
{
have: UserInfo{
Method: SecondFactorMethodTOTP,
HasDuo: true,
HasTOTP: true,
HasWebauthn: true,
},
availableMethods: []string{SecondFactorMethodWebauthn, SecondFactorMethodDuo},
changed: true,
want: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: true,
HasTOTP: true,
HasWebauthn: true,
},
},
{
have: UserInfo{
HasDuo: true,
HasTOTP: true,
HasWebauthn: true,
},
availableMethods: []string{SecondFactorMethodTOTP, SecondFactorMethodWebauthn, SecondFactorMethodDuo},
changed: true,
want: UserInfo{
Method: SecondFactorMethodTOTP,
HasDuo: true,
HasTOTP: true,
HasWebauthn: true,
},
},
{
have: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: true,
HasTOTP: false,
HasWebauthn: false,
},
availableMethods: []string{SecondFactorMethodTOTP},
changed: true,
want: UserInfo{
Method: SecondFactorMethodTOTP,
HasDuo: true,
HasTOTP: false,
HasWebauthn: false,
},
},
{
have: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: false,
HasTOTP: false,
HasWebauthn: false,
},
availableMethods: []string{SecondFactorMethodTOTP},
changed: true,
want: UserInfo{
Method: SecondFactorMethodTOTP,
HasDuo: false,
HasTOTP: false,
HasWebauthn: false,
},
},
{
have: UserInfo{
Method: SecondFactorMethodTOTP,
HasDuo: false,
HasTOTP: false,
HasWebauthn: false,
},
availableMethods: []string{SecondFactorMethodWebauthn},
changed: true,
want: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: false,
HasTOTP: false,
HasWebauthn: false,
},
},
{
have: UserInfo{
Method: SecondFactorMethodTOTP,
HasDuo: false,
HasTOTP: false,
HasWebauthn: false,
},
availableMethods: []string{SecondFactorMethodDuo},
changed: true,
want: UserInfo{
Method: SecondFactorMethodDuo,
HasDuo: false,
HasTOTP: false,
HasWebauthn: false,
},
},
{
have: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
availableMethods: []string{SecondFactorMethodTOTP, SecondFactorMethodWebauthn, SecondFactorMethodDuo},
changed: false,
want: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
},
{
have: UserInfo{
Method: "",
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
availableMethods: []string{SecondFactorMethodWebauthn, SecondFactorMethodDuo},
changed: true,
want: UserInfo{
Method: SecondFactorMethodWebauthn,
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
},
{
have: UserInfo{
Method: "",
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
availableMethods: []string{SecondFactorMethodDuo},
changed: true,
want: UserInfo{
Method: SecondFactorMethodDuo,
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
},
{
have: UserInfo{
Method: "",
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
availableMethods: nil,
changed: false,
want: UserInfo{
Method: "",
HasDuo: false,
HasTOTP: true,
HasWebauthn: true,
},
},
}
for i, tc := range testCases {
t.Run(testName(i, tc.have, tc.availableMethods), func(t *testing.T) {
changed := tc.have.SetDefaultPreferred2FAMethod(tc.availableMethods)
assert.Equal(t, tc.changed, changed)
assert.Equal(t, tc.want, tc.have)
})
}
}

View File

@ -86,7 +86,9 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
// Information about the user. // Information about the user.
r.GET("/api/user/info", autheliaMiddleware( r.GET("/api/user/info", autheliaMiddleware(
middlewares.RequireFirstFactor(handlers.UserInfoGet))) middlewares.RequireFirstFactor(handlers.UserInfoGET)))
r.POST("/api/user/info", autheliaMiddleware(
middlewares.RequireFirstFactor(handlers.UserInfoPOST)))
r.POST("/api/user/info/2fa_method", autheliaMiddleware( r.POST("/api/user/info/2fa_method", autheliaMiddleware(
middlewares.RequireFirstFactor(handlers.MethodPreferencePost))) middlewares.RequireFirstFactor(handlers.MethodPreferencePost)))

View File

@ -11,7 +11,6 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
@ -205,7 +204,7 @@ func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username strin
case err == nil: case err == nil:
return method, nil return method, nil
case errors.Is(err, sql.ErrNoRows): case errors.Is(err, sql.ErrNoRows):
return "", nil return "", sql.ErrNoRows
default: default:
return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err) return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err)
} }
@ -216,17 +215,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m
err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username) err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username)
switch { switch {
case err == nil: case err == nil, errors.Is(err, sql.ErrNoRows):
return info, nil
case errors.Is(err, sql.ErrNoRows):
if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0]); err != nil {
return model.UserInfo{}, fmt.Errorf("error upserting preferred two factor method while selecting user info for user '%s': %w", username, err)
}
if err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username); err != nil {
return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)
}
return info, nil return info, nil
default: default:
return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err) return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)

View File

@ -1,6 +1,6 @@
import { useRemoteCall } from "@hooks/RemoteCall"; import { useRemoteCall } from "@hooks/RemoteCall";
import { getUserInfo } from "@services/UserInfo"; import { postUserInfo } from "@services/UserInfo";
export function useUserInfo() { export function useUserInfoPOST() {
return useRemoteCall(getUserInfo, []); return useRemoteCall(postUserInfo, []);
} }

View File

@ -1,7 +1,7 @@
import { SecondFactorMethod } from "@models/Methods"; import { SecondFactorMethod } from "@models/Methods";
import { UserInfo } from "@models/UserInfo"; import { UserInfo } from "@models/UserInfo";
import { UserInfo2FAMethodPath, UserInfoPath } from "@services/Api"; import { UserInfo2FAMethodPath, UserInfoPath } from "@services/Api";
import { Get, PostWithOptionalResponse } from "@services/Client"; import { Post, PostWithOptionalResponse } from "@services/Client";
export type Method2FA = "webauthn" | "totp" | "mobile_push"; export type Method2FA = "webauthn" | "totp" | "mobile_push";
@ -39,8 +39,8 @@ export function toString(method: SecondFactorMethod): Method2FA {
} }
} }
export async function getUserInfo(): Promise<UserInfo> { export async function postUserInfo(): Promise<UserInfo> {
const res = await Get<UserInfoPayload>(UserInfoPath); const res = await Post<UserInfoPayload>(UserInfoPath);
return { ...res, method: toEnum(res.method) }; return { ...res, method: toEnum(res.method) };
} }

View File

@ -16,7 +16,7 @@ import { useRedirectionURL } from "@hooks/RedirectionURL";
import { useRedirector } from "@hooks/Redirector"; import { useRedirector } from "@hooks/Redirector";
import { useRequestMethod } from "@hooks/RequestMethod"; import { useRequestMethod } from "@hooks/RequestMethod";
import { useAutheliaState } from "@hooks/State"; import { useAutheliaState } from "@hooks/State";
import { useUserInfo } from "@hooks/UserInfo"; import { useUserInfoPOST } from "@hooks/UserInfo";
import { SecondFactorMethod } from "@models/Methods"; import { SecondFactorMethod } from "@models/Methods";
import { checkSafeRedirection } from "@services/SafeRedirection"; import { checkSafeRedirection } from "@services/SafeRedirection";
import { AuthenticationLevel } from "@services/State"; import { AuthenticationLevel } from "@services/State";
@ -44,7 +44,7 @@ const LoginPortal = function (props: Props) {
const redirector = useRedirector(); const redirector = useRedirector();
const [state, fetchState, , fetchStateError] = useAutheliaState(); const [state, fetchState, , fetchStateError] = useAutheliaState();
const [userInfo, fetchUserInfo, , fetchUserInfoError] = useUserInfo(); const [userInfo, fetchUserInfo, , fetchUserInfoError] = useUserInfoPOST();
const [configuration, fetchConfiguration, , fetchConfigurationError] = useConfiguration(); const [configuration, fetchConfiguration, , fetchConfigurationError] = useConfiguration();
const redirect = useCallback((url: string) => navigate(url), [navigate]); const redirect = useCallback((url: string) => navigate(url), [navigate]);

View File

@ -85,6 +85,7 @@ const SecondFactorForm = function (props: Props) {
return ( return (
<LoginLayout id="second-factor-stage" title={`${translate("Hi")} ${props.userInfo.display_name}`} showBrand> <LoginLayout id="second-factor-stage" title={`${translate("Hi")} ${props.userInfo.display_name}`} showBrand>
{props.configuration.available_methods.size > 1 ? (
<MethodSelectionDialog <MethodSelectionDialog
open={methodSelectionOpen} open={methodSelectionOpen}
methods={props.configuration.available_methods} methods={props.configuration.available_methods}
@ -92,15 +93,18 @@ const SecondFactorForm = function (props: Props) {
onClose={() => setMethodSelectionOpen(false)} onClose={() => setMethodSelectionOpen(false)}
onClick={handleMethodSelected} onClick={handleMethodSelected}
/> />
) : null}
<Grid container> <Grid container>
<Grid item xs={12}> <Grid item xs={12}>
<Button color="secondary" onClick={handleLogoutClick} id="logout-button"> <Button color="secondary" onClick={handleLogoutClick} id="logout-button">
{translate("Logout")} {translate("Logout")}
</Button> </Button>
{" | "} {props.configuration.available_methods.size > 1 ? " | " : null}
{props.configuration.available_methods.size > 1 ? (
<Button color="secondary" onClick={handleMethodSelectionClick} id="methods-button"> <Button color="secondary" onClick={handleMethodSelectionClick} id="methods-button">
{translate("Methods")} {translate("Methods")}
</Button> </Button>
) : null}
</Grid> </Grid>
<Grid item xs={12} className={style.methodContainer}> <Grid item xs={12} className={style.methodContainer}>
<Routes> <Routes>