diff --git a/api/openapi.yml b/api/openapi.yml index 1a525d38..f0ac2014 100644 --- a/api/openapi.yml +++ b/api/openapi.yml @@ -317,6 +317,25 @@ paths: description: Forbidden security: - authelia_auth: [] + /api/user/info/totp: + get: + tags: + - User TOTP Information + summary: User TOTP Configuration + description: > + The user TOTP info endpoint provides information necessary to display the TOTP component to validate their + TOTP input such as the period/frequency and number of digits. + responses: + "200": + description: Successful Operation + content: + application/json: + schema: + $ref: '#/components/schemas/handlers.UserInfoTOTP' + "403": + description: Forbidden + security: + - authelia_auth: [] /api/user/info/2fa_method: post: tags: @@ -640,9 +659,6 @@ components: second_factor_enabled: type: boolean description: If second factor is enabled. - totp_period: - type: integer - example: 30 handlers.DuoDeviceBody: required: - device @@ -841,6 +857,25 @@ components: has_duo: type: boolean example: true + handlers.UserInfoTOTP: + type: object + properties: + status: + type: string + example: OK + data: + type: object + properties: + period: + default: 30 + description: The period defined in the users TOTP configuration + type: integer + example: 30 + digits: + default: 6 + description: The number of digits defined in the users TOTP configuration + type: integer + example: 6 handlers.UserInfo.MethodBody: required: - method diff --git a/cmd/authelia-scripts/cmd_build.go b/cmd/authelia-scripts/cmd_build.go index dff0d31a..e6cbb3b3 100644 --- a/cmd/authelia-scripts/cmd_build.go +++ b/cmd/authelia-scripts/cmd_build.go @@ -86,7 +86,7 @@ func buildFrontend(branch string) { } func buildSwagger() { - swaggerVer := "4.1.0" + swaggerVer := "4.1.2" cmd := utils.CommandWithStdout("bash", "-c", "wget -q https://github.com/swagger-api/swagger-ui/archive/v"+swaggerVer+".tar.gz -O ./v"+swaggerVer+".tar.gz") err := cmd.Run() diff --git a/config.template.yml b/config.template.yml index ac468213..e0dc2e25 100644 --- a/config.template.yml +++ b/config.template.yml @@ -90,16 +90,28 @@ log: ## ## Parameters used for TOTP generation. totp: - ## The issuer name displayed in the Authenticator application of your choice - ## See: https://github.com/google/google-authenticator/wiki/Key-Uri-Format for more info on issuer names + ## The issuer name displayed in the Authenticator application of your choice. issuer: authelia.com - ## The period in seconds a one-time password is current for. Changing this will require all users to register - ## their TOTP applications again. Warning: before changing period read the docs link below. + + ## The TOTP algorithm to use. + ## It is CRITICAL you read the documentation before changing this option: + ## https://www.authelia.com/docs/configuration/one-time-password.html#algorithm + algorithm: sha1 + + ## The number of digits a user has to input. Must either be 6 or 8. + ## Changing this option only affects newly generated TOTP configurations. + ## It is CRITICAL you read the documentation before changing this option: + ## https://www.authelia.com/docs/configuration/one-time-password.html#digits + digits: 6 + + ## The period in seconds a one-time password is valid for. + ## Changing this option only affects newly generated TOTP configurations. period: 30 + ## The skew controls number of one-time passwords either side of the current one that are valid. ## Warning: before changing skew read the docs link below. skew: 1 - ## See: https://www.authelia.com/docs/configuration/one-time-password.html#period-and-skew to read the documentation. + ## See: https://www.authelia.com/docs/configuration/one-time-password.html#input-validation to read the documentation. ## ## Duo Push API Configuration diff --git a/docs/configuration/one-time-password.md b/docs/configuration/one-time-password.md index e6ff8c40..56b264ba 100644 --- a/docs/configuration/one-time-password.md +++ b/docs/configuration/one-time-password.md @@ -7,7 +7,7 @@ nav_order: 16 # Time-based One-Time Password -Authelia uses time based one-time passwords as the OTP method. You have +Authelia uses time-based one-time passwords as the OTP method. You have the option to tune the settings of the TOTP generation, and you can see a full example of TOTP configuration below, as well as sections describing them. @@ -15,6 +15,8 @@ full example of TOTP configuration below, as well as sections describing them. ```yaml totp: issuer: authelia.com + algorithm: sha1 + digits: 6 period: 30 skew: 1 ``` @@ -37,17 +39,56 @@ differentiate applications registered by the user. Authelia allows customisation of the issuer to differentiate the entry created by Authelia from others. -## Period and Skew +### algorithm +
+type: string +{: .label .label-config .label-purple } +default: sha1 +{: .label .label-config .label-blue } +required: no +{: .label .label-config .label-green } +
-The period and skew configuration parameters affect each other. The default values are -a period of 30 and a skew of 1. It is highly recommended you do not change these unless -you wish to set skew to 0. +_**Important Note:** Many TOTP applications do not support this option. It is strongly advised you find out which +applications your users use and test them before changing this option. It is insufficient to test that the application +can add the key, it must also authenticate with Authelia as some applications silently ignore these options. Bitwarden +is the only one that has been tested at this time. If you'd like to contribute to documenting support for this option +please see [Issue 2650](https://github.com/authelia/authelia/issues/2650)._ -The way you configure these affects security by changing the length of time a one-time -password is valid for. The formula to calculate the effective validity period is -`period + (period * skew * 2)`. For example period 30 and skew 1 would result in 90 -seconds of validity, and period 30 and skew 2 would result in 150 seconds of validity. +The algorithm used for the TOTP key. +Possible Values (case-insensitive): +- `sha1` +- `sha256` +- `sha512` + +Changing this value only affects newly registered TOTP keys. See the [Registration](#registration) section for more +information. + +### digits +
+type: integer +{: .label .label-config .label-purple } +default: 6 +{: .label .label-config .label-blue } +required: no +{: .label .label-config .label-green } +
+ +_**Important Note:** Some TOTP applications do not support this option. It is strongly advised you find out which +applications your users use and test them before changing this option. It is insufficient to test that the application +can add the key, it must also authenticate with Authelia as some applications silently ignore these options. Bitwarden +is the only one that has been tested at this time. If you'd like to contribute to documenting support for this option +please see [Issue 2650](https://github.com/authelia/authelia/issues/2650)._ + +The number of digits a user needs to input to perform authentication. It's generally not recommended for this to be +altered as many TOTP applications do not support anything other than 6. What's worse is some TOTP applications allow +you to add the key, but do not use the correct number of digits specified by the key. + +The valid values are `6` or `8`. + +Changing this value only affects newly registered TOTP keys. See the [Registration](#registration) section for more +information. ### period
@@ -59,10 +100,13 @@ required: no {: .label .label-config .label-green }
-Configures the period of time in seconds a one-time password is current for. It is important -to note that changing this value will require your users to register their application again. +The period of time in seconds between key rotations or the time element of TOTP. Please see the +[input validation](#input-validation) section for how this option and the [skew](#skew) option interact with each other. -It is recommended to keep this value set to 30, the minimum is 1. +It is recommended to keep this value set to 30, the minimum is 15. + +Changing this value only affects newly registered TOTP keys. See the [Registration](#registration) section for more +information. ### skew
@@ -74,19 +118,64 @@ required: no {: .label .label-config .label-green }
-Configures the number of one-time passwords either side of the current one that are -considered valid, each time you increase this it makes two more one-time passwords valid. -For example the default of 1 has a total of 3 keys valid. A value of 2 has 5 one-time passwords -valid. +The number of one time passwords either side of the current valid one time password that should also be considered valid. +The default of 1 results in 3 one time passwords valid. A setting of 2 would result in 5. With the default period of 30 +this would result in 90 and 150 seconds of valid one time passwords respectively. Please see the +[input validation](#input-validation) section for how this option and the [period](#period) option interact with each +other. -It is recommended to keep this value set to 0 or 1, the minimum is 0. +Changing this value affects all TOTP validations, not just newly registered ones. + +## Registration +When users register their TOTP device for the first time, the current [issuer](#issuer), [algorithm](#algorithm), and +[period](#period) are used to generate the TOTP link and QR code. These values are saved to the database for future +validations. + +This means if the configuration options are changed, users will not need to regenerate their keys. This functionality +takes effect from 4.33.0 onwards, previously the effect was the keys would just fail to validate. If you'd like to force +users to register a new device, you can delete the old device for a particular user by using the +`authelia storage totp delete ` command regardless of if you change the settings or not. + +## Input Validation +The period and skew configuration parameters affect each other. The default values are a period of 30 and a skew of 1. +It is highly recommended you do not change these unless you wish to set skew to 0. + +The way you configure these affects security by changing the length of time a one-time +password is valid for. The formula to calculate the effective validity period is +`period + (period * skew * 2)`. For example period 30 and skew 1 would result in 90 +seconds of validity, and period 30 and skew 2 would result in 150 seconds of validity. ## System time accuracy - It's important to note that if the system time is not accurate enough then clients will seemingly not generate valid passwords for TOTP. Conversely this is the same when the client time is not accurate enough. This is due to the Time-based One Time Passwords being time-based. Authelia by default checks the system time against an [NTP server](./ntp.md#address) on startup. This helps to prevent a time synchronization issue on the server being an issue. There is however no effective and reliable way to check the -clients. \ No newline at end of file +clients. + +## Encryption +The TOTP secret is [encrypted](storage/index.md#encryption_key) in the database in version 4.33.0 and above. This is so +a user having access to only the database cannot easily compromise your two-factor authentication method. + +This may be inconvenient for some users who wish to export TOTP keys from Authelia to other services. As such there is +a command specifically for exporting TOTP configurations from the database. These commands require the configuration or +at least a minimal configuration that has the storage backend connection details and the encryption key. + +Export in [Key URI Format](https://github.com/google/google-authenticator/wiki/Key-Uri-Format): + +```shell +$ authelia storage totp export --format uri +``` + +Export as CSV: + +```shell +$ authelia storage totp export --format csv +``` + +Help: + +```shell +$ authelia storage totp export --help +``` diff --git a/docs/contributing/commitmsg-guidelines.md b/docs/contributing/commitmsg-guidelines.md index b9d4fd7d..3aa78a4c 100644 --- a/docs/contributing/commitmsg-guidelines.md +++ b/docs/contributing/commitmsg-guidelines.md @@ -86,6 +86,7 @@ The scope should be the name of the package affected * storage * suites * templates +* totp * utils There are currently a few exceptions to the "use package name" rule: diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go index 70ed2058..4c68c2ef 100644 --- a/internal/commands/helpers.go +++ b/internal/commands/helpers.go @@ -10,17 +10,18 @@ import ( "github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/totp" "github.com/authelia/authelia/v4/internal/utils" ) func getStorageProvider() (provider storage.Provider) { switch { case config.Storage.PostgreSQL != nil: - return storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL, config.Storage.EncryptionKey) + return storage.NewPostgreSQLProvider(config) case config.Storage.MySQL != nil: - return storage.NewMySQLProvider(*config.Storage.MySQL, config.Storage.EncryptionKey) + return storage.NewMySQLProvider(config) case config.Storage.Local != nil: - return storage.NewSQLiteProvider(config.Storage.Local.Path, config.Storage.EncryptionKey) + return storage.NewSQLiteProvider(config) default: return nil } @@ -71,6 +72,8 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [ errors = append(errors, err) } + totpProvider := totp.NewTimeBasedProvider(config.TOTP) + return middlewares.Providers{ Authorizer: authorizer, UserProvider: userProvider, @@ -80,5 +83,6 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [ NTP: ntpProvider, Notifier: notifier, SessionProvider: sessionProvider, + TOTP: totpProvider, }, warnings, errors } diff --git a/internal/commands/storage.go b/internal/commands/storage.go index 8d3cc2e1..df125007 100644 --- a/internal/commands/storage.go +++ b/internal/commands/storage.go @@ -35,7 +35,7 @@ func NewStorageCmd() (cmd *cobra.Command) { newStorageMigrateCmd(), newStorageSchemaInfoCmd(), newStorageEncryptionCmd(), - newStorageExportCmd(), + newStorageTOTPCmd(), ) return cmd @@ -79,25 +79,57 @@ func newStorageEncryptionChangeKeyCmd() (cmd *cobra.Command) { return cmd } -func newStorageExportCmd() (cmd *cobra.Command) { +func newStorageTOTPCmd() (cmd *cobra.Command) { cmd = &cobra.Command{ - Use: "export", - Short: "Performs exports", + Use: "totp", + Short: "Manage TOTP configurations", } - cmd.AddCommand(newStorageExportTOTPConfigurationsCmd()) + cmd.AddCommand( + newStorageTOTPGenerateCmd(), + newStorageTOTPDeleteCmd(), + newStorageTOTPExportCmd(), + ) return cmd } -func newStorageExportTOTPConfigurationsCmd() (cmd *cobra.Command) { +func newStorageTOTPGenerateCmd() (cmd *cobra.Command) { cmd = &cobra.Command{ - Use: "totp-configurations", - Short: "Performs exports of the totp configurations", - RunE: storageExportTOTPConfigurationsRunE, + Use: "generate username", + Short: "Generate a TOTP configuration for a user", + RunE: storageTOTPGenerateRunE, + Args: cobra.ExactArgs(1), } - cmd.Flags().String("format", storageExportFormatCSV, "changes the format of the export, options are csv and uri") + cmd.Flags().Uint("period", 30, "set the TOTP period") + cmd.Flags().Uint("digits", 6, "set the TOTP digits") + cmd.Flags().String("algorithm", "SHA1", "set the TOTP algorithm") + cmd.Flags().String("issuer", "Authelia", "set the TOTP issuer") + cmd.Flags().BoolP("force", "f", false, "forces the TOTP configuration to be generated regardless if it exists or not") + + return cmd +} + +func newStorageTOTPDeleteCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "delete username", + Short: "Delete a TOTP configuration for a user", + RunE: storageTOTPDeleteRunE, + Args: cobra.ExactArgs(1), + } + + return cmd +} + +func newStorageTOTPExportCmd() (cmd *cobra.Command) { + cmd = &cobra.Command{ + Use: "export", + Short: "Performs exports of the TOTP configurations", + RunE: storageTOTPExportRunE, + } + + cmd.Flags().String("format", storageExportFormatURI, "sets the output format") return cmd } diff --git a/internal/commands/storage_run.go b/internal/commands/storage_run.go index be7915e2..c24a7f67 100644 --- a/internal/commands/storage_run.go +++ b/internal/commands/storage_run.go @@ -14,6 +14,7 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/validator" "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/totp" ) func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) { @@ -52,6 +53,10 @@ func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) { "postgres.username": "storage.postgres.username", "postgres.password": "storage.postgres.password", "postgres.schema": "storage.postgres.schema", + "period": "totp.period", + "digits": "totp.digits", + "algorithm": "totp.algorithm", + "issuer": "totp.issuer", } sources = append(sources, configuration.NewEnvironmentSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter)) @@ -62,7 +67,7 @@ func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) { config = &schema.Configuration{} - _, err = configuration.LoadAdvanced(val, "storage", &config.Storage, sources...) + _, err = configuration.LoadAdvanced(val, "", &config, sources...) if err != nil { return err } @@ -84,6 +89,8 @@ func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) { validator.ValidateStorage(config.Storage, val) + validator.ValidateTOTP(config, val) + if val.HasErrors() { var finalErr error @@ -109,9 +116,6 @@ func storageSchemaEncryptionCheckRunE(cmd *cobra.Command, args []string) (err er ) provider = getStorageProvider() - if provider == nil { - return errNoStorageProvider - } defer func() { _ = provider.Close() @@ -145,9 +149,6 @@ func storageSchemaEncryptionChangeKeyRunE(cmd *cobra.Command, args []string) (er ) provider = getStorageProvider() - if provider == nil { - return errNoStorageProvider - } defer func() { _ = provider.Close() @@ -188,16 +189,83 @@ func storageSchemaEncryptionChangeKeyRunE(cmd *cobra.Command, args []string) (er return nil } -func storageExportTOTPConfigurationsRunE(cmd *cobra.Command, args []string) (err error) { +func storageTOTPGenerateRunE(cmd *cobra.Command, args []string) (err error) { + var ( + provider storage.Provider + ctx = context.Background() + c *models.TOTPConfiguration + force bool + ) + + provider = getStorageProvider() + + defer func() { + _ = provider.Close() + }() + + force, err = cmd.Flags().GetBool("force") + + _, err = provider.LoadTOTPConfiguration(ctx, args[0]) + if err == nil && !force { + return fmt.Errorf("%s already has a TOTP configuration, use --force to overwrite", args[0]) + } + + if err != nil && !errors.Is(err, storage.ErrNoTOTPConfiguration) { + return err + } + + totpProvider := totp.NewTimeBasedProvider(config.TOTP) + + if c, err = totpProvider.Generate(args[0]); err != nil { + return err + } + + err = provider.SaveTOTPConfiguration(ctx, *c) + if err != nil { + return err + } + + fmt.Printf("Generated TOTP configuration for user '%s': %s", args[0], c.URI()) + + return nil +} + +func storageTOTPDeleteRunE(cmd *cobra.Command, args []string) (err error) { + var ( + provider storage.Provider + ctx = context.Background() + ) + + user := args[0] + + provider = getStorageProvider() + + defer func() { + _ = provider.Close() + }() + + _, err = provider.LoadTOTPConfiguration(ctx, user) + if err != nil { + return fmt.Errorf("can't delete configuration for user '%s': %+v", user, err) + } + + err = provider.DeleteTOTPConfiguration(ctx, user) + if err != nil { + return fmt.Errorf("can't delete configuration for user '%s': %+v", user, err) + } + + fmt.Printf("Deleted TOTP configuration for user '%s'.", user) + + return nil +} + +func storageTOTPExportRunE(cmd *cobra.Command, args []string) (err error) { var ( provider storage.Provider ctx = context.Background() ) provider = getStorageProvider() - if provider == nil { - return errNoStorageProvider - } defer func() { _ = provider.Close() @@ -236,9 +304,9 @@ func storageExportTOTPConfigurationsRunE(cmd *cobra.Command, args []string) (err for _, c := range configurations { switch format { case storageExportFormatCSV: - fmt.Printf("%s,%s,%s,%d,%d,%s\n", "Authelia", c.Username, c.Algorithm, c.Digits, c.Period, string(c.Secret)) + fmt.Printf("%s,%s,%s,%d,%d,%s\n", c.Issuer, c.Username, c.Algorithm, c.Digits, c.Period, string(c.Secret)) case storageExportFormatURI: - fmt.Printf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=%s&digits=%d&period=%d\n", "Authelia", c.Username, string(c.Secret), "Authelia", c.Algorithm, c.Digits, c.Period) + fmt.Println(c.URI()) } } @@ -298,14 +366,11 @@ func newStorageMigrateListRunE(up bool) func(cmd *cobra.Command, args []string) var ( provider storage.Provider ctx = context.Background() - migrations []storage.SchemaMigration + migrations []models.SchemaMigration directionStr string ) provider = getStorageProvider() - if provider == nil { - return errNoStorageProvider - } defer func() { _ = provider.Close() @@ -345,9 +410,6 @@ func newStorageMigrationRunE(up bool) func(cmd *cobra.Command, args []string) (e ) provider = getStorageProvider() - if provider == nil { - return errNoStorageProvider - } defer func() { _ = provider.Close() @@ -420,9 +482,6 @@ func storageSchemaInfoRunE(_ *cobra.Command, _ []string) (err error) { ) provider = getStorageProvider() - if provider == nil { - return errNoStorageProvider - } defer func() { _ = provider.Close() diff --git a/internal/configuration/config.template.yml b/internal/configuration/config.template.yml index ac468213..e0dc2e25 100644 --- a/internal/configuration/config.template.yml +++ b/internal/configuration/config.template.yml @@ -90,16 +90,28 @@ log: ## ## Parameters used for TOTP generation. totp: - ## The issuer name displayed in the Authenticator application of your choice - ## See: https://github.com/google/google-authenticator/wiki/Key-Uri-Format for more info on issuer names + ## The issuer name displayed in the Authenticator application of your choice. issuer: authelia.com - ## The period in seconds a one-time password is current for. Changing this will require all users to register - ## their TOTP applications again. Warning: before changing period read the docs link below. + + ## The TOTP algorithm to use. + ## It is CRITICAL you read the documentation before changing this option: + ## https://www.authelia.com/docs/configuration/one-time-password.html#algorithm + algorithm: sha1 + + ## The number of digits a user has to input. Must either be 6 or 8. + ## Changing this option only affects newly generated TOTP configurations. + ## It is CRITICAL you read the documentation before changing this option: + ## https://www.authelia.com/docs/configuration/one-time-password.html#digits + digits: 6 + + ## The period in seconds a one-time password is valid for. + ## Changing this option only affects newly generated TOTP configurations. period: 30 + ## The skew controls number of one-time passwords either side of the current one that are valid. ## Warning: before changing skew read the docs link below. skew: 1 - ## See: https://www.authelia.com/docs/configuration/one-time-password.html#period-and-skew to read the documentation. + ## See: https://www.authelia.com/docs/configuration/one-time-password.html#input-validation to read the documentation. ## ## Duo Push API Configuration diff --git a/internal/configuration/schema/const.go b/internal/configuration/schema/const.go index ea9b8822..1c483b61 100644 --- a/internal/configuration/schema/const.go +++ b/internal/configuration/schema/const.go @@ -23,3 +23,15 @@ const LDAPImplementationCustom = "custom" // LDAPImplementationActiveDirectory is the string for the Active Directory LDAP implementation. const LDAPImplementationActiveDirectory = "activedirectory" + +// TOTP Algorithm. +const ( + TOTPAlgorithmSHA1 = "SHA1" + TOTPAlgorithmSHA256 = "SHA256" + TOTPAlgorithmSHA512 = "SHA512" +) + +var ( + // TOTPPossibleAlgorithms is a list of valid TOTP Algorithms. + TOTPPossibleAlgorithms = []string{TOTPAlgorithmSHA1, TOTPAlgorithmSHA256, TOTPAlgorithmSHA512} +) diff --git a/internal/configuration/schema/totp.go b/internal/configuration/schema/totp.go index 061f4736..e743153c 100644 --- a/internal/configuration/schema/totp.go +++ b/internal/configuration/schema/totp.go @@ -2,16 +2,20 @@ package schema // TOTPConfiguration represents the configuration related to TOTP options. type TOTPConfiguration struct { - Issuer string `koanf:"issuer"` - Period int `koanf:"period"` - Skew *int `koanf:"skew"` + Issuer string `koanf:"issuer"` + Algorithm string `koanf:"algorithm"` + Digits uint `koanf:"digits"` + Period uint `koanf:"period"` + Skew *uint `koanf:"skew"` } -var defaultOtpSkew = 1 +var defaultOtpSkew = uint(1) // DefaultTOTPConfiguration represents default configuration parameters for TOTP generation. var DefaultTOTPConfiguration = TOTPConfiguration{ - Issuer: "Authelia", - Period: 30, - Skew: &defaultOtpSkew, + Issuer: "Authelia", + Algorithm: TOTPAlgorithmSHA1, + Digits: 6, + Period: 30, + Skew: &defaultOtpSkew, } diff --git a/internal/configuration/validator/configuration.go b/internal/configuration/validator/configuration.go index 0f5ec55b..12bba0e6 100644 --- a/internal/configuration/validator/configuration.go +++ b/internal/configuration/validator/configuration.go @@ -32,13 +32,9 @@ func ValidateConfiguration(configuration *schema.Configuration, validator *schem ValidateTheme(configuration, validator) - if configuration.TOTP == nil { - configuration.TOTP = &schema.DefaultTOTPConfiguration - } - ValidateLogging(configuration, validator) - ValidateTOTP(configuration.TOTP, validator) + ValidateTOTP(configuration, validator) ValidateAuthenticationBackend(&configuration.AuthenticationBackend, validator) diff --git a/internal/configuration/validator/const.go b/internal/configuration/validator/const.go index 2d3b88d5..cc057754 100644 --- a/internal/configuration/validator/const.go +++ b/internal/configuration/validator/const.go @@ -54,6 +54,13 @@ const ( errFmtNotifierSMTPNotConfigured = "smtp notifier: the '%s' must be configured" ) +// TOTP Error constants. +const ( + errFmtTOTPInvalidAlgorithm = "totp: algorithm '%s' is invalid: must be one of %s" + errFmtTOTPInvalidPeriod = "totp: period '%d' is invalid: must be 15 or more" + errFmtTOTPInvalidDigits = "totp: digits '%d' is invalid: must be 6 or 8" +) + // OpenID Error constants. const ( errFmtOIDCClientsDuplicateID = "openid connect provider: one or more clients have the same ID" @@ -157,6 +164,8 @@ var ValidKeys = []string{ // TOTP Keys. "totp.issuer", + "totp.algorithm", + "totp.digits", "totp.period", "totp.skew", diff --git a/internal/configuration/validator/totp.go b/internal/configuration/validator/totp.go index 5a394b0c..1310b671 100644 --- a/internal/configuration/validator/totp.go +++ b/internal/configuration/validator/totp.go @@ -2,25 +2,47 @@ package validator import ( "fmt" + "strings" "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/utils" ) // ValidateTOTP validates and update TOTP configuration. -func ValidateTOTP(configuration *schema.TOTPConfiguration, validator *schema.StructValidator) { - if configuration.Issuer == "" { - configuration.Issuer = schema.DefaultTOTPConfiguration.Issuer +func ValidateTOTP(configuration *schema.Configuration, validator *schema.StructValidator) { + if configuration.TOTP == nil { + configuration.TOTP = &schema.DefaultTOTPConfiguration + + return } - if configuration.Period == 0 { - configuration.Period = schema.DefaultTOTPConfiguration.Period - } else if configuration.Period < 0 { - validator.Push(fmt.Errorf("TOTP Period must be 1 or more")) + if configuration.TOTP.Issuer == "" { + configuration.TOTP.Issuer = schema.DefaultTOTPConfiguration.Issuer } - if configuration.Skew == nil { - configuration.Skew = schema.DefaultTOTPConfiguration.Skew - } else if *configuration.Skew < 0 { - validator.Push(fmt.Errorf("TOTP Skew must be 0 or more")) + if configuration.TOTP.Algorithm == "" { + configuration.TOTP.Algorithm = schema.DefaultTOTPConfiguration.Algorithm + } else { + configuration.TOTP.Algorithm = strings.ToUpper(configuration.TOTP.Algorithm) + + if !utils.IsStringInSlice(configuration.TOTP.Algorithm, schema.TOTPPossibleAlgorithms) { + validator.Push(fmt.Errorf(errFmtTOTPInvalidAlgorithm, configuration.TOTP.Algorithm, strings.Join(schema.TOTPPossibleAlgorithms, ", "))) + } + } + + if configuration.TOTP.Period == 0 { + configuration.TOTP.Period = schema.DefaultTOTPConfiguration.Period + } else if configuration.TOTP.Period < 15 { + validator.Push(fmt.Errorf(errFmtTOTPInvalidPeriod, configuration.TOTP.Period)) + } + + if configuration.TOTP.Digits == 0 { + configuration.TOTP.Digits = schema.DefaultTOTPConfiguration.Digits + } else if configuration.TOTP.Digits != 6 && configuration.TOTP.Digits != 8 { + validator.Push(fmt.Errorf(errFmtTOTPInvalidDigits, configuration.TOTP.Digits)) + } + + if configuration.TOTP.Skew == nil { + configuration.TOTP.Skew = schema.DefaultTOTPConfiguration.Skew } } diff --git a/internal/configuration/validator/totp_test.go b/internal/configuration/validator/totp_test.go index 925217bc..497284f3 100644 --- a/internal/configuration/validator/totp_test.go +++ b/internal/configuration/validator/totp_test.go @@ -1,6 +1,8 @@ package validator import ( + "fmt" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -11,26 +13,61 @@ import ( func TestShouldSetDefaultTOTPValues(t *testing.T) { validator := schema.NewStructValidator() - config := schema.TOTPConfiguration{} + config := &schema.Configuration{ + TOTP: &schema.TOTPConfiguration{}, + } - ValidateTOTP(&config, validator) + ValidateTOTP(config, validator) require.Len(t, validator.Errors(), 0) - assert.Equal(t, "Authelia", config.Issuer) - assert.Equal(t, *schema.DefaultTOTPConfiguration.Skew, *config.Skew) - assert.Equal(t, schema.DefaultTOTPConfiguration.Period, config.Period) + assert.Equal(t, "Authelia", config.TOTP.Issuer) + assert.Equal(t, schema.DefaultTOTPConfiguration.Algorithm, config.TOTP.Algorithm) + assert.Equal(t, schema.DefaultTOTPConfiguration.Skew, config.TOTP.Skew) + assert.Equal(t, schema.DefaultTOTPConfiguration.Period, config.TOTP.Period) } -func TestShouldRaiseErrorWhenInvalidTOTPMinimumValues(t *testing.T) { - var badSkew = -1 - +func TestShouldNormalizeTOTPAlgorithm(t *testing.T) { validator := schema.NewStructValidator() - config := schema.TOTPConfiguration{ - Period: -5, - Skew: &badSkew, + + config := &schema.Configuration{ + TOTP: &schema.TOTPConfiguration{ + Algorithm: "sha1", + }, } - ValidateTOTP(&config, validator) - assert.Len(t, validator.Errors(), 2) - assert.EqualError(t, validator.Errors()[0], "TOTP Period must be 1 or more") - assert.EqualError(t, validator.Errors()[1], "TOTP Skew must be 0 or more") + + ValidateTOTP(config, validator) + + assert.Len(t, validator.Errors(), 0) + assert.Equal(t, "SHA1", config.TOTP.Algorithm) +} + +func TestShouldRaiseErrorWhenInvalidTOTPAlgorithm(t *testing.T) { + validator := schema.NewStructValidator() + + config := &schema.Configuration{ + TOTP: &schema.TOTPConfiguration{ + Algorithm: "sha3", + }, + } + + ValidateTOTP(config, validator) + + require.Len(t, validator.Errors(), 1) + assert.EqualError(t, validator.Errors()[0], fmt.Sprintf(errFmtTOTPInvalidAlgorithm, "SHA3", strings.Join(schema.TOTPPossibleAlgorithms, ", "))) +} + +func TestShouldRaiseErrorWhenInvalidTOTPValues(t *testing.T) { + validator := schema.NewStructValidator() + config := &schema.Configuration{ + TOTP: &schema.TOTPConfiguration{ + Period: 5, + Digits: 20, + }, + } + + ValidateTOTP(config, validator) + + require.Len(t, validator.Errors(), 2) + assert.EqualError(t, validator.Errors()[0], fmt.Sprintf(errFmtTOTPInvalidPeriod, 5)) + assert.EqualError(t, validator.Errors()[1], fmt.Sprintf(errFmtTOTPInvalidDigits, 20)) } diff --git a/internal/handlers/const.go b/internal/handlers/const.go index cec52f69..fe292e72 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -91,12 +91,6 @@ const ( pathOpenIDConnectConsent = "/api/oidc/consent" ) -const ( - totpAlgoSHA1 = "SHA1" - totpAlgoSHA256 = "SHA256" - totpAlgoSHA512 = "SHA512" -) - const ( accept = "accept" reject = "reject" diff --git a/internal/handlers/handler_configuration.go b/internal/handlers/handler_configuration.go index 8ad2dd3f..6f70ba7b 100644 --- a/internal/handlers/handler_configuration.go +++ b/internal/handlers/handler_configuration.go @@ -5,18 +5,10 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" ) -// ConfigurationBody the content returned by the configuration endpoint. -type ConfigurationBody struct { - AvailableMethods MethodList `json:"available_methods"` - SecondFactorEnabled bool `json:"second_factor_enabled"` // whether second factor is enabled or not. - TOTPPeriod int `json:"totp_period"` -} - // ConfigurationGet get the configuration accessible to authenticated users. func ConfigurationGet(ctx *middlewares.AutheliaCtx) { - body := ConfigurationBody{} + body := configurationBody{} body.AvailableMethods = MethodList{authentication.TOTP, authentication.U2F} - body.TOTPPeriod = ctx.Configuration.TOTP.Period if ctx.Configuration.DuoAPI != nil { body.AvailableMethods = append(body.AvailableMethods, authentication.Push) diff --git a/internal/handlers/handler_configuration_test.go b/internal/handlers/handler_configuration_test.go index 24978f23..ab69e87f 100644 --- a/internal/handlers/handler_configuration_test.go +++ b/internal/handlers/handler_configuration_test.go @@ -29,15 +29,9 @@ func (s *SecondFactorAvailableMethodsFixture) TearDownTest() { } func (s *SecondFactorAvailableMethodsFixture) TestShouldServeDefaultMethods() { - s.mock.Ctx.Configuration = schema.Configuration{ - TOTP: &schema.TOTPConfiguration{ - Period: schema.DefaultTOTPConfiguration.Period, - }, - } - expectedBody := ConfigurationBody{ + expectedBody := configurationBody{ AvailableMethods: []string{"totp", "u2f"}, SecondFactorEnabled: false, - TOTPPeriod: schema.DefaultTOTPConfiguration.Period, } ConfigurationGet(s.mock.Ctx) @@ -47,14 +41,10 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldServeDefaultMethods() { func (s *SecondFactorAvailableMethodsFixture) TestShouldServeDefaultMethodsAndMobilePush() { s.mock.Ctx.Configuration = schema.Configuration{ DuoAPI: &schema.DuoAPIConfiguration{}, - TOTP: &schema.TOTPConfiguration{ - Period: schema.DefaultTOTPConfiguration.Period, - }, } - expectedBody := ConfigurationBody{ + expectedBody := configurationBody{ AvailableMethods: []string{"totp", "u2f", "mobile_push"}, SecondFactorEnabled: false, - TOTPPeriod: schema.DefaultTOTPConfiguration.Period, } ConfigurationGet(s.mock.Ctx) @@ -62,11 +52,6 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldServeDefaultMethodsAndMo } func (s *SecondFactorAvailableMethodsFixture) TestShouldCheckSecondFactorIsDisabledWhenNoRuleIsSetToTwoFactor() { - s.mock.Ctx.Configuration = schema.Configuration{ - TOTP: &schema.TOTPConfiguration{ - Period: schema.DefaultTOTPConfiguration.Period, - }, - } s.mock.Ctx.Providers.Authorizer = authorization.NewAuthorizer( &schema.Configuration{ AccessControl: schema.AccessControlConfiguration{ @@ -87,19 +72,13 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldCheckSecondFactorIsDisab }, }}) ConfigurationGet(s.mock.Ctx) - s.mock.Assert200OK(s.T(), ConfigurationBody{ + s.mock.Assert200OK(s.T(), configurationBody{ AvailableMethods: []string{"totp", "u2f"}, SecondFactorEnabled: false, - TOTPPeriod: schema.DefaultTOTPConfiguration.Period, }) } func (s *SecondFactorAvailableMethodsFixture) TestShouldCheckSecondFactorIsEnabledWhenDefaultPolicySetToTwoFactor() { - s.mock.Ctx.Configuration = schema.Configuration{ - TOTP: &schema.TOTPConfiguration{ - Period: schema.DefaultTOTPConfiguration.Period, - }, - } s.mock.Ctx.Providers.Authorizer = authorization.NewAuthorizer(&schema.Configuration{ AccessControl: schema.AccessControlConfiguration{ DefaultPolicy: "two_factor", @@ -119,19 +98,13 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldCheckSecondFactorIsEnabl }, }}) ConfigurationGet(s.mock.Ctx) - s.mock.Assert200OK(s.T(), ConfigurationBody{ + s.mock.Assert200OK(s.T(), configurationBody{ AvailableMethods: []string{"totp", "u2f"}, SecondFactorEnabled: true, - TOTPPeriod: schema.DefaultTOTPConfiguration.Period, }) } func (s *SecondFactorAvailableMethodsFixture) TestShouldCheckSecondFactorIsEnabledWhenSomePolicySetToTwoFactor() { - s.mock.Ctx.Configuration = schema.Configuration{ - TOTP: &schema.TOTPConfiguration{ - Period: schema.DefaultTOTPConfiguration.Period, - }, - } s.mock.Ctx.Providers.Authorizer = authorization.NewAuthorizer( &schema.Configuration{ AccessControl: schema.AccessControlConfiguration{ @@ -152,10 +125,9 @@ func (s *SecondFactorAvailableMethodsFixture) TestShouldCheckSecondFactorIsEnabl }, }}) ConfigurationGet(s.mock.Ctx) - s.mock.Assert200OK(s.T(), ConfigurationBody{ + s.mock.Assert200OK(s.T(), configurationBody{ AvailableMethods: []string{"totp", "u2f"}, SecondFactorEnabled: true, - TOTPPeriod: schema.DefaultTOTPConfiguration.Period, }) } diff --git a/internal/handlers/handler_firstfactor_test.go b/internal/handlers/handler_firstfactor_test.go index 616eb159..bf332ba3 100644 --- a/internal/handlers/handler_firstfactor_test.go +++ b/internal/handlers/handler_firstfactor_test.go @@ -57,7 +57,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() { CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). Return(false, fmt.Errorf("failed")) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "test", @@ -85,7 +85,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). Return(false, fmt.Errorf("invalid credentials")) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "test", @@ -111,7 +111,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). Return(false, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "test", @@ -137,7 +137,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() { CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). Return(true, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) @@ -164,7 +164,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() { CheckUserPassword(gomock.Eq("test"), gomock.Eq("hello")). Return(true, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(fmt.Errorf("failed")) @@ -195,7 +195,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() { Groups: []string{"dev", "admins"}, }, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) @@ -235,7 +235,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() { Groups: []string{"dev", "admins"}, }, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) @@ -279,7 +279,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess Groups: []string{"dev", "admins"}, }, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) @@ -337,7 +337,7 @@ func (s *FirstFactorRedirectionSuite) SetupTest() { Groups: []string{"dev", "admins"}, }, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Any()). Return(nil) diff --git a/internal/handlers/handler_register_duo_device_test.go b/internal/handlers/handler_register_duo_device_test.go index 2d224eea..07792998 100644 --- a/internal/handlers/handler_register_duo_device_test.go +++ b/internal/handlers/handler_register_duo_device_test.go @@ -128,7 +128,7 @@ func (s *RegisterDuoDeviceSuite) TestShouldRespondWithDeny() { func (s *RegisterDuoDeviceSuite) TestShouldRespondOK() { s.mock.Ctx.Request.SetBodyString("{\"device\":\"1234567890123456\", \"method\":\"push\"}") - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). SavePreferredDuoDevice(gomock.Eq(s.mock.Ctx), gomock.Eq(models.DuoDevice{Username: "john", Device: "1234567890123456", Method: "push"})). Return(nil) diff --git a/internal/handlers/handler_register_totp.go b/internal/handlers/handler_register_totp.go index 81353851..4542d264 100644 --- a/internal/handlers/handler_register_totp.go +++ b/internal/handlers/handler_register_totp.go @@ -3,9 +3,6 @@ package handlers import ( "fmt" - "github.com/pquerna/otp" - "github.com/pquerna/otp/totp" - "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/session" @@ -39,39 +36,24 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle }) func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) { - algorithm := otp.AlgorithmSHA1 + var ( + config *models.TOTPConfiguration + err error + ) - key, err := totp.Generate(totp.GenerateOpts{ - Issuer: ctx.Configuration.TOTP.Issuer, - AccountName: username, - Period: uint(ctx.Configuration.TOTP.Period), - SecretSize: 32, - Digits: otp.Digits(6), - Algorithm: algorithm, - }) - - if err != nil { + if config, err = ctx.Providers.TOTP.Generate(username); err != nil { ctx.Error(fmt.Errorf("unable to generate TOTP key: %s", err), messageUnableToRegisterOneTimePassword) - return } - config := models.TOTPConfiguration{ - Username: username, - Algorithm: otpAlgoToString(algorithm), - Digits: 6, - Secret: []byte(key.Secret()), - Period: key.Period(), - } - - err = ctx.Providers.StorageProvider.SaveTOTPConfiguration(ctx, config) + err = ctx.Providers.StorageProvider.SaveTOTPConfiguration(ctx, *config) if err != nil { ctx.Error(fmt.Errorf("unable to save TOTP secret in DB: %s", err), messageUnableToRegisterOneTimePassword) return } response := TOTPKeyResponse{ - OTPAuthURL: key.URL(), - Base32Secret: key.Secret(), + OTPAuthURL: config.URI(), + Base32Secret: string(config.Secret), } err = ctx.SetJSONBody(response) diff --git a/internal/handlers/handler_register_u2f_step1_test.go b/internal/handlers/handler_register_u2f_step1_test.go index b9c7c81b..5e6707e1 100644 --- a/internal/handlers/handler_register_u2f_step1_test.go +++ b/internal/handlers/handler_register_u2f_step1_test.go @@ -48,16 +48,16 @@ func createToken(secret, username, action string, expiresAt time.Time) (data str } func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() { - token, v := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration, + token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration, time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerification(s.mock.Ctx, gomock.Eq(v.JTI.String())). + s.mock.StorageMock.EXPECT(). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) - s.mock.StorageProviderMock.EXPECT(). - RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(v.JTI.String())). + s.mock.StorageMock.EXPECT(). + RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(nil) SecondFactorU2FIdentityFinish(s.mock.Ctx) @@ -68,16 +68,16 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() { s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") - token, v := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration, + token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration, time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). - FindIdentityVerification(s.mock.Ctx, gomock.Eq(v.JTI.String())). + s.mock.StorageMock.EXPECT(). + FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) - s.mock.StorageProviderMock.EXPECT(). - RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(v.JTI.String())). + s.mock.StorageMock.EXPECT(). + RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(nil) SecondFactorU2FIdentityFinish(s.mock.Ctx) diff --git a/internal/handlers/handler_sign_duo_test.go b/internal/handlers/handler_sign_duo_test.go index 26f067c6..3ab9d599 100644 --- a/internal/handlers/handler_sign_duo_test.go +++ b/internal/handlers/handler_sign_duo_test.go @@ -39,7 +39,7 @@ func (s *SecondFactorDuoPostSuite) TearDownTest() { func (s *SecondFactorDuoPostSuite) TestShouldEnroll() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(nil, errors.New("no Duo device and method saved")) @@ -69,7 +69,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldEnroll() { func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT().LoadPreferredDuoDevice(s.mock.Ctx, "john").Return(nil, errors.New("no Duo device and method saved")) + s.mock.StorageMock.EXPECT().LoadPreferredDuoDevice(s.mock.Ctx, "john").Return(nil, errors.New("no Duo device and method saved")) var duoDevices = []duo.Device{ {Capabilities: []string{"auto", "push", "sms", "mobile_otp"}, Number: " ", Device: "12345ABCDEFGHIJ67890", DisplayName: "Test Device 1"}, @@ -85,11 +85,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() { duoMock.EXPECT().PreAuthCall(s.mock.Ctx, gomock.Eq(values)).Return(&preAuthResponse, nil) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). SavePreferredDuoDevice(s.mock.Ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}). Return(nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -124,7 +124,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() { func (s *SecondFactorDuoPostSuite) TestShouldDenyAutoSelect() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(nil, errors.New("no Duo device and method saved")) @@ -156,7 +156,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldDenyAutoSelect() { func (s *SecondFactorDuoPostSuite) TestShouldFailAutoSelect() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(nil, errors.New("no Duo device and method saved")) @@ -174,7 +174,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldFailAutoSelect() { func (s *SecondFactorDuoPostSuite) TestShouldDeleteOldDeviceAndEnroll() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "NOTEXISTENT", Method: "push"}, nil) @@ -189,7 +189,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldDeleteOldDeviceAndEnroll() { duoMock.EXPECT().PreAuthCall(s.mock.Ctx, gomock.Eq(values)).Return(&preAuthResponse, nil) - s.mock.StorageProviderMock.EXPECT().DeletePreferredDuoDevice(s.mock.Ctx, "john").Return(nil) + s.mock.StorageMock.EXPECT().DeletePreferredDuoDevice(s.mock.Ctx, "john").Return(nil) bodyBytes, err := json.Marshal(signDuoRequestBody{}) s.Require().NoError(err) @@ -206,7 +206,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldDeleteOldDeviceAndEnroll() { func (s *SecondFactorDuoPostSuite) TestShouldDeleteOldDeviceAndCallPreauthAPIWithInvalidDevicesAndEnroll() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "NOTEXISTENT", Method: "push"}, nil) @@ -223,7 +223,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldDeleteOldDeviceAndCallPreauthAPIWit duoMock.EXPECT().PreAuthCall(s.mock.Ctx, gomock.Eq(values)).Return(&preAuthResponse, nil) - s.mock.StorageProviderMock.EXPECT().DeletePreferredDuoDevice(s.mock.Ctx, "john").Return(nil) + s.mock.StorageMock.EXPECT().DeletePreferredDuoDevice(s.mock.Ctx, "john").Return(nil) bodyBytes, err := json.Marshal(signDuoRequestBody{}) s.Require().NoError(err) @@ -239,7 +239,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldDeleteOldDeviceAndCallPreauthAPIWit func (s *SecondFactorDuoPostSuite) TestShouldUseOldDeviceAndSelect() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "NOTEXISTENT", Method: "push"}, nil) @@ -274,11 +274,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseOldDeviceAndSelect() { func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "invalidmethod"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -303,7 +303,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() { duoMock.EXPECT().PreAuthCall(s.mock.Ctx, gomock.Eq(values)).Return(&preAuthResponse, nil) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). SavePreferredDuoDevice(s.mock.Ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}). Return(nil) @@ -330,7 +330,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() { func (s *SecondFactorDuoPostSuite) TestShouldCallDuoPreauthAPIAndAllowAccess() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) @@ -354,7 +354,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoPreauthAPIAndAllowAccess() { func (s *SecondFactorDuoPostSuite) TestShouldCallDuoPreauthAPIAndDenyAccess() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) @@ -384,7 +384,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoPreauthAPIAndDenyAccess() { func (s *SecondFactorDuoPostSuite) TestShouldCallDuoPreauthAPIAndFail() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) @@ -402,11 +402,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoPreauthAPIAndFail() { func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndDenyAccess() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -454,7 +454,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndDenyAccess() { func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndFail() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) @@ -485,11 +485,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndFail() { func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToDefaultURL() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -534,11 +534,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToDefaultURL() { func (s *SecondFactorDuoPostSuite) TestShouldNotReturnRedirectURL() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -579,11 +579,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotReturnRedirectURL() { func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToSafeTargetURL() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -628,11 +628,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToSafeTargetURL() { func (s *SecondFactorDuoPostSuite) TestShouldNotRedirectToUnsafeURL() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -675,11 +675,11 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotRedirectToUnsafeURL() { func (s *SecondFactorDuoPostSuite) TestShouldRegenerateSessionForPreventingSessionFixation() { duoMock := mocks.NewMockAPI(s.mock.Ctrl) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadPreferredDuoDevice(s.mock.Ctx, "john"). Return(&models.DuoDevice{ID: 1, Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", diff --git a/internal/handlers/handler_sign_totp.go b/internal/handlers/handler_sign_totp.go index 31ce766e..7657cc2d 100644 --- a/internal/handlers/handler_sign_totp.go +++ b/internal/handlers/handler_sign_totp.go @@ -6,73 +6,71 @@ import ( ) // SecondFactorTOTPPost validate the TOTP passcode provided by the user. -func SecondFactorTOTPPost(totpVerifier TOTPVerifier) middlewares.RequestHandler { - return func(ctx *middlewares.AutheliaCtx) { - requestBody := signTOTPRequestBody{} +func SecondFactorTOTPPost(ctx *middlewares.AutheliaCtx) { + requestBody := signTOTPRequestBody{} - if err := ctx.ParseBody(&requestBody); err != nil { - ctx.Logger.Errorf(logFmtErrParseRequestBody, regulation.AuthTypeTOTP, err) + if err := ctx.ParseBody(&requestBody); err != nil { + ctx.Logger.Errorf(logFmtErrParseRequestBody, regulation.AuthTypeTOTP, err) - respondUnauthorized(ctx, messageMFAValidationFailed) + respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + return + } - userSession := ctx.GetSession() + userSession := ctx.GetSession() - config, err := ctx.Providers.StorageProvider.LoadTOTPConfiguration(ctx, userSession.Username) - if err != nil { - ctx.Logger.Errorf("Failed to load TOTP configuration: %+v", err) + config, err := ctx.Providers.StorageProvider.LoadTOTPConfiguration(ctx, userSession.Username) + if err != nil { + ctx.Logger.Errorf("Failed to load TOTP configuration: %+v", err) - respondUnauthorized(ctx, messageMFAValidationFailed) + respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + return + } - isValid, err := totpVerifier.Verify(config, requestBody.Token) - if err != nil { - ctx.Logger.Errorf("Failed to perform TOTP verification: %+v", err) + isValid, err := ctx.Providers.TOTP.Validate(requestBody.Token, config) + if err != nil { + ctx.Logger.Errorf("Failed to perform TOTP verification: %+v", err) - respondUnauthorized(ctx, messageMFAValidationFailed) + respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + return + } - if !isValid { - _ = markAuthenticationAttempt(ctx, false, nil, userSession.Username, regulation.AuthTypeTOTP, nil) + if !isValid { + _ = markAuthenticationAttempt(ctx, false, nil, userSession.Username, regulation.AuthTypeTOTP, nil) - respondUnauthorized(ctx, messageMFAValidationFailed) + respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + return + } - if err = markAuthenticationAttempt(ctx, true, nil, userSession.Username, regulation.AuthTypeTOTP, nil); err != nil { - respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + if err = markAuthenticationAttempt(ctx, true, nil, userSession.Username, regulation.AuthTypeTOTP, nil); err != nil { + respondUnauthorized(ctx, messageMFAValidationFailed) + return + } - if err = ctx.Providers.SessionProvider.RegenerateSession(ctx.RequestCtx); err != nil { - ctx.Logger.Errorf(logFmtErrSessionRegenerate, regulation.AuthTypeTOTP, userSession.Username, err) + if err = ctx.Providers.SessionProvider.RegenerateSession(ctx.RequestCtx); err != nil { + ctx.Logger.Errorf(logFmtErrSessionRegenerate, regulation.AuthTypeTOTP, userSession.Username, err) - respondUnauthorized(ctx, messageMFAValidationFailed) + respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + return + } - userSession.SetTwoFactor(ctx.Clock.Now()) + userSession.SetTwoFactor(ctx.Clock.Now()) - if err = ctx.SaveSession(userSession); err != nil { - ctx.Logger.Errorf(logFmtErrSessionSave, "authentication time", regulation.AuthTypeTOTP, userSession.Username, err) + if err = ctx.SaveSession(userSession); err != nil { + ctx.Logger.Errorf(logFmtErrSessionSave, "authentication time", regulation.AuthTypeTOTP, userSession.Username, err) - respondUnauthorized(ctx, messageMFAValidationFailed) + respondUnauthorized(ctx, messageMFAValidationFailed) - return - } + return + } - if userSession.OIDCWorkflowSession != nil { - handleOIDCWorkflowResponse(ctx) - } else { - Handle2FAResponse(ctx, requestBody.TargetURL) - } + if userSession.OIDCWorkflowSession != nil { + handleOIDCWorkflowResponse(ctx) + } else { + Handle2FAResponse(ctx, requestBody.TargetURL) } } diff --git a/internal/handlers/handler_sign_totp_test.go b/internal/handlers/handler_sign_totp_test.go index af5b4b47..b0163a0e 100644 --- a/internal/handlers/handler_sign_totp_test.go +++ b/internal/handlers/handler_sign_totp_test.go @@ -37,15 +37,13 @@ func (s *HandlerSignTOTPSuite) TearDownTest() { } func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() { - verifier := NewMockTOTPVerifier(s.mock.Ctrl) - config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: []byte("secret"), Period: 30, Algorithm: "SHA1"} - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). Return(&config, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -56,9 +54,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() { RemoteIP: models.NewIPAddressFromString("0.0.0.0"), })) - verifier.EXPECT(). - Verify(gomock.Eq(&config), gomock.Eq("abc")). - Return(true, nil) + s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil) s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL @@ -68,22 +64,20 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() { s.Require().NoError(err) s.mock.Ctx.Request.SetBody(bodyBytes) - SecondFactorTOTPPost(verifier)(s.mock.Ctx) + SecondFactorTOTPPost(s.mock.Ctx) s.mock.Assert200OK(s.T(), redirectResponse{ Redirect: testRedirectionURL, }) } func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() { - verifier := NewMockTOTPVerifier(s.mock.Ctrl) - config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: []byte("secret"), Period: 30, Algorithm: "SHA1"} - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). Return(&config, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -94,9 +88,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() { RemoteIP: models.NewIPAddressFromString("0.0.0.0"), })) - verifier.EXPECT(). - Verify(gomock.Eq(&config), gomock.Eq("abc")). - Return(true, nil) + s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ Token: "abc", @@ -104,20 +96,18 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() { s.Require().NoError(err) s.mock.Ctx.Request.SetBody(bodyBytes) - SecondFactorTOTPPost(verifier)(s.mock.Ctx) + SecondFactorTOTPPost(s.mock.Ctx) s.mock.Assert200OK(s.T(), nil) } func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() { - verifier := NewMockTOTPVerifier(s.mock.Ctrl) - config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: []byte("secret"), Period: 30, Algorithm: "SHA1"} - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). Return(&config, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -128,9 +118,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() { RemoteIP: models.NewIPAddressFromString("0.0.0.0"), })) - verifier.EXPECT(). - Verify(gomock.Eq(&config), gomock.Eq("abc")). - Return(true, nil) + s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ Token: "abc", @@ -139,20 +127,18 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() { s.Require().NoError(err) s.mock.Ctx.Request.SetBody(bodyBytes) - SecondFactorTOTPPost(verifier)(s.mock.Ctx) + SecondFactorTOTPPost(s.mock.Ctx) s.mock.Assert200OK(s.T(), redirectResponse{ Redirect: "https://mydomain.local", }) } func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() { - verifier := NewMockTOTPVerifier(s.mock.Ctrl) - - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). Return(&models.TOTPConfiguration{Secret: []byte("secret")}, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -163,31 +149,30 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() { RemoteIP: models.NewIPAddressFromString("0.0.0.0"), })) - verifier.EXPECT(). - Verify(gomock.Eq(&models.TOTPConfiguration{Secret: []byte("secret")}), gomock.Eq("abc")). + s.mock.TOTPMock.EXPECT(). + Validate(gomock.Eq("abc"), gomock.Eq(&models.TOTPConfiguration{Secret: []byte("secret")})). Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ Token: "abc", TargetURL: "http://mydomain.local", }) + s.Require().NoError(err) s.mock.Ctx.Request.SetBody(bodyBytes) - SecondFactorTOTPPost(verifier)(s.mock.Ctx) + SecondFactorTOTPPost(s.mock.Ctx) s.mock.Assert200OK(s.T(), nil) } func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFixation() { - verifier := NewMockTOTPVerifier(s.mock.Ctrl) - config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: []byte("secret"), Period: 30, Algorithm: "SHA1"} - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()). Return(&config, nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -198,8 +183,8 @@ func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFi RemoteIP: models.NewIPAddressFromString("0.0.0.0"), })) - verifier.EXPECT(). - Verify(gomock.Eq(&config), gomock.Eq("abc")). + s.mock.TOTPMock.EXPECT(). + Validate(gomock.Eq("abc"), gomock.Eq(&config)). Return(true, nil) bodyBytes, err := json.Marshal(signTOTPRequestBody{ @@ -211,7 +196,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFi r := regexp.MustCompile("^authelia_session=(.*); path=") res := r.FindAllStringSubmatch(string(s.mock.Ctx.Response.Header.PeekCookie("authelia_session")), -1) - SecondFactorTOTPPost(verifier)(s.mock.Ctx) + SecondFactorTOTPPost(s.mock.Ctx) s.mock.Assert200OK(s.T(), nil) s.Assert().NotEqual( diff --git a/internal/handlers/handler_sign_u2f_step2_test.go b/internal/handlers/handler_sign_u2f_step2_test.go index e921c113..8cb3c289 100644 --- a/internal/handlers/handler_sign_u2f_step2_test.go +++ b/internal/handlers/handler_sign_u2f_step2_test.go @@ -37,13 +37,13 @@ func (s *HandlerSignU2FStep2Suite) TearDownTest() { } func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToDefaultURL() { - u2fVerifier := NewMockU2FVerifier(s.mock.Ctrl) + u2fVerifier := mocks.NewMockU2FVerifier(s.mock.Ctrl) u2fVerifier.EXPECT(). Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -69,13 +69,13 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToDefaultURL() { } func (s *HandlerSignU2FStep2Suite) TestShouldNotReturnRedirectURL() { - u2fVerifier := NewMockU2FVerifier(s.mock.Ctrl) + u2fVerifier := mocks.NewMockU2FVerifier(s.mock.Ctrl) u2fVerifier.EXPECT(). Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -97,13 +97,13 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotReturnRedirectURL() { } func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToSafeTargetURL() { - u2fVerifier := NewMockU2FVerifier(s.mock.Ctrl) + u2fVerifier := mocks.NewMockU2FVerifier(s.mock.Ctrl) u2fVerifier.EXPECT(). Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -128,13 +128,13 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToSafeTargetURL() { } func (s *HandlerSignU2FStep2Suite) TestShouldNotRedirectToUnsafeURL() { - u2fVerifier := NewMockU2FVerifier(s.mock.Ctrl) + u2fVerifier := mocks.NewMockU2FVerifier(s.mock.Ctrl) u2fVerifier.EXPECT(). Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", @@ -157,13 +157,13 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotRedirectToUnsafeURL() { } func (s *HandlerSignU2FStep2Suite) TestShouldRegenerateSessionForPreventingSessionFixation() { - u2fVerifier := NewMockU2FVerifier(s.mock.Ctrl) + u2fVerifier := mocks.NewMockU2FVerifier(s.mock.Ctrl) u2fVerifier.EXPECT(). Verify(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) - s.mock.StorageProviderMock. + s.mock.StorageMock. EXPECT(). AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{ Username: "john", diff --git a/internal/handlers/handler_user_info.go b/internal/handlers/handler_user_info.go index f184a965..d3c78caf 100644 --- a/internal/handlers/handler_user_info.go +++ b/internal/handlers/handler_user_info.go @@ -27,14 +27,9 @@ func UserInfoGet(ctx *middlewares.AutheliaCtx) { } } -// MethodBody the selected 2FA method. -type MethodBody struct { - Method string `json:"method" valid:"required"` -} - // MethodPreferencePost update the user preferences regarding 2FA method. func MethodPreferencePost(ctx *middlewares.AutheliaCtx) { - bodyJSON := MethodBody{} + bodyJSON := preferred2FAMethodBody{} err := ctx.ParseBody(&bodyJSON) if err != nil { diff --git a/internal/handlers/handler_user_info_test.go b/internal/handlers/handler_user_info_test.go index c062f3f3..11e253a3 100644 --- a/internal/handlers/handler_user_info_test.go +++ b/internal/handlers/handler_user_info_test.go @@ -96,7 +96,7 @@ func TestMethodSetToU2F(t *testing.T) { err := mock.Ctx.SaveSession(userSession) require.NoError(t, err) - mock.StorageProviderMock. + mock.StorageMock. EXPECT(). LoadUserInfo(mock.Ctx, gomock.Eq("john")). Return(resp.db, resp.err) @@ -139,7 +139,7 @@ func TestMethodSetToU2F(t *testing.T) { } func (s *FetchSuite) TestShouldReturnError500WhenStorageFailsToLoad() { - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). LoadUserInfo(s.mock.Ctx, gomock.Eq("john")). Return(models.UserInfo{}, fmt.Errorf("failure")) @@ -211,7 +211,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenBadMethodProvided() { func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() { s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}")) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")). Return(fmt.Errorf("Failure")) @@ -224,7 +224,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() { func (s *SaveSuite) TestShouldReturn200WhenMethodIsSuccessfullySaved() { s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}")) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")). Return(nil) diff --git a/internal/handlers/handler_user_totp.go b/internal/handlers/handler_user_totp.go new file mode 100644 index 00000000..2f7344bc --- /dev/null +++ b/internal/handlers/handler_user_totp.go @@ -0,0 +1,34 @@ +package handlers + +import ( + "errors" + + "github.com/valyala/fasthttp" + + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/storage" +) + +// UserTOTPGet returns the users TOTP configuration. +func UserTOTPGet(ctx *middlewares.AutheliaCtx) { + userSession := ctx.GetSession() + + config, err := ctx.Providers.StorageProvider.LoadTOTPConfiguration(ctx, userSession.Username) + if err != nil { + if errors.Is(err, storage.ErrNoTOTPConfiguration) { + ctx.SetStatusCode(fasthttp.StatusNotFound) + ctx.Error(err, "No TOTP Configuration.") + } else { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.Error(err, "Unknown Error.") + } + + return + } + + if err = ctx.SetJSONBody(config); err != nil { + ctx.Logger.Errorf("Unable to perform TOTP configuration response: %s", err) + } + + ctx.SetStatusCode(fasthttp.StatusOK) +} diff --git a/internal/handlers/totp.go b/internal/handlers/totp.go deleted file mode 100644 index 817b7ab8..00000000 --- a/internal/handlers/totp.go +++ /dev/null @@ -1,64 +0,0 @@ -package handlers - -import ( - "errors" - "time" - - "github.com/pquerna/otp" - "github.com/pquerna/otp/totp" - - "github.com/authelia/authelia/v4/internal/models" -) - -// TOTPVerifier is the interface for verifying TOTPs. -type TOTPVerifier interface { - Verify(config *models.TOTPConfiguration, token string) (bool, error) -} - -// TOTPVerifierImpl the production implementation for TOTP verification. -type TOTPVerifierImpl struct { - Period uint - Skew uint -} - -// Verify verifies TOTPs. -func (tv *TOTPVerifierImpl) Verify(config *models.TOTPConfiguration, token string) (bool, error) { - if config == nil { - return false, errors.New("config not provided") - } - - opts := totp.ValidateOpts{ - Period: uint(config.Period), - Skew: tv.Skew, - Digits: otp.Digits(config.Digits), - Algorithm: otpStringToAlgo(config.Algorithm), - } - - return totp.ValidateCustom(token, string(config.Secret), time.Now().UTC(), opts) -} - -func otpAlgoToString(algorithm otp.Algorithm) (out string) { - switch algorithm { - case otp.AlgorithmSHA1: - return totpAlgoSHA1 - case otp.AlgorithmSHA256: - return totpAlgoSHA256 - case otp.AlgorithmSHA512: - return totpAlgoSHA512 - default: - return "" - } -} - -func otpStringToAlgo(in string) (algorithm otp.Algorithm) { - switch in { - case totpAlgoSHA1: - return otp.AlgorithmSHA1 - case totpAlgoSHA256: - return otp.AlgorithmSHA256 - case totpAlgoSHA512: - return otp.AlgorithmSHA512 - default: - return otp.AlgorithmSHA1 - } -} diff --git a/internal/handlers/totp_mock.go b/internal/handlers/totp_mock.go deleted file mode 100644 index 0da2a4e8..00000000 --- a/internal/handlers/totp_mock.go +++ /dev/null @@ -1,51 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: internal/handlers/totp.go - -// Package handlers is a generated GoMock package. -package handlers - -import ( - "reflect" - - "github.com/golang/mock/gomock" - - "github.com/authelia/authelia/v4/internal/models" -) - -// MockTOTPVerifier is a mock of TOTPVerifier interface. -type MockTOTPVerifier struct { - ctrl *gomock.Controller - recorder *MockTOTPVerifierMockRecorder -} - -// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier. -type MockTOTPVerifierMockRecorder struct { - mock *MockTOTPVerifier -} - -// NewMockTOTPVerifier creates a new mock instance. -func NewMockTOTPVerifier(ctrl *gomock.Controller) *MockTOTPVerifier { - mock := &MockTOTPVerifier{ctrl: ctrl} - mock.recorder = &MockTOTPVerifierMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTOTPVerifier) EXPECT() *MockTOTPVerifierMockRecorder { - return m.recorder -} - -// Verify mocks base method. -func (m *MockTOTPVerifier) Verify(arg0 *models.TOTPConfiguration, arg1 string) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Verify", arg0, arg1) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Verify indicates an expected call of Verify. -func (mr *MockTOTPVerifierMockRecorder) Verify(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), arg0, arg1) -} diff --git a/internal/handlers/types.go b/internal/handlers/types.go index 2082b62c..e26d9840 100644 --- a/internal/handlers/types.go +++ b/internal/handlers/types.go @@ -11,22 +11,10 @@ type MethodList = []string type authorizationMatching int -// UserInfo is the model of user info and second factor preferences. -type UserInfo struct { - // The users display name. - DisplayName string `json:"display_name"` - - // The preferred 2FA method. - Method string `json:"method" valid:"required"` - - // True if a security key has been registered. - HasU2F bool `json:"has_u2f" valid:"required"` - - // True if a TOTP device has been registered. - HasTOTP bool `json:"has_totp" valid:"required"` - - // True if a Duo device and method has been enrolled. - HasDuo bool `json:"has_duo" valid:"required"` +// configurationBody the content returned by the configuration endpoint. +type configurationBody struct { + AvailableMethods MethodList `json:"available_methods"` + SecondFactorEnabled bool `json:"second_factor_enabled"` // whether second factor is enabled or not. } // signTOTPRequestBody model of the request body received by TOTP authentication endpoint. @@ -46,6 +34,11 @@ type signDuoRequestBody struct { Passcode string `json:"passcode"` } +// preferred2FAMethodBody the selected 2FA method. +type preferred2FAMethodBody struct { + Method string `json:"method" valid:"required"` +} + // firstFactorRequestBody represents the JSON body received by the endpoint. type firstFactorRequestBody struct { Username string `json:"username" valid:"required"` diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go index 82af45d2..a557aeee 100644 --- a/internal/middlewares/identity_verification_test.go +++ b/internal/middlewares/identity_verification_test.go @@ -55,7 +55,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) { mock.Ctx.Configuration.JWTSecret = testJWTSecret - mock.StorageProviderMock.EXPECT(). + mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(fmt.Errorf("cannot save")) @@ -74,7 +74,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) { mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") - mock.StorageProviderMock.EXPECT(). + mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) @@ -96,7 +96,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) { mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") - mock.StorageProviderMock.EXPECT(). + mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) @@ -114,7 +114,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) { mock.Ctx.Configuration.JWTSecret = testJWTSecret mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") - mock.StorageProviderMock.EXPECT(). + mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) @@ -132,7 +132,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) { mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http") mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host") - mock.StorageProviderMock.EXPECT(). + mock.StorageMock.EXPECT(). SaveIdentityVerification(mock.Ctx, gomock.Any()). Return(nil) @@ -208,7 +208,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB( s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(false, nil) @@ -244,7 +244,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() { time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) @@ -259,7 +259,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() { time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) @@ -276,11 +276,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(fmt.Errorf("cannot remove")) @@ -295,11 +295,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete( time.Now().Add(1*time.Minute)) s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token)) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(true, nil) - s.mock.StorageProviderMock.EXPECT(). + s.mock.StorageMock.EXPECT(). RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())). Return(nil) diff --git a/internal/middlewares/types.go b/internal/middlewares/types.go index 498e62fd..585f7cdf 100644 --- a/internal/middlewares/types.go +++ b/internal/middlewares/types.go @@ -13,6 +13,7 @@ import ( "github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/totp" "github.com/authelia/authelia/v4/internal/utils" ) @@ -37,6 +38,7 @@ type Providers struct { UserProvider authentication.UserProvider StorageProvider storage.Provider Notifier notification.Notifier + TOTP totp.Provider } // RequestHandler represents an Authelia request handler. diff --git a/internal/mocks/mock_authelia_ctx.go b/internal/mocks/authelia_ctx.go similarity index 94% rename from internal/mocks/mock_authelia_ctx.go rename to internal/mocks/authelia_ctx.go index e5626515..d0f694f7 100644 --- a/internal/mocks/mock_authelia_ctx.go +++ b/internal/mocks/authelia_ctx.go @@ -18,7 +18,6 @@ import ( "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/regulation" "github.com/authelia/authelia/v4/internal/session" - "github.com/authelia/authelia/v4/internal/storage" ) // MockAutheliaCtx a mock of AutheliaCtx. @@ -29,9 +28,10 @@ type MockAutheliaCtx struct { Ctrl *gomock.Controller // Providers. - UserProviderMock *MockUserProvider - StorageProviderMock *storage.MockProvider - NotifierMock *MockNotifier + UserProviderMock *MockUserProvider + StorageMock *MockStorage + NotifierMock *MockNotifier + TOTPMock *MockTOTP UserSession *session.UserSession @@ -98,8 +98,8 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx { mockAuthelia.UserProviderMock = NewMockUserProvider(mockAuthelia.Ctrl) providers.UserProvider = mockAuthelia.UserProviderMock - mockAuthelia.StorageProviderMock = storage.NewMockProvider(mockAuthelia.Ctrl) - providers.StorageProvider = mockAuthelia.StorageProviderMock + mockAuthelia.StorageMock = NewMockStorage(mockAuthelia.Ctrl) + providers.StorageProvider = mockAuthelia.StorageMock mockAuthelia.NotifierMock = NewMockNotifier(mockAuthelia.Ctrl) providers.Notifier = mockAuthelia.NotifierMock @@ -112,6 +112,9 @@ func NewMockAutheliaCtx(t *testing.T) *MockAutheliaCtx { providers.Regulator = regulation.NewRegulator(configuration.Regulation, providers.StorageProvider, &mockAuthelia.Clock) + mockAuthelia.TOTPMock = NewMockTOTP(mockAuthelia.Ctrl) + providers.TOTP = mockAuthelia.TOTPMock + request := &fasthttp.RequestCtx{} // Set a cookie to identify this client throughout the test. // request.Request.Header.SetCookie("authelia_session", "client_cookie") diff --git a/internal/mocks/mock_duo_api.go b/internal/mocks/duo_api.go similarity index 100% rename from internal/mocks/mock_duo_api.go rename to internal/mocks/duo_api.go diff --git a/internal/mocks/generate.go b/internal/mocks/generate.go new file mode 100644 index 00000000..884f0c06 --- /dev/null +++ b/internal/mocks/generate.go @@ -0,0 +1,11 @@ +package mocks + +// This file is used to generate mocks. You can generate all mocks using the +// command `go generate github.com/authelia/authelia/v4/internal/mocks`. + +//go:generate mockgen -package mocks -destination user_provider.go -mock_names UserProvider=MockUserProvider github.com/authelia/authelia/v4/internal/authentication UserProvider +//go:generate mockgen -package mocks -destination notifier.go -mock_names Notifier=MockNotifier github.com/authelia/authelia/v4/internal/notification Notifier +//go:generate mockgen -package mocks -destination totp.go -mock_names Provider=MockTOTP github.com/authelia/authelia/v4/internal/totp Provider +//go:generate mockgen -package mocks -destination u2f_verifier.go -mock_names U2FVerifier=MockU2FVerifier github.com/authelia/authelia/v4/internal/handlers U2FVerifier +//go:generate mockgen -package mocks -destination storage.go -mock_names Provider=MockStorage github.com/authelia/authelia/v4/internal/storage Provider +//go:generate mockgen -package mocks -destination duo_api.go -mock_names API=MockAPI github.com/authelia/authelia/v4/internal/duo API diff --git a/internal/mocks/mock_notifier.go b/internal/mocks/notifier.go similarity index 96% rename from internal/mocks/mock_notifier.go rename to internal/mocks/notifier.go index 05ca829f..65ace6ad 100644 --- a/internal/mocks/mock_notifier.go +++ b/internal/mocks/notifier.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/authelia/authelia/v4/internal/notification (interfaces: Notifier) -// Package mock_notification is a generated GoMock package. +// Package mocks is a generated GoMock package. package mocks import ( diff --git a/internal/storage/provider_mock.go b/internal/mocks/storage.go similarity index 57% rename from internal/storage/provider_mock.go rename to internal/mocks/storage.go index a3aca1e2..6212c17f 100644 --- a/internal/storage/provider_mock.go +++ b/internal/mocks/storage.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/authelia/authelia/v4/internal/storage (interfaces: Provider) -// Package storage is a generated GoMock package. -package storage +// Package mocks is a generated GoMock package. +package mocks import ( context "context" @@ -14,31 +14,31 @@ import ( models "github.com/authelia/authelia/v4/internal/models" ) -// MockProvider is a mock of Provider interface. -type MockProvider struct { +// MockStorage is a mock of Provider interface. +type MockStorage struct { ctrl *gomock.Controller - recorder *MockProviderMockRecorder + recorder *MockStorageMockRecorder } -// MockProviderMockRecorder is the mock recorder for MockProvider. -type MockProviderMockRecorder struct { - mock *MockProvider +// MockStorageMockRecorder is the mock recorder for MockStorage. +type MockStorageMockRecorder struct { + mock *MockStorage } -// NewMockProvider creates a new mock instance. -func NewMockProvider(ctrl *gomock.Controller) *MockProvider { - mock := &MockProvider{ctrl: ctrl} - mock.recorder = &MockProviderMockRecorder{mock} +// NewMockStorage creates a new mock instance. +func NewMockStorage(ctrl *gomock.Controller) *MockStorage { + mock := &MockStorage{ctrl: ctrl} + mock.recorder = &MockStorageMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockProvider) EXPECT() *MockProviderMockRecorder { +func (m *MockStorage) EXPECT() *MockStorageMockRecorder { return m.recorder } // AppendAuthenticationLog mocks base method. -func (m *MockProvider) AppendAuthenticationLog(arg0 context.Context, arg1 models.AuthenticationAttempt) error { +func (m *MockStorage) AppendAuthenticationLog(arg0 context.Context, arg1 models.AuthenticationAttempt) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AppendAuthenticationLog", arg0, arg1) ret0, _ := ret[0].(error) @@ -46,13 +46,13 @@ func (m *MockProvider) AppendAuthenticationLog(arg0 context.Context, arg1 models } // AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog. -func (mr *MockProviderMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorage)(nil).AppendAuthenticationLog), arg0, arg1) } // Close mocks base method. -func (m *MockProvider) Close() error { +func (m *MockStorage) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") ret0, _ := ret[0].(error) @@ -60,13 +60,13 @@ func (m *MockProvider) Close() error { } // Close indicates an expected call of Close. -func (mr *MockProviderMockRecorder) Close() *gomock.Call { +func (mr *MockStorageMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockProvider)(nil).Close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close)) } // DeletePreferredDuoDevice mocks base method. -func (m *MockProvider) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error { +func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeletePreferredDuoDevice", arg0, arg1) ret0, _ := ret[0].(error) @@ -74,13 +74,13 @@ func (m *MockProvider) DeletePreferredDuoDevice(arg0 context.Context, arg1 strin } // DeletePreferredDuoDevice indicates an expected call of DeletePreferredDuoDevice. -func (mr *MockProviderMockRecorder) DeletePreferredDuoDevice(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) DeletePreferredDuoDevice(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePreferredDuoDevice", reflect.TypeOf((*MockProvider)(nil).DeletePreferredDuoDevice), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePreferredDuoDevice", reflect.TypeOf((*MockStorage)(nil).DeletePreferredDuoDevice), arg0, arg1) } // DeleteTOTPConfiguration mocks base method. -func (m *MockProvider) DeleteTOTPConfiguration(arg0 context.Context, arg1 string) error { +func (m *MockStorage) DeleteTOTPConfiguration(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "DeleteTOTPConfiguration", arg0, arg1) ret0, _ := ret[0].(error) @@ -88,13 +88,13 @@ func (m *MockProvider) DeleteTOTPConfiguration(arg0 context.Context, arg1 string } // DeleteTOTPConfiguration indicates an expected call of DeleteTOTPConfiguration. -func (mr *MockProviderMockRecorder) DeleteTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) DeleteTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPConfiguration), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).DeleteTOTPConfiguration), arg0, arg1) } // FindIdentityVerification mocks base method. -func (m *MockProvider) FindIdentityVerification(arg0 context.Context, arg1 string) (bool, error) { +func (m *MockStorage) FindIdentityVerification(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FindIdentityVerification", arg0, arg1) ret0, _ := ret[0].(bool) @@ -103,13 +103,13 @@ func (m *MockProvider) FindIdentityVerification(arg0 context.Context, arg1 strin } // FindIdentityVerification indicates an expected call of FindIdentityVerification. -func (mr *MockProviderMockRecorder) FindIdentityVerification(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) FindIdentityVerification(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerification", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerification), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerification", reflect.TypeOf((*MockStorage)(nil).FindIdentityVerification), arg0, arg1) } // LoadAuthenticationLogs mocks base method. -func (m *MockProvider) LoadAuthenticationLogs(arg0 context.Context, arg1 string, arg2 time.Time, arg3, arg4 int) ([]models.AuthenticationAttempt, error) { +func (m *MockStorage) LoadAuthenticationLogs(arg0 context.Context, arg1 string, arg2 time.Time, arg3, arg4 int) ([]models.AuthenticationAttempt, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadAuthenticationLogs", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].([]models.AuthenticationAttempt) @@ -118,13 +118,13 @@ func (m *MockProvider) LoadAuthenticationLogs(arg0 context.Context, arg1 string, } // LoadAuthenticationLogs indicates an expected call of LoadAuthenticationLogs. -func (mr *MockProviderMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4) } // LoadPreferred2FAMethod mocks base method. -func (m *MockProvider) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { +func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", arg0, arg1) ret0, _ := ret[0].(string) @@ -133,13 +133,13 @@ func (m *MockProvider) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) } // LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod. -func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadPreferred2FAMethod(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockStorage)(nil).LoadPreferred2FAMethod), arg0, arg1) } // LoadPreferredDuoDevice mocks base method. -func (m *MockProvider) LoadPreferredDuoDevice(arg0 context.Context, arg1 string) (*models.DuoDevice, error) { +func (m *MockStorage) LoadPreferredDuoDevice(arg0 context.Context, arg1 string) (*models.DuoDevice, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadPreferredDuoDevice", arg0, arg1) ret0, _ := ret[0].(*models.DuoDevice) @@ -148,13 +148,13 @@ func (m *MockProvider) LoadPreferredDuoDevice(arg0 context.Context, arg1 string) } // LoadPreferredDuoDevice indicates an expected call of LoadPreferredDuoDevice. -func (mr *MockProviderMockRecorder) LoadPreferredDuoDevice(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadPreferredDuoDevice(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferredDuoDevice", reflect.TypeOf((*MockProvider)(nil).LoadPreferredDuoDevice), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferredDuoDevice", reflect.TypeOf((*MockStorage)(nil).LoadPreferredDuoDevice), arg0, arg1) } // LoadTOTPConfiguration mocks base method. -func (m *MockProvider) LoadTOTPConfiguration(arg0 context.Context, arg1 string) (*models.TOTPConfiguration, error) { +func (m *MockStorage) LoadTOTPConfiguration(arg0 context.Context, arg1 string) (*models.TOTPConfiguration, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadTOTPConfiguration", arg0, arg1) ret0, _ := ret[0].(*models.TOTPConfiguration) @@ -163,13 +163,13 @@ func (m *MockProvider) LoadTOTPConfiguration(arg0 context.Context, arg1 string) } // LoadTOTPConfiguration indicates an expected call of LoadTOTPConfiguration. -func (mr *MockProviderMockRecorder) LoadTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).LoadTOTPConfiguration), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).LoadTOTPConfiguration), arg0, arg1) } // LoadTOTPConfigurations mocks base method. -func (m *MockProvider) LoadTOTPConfigurations(arg0 context.Context, arg1, arg2 int) ([]models.TOTPConfiguration, error) { +func (m *MockStorage) LoadTOTPConfigurations(arg0 context.Context, arg1, arg2 int) ([]models.TOTPConfiguration, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadTOTPConfigurations", arg0, arg1, arg2) ret0, _ := ret[0].([]models.TOTPConfiguration) @@ -178,13 +178,13 @@ func (m *MockProvider) LoadTOTPConfigurations(arg0 context.Context, arg1, arg2 i } // LoadTOTPConfigurations indicates an expected call of LoadTOTPConfigurations. -func (mr *MockProviderMockRecorder) LoadTOTPConfigurations(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadTOTPConfigurations(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfigurations", reflect.TypeOf((*MockProvider)(nil).LoadTOTPConfigurations), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfigurations", reflect.TypeOf((*MockStorage)(nil).LoadTOTPConfigurations), arg0, arg1, arg2) } // LoadU2FDevice mocks base method. -func (m *MockProvider) LoadU2FDevice(arg0 context.Context, arg1 string) (*models.U2FDevice, error) { +func (m *MockStorage) LoadU2FDevice(arg0 context.Context, arg1 string) (*models.U2FDevice, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadU2FDevice", arg0, arg1) ret0, _ := ret[0].(*models.U2FDevice) @@ -193,13 +193,13 @@ func (m *MockProvider) LoadU2FDevice(arg0 context.Context, arg1 string) (*models } // LoadU2FDevice indicates an expected call of LoadU2FDevice. -func (mr *MockProviderMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockProvider)(nil).LoadU2FDevice), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevice), arg0, arg1) } // LoadUserInfo mocks base method. -func (m *MockProvider) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) { +func (m *MockStorage) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LoadUserInfo", arg0, arg1) ret0, _ := ret[0].(models.UserInfo) @@ -208,13 +208,13 @@ func (m *MockProvider) LoadUserInfo(arg0 context.Context, arg1 string) (models.U } // LoadUserInfo indicates an expected call of LoadUserInfo. -func (mr *MockProviderMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockProvider)(nil).LoadUserInfo), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1) } // RemoveIdentityVerification mocks base method. -func (m *MockProvider) RemoveIdentityVerification(arg0 context.Context, arg1 string) error { +func (m *MockStorage) RemoveIdentityVerification(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1) ret0, _ := ret[0].(error) @@ -222,13 +222,13 @@ func (m *MockProvider) RemoveIdentityVerification(arg0 context.Context, arg1 str } // RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification. -func (mr *MockProviderMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerification), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).RemoveIdentityVerification), arg0, arg1) } // SaveIdentityVerification mocks base method. -func (m *MockProvider) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error { +func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveIdentityVerification", arg0, arg1) ret0, _ := ret[0].(error) @@ -236,13 +236,13 @@ func (m *MockProvider) SaveIdentityVerification(arg0 context.Context, arg1 model } // SaveIdentityVerification indicates an expected call of SaveIdentityVerification. -func (mr *MockProviderMockRecorder) SaveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SaveIdentityVerification(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerification), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).SaveIdentityVerification), arg0, arg1) } // SavePreferred2FAMethod mocks base method. -func (m *MockProvider) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { +func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SavePreferred2FAMethod", arg0, arg1, arg2) ret0, _ := ret[0].(error) @@ -250,13 +250,13 @@ func (m *MockProvider) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 s } // SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod. -func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SavePreferred2FAMethod(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockStorage)(nil).SavePreferred2FAMethod), arg0, arg1, arg2) } // SavePreferredDuoDevice mocks base method. -func (m *MockProvider) SavePreferredDuoDevice(arg0 context.Context, arg1 models.DuoDevice) error { +func (m *MockStorage) SavePreferredDuoDevice(arg0 context.Context, arg1 models.DuoDevice) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SavePreferredDuoDevice", arg0, arg1) ret0, _ := ret[0].(error) @@ -264,13 +264,13 @@ func (m *MockProvider) SavePreferredDuoDevice(arg0 context.Context, arg1 models. } // SavePreferredDuoDevice indicates an expected call of SavePreferredDuoDevice. -func (mr *MockProviderMockRecorder) SavePreferredDuoDevice(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SavePreferredDuoDevice(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferredDuoDevice", reflect.TypeOf((*MockProvider)(nil).SavePreferredDuoDevice), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferredDuoDevice", reflect.TypeOf((*MockStorage)(nil).SavePreferredDuoDevice), arg0, arg1) } // SaveTOTPConfiguration mocks base method. -func (m *MockProvider) SaveTOTPConfiguration(arg0 context.Context, arg1 models.TOTPConfiguration) error { +func (m *MockStorage) SaveTOTPConfiguration(arg0 context.Context, arg1 models.TOTPConfiguration) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveTOTPConfiguration", arg0, arg1) ret0, _ := ret[0].(error) @@ -278,13 +278,13 @@ func (m *MockProvider) SaveTOTPConfiguration(arg0 context.Context, arg1 models.T } // SaveTOTPConfiguration indicates an expected call of SaveTOTPConfiguration. -func (mr *MockProviderMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).SaveTOTPConfiguration), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).SaveTOTPConfiguration), arg0, arg1) } // SaveU2FDevice mocks base method. -func (m *MockProvider) SaveU2FDevice(arg0 context.Context, arg1 models.U2FDevice) error { +func (m *MockStorage) SaveU2FDevice(arg0 context.Context, arg1 models.U2FDevice) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveU2FDevice", arg0, arg1) ret0, _ := ret[0].(error) @@ -292,13 +292,13 @@ func (m *MockProvider) SaveU2FDevice(arg0 context.Context, arg1 models.U2FDevice } // SaveU2FDevice indicates an expected call of SaveU2FDevice. -func (mr *MockProviderMockRecorder) SaveU2FDevice(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SaveU2FDevice(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDevice", reflect.TypeOf((*MockProvider)(nil).SaveU2FDevice), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDevice", reflect.TypeOf((*MockStorage)(nil).SaveU2FDevice), arg0, arg1) } // SchemaEncryptionChangeKey mocks base method. -func (m *MockProvider) SchemaEncryptionChangeKey(arg0 context.Context, arg1 string) error { +func (m *MockStorage) SchemaEncryptionChangeKey(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaEncryptionChangeKey", arg0, arg1) ret0, _ := ret[0].(error) @@ -306,13 +306,13 @@ func (m *MockProvider) SchemaEncryptionChangeKey(arg0 context.Context, arg1 stri } // SchemaEncryptionChangeKey indicates an expected call of SchemaEncryptionChangeKey. -func (mr *MockProviderMockRecorder) SchemaEncryptionChangeKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaEncryptionChangeKey(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaEncryptionChangeKey", reflect.TypeOf((*MockProvider)(nil).SchemaEncryptionChangeKey), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaEncryptionChangeKey", reflect.TypeOf((*MockStorage)(nil).SchemaEncryptionChangeKey), arg0, arg1) } // SchemaEncryptionCheckKey mocks base method. -func (m *MockProvider) SchemaEncryptionCheckKey(arg0 context.Context, arg1 bool) error { +func (m *MockStorage) SchemaEncryptionCheckKey(arg0 context.Context, arg1 bool) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaEncryptionCheckKey", arg0, arg1) ret0, _ := ret[0].(error) @@ -320,13 +320,13 @@ func (m *MockProvider) SchemaEncryptionCheckKey(arg0 context.Context, arg1 bool) } // SchemaEncryptionCheckKey indicates an expected call of SchemaEncryptionCheckKey. -func (mr *MockProviderMockRecorder) SchemaEncryptionCheckKey(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaEncryptionCheckKey(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaEncryptionCheckKey", reflect.TypeOf((*MockProvider)(nil).SchemaEncryptionCheckKey), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaEncryptionCheckKey", reflect.TypeOf((*MockStorage)(nil).SchemaEncryptionCheckKey), arg0, arg1) } // SchemaLatestVersion mocks base method. -func (m *MockProvider) SchemaLatestVersion() (int, error) { +func (m *MockStorage) SchemaLatestVersion() (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaLatestVersion") ret0, _ := ret[0].(int) @@ -335,13 +335,13 @@ func (m *MockProvider) SchemaLatestVersion() (int, error) { } // SchemaLatestVersion indicates an expected call of SchemaLatestVersion. -func (mr *MockProviderMockRecorder) SchemaLatestVersion() *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaLatestVersion() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaLatestVersion", reflect.TypeOf((*MockProvider)(nil).SchemaLatestVersion)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaLatestVersion", reflect.TypeOf((*MockStorage)(nil).SchemaLatestVersion)) } // SchemaMigrate mocks base method. -func (m *MockProvider) SchemaMigrate(arg0 context.Context, arg1 bool, arg2 int) error { +func (m *MockStorage) SchemaMigrate(arg0 context.Context, arg1 bool, arg2 int) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaMigrate", arg0, arg1, arg2) ret0, _ := ret[0].(error) @@ -349,13 +349,13 @@ func (m *MockProvider) SchemaMigrate(arg0 context.Context, arg1 bool, arg2 int) } // SchemaMigrate indicates an expected call of SchemaMigrate. -func (mr *MockProviderMockRecorder) SchemaMigrate(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaMigrate(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrate", reflect.TypeOf((*MockProvider)(nil).SchemaMigrate), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrate", reflect.TypeOf((*MockStorage)(nil).SchemaMigrate), arg0, arg1, arg2) } // SchemaMigrationHistory mocks base method. -func (m *MockProvider) SchemaMigrationHistory(arg0 context.Context) ([]models.Migration, error) { +func (m *MockStorage) SchemaMigrationHistory(arg0 context.Context) ([]models.Migration, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaMigrationHistory", arg0) ret0, _ := ret[0].([]models.Migration) @@ -364,43 +364,43 @@ func (m *MockProvider) SchemaMigrationHistory(arg0 context.Context) ([]models.Mi } // SchemaMigrationHistory indicates an expected call of SchemaMigrationHistory. -func (mr *MockProviderMockRecorder) SchemaMigrationHistory(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaMigrationHistory(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationHistory", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationHistory), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationHistory", reflect.TypeOf((*MockStorage)(nil).SchemaMigrationHistory), arg0) } // SchemaMigrationsDown mocks base method. -func (m *MockProvider) SchemaMigrationsDown(arg0 context.Context, arg1 int) ([]SchemaMigration, error) { +func (m *MockStorage) SchemaMigrationsDown(arg0 context.Context, arg1 int) ([]models.SchemaMigration, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaMigrationsDown", arg0, arg1) - ret0, _ := ret[0].([]SchemaMigration) + ret0, _ := ret[0].([]models.SchemaMigration) ret1, _ := ret[1].(error) return ret0, ret1 } // SchemaMigrationsDown indicates an expected call of SchemaMigrationsDown. -func (mr *MockProviderMockRecorder) SchemaMigrationsDown(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaMigrationsDown(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsDown", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsDown), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsDown", reflect.TypeOf((*MockStorage)(nil).SchemaMigrationsDown), arg0, arg1) } // SchemaMigrationsUp mocks base method. -func (m *MockProvider) SchemaMigrationsUp(arg0 context.Context, arg1 int) ([]SchemaMigration, error) { +func (m *MockStorage) SchemaMigrationsUp(arg0 context.Context, arg1 int) ([]models.SchemaMigration, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaMigrationsUp", arg0, arg1) - ret0, _ := ret[0].([]SchemaMigration) + ret0, _ := ret[0].([]models.SchemaMigration) ret1, _ := ret[1].(error) return ret0, ret1 } // SchemaMigrationsUp indicates an expected call of SchemaMigrationsUp. -func (mr *MockProviderMockRecorder) SchemaMigrationsUp(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaMigrationsUp(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsUp", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsUp), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsUp", reflect.TypeOf((*MockStorage)(nil).SchemaMigrationsUp), arg0, arg1) } // SchemaTables mocks base method. -func (m *MockProvider) SchemaTables(arg0 context.Context) ([]string, error) { +func (m *MockStorage) SchemaTables(arg0 context.Context) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaTables", arg0) ret0, _ := ret[0].([]string) @@ -409,13 +409,13 @@ func (m *MockProvider) SchemaTables(arg0 context.Context) ([]string, error) { } // SchemaTables indicates an expected call of SchemaTables. -func (mr *MockProviderMockRecorder) SchemaTables(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaTables(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaTables", reflect.TypeOf((*MockProvider)(nil).SchemaTables), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaTables", reflect.TypeOf((*MockStorage)(nil).SchemaTables), arg0) } // SchemaVersion mocks base method. -func (m *MockProvider) SchemaVersion(arg0 context.Context) (int, error) { +func (m *MockStorage) SchemaVersion(arg0 context.Context) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SchemaVersion", arg0) ret0, _ := ret[0].(int) @@ -424,13 +424,13 @@ func (m *MockProvider) SchemaVersion(arg0 context.Context) (int, error) { } // SchemaVersion indicates an expected call of SchemaVersion. -func (mr *MockProviderMockRecorder) SchemaVersion(arg0 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) SchemaVersion(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaVersion", reflect.TypeOf((*MockProvider)(nil).SchemaVersion), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaVersion", reflect.TypeOf((*MockStorage)(nil).SchemaVersion), arg0) } // StartupCheck mocks base method. -func (m *MockProvider) StartupCheck() error { +func (m *MockStorage) StartupCheck() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "StartupCheck") ret0, _ := ret[0].(error) @@ -438,13 +438,13 @@ func (m *MockProvider) StartupCheck() error { } // StartupCheck indicates an expected call of StartupCheck. -func (mr *MockProviderMockRecorder) StartupCheck() *gomock.Call { +func (mr *MockStorageMockRecorder) StartupCheck() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockProvider)(nil).StartupCheck)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockStorage)(nil).StartupCheck)) } // UpdateTOTPConfigurationSecret mocks base method. -func (m *MockProvider) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 models.TOTPConfiguration) error { +func (m *MockStorage) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 models.TOTPConfiguration) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateTOTPConfigurationSecret", arg0, arg1) ret0, _ := ret[0].(error) @@ -452,7 +452,7 @@ func (m *MockProvider) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 } // UpdateTOTPConfigurationSecret indicates an expected call of UpdateTOTPConfigurationSecret. -func (mr *MockProviderMockRecorder) UpdateTOTPConfigurationSecret(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockStorageMockRecorder) UpdateTOTPConfigurationSecret(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTOTPConfigurationSecret", reflect.TypeOf((*MockProvider)(nil).UpdateTOTPConfigurationSecret), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTOTPConfigurationSecret", reflect.TypeOf((*MockStorage)(nil).UpdateTOTPConfigurationSecret), arg0, arg1) } diff --git a/internal/mocks/totp.go b/internal/mocks/totp.go new file mode 100644 index 00000000..3e2a46c5 --- /dev/null +++ b/internal/mocks/totp.go @@ -0,0 +1,81 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/authelia/authelia/v4/internal/totp (interfaces: Provider) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + + models "github.com/authelia/authelia/v4/internal/models" +) + +// MockTOTP is a mock of Provider interface. +type MockTOTP struct { + ctrl *gomock.Controller + recorder *MockTOTPMockRecorder +} + +// MockTOTPMockRecorder is the mock recorder for MockTOTP. +type MockTOTPMockRecorder struct { + mock *MockTOTP +} + +// NewMockTOTP creates a new mock instance. +func NewMockTOTP(ctrl *gomock.Controller) *MockTOTP { + mock := &MockTOTP{ctrl: ctrl} + mock.recorder = &MockTOTPMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTOTP) EXPECT() *MockTOTPMockRecorder { + return m.recorder +} + +// Generate mocks base method. +func (m *MockTOTP) Generate(arg0 string) (*models.TOTPConfiguration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Generate", arg0) + ret0, _ := ret[0].(*models.TOTPConfiguration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Generate indicates an expected call of Generate. +func (mr *MockTOTPMockRecorder) Generate(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Generate", reflect.TypeOf((*MockTOTP)(nil).Generate), arg0) +} + +// GenerateCustom mocks base method. +func (m *MockTOTP) GenerateCustom(arg0, arg1 string, arg2, arg3, arg4 uint) (*models.TOTPConfiguration, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GenerateCustom", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(*models.TOTPConfiguration) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GenerateCustom indicates an expected call of GenerateCustom. +func (mr *MockTOTPMockRecorder) GenerateCustom(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateCustom", reflect.TypeOf((*MockTOTP)(nil).GenerateCustom), arg0, arg1, arg2, arg3, arg4) +} + +// Validate mocks base method. +func (m *MockTOTP) Validate(arg0 string, arg1 *models.TOTPConfiguration) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Validate", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Validate indicates an expected call of Validate. +func (mr *MockTOTPMockRecorder) Validate(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockTOTP)(nil).Validate), arg0, arg1) +} diff --git a/internal/handlers/u2f_mock.go b/internal/mocks/u2f_verifier.go similarity index 54% rename from internal/handlers/u2f_mock.go rename to internal/mocks/u2f_verifier.go index 151adff6..85715b55 100644 --- a/internal/handlers/u2f_mock.go +++ b/internal/mocks/u2f_verifier.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: internal/handlers/u2f.go +// Source: github.com/authelia/authelia/v4/internal/handlers (interfaces: U2FVerifier) -// Package handlers is a generated GoMock package. -package handlers +// Package mocks is a generated GoMock package. +package mocks import ( reflect "reflect" @@ -11,39 +11,39 @@ import ( u2f "github.com/tstranex/u2f" ) -// MockU2FVerifier is a mock of U2FVerifier interface +// MockU2FVerifier is a mock of U2FVerifier interface. type MockU2FVerifier struct { ctrl *gomock.Controller recorder *MockU2FVerifierMockRecorder } -// MockU2FVerifierMockRecorder is the mock recorder for MockU2FVerifier +// MockU2FVerifierMockRecorder is the mock recorder for MockU2FVerifier. type MockU2FVerifierMockRecorder struct { mock *MockU2FVerifier } -// NewMockU2FVerifier creates a new mock instance +// NewMockU2FVerifier creates a new mock instance. func NewMockU2FVerifier(ctrl *gomock.Controller) *MockU2FVerifier { mock := &MockU2FVerifier{ctrl: ctrl} mock.recorder = &MockU2FVerifierMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockU2FVerifier) EXPECT() *MockU2FVerifierMockRecorder { return m.recorder } -// Verify mocks base method -func (m *MockU2FVerifier) Verify(keyHandle, publicKey []byte, signResponse u2f.SignResponse, challenge u2f.Challenge) error { +// Verify mocks base method. +func (m *MockU2FVerifier) Verify(arg0, arg1 []byte, arg2 u2f.SignResponse, arg3 u2f.Challenge) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Verify", keyHandle, publicKey, signResponse, challenge) + ret := m.ctrl.Call(m, "Verify", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } -// Verify indicates an expected call of Verify -func (mr *MockU2FVerifierMockRecorder) Verify(keyHandle, publicKey, signResponse, challenge interface{}) *gomock.Call { +// Verify indicates an expected call of Verify. +func (mr *MockU2FVerifierMockRecorder) Verify(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockU2FVerifier)(nil).Verify), keyHandle, publicKey, signResponse, challenge) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockU2FVerifier)(nil).Verify), arg0, arg1, arg2, arg3) } diff --git a/internal/mocks/mock_user_provider.go b/internal/mocks/user_provider.go similarity index 100% rename from internal/mocks/mock_user_provider.go rename to internal/mocks/user_provider.go diff --git a/internal/models/model_authentication_attempt.go b/internal/models/authentication_attempt.go similarity index 100% rename from internal/models/model_authentication_attempt.go rename to internal/models/authentication_attempt.go diff --git a/internal/models/model_identity_verification.go b/internal/models/identity_verification.go similarity index 100% rename from internal/models/model_identity_verification.go rename to internal/models/identity_verification.go diff --git a/internal/models/model_migration.go b/internal/models/migration.go similarity index 100% rename from internal/models/model_migration.go rename to internal/models/migration.go diff --git a/internal/models/model_totp_configuration.go b/internal/models/model_totp_configuration.go deleted file mode 100644 index decb361d..00000000 --- a/internal/models/model_totp_configuration.go +++ /dev/null @@ -1,11 +0,0 @@ -package models - -// TOTPConfiguration represents a users TOTP configuration row in the database. -type TOTPConfiguration struct { - ID int `db:"id"` - Username string `db:"username"` - Algorithm string `db:"algorithm"` - Digits int `db:"digits"` - Period uint64 `db:"totp_period"` - Secret []byte `db:"secret"` -} diff --git a/internal/storage/types.go b/internal/models/schema_migration.go similarity index 97% rename from internal/storage/types.go rename to internal/models/schema_migration.go index 89a37ac0..011d65cc 100644 --- a/internal/storage/types.go +++ b/internal/models/schema_migration.go @@ -1,4 +1,4 @@ -package storage +package models // SchemaMigration represents an intended migration. type SchemaMigration struct { diff --git a/internal/models/totp_configuration.go b/internal/models/totp_configuration.go new file mode 100644 index 00000000..674bc43e --- /dev/null +++ b/internal/models/totp_configuration.go @@ -0,0 +1,36 @@ +package models + +import ( + "net/url" + "strconv" +) + +// TOTPConfiguration represents a users TOTP configuration row in the database. +type TOTPConfiguration struct { + ID int `db:"id" json:"-"` + Username string `db:"username" json:"-"` + Issuer string `db:"issuer" json:"-"` + Algorithm string `db:"algorithm" json:"-"` + Digits uint `db:"digits" json:"digits"` + Period uint `db:"totp_period" json:"period"` + Secret []byte `db:"secret" json:"-"` +} + +// URI shows the configuration in the URI representation. +func (c TOTPConfiguration) URI() (uri string) { + v := url.Values{} + v.Set("secret", string(c.Secret)) + v.Set("issuer", c.Issuer) + v.Set("period", strconv.FormatUint(uint64(c.Period), 10)) + v.Set("algorithm", c.Algorithm) + v.Set("digits", strconv.Itoa(int(c.Digits))) + + u := url.URL{ + Scheme: "otpauth", + Host: "totp", + Path: "/" + c.Issuer + ":" + c.Username, + RawQuery: v.Encode(), + } + + return u.String() +} diff --git a/internal/models/totp_configuration_test.go b/internal/models/totp_configuration_test.go new file mode 100644 index 00000000..c2390199 --- /dev/null +++ b/internal/models/totp_configuration_test.go @@ -0,0 +1,40 @@ +package models + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +/* + TestShouldOnlyMarshalPeriodAndDigitsAndAbsolutelyNeverSecret. + This test is vital to ensuring the TOTP configuration is marshalled correctly. If encoding/json suddenly changes + upstream and the json tag value of '-' doesn't exclude the field from marshalling then this test will pickup this + issue prior to code being shipped. + + For this reason it's essential that the marshalled object contains all values populated, especially the secret. +*/ +func TestShouldOnlyMarshalPeriodAndDigitsAndAbsolutelyNeverSecret(t *testing.T) { + object := TOTPConfiguration{ + ID: 1, + Username: "john", + Issuer: "Authelia", + Algorithm: "SHA1", + Digits: 6, + Period: 30, + + // DO NOT CHANGE THIS VALUE UNLESS YOU FULLY UNDERSTAND THE COMMENT AT THE TOP OF THIS TEST. + Secret: []byte("ABC123"), + } + + data, err := json.Marshal(object) + assert.NoError(t, err) + + assert.Equal(t, "{\"digits\":6,\"period\":30}", string(data)) + + // DO NOT REMOVE OR CHANGE THESE TESTS UNLESS YOU FULLY UNDERSTAND THE COMMENT AT THE TOP OF THIS TEST. + require.NotContains(t, string(data), "secret") + require.NotContains(t, string(data), "ABC123") +} diff --git a/internal/models/type_startup_check.go b/internal/models/type_startup_check.go deleted file mode 100644 index 76ca09ff..00000000 --- a/internal/models/type_startup_check.go +++ /dev/null @@ -1,6 +0,0 @@ -package models - -// StartupCheck represents a provider that has a startup check. -type StartupCheck interface { - StartupCheck() (err error) -} diff --git a/internal/models/type_ipaddress.go b/internal/models/types.go similarity index 88% rename from internal/models/type_ipaddress.go rename to internal/models/types.go index 09c529f2..f1cba448 100644 --- a/internal/models/type_ipaddress.go +++ b/internal/models/types.go @@ -46,3 +46,8 @@ func (ip *IPAddress) Scan(src interface{}) (err error) { return nil } + +// StartupCheck represents a provider that has a startup check. +type StartupCheck interface { + StartupCheck() (err error) +} diff --git a/internal/models/model_u2f_device.go b/internal/models/u2f_device.go similarity index 100% rename from internal/models/model_u2f_device.go rename to internal/models/u2f_device.go diff --git a/internal/models/model_userinfo.go b/internal/models/user_info.go similarity index 100% rename from internal/models/model_userinfo.go rename to internal/models/user_info.go diff --git a/internal/regulation/regulator_test.go b/internal/regulation/regulator_test.go index 5201ee38..4c763819 100644 --- a/internal/regulation/regulator_test.go +++ b/internal/regulation/regulator_test.go @@ -13,7 +13,6 @@ import ( "github.com/authelia/authelia/v4/internal/mocks" "github.com/authelia/authelia/v4/internal/models" "github.com/authelia/authelia/v4/internal/regulation" - "github.com/authelia/authelia/v4/internal/storage" ) type RegulatorSuite struct { @@ -21,14 +20,14 @@ type RegulatorSuite struct { ctx context.Context ctrl *gomock.Controller - storageMock *storage.MockProvider + storageMock *mocks.MockStorage configuration schema.RegulationConfiguration clock mocks.TestingClock } func (s *RegulatorSuite) SetupTest() { s.ctrl = gomock.NewController(s.T()) - s.storageMock = storage.NewMockProvider(s.ctrl) + s.storageMock = mocks.NewMockStorage(s.ctrl) s.ctx = context.Background() s.configuration = schema.RegulationConfiguration{ diff --git a/internal/server/server.go b/internal/server/server.go index 2e826aeb..52181c9c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -88,6 +88,8 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr middlewares.RequireFirstFactor(handlers.UserInfoGet))) r.POST("/api/user/info/2fa_method", autheliaMiddleware( middlewares.RequireFirstFactor(handlers.MethodPreferencePost))) + r.GET("/api/user/info/totp", autheliaMiddleware( + middlewares.RequireFirstFactor(handlers.UserTOTPGet))) // TOTP related endpoints. r.POST("/api/secondfactor/totp/identity/start", autheliaMiddleware( @@ -95,10 +97,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr r.POST("/api/secondfactor/totp/identity/finish", autheliaMiddleware( middlewares.RequireFirstFactor(handlers.SecondFactorTOTPIdentityFinish))) r.POST("/api/secondfactor/totp", autheliaMiddleware( - middlewares.RequireFirstFactor(handlers.SecondFactorTOTPPost(&handlers.TOTPVerifierImpl{ - Period: uint(configuration.TOTP.Period), - Skew: uint(*configuration.TOTP.Skew), - })))) + middlewares.RequireFirstFactor(handlers.SecondFactorTOTPPost))) // U2F related endpoints. r.POST("/api/secondfactor/u2f/identity/start", autheliaMiddleware( diff --git a/internal/storage/errors.go b/internal/storage/errors.go index 4bae8cd8..75044613 100644 --- a/internal/storage/errors.go +++ b/internal/storage/errors.go @@ -8,8 +8,8 @@ var ( // ErrNoAuthenticationLogs error thrown when no matching authentication logs hve been found in DB. ErrNoAuthenticationLogs = errors.New("no matching authentication logs found") - // ErrNoTOTPSecret error thrown when no TOTP secret has been found in DB. - ErrNoTOTPSecret = errors.New("no TOTP secret registered") + // ErrNoTOTPConfiguration error thrown when no TOTP configuration has been found in DB. + ErrNoTOTPConfiguration = errors.New("no TOTP configuration for user") // ErrNoU2FDeviceHandle error thrown when no U2F device handle has been found in DB. ErrNoU2FDeviceHandle = errors.New("no U2F device handle found") diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go index f7222543..e32307af 100644 --- a/internal/storage/migrations.go +++ b/internal/storage/migrations.go @@ -7,6 +7,8 @@ import ( "sort" "strconv" "strings" + + "github.com/authelia/authelia/v4/internal/models" ) //go:embed migrations/* @@ -44,7 +46,7 @@ func latestMigrationVersion(providerName string) (version int, err error) { return version, nil } -func loadMigration(providerName string, version int, up bool) (migration *SchemaMigration, err error) { +func loadMigration(providerName string, version int, up bool) (migration *models.SchemaMigration, err error) { entries, err := migrationsFS.ReadDir("migrations") if err != nil { return nil, err @@ -83,7 +85,7 @@ func loadMigration(providerName string, version int, up bool) (migration *Schema // loadMigrations scans the migrations fs and loads the appropriate migrations for a given providerName, prior and // target versions. If the target version is -1 this indicates the latest version. If the target version is 0 // this indicates the database zero state. -func loadMigrations(providerName string, prior, target int) (migrations []SchemaMigration, err error) { +func loadMigrations(providerName string, prior, target int) (migrations []models.SchemaMigration, err error) { if prior == target && (prior != -1 || target != -1) { return nil, ErrMigrateCurrentVersionSameAsTarget } @@ -125,7 +127,7 @@ func loadMigrations(providerName string, prior, target int) (migrations []Schema return migrations, nil } -func skipMigration(providerName string, up bool, target, prior int, migration *SchemaMigration) (skip bool) { +func skipMigration(providerName string, up bool, target, prior int, migration *models.SchemaMigration) (skip bool) { if migration.Provider != providerAll && migration.Provider != providerName { // Skip if migration.Provider is not a match. return true @@ -163,21 +165,21 @@ func skipMigration(providerName string, up bool, target, prior int, migration *S return false } -func scanMigration(m string) (migration SchemaMigration, err error) { +func scanMigration(m string) (migration models.SchemaMigration, err error) { result := reMigration.FindStringSubmatch(m) if result == nil || len(result) != 5 { - return SchemaMigration{}, errors.New("invalid migration: could not parse the format") + return models.SchemaMigration{}, errors.New("invalid migration: could not parse the format") } - migration = SchemaMigration{ + migration = models.SchemaMigration{ Name: strings.ReplaceAll(result[2], "_", " "), Provider: result[3], } data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m)) if err != nil { - return SchemaMigration{}, err + return models.SchemaMigration{}, err } migration.Query = string(data) @@ -188,7 +190,7 @@ func scanMigration(m string) (migration SchemaMigration, err error) { case "down": migration.Up = false default: - return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4]) + return models.SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4]) } migration.Version, _ = strconv.Atoi(result[1]) @@ -197,7 +199,7 @@ func scanMigration(m string) (migration SchemaMigration, err error) { case providerAll, providerSQLite, providerMySQL, providerPostgres: break default: - return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3]) + return models.SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3]) } return migration, nil diff --git a/internal/storage/provider.go b/internal/storage/provider.go index b046dcbb..bd86ac95 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -40,8 +40,8 @@ type Provider interface { SchemaMigrate(ctx context.Context, up bool, version int) (err error) SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error) - SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error) - SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error) + SchemaMigrationsUp(ctx context.Context, version int) (migrations []models.SchemaMigration, err error) + SchemaMigrationsDown(ctx context.Context, version int) (migrations []models.SchemaMigration, err error) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 9f448165..6a983bd7 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -12,19 +12,21 @@ import ( "github.com/sirupsen/logrus" "github.com/authelia/authelia/v4/internal/authentication" + "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/models" ) // NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's. -func NewSQLProvider(name, driverName, dataSourceName, encryptionKey string) (provider SQLProvider) { +func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceName string) (provider SQLProvider) { db, err := sqlx.Open(driverName, dataSourceName) provider = SQLProvider{ db: db, - key: sha256.Sum256([]byte(encryptionKey)), + key: sha256.Sum256([]byte(config.Storage.EncryptionKey)), name: name, driverName: driverName, + config: config, errOpen: err, log: logging.Logger(), @@ -64,10 +66,6 @@ func NewSQLProvider(name, driverName, dataSourceName, encryptionKey string) (pro sqlFmtRenameTable: queryFmtRenameTable, } - key := sha256.Sum256([]byte(encryptionKey)) - - provider.key = key - return provider } @@ -77,6 +75,7 @@ type SQLProvider struct { key [32]byte name string driverName string + config *schema.Configuration errOpen error log *logrus.Logger @@ -251,7 +250,7 @@ func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.T } if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig, - config.Username, config.Algorithm, config.Digits, config.Period, config.Secret); err != nil { + config.Username, config.Issuer, config.Algorithm, config.Digits, config.Period, config.Secret); err != nil { return fmt.Errorf("error upserting TOTP configuration: %w", err) } @@ -273,7 +272,7 @@ func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNoTOTPSecret + return nil, ErrNoTOTPConfiguration } return nil, fmt.Errorf("error selecting TOTP configuration: %w", err) diff --git a/internal/storage/sql_provider_backend_mysql.go b/internal/storage/sql_provider_backend_mysql.go index 8805aea2..dfdff770 100644 --- a/internal/storage/sql_provider_backend_mysql.go +++ b/internal/storage/sql_provider_backend_mysql.go @@ -15,9 +15,9 @@ type MySQLProvider struct { } // NewMySQLProvider a MySQL provider. -func NewMySQLProvider(config schema.MySQLStorageConfiguration, encryptionKey string) (provider *MySQLProvider) { +func NewMySQLProvider(config *schema.Configuration) (provider *MySQLProvider) { provider = &MySQLProvider{ - SQLProvider: NewSQLProvider(providerMySQL, providerMySQL, dataSourceNameMySQL(config), encryptionKey), + SQLProvider: NewSQLProvider(config, providerMySQL, providerMySQL, dataSourceNameMySQL(*config.Storage.MySQL)), } // All providers have differing SELECT existing table statements. diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 8b59367a..b5e47568 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -16,9 +16,9 @@ type PostgreSQLProvider struct { } // NewPostgreSQLProvider a PostgreSQL provider. -func NewPostgreSQLProvider(config schema.PostgreSQLStorageConfiguration, encryptionKey string) (provider *PostgreSQLProvider) { +func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLProvider) { provider = &PostgreSQLProvider{ - SQLProvider: NewSQLProvider(providerPostgres, "pgx", dataSourceNamePostgreSQL(config), encryptionKey), + SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dataSourceNamePostgreSQL(*config.Storage.PostgreSQL)), } // All providers have differing SELECT existing table statements. diff --git a/internal/storage/sql_provider_backend_sqlite.go b/internal/storage/sql_provider_backend_sqlite.go index f54309b0..95b39897 100644 --- a/internal/storage/sql_provider_backend_sqlite.go +++ b/internal/storage/sql_provider_backend_sqlite.go @@ -2,6 +2,8 @@ package storage import ( _ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string. + + "github.com/authelia/authelia/v4/internal/configuration/schema" ) // SQLiteProvider is a SQLite3 provider. @@ -10,9 +12,9 @@ type SQLiteProvider struct { } // NewSQLiteProvider constructs a SQLite provider. -func NewSQLiteProvider(path, encryptionKey string) (provider *SQLiteProvider) { +func NewSQLiteProvider(config *schema.Configuration) (provider *SQLiteProvider) { provider = &SQLiteProvider{ - SQLProvider: NewSQLProvider(providerSQLite, "sqlite3", path, encryptionKey), + SQLProvider: NewSQLProvider(config, providerSQLite, "sqlite3", config.Storage.Local.Path), } // All providers have differing SELECT existing table statements. diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index 2ce60164..fbe7a0d4 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -75,12 +75,12 @@ const ( const ( queryFmtSelectTOTPConfiguration = ` - SELECT id, username, algorithm, digits, totp_period, secret + SELECT id, username, issuer, algorithm, digits, totp_period, secret FROM %s WHERE username = ?;` queryFmtSelectTOTPConfigurations = ` - SELECT id, username, algorithm, digits, totp_period, secret + SELECT id, username, issuer, algorithm, digits, totp_period, secret FROM %s LIMIT ? OFFSET ?;` @@ -98,14 +98,14 @@ const ( WHERE username = ?;` queryFmtUpsertTOTPConfiguration = ` - REPLACE INTO %s (username, algorithm, digits, totp_period, secret) - VALUES (?, ?, ?, ?, ?);` + REPLACE INTO %s (username, issuer, algorithm, digits, totp_period, secret) + VALUES (?, ?, ?, ?, ?, ?);` queryFmtPostgresUpsertTOTPConfiguration = ` - INSERT INTO %s (username, algorithm, digits, totp_period, secret) - VALUES ($1, $2, $3, $4, $5) + INSERT INTO %s (username, issuer, algorithm, digits, totp_period, secret) + VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (username) - DO UPDATE SET algorithm = $2, digits = $3, totp_period = $4, secret = $5;` + DO UPDATE SET issuer = $2, algorithm = $3, digits = $4, totp_period = $5, secret = $6;` queryFmtDeleteTOTPConfiguration = ` DELETE FROM %s diff --git a/internal/storage/sql_provider_queries_special.go b/internal/storage/sql_provider_queries_special.go index 370e44b7..a7887c21 100644 --- a/internal/storage/sql_provider_queries_special.go +++ b/internal/storage/sql_provider_queries_special.go @@ -35,7 +35,11 @@ const ( FROM %s ORDER BY username ASC;` - queryFmtPre1InsertTOTPConfiguration = ` + queryFmtPre1To1InsertTOTPConfiguration = ` + INSERT INTO %s (username, issuer, totp_period, secret) + VALUES (?, ?, ?, ?);` + + queryFmt1ToPre1InsertTOTPConfiguration = ` INSERT INTO %s (username, secret) VALUES (?, ?);` diff --git a/internal/storage/sql_provider_schema.go b/internal/storage/sql_provider_schema.go index e0005666..dbe2bc1a 100644 --- a/internal/storage/sql_provider_schema.go +++ b/internal/storage/sql_provider_schema.go @@ -206,7 +206,7 @@ func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after in return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr) } -func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration SchemaMigration) (err error) { +func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration models.SchemaMigration) (err error) { _, err = p.db.ExecContext(ctx, migration.Query) if err != nil { return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err) @@ -231,7 +231,7 @@ func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration SchemaMi return p.schemaMigrateFinalize(ctx, migration) } -func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration SchemaMigration) (err error) { +func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration models.SchemaMigration) (err error) { return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After()) } @@ -247,7 +247,7 @@ func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, } // SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version. -func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error) { +func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []models.SchemaMigration, err error) { current, err := p.SchemaVersion(ctx) if err != nil { return migrations, err @@ -265,7 +265,7 @@ func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migr } // SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version. -func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error) { +func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []models.SchemaMigration, err error) { current, err := p.SchemaVersion(ctx) if err != nil { return migrations, err diff --git a/internal/storage/sql_provider_schema_pre1.go b/internal/storage/sql_provider_schema_pre1.go index 7cddddde..e577d1f7 100644 --- a/internal/storage/sql_provider_schema_pre1.go +++ b/internal/storage/sql_provider_schema_pre1.go @@ -226,7 +226,7 @@ func (p *SQLProvider) schemaMigratePre1To1TOTP(ctx context.Context) (err error) } for _, config := range totpConfigs { - _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tableTOTPConfigurations), config.Username, config.Secret) + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertTOTPConfiguration), tableTOTPConfigurations), config.Username, p.config.TOTP.Issuer, p.config.TOTP.Period, config.Secret) if err != nil { return err } @@ -414,7 +414,7 @@ func (p *SQLProvider) schemaMigrate1ToPre1TOTP(ctx context.Context) (err error) } for _, config := range totpConfigs { - _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tablePre1TOTPSecrets), config.Username, config.Secret) + _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertTOTPConfiguration), tablePre1TOTPSecrets), config.Username, config.Secret) if err != nil { return err } diff --git a/internal/suites/CLI/storage.yml b/internal/suites/CLI/storage.yml index ca6381fa..9b8111b2 100644 --- a/internal/suites/CLI/storage.yml +++ b/internal/suites/CLI/storage.yml @@ -1,6 +1,6 @@ --- storage: - encryption_key: a_cli_encryption_key_which_isnt_secure + encryption_key: a_not_so_secure_encryption_key local: path: /tmp/db.sqlite3 ... diff --git a/internal/suites/Mariadb/configuration.yml b/internal/suites/MariaDB/configuration.yml similarity index 100% rename from internal/suites/Mariadb/configuration.yml rename to internal/suites/MariaDB/configuration.yml diff --git a/internal/suites/MariaDB/docker-compose.yml b/internal/suites/MariaDB/docker-compose.yml new file mode 100644 index 00000000..dbe20d3b --- /dev/null +++ b/internal/suites/MariaDB/docker-compose.yml @@ -0,0 +1,9 @@ +--- +version: '3' +services: + authelia-backend: + volumes: + - './MariaDB/configuration.yml:/config/configuration.yml:ro' + - './MariaDB/users.yml:/config/users.yml' + - './common/ssl:/config/ssl:ro' +... diff --git a/internal/suites/Mariadb/users.yml b/internal/suites/MariaDB/users.yml similarity index 100% rename from internal/suites/Mariadb/users.yml rename to internal/suites/MariaDB/users.yml diff --git a/internal/suites/Mariadb/docker-compose.yml b/internal/suites/Mariadb/docker-compose.yml deleted file mode 100644 index 049b28c3..00000000 --- a/internal/suites/Mariadb/docker-compose.yml +++ /dev/null @@ -1,9 +0,0 @@ ---- -version: '3' -services: - authelia-backend: - volumes: - - './Mariadb/configuration.yml:/config/configuration.yml:ro' - - './Mariadb/users.yml:/config/users.yml' - - './common/ssl:/config/ssl:ro' -... diff --git a/internal/suites/action_totp.go b/internal/suites/action_totp.go index 95e1a72e..07d00736 100644 --- a/internal/suites/action_totp.go +++ b/internal/suites/action_totp.go @@ -30,7 +30,7 @@ func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string { func (rs *RodSession) doEnterOTP(t *testing.T, page *rod.Page, code string) { inputs := rs.WaitElementsLocatedByCSSSelector(t, page, "otp-input input") - for i := 0; i < 6; i++ { + for i := 0; i < len(code); i++ { _ = inputs[i].Input(string(code[i])) } } diff --git a/internal/suites/const.go b/internal/suites/const.go index 94016815..d4c3ce1f 100644 --- a/internal/suites/const.go +++ b/internal/suites/const.go @@ -3,6 +3,8 @@ package suites import ( "fmt" "os" + + "github.com/authelia/authelia/v4/internal/configuration/schema" ) // BaseDomain the base domain. @@ -55,3 +57,18 @@ const ( testUsername = "john" testPassword = "password" ) + +var ( + storageLocalTmpConfig = schema.Configuration{ + TOTP: &schema.TOTPConfiguration{ + Issuer: "Authelia", + Period: 6, + }, + Storage: schema.StorageConfiguration{ + EncryptionKey: "a_not_so_secure_encryption_key", + Local: &schema.LocalStorageConfiguration{ + Path: "/tmp/db.sqlite3", + }, + }, + } +) diff --git a/internal/suites/suite_cli_test.go b/internal/suites/suite_cli_test.go index 1b7a2a52..f287646e 100644 --- a/internal/suites/suite_cli_test.go +++ b/internal/suites/suite_cli_test.go @@ -5,10 +5,9 @@ import ( "fmt" "os" "regexp" + "strconv" "testing" - "github.com/pquerna/otp" - "github.com/pquerna/otp/totp" "github.com/stretchr/testify/suite" "github.com/authelia/authelia/v4/internal/models" @@ -178,6 +177,8 @@ func (s *CLISuite) TestStorageShouldShowErrWithoutConfig() { } func (s *CLISuite) TestStorage00ShouldShowCorrectPreInitInformation() { + _ = os.Remove("/tmp/db.sqlite3") + output, err := s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "schema-info", "--config", "/config/configuration.storage.yml"}) s.Assert().NoError(err) @@ -187,7 +188,7 @@ func (s *CLISuite) TestStorage00ShouldShowCorrectPreInitInformation() { patternOutdated := regexp.MustCompile(`Error: schema is version \d+ which is outdated please migrate to version \d+ in order to use this command or use an older binary`) - output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "export", "totp-configurations", "--config", "/config/configuration.storage.yml"}) + output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "totp", "export", "--config", "/config/configuration.storage.yml"}) s.Assert().EqualError(err, "exit status 1") s.Assert().Regexp(patternOutdated, output) @@ -267,16 +268,14 @@ func (s *CLISuite) TestStorage02ShouldShowSchemaInfo() { } func (s *CLISuite) TestStorage03ShouldExportTOTP() { - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_cli_encryption_key_which_isnt_secure") + storageProvider := storage.NewSQLiteProvider(&storageLocalTmpConfig) - s.Require().NoError(provider.StartupCheck()) + s.Require().NoError(storageProvider.StartupCheck()) ctx := context.Background() var ( - err error - key *otp.Key - config models.TOTPConfiguration + err error ) var ( @@ -287,39 +286,53 @@ func (s *CLISuite) TestStorage03ShouldExportTOTP() { expectedLinesCSV = append(expectedLinesCSV, "issuer,username,algorithm,digits,period,secret") - for _, name := range []string{"john", "mary", "fred"} { - key, err = totp.Generate(totp.GenerateOpts{ - Issuer: "Authelia", - AccountName: name, - Period: uint(30), - SecretSize: 32, - Digits: otp.Digits(6), - Algorithm: otp.AlgorithmSHA1, - }) - s.Require().NoError(err) - - config = models.TOTPConfiguration{ - Username: name, - Algorithm: "SHA1", + configs := []*models.TOTPConfiguration{ + { + Username: "john", + Period: 30, Digits: 6, - Secret: []byte(key.Secret()), - Period: key.Period(), - } - - expectedLinesCSV = append(expectedLinesCSV, fmt.Sprintf("%s,%s,%s,%d,%d,%s", "Authelia", config.Username, config.Algorithm, config.Digits, config.Period, string(config.Secret))) - expectedLines = append(expectedLines, fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=%s&digits=%d&period=%d", "Authelia", config.Username, string(config.Secret), "Authelia", config.Algorithm, config.Digits, config.Period)) - - s.Require().NoError(provider.SaveTOTPConfiguration(ctx, config)) + Algorithm: "SHA1", + }, + { + Username: "mary", + Period: 45, + Digits: 6, + Algorithm: "SHA1", + }, + { + Username: "fred", + Period: 30, + Digits: 8, + Algorithm: "SHA1", + }, + { + Username: "jone", + Period: 30, + Digits: 6, + Algorithm: "SHA512", + }, } - output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "export", "totp-configurations", "--format", "uri", "--config", "/config/configuration.storage.yml"}) + for _, config := range configs { + output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "totp", "generate", config.Username, "--period", strconv.Itoa(int(config.Period)), "--algorithm", config.Algorithm, "--digits", strconv.Itoa(int(config.Digits)), "--config", "/config/configuration.storage.yml"}) + s.Assert().NoError(err) + + config, err = storageProvider.LoadTOTPConfiguration(ctx, config.Username) + s.Assert().NoError(err) + s.Assert().Contains(output, config.URI()) + + expectedLinesCSV = append(expectedLinesCSV, fmt.Sprintf("%s,%s,%s,%d,%d,%s", "Authelia", config.Username, config.Algorithm, config.Digits, config.Period, string(config.Secret))) + expectedLines = append(expectedLines, config.URI()) + } + + output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "totp", "export", "--format", "uri", "--config", "/config/configuration.storage.yml"}) s.Assert().NoError(err) for _, expectedLine := range expectedLines { s.Assert().Contains(output, expectedLine) } - output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "export", "totp-configurations", "--format", "csv", "--config", "/config/configuration.storage.yml"}) + output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "totp", "export", "--format", "csv", "--config", "/config/configuration.storage.yml"}) s.Assert().NoError(err) for _, expectedLine := range expectedLinesCSV { @@ -347,7 +360,7 @@ func (s *CLISuite) TestStorage04ShouldChangeEncryptionKey() { output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "encryption", "check", "--verbose", "--config", "/config/configuration.storage.yml"}) s.Assert().NoError(err) - s.Assert().Contains(output, "Encryption key validation: failed.\n\nError: the encryption key is not valid against the schema check value, 3 of 3 total TOTP secrets were invalid.\n") + s.Assert().Contains(output, "Encryption key validation: failed.\n\nError: the encryption key is not valid against the schema check value, 4 of 4 total TOTP secrets were invalid.\n") output, err = s.Exec("authelia-backend", []string{"authelia", s.testArg, s.coverageArg, "storage", "encryption", "check", "--encryption-key", "apple-apple-apple-apple", "--config", "/config/configuration.storage.yml"}) s.Assert().NoError(err) diff --git a/internal/suites/suite_duo_push_test.go b/internal/suites/suite_duo_push_test.go index 208296f0..961afc00 100644 --- a/internal/suites/suite_duo_push_test.go +++ b/internal/suites/suite_duo_push_test.go @@ -57,7 +57,7 @@ func (s *DuoPushWebDriverSuite) TearDownTest() { }() // Set default 2FA preference and clean up any Duo device already in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferred2FAMethod(ctx, "john", "totp")) require.NoError(s.T(), provider.DeletePreferredDuoDevice(ctx, "john")) } @@ -152,7 +152,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() { defer cancel() // Set default 2FA preference to enable Select Device link in frontend. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "ABCDEFGHIJ1234567890", Method: "push"})) var PreAuthAPIResponse = duo.PreAuthResponse{ @@ -231,7 +231,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectNewDeviceAfterSavedDeviceMethodI } // Setup unsupported Duo device in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "ABCDEFGHIJ1234567890", Method: "sms"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Allow) @@ -257,7 +257,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAutoSelectNewDeviceAfterSavedDeviceIsN } // Setup unsupported Duo device in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "ABCDEFGHIJ1234567890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Allow) @@ -276,7 +276,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailSelectionBecauseOfSelectionBypasse StatusMessage: "Allowing unknown user", } - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Deny) @@ -296,7 +296,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailSelectionBecauseOfSelectionDenied( StatusMessage: "We're sorry, access is not allowed.", } - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Deny) @@ -317,7 +317,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailAuthenticationBecausePreauthDenied StatusMessage: "We're sorry, access is not allowed.", } - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) @@ -344,7 +344,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSucceedAuthentication() { } // Setup Duo device in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Allow) @@ -371,7 +371,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailAuthentication() { } // Setup Duo device in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Deny) @@ -430,7 +430,7 @@ func (s *DuoPushDefaultRedirectionSuite) TestUserIsRedirectedToDefaultURL() { } // Setup Duo device in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Allow) @@ -475,7 +475,7 @@ func (s *DuoPushSuite) TestUserPreferencesScenario() { ctx := context.Background() // Setup Duo device in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.SavePreferredDuoDevice(ctx, models.DuoDevice{Username: "john", Device: "12345ABCDEFGHIJ67890", Method: "push"})) ConfigureDuoPreAuth(s.T(), PreAuthAPIResponse) ConfigureDuo(s.T(), Allow) diff --git a/internal/suites/suite_mariadb.go b/internal/suites/suite_mariadb.go index f445826f..61aa6374 100644 --- a/internal/suites/suite_mariadb.go +++ b/internal/suites/suite_mariadb.go @@ -5,12 +5,12 @@ import ( "time" ) -var mariadbSuiteName = "Mariadb" +var mariadbSuiteName = "MariaDB" func init() { dockerEnvironment := NewDockerEnvironment([]string{ "internal/suites/docker-compose.yml", - "internal/suites/Mariadb/docker-compose.yml", + "internal/suites/MariaDB/docker-compose.yml", "internal/suites/example/compose/authelia/docker-compose.backend.{}.yml", "internal/suites/example/compose/authelia/docker-compose.frontend.{}.yml", "internal/suites/example/compose/nginx/backend/docker-compose.yml", diff --git a/internal/suites/suite_mariadb_test.go b/internal/suites/suite_mariadb_test.go index eb201aa6..1c587cad 100644 --- a/internal/suites/suite_mariadb_test.go +++ b/internal/suites/suite_mariadb_test.go @@ -6,26 +6,26 @@ import ( "github.com/stretchr/testify/suite" ) -type MariadbSuite struct { +type MariaDBSuite struct { *RodSuite } -func NewMariadbSuite() *MariadbSuite { - return &MariadbSuite{RodSuite: new(RodSuite)} +func NewMariaDBSuite() *MariaDBSuite { + return &MariaDBSuite{RodSuite: new(RodSuite)} } -func (s *MariadbSuite) TestOneFactorScenario() { +func (s *MariaDBSuite) TestOneFactorScenario() { suite.Run(s.T(), NewOneFactorScenario()) } -func (s *MariadbSuite) TestTwoFactorScenario() { +func (s *MariaDBSuite) TestTwoFactorScenario() { suite.Run(s.T(), NewTwoFactorScenario()) } -func TestMariadbSuite(t *testing.T) { +func TestMariaDBSuite(t *testing.T) { if testing.Short() { t.Skip("skipping suite test in short mode") } - suite.Run(t, NewMariadbSuite()) + suite.Run(t, NewMariaDBSuite()) } diff --git a/internal/suites/suite_standalone_test.go b/internal/suites/suite_standalone_test.go index 75d3ab35..dc76412d 100644 --- a/internal/suites/suite_standalone_test.go +++ b/internal/suites/suite_standalone_test.go @@ -121,7 +121,7 @@ func (s *StandaloneWebDriverSuite) TestShouldCheckUserIsAskedToRegisterDevice() password := "password" // Clean up any TOTP secret already in DB. - provider := storage.NewSQLiteProvider("/tmp/db.sqlite3", "a_not_so_secure_encryption_key") + provider := storage.NewSQLiteProvider(&storageLocalTmpConfig) require.NoError(s.T(), provider.StartupCheck()) require.NoError(s.T(), provider.DeleteTOTPConfiguration(ctx, username)) diff --git a/internal/totp/helpers.go b/internal/totp/helpers.go new file mode 100644 index 00000000..e464c1d7 --- /dev/null +++ b/internal/totp/helpers.go @@ -0,0 +1,20 @@ +package totp + +import ( + "github.com/pquerna/otp" + + "github.com/authelia/authelia/v4/internal/configuration/schema" +) + +func otpStringToAlgo(in string) (algorithm otp.Algorithm) { + switch in { + case schema.TOTPAlgorithmSHA1: + return otp.AlgorithmSHA1 + case schema.TOTPAlgorithmSHA256: + return otp.AlgorithmSHA256 + case schema.TOTPAlgorithmSHA512: + return otp.AlgorithmSHA512 + default: + return otp.AlgorithmSHA1 + } +} diff --git a/internal/totp/provider.go b/internal/totp/provider.go new file mode 100644 index 00000000..76f4d2f6 --- /dev/null +++ b/internal/totp/provider.go @@ -0,0 +1,12 @@ +package totp + +import ( + "github.com/authelia/authelia/v4/internal/models" +) + +// Provider for TOTP functionality. +type Provider interface { + Generate(username string) (config *models.TOTPConfiguration, err error) + GenerateCustom(username, algorithm string, digits, period, secretSize uint) (config *models.TOTPConfiguration, err error) + Validate(token string, config *models.TOTPConfiguration) (valid bool, err error) +} diff --git a/internal/totp/totp.go b/internal/totp/totp.go new file mode 100644 index 00000000..6d12e80b --- /dev/null +++ b/internal/totp/totp.go @@ -0,0 +1,78 @@ +package totp + +import ( + "time" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + + "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/models" +) + +// NewTimeBasedProvider creates a new totp.TimeBased which implements the totp.Provider. +func NewTimeBasedProvider(config *schema.TOTPConfiguration) (provider *TimeBased) { + provider = &TimeBased{ + config: config, + } + + if config.Skew != nil { + provider.skew = *config.Skew + } else { + provider.skew = 1 + } + + return provider +} + +// TimeBased totp.Provider for production use. +type TimeBased struct { + config *schema.TOTPConfiguration + skew uint +} + +// GenerateCustom generates a TOTP with custom options. +func (p TimeBased) GenerateCustom(username, algorithm string, digits, period, secretSize uint) (config *models.TOTPConfiguration, err error) { + var key *otp.Key + + opts := totp.GenerateOpts{ + Issuer: p.config.Issuer, + AccountName: username, + Period: period, + SecretSize: secretSize, + Digits: otp.Digits(digits), + Algorithm: otpStringToAlgo(algorithm), + } + + if key, err = totp.Generate(opts); err != nil { + return nil, err + } + + config = &models.TOTPConfiguration{ + Username: username, + Issuer: p.config.Issuer, + Algorithm: algorithm, + Digits: digits, + Secret: []byte(key.Secret()), + Period: period, + } + + return config, nil +} + +// Generate generates a TOTP with default options. +func (p TimeBased) Generate(username string) (config *models.TOTPConfiguration, err error) { + return p.GenerateCustom(username, p.config.Algorithm, p.config.Digits, p.config.Period, 32) +} + +// Validate the token against the given configuration. +func (p TimeBased) Validate(token string, config *models.TOTPConfiguration) (valid bool, err error) { + opts := totp.ValidateOpts{ + Period: config.Period, + Skew: p.skew, + Digits: otp.Digits(config.Digits), + Algorithm: otpStringToAlgo(config.Algorithm), + } + + return totp.ValidateCustom(token, string(config.Secret), time.Now().UTC(), opts) +} diff --git a/web/.commitlintrc.js b/web/.commitlintrc.js index a2897b48..16eedc31 100644 --- a/web/.commitlintrc.js +++ b/web/.commitlintrc.js @@ -45,6 +45,7 @@ module.exports = { "storage", "suites", "templates", + "totp", "utils", "web", ], diff --git a/web/src/hooks/UserInfo.ts b/web/src/hooks/UserInfo.ts index b59ede4d..88eb258a 100644 --- a/web/src/hooks/UserInfo.ts +++ b/web/src/hooks/UserInfo.ts @@ -1,6 +1,6 @@ import { useRemoteCall } from "@hooks/RemoteCall"; -import { getUserPreferences } from "@services/UserPreferences"; +import { getUserInfo } from "@services/UserInfo"; -export function useUserPreferences() { - return useRemoteCall(getUserPreferences, []); +export function useUserInfo() { + return useRemoteCall(getUserInfo, []); } diff --git a/web/src/hooks/UserInfoTOTPConfiguration.ts b/web/src/hooks/UserInfoTOTPConfiguration.ts new file mode 100644 index 00000000..dba45ab0 --- /dev/null +++ b/web/src/hooks/UserInfoTOTPConfiguration.ts @@ -0,0 +1,6 @@ +import { useRemoteCall } from "@hooks/RemoteCall"; +import { getUserInfoTOTPConfiguration } from "@services/UserInfoTOTPConfiguration"; + +export function useUserInfoTOTPConfiguration() { + return useRemoteCall(getUserInfoTOTPConfiguration, []); +} diff --git a/web/src/models/Configuration.ts b/web/src/models/Configuration.ts index fa0c2751..da9f4cf8 100644 --- a/web/src/models/Configuration.ts +++ b/web/src/models/Configuration.ts @@ -3,5 +3,4 @@ import { SecondFactorMethod } from "@models/Methods"; export interface Configuration { available_methods: Set; second_factor_enabled: boolean; - totp_period: number; } diff --git a/web/src/models/UserInfoTOTPConfiguration.ts b/web/src/models/UserInfoTOTPConfiguration.ts new file mode 100644 index 00000000..adbcea5e --- /dev/null +++ b/web/src/models/UserInfoTOTPConfiguration.ts @@ -0,0 +1,4 @@ +export interface UserInfoTOTPConfiguration { + period: number; + digits: number; +} diff --git a/web/src/services/Api.ts b/web/src/services/Api.ts index e315eb91..53e1c091 100644 --- a/web/src/services/Api.ts +++ b/web/src/services/Api.ts @@ -34,6 +34,7 @@ export const LogoutPath = basePath + "/api/logout"; export const StatePath = basePath + "/api/state"; export const UserInfoPath = basePath + "/api/user/info"; export const UserInfo2FAMethodPath = basePath + "/api/user/info/2fa_method"; +export const UserInfoTOTPConfigurationPath = basePath + "/api/user/info/totp"; export const ConfigurationPath = basePath + "/api/configuration"; diff --git a/web/src/services/Configuration.ts b/web/src/services/Configuration.ts index 0bbb4349..7e0c2468 100644 --- a/web/src/services/Configuration.ts +++ b/web/src/services/Configuration.ts @@ -1,12 +1,11 @@ import { Configuration } from "@models/Configuration"; import { ConfigurationPath } from "@services/Api"; import { Get } from "@services/Client"; -import { toEnum, Method2FA } from "@services/UserPreferences"; +import { toEnum, Method2FA } from "@services/UserInfo"; interface ConfigurationPayload { available_methods: Method2FA[]; second_factor_enabled: boolean; - totp_period: number; } export async function getConfiguration(): Promise { diff --git a/web/src/services/UserPreferences.ts b/web/src/services/UserInfo.ts similarity index 95% rename from web/src/services/UserPreferences.ts rename to web/src/services/UserInfo.ts index 0fe7a49c..bccc6a86 100644 --- a/web/src/services/UserPreferences.ts +++ b/web/src/services/UserInfo.ts @@ -39,7 +39,7 @@ export function toString(method: SecondFactorMethod): Method2FA { } } -export async function getUserPreferences(): Promise { +export async function getUserInfo(): Promise { const res = await Get(UserInfoPath); return { ...res, method: toEnum(res.method) }; } diff --git a/web/src/services/UserInfoTOTPConfiguration.ts b/web/src/services/UserInfoTOTPConfiguration.ts new file mode 100644 index 00000000..f32a431d --- /dev/null +++ b/web/src/services/UserInfoTOTPConfiguration.ts @@ -0,0 +1,15 @@ +import { UserInfoTOTPConfiguration } from "@models/UserInfoTOTPConfiguration"; +import { UserInfoTOTPConfigurationPath } from "@services/Api"; +import { Get } from "@services/Client"; + +export type TOTPDigits = 6 | 8; + +export interface UserInfoTOTPConfigurationPayload { + period: number; + digits: TOTPDigits; +} + +export async function getUserInfoTOTPConfiguration(): Promise { + const res = await Get(UserInfoTOTPConfigurationPath); + return { ...res }; +} diff --git a/web/src/views/LoginPortal/LoginPortal.tsx b/web/src/views/LoginPortal/LoginPortal.tsx index 29676965..a0c46e20 100644 --- a/web/src/views/LoginPortal/LoginPortal.tsx +++ b/web/src/views/LoginPortal/LoginPortal.tsx @@ -16,7 +16,7 @@ import { useRedirectionURL } from "@hooks/RedirectionURL"; import { useRedirector } from "@hooks/Redirector"; import { useRequestMethod } from "@hooks/RequestMethod"; import { useAutheliaState } from "@hooks/State"; -import { useUserPreferences as userUserInfo } from "@hooks/UserInfo"; +import { useUserInfo } from "@hooks/UserInfo"; import { SecondFactorMethod } from "@models/Methods"; import { checkSafeRedirection } from "@services/SafeRedirection"; import { AuthenticationLevel } from "@services/State"; @@ -44,7 +44,7 @@ const LoginPortal = function (props: Props) { const redirector = useRedirector(); const [state, fetchState, , fetchStateError] = useAutheliaState(); - const [userInfo, fetchUserInfo, , fetchUserInfoError] = userUserInfo(); + const [userInfo, fetchUserInfo, , fetchUserInfoError] = useUserInfo(); const [configuration, fetchConfiguration, , fetchConfigurationError] = useConfiguration(); const redirect = useCallback((url: string) => navigate(url), [navigate]); diff --git a/web/src/views/LoginPortal/SecondFactor/IconWithContext.tsx b/web/src/views/LoginPortal/SecondFactor/IconWithContext.tsx index 8ada92ed..c7d0d6c3 100644 --- a/web/src/views/LoginPortal/SecondFactor/IconWithContext.tsx +++ b/web/src/views/LoginPortal/SecondFactor/IconWithContext.tsx @@ -5,7 +5,7 @@ import classnames from "classnames"; interface IconWithContextProps { icon: ReactNode; - context: ReactNode; + children: ReactNode; className?: string; } @@ -33,7 +33,7 @@ const IconWithContext = function (props: IconWithContextProps) {
{props.icon}
-
{props.context}
+
{props.children}
); }; diff --git a/web/src/views/LoginPortal/SecondFactor/OTPDial.tsx b/web/src/views/LoginPortal/SecondFactor/OTPDial.tsx index 73c9d0e5..5c6486d8 100644 --- a/web/src/views/LoginPortal/SecondFactor/OTPDial.tsx +++ b/web/src/views/LoginPortal/SecondFactor/OTPDial.tsx @@ -12,6 +12,8 @@ import { State } from "@views/LoginPortal/SecondFactor/OneTimePasswordMethod"; export interface Props { passcode: string; state: State; + + digits: number; period: number; onChange: (passcode: string) => void; @@ -19,22 +21,23 @@ export interface Props { const OTPDial = function (props: Props) { const style = useStyles(); - const dial = ( - - - - ); - return } context={dial} />; + return ( + }> + + + + + ); }; export default OTPDial; diff --git a/web/src/views/LoginPortal/SecondFactor/OneTimePasswordMethod.tsx b/web/src/views/LoginPortal/SecondFactor/OneTimePasswordMethod.tsx index 8640d693..f49b213e 100644 --- a/web/src/views/LoginPortal/SecondFactor/OneTimePasswordMethod.tsx +++ b/web/src/views/LoginPortal/SecondFactor/OneTimePasswordMethod.tsx @@ -1,8 +1,10 @@ import React, { useCallback, useEffect, useRef, useState } from "react"; import { useRedirectionURL } from "@hooks/RedirectionURL"; +import { useUserInfoTOTPConfiguration } from "@hooks/UserInfoTOTPConfiguration"; import { completeTOTPSignIn } from "@services/OneTimePassword"; import { AuthenticationLevel } from "@services/State"; +import LoadingPage from "@views/LoadingPage/LoadingPage"; import MethodContainer, { State as MethodContainerState } from "@views/LoginPortal/SecondFactor/MethodContainer"; import OTPDial from "@views/LoginPortal/SecondFactor/OTPDial"; @@ -17,7 +19,6 @@ export interface Props { id: string; authenticationLevel: AuthenticationLevel; registered: boolean; - totp_period: number; onRegisterClick: () => void; onSignInError: (err: Error) => void; @@ -35,6 +36,20 @@ const OneTimePasswordMethod = function (props: Props) { const onSignInErrorCallback = useRef(onSignInError).current; const onSignInSuccessCallback = useRef(onSignInSuccess).current; + const [resp, fetch, , err] = useUserInfoTOTPConfiguration(); + + useEffect(() => { + if (err) { + console.error(err); + onSignInErrorCallback(new Error("Could not obtain user settings")); + setState(State.Failure); + } + }, [onSignInErrorCallback, err]); + + useEffect(() => { + fetch(); + }, [fetch]); + const signInFunc = useCallback(async () => { if (!props.registered || props.authenticationLevel === AuthenticationLevel.TwoFactor) { return; @@ -42,7 +57,7 @@ const OneTimePasswordMethod = function (props: Props) { const passcodeStr = `${passcode}`; - if (!passcode || passcodeStr.length !== 6) { + if (!passcode || passcodeStr.length !== (resp?.digits || 6)) { return; } @@ -62,6 +77,7 @@ const OneTimePasswordMethod = function (props: Props) { onSignInSuccessCallback, passcode, redirectionURL, + resp, props.authenticationLevel, props.registered, ]); @@ -94,7 +110,19 @@ const OneTimePasswordMethod = function (props: Props) { state={methodState} onRegisterClick={props.onRegisterClick} > - +
+ {resp !== undefined || err !== undefined ? ( + + ) : ( + + )} +
); }; diff --git a/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx b/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx index 5d9d7e00..97c266da 100644 --- a/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx +++ b/web/src/views/LoginPortal/SecondFactor/SecondFactorForm.tsx @@ -17,7 +17,7 @@ import { SecondFactorMethod } from "@models/Methods"; import { UserInfo } from "@models/UserInfo"; import { initiateTOTPRegistrationProcess, initiateU2FRegistrationProcess } from "@services/RegisterDevice"; import { AuthenticationLevel } from "@services/State"; -import { setPreferred2FAMethod } from "@services/UserPreferences"; +import { setPreferred2FAMethod } from "@services/UserInfo"; import MethodSelectionDialog from "@views/LoginPortal/SecondFactor/MethodSelectionDialog"; import OneTimePasswordMethod from "@views/LoginPortal/SecondFactor/OneTimePasswordMethod"; import PushNotificationMethod from "@views/LoginPortal/SecondFactor/PushNotificationMethod"; @@ -116,7 +116,6 @@ const SecondFactorForm = function (props: Props) { authenticationLevel={props.authenticationLevel} // Whether the user has a TOTP secret registered already registered={props.userInfo.has_totp} - totp_period={props.configuration.totp_period} onRegisterClick={initiateRegistration(initiateTOTPRegistrationProcess)} onSignInError={(err) => createErrorNotification(err.message)} onSignInSuccess={props.onAuthenticationSuccess} diff --git a/web/src/views/LoginPortal/SecondFactor/SecurityKeyMethod.tsx b/web/src/views/LoginPortal/SecondFactor/SecurityKeyMethod.tsx index 44710206..ab61a304 100644 --- a/web/src/views/LoginPortal/SecondFactor/SecurityKeyMethod.tsx +++ b/web/src/views/LoginPortal/SecondFactor/SecurityKeyMethod.tsx @@ -143,21 +143,18 @@ function Icon(props: IconProps) { const touch = ( } - context={} className={state === State.WaitTouch ? undefined : "hidden"} - /> + > + + ); const failure = ( - } - context={ - - } - className={state === State.Failure ? undefined : "hidden"} - /> + } className={state === State.Failure ? undefined : "hidden"}> + + ); return (