From 6276883f04c25af02339432bef4349f3904f19b2 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Wed, 2 Mar 2022 17:40:26 +1100 Subject: [PATCH] refactor(configuration): utilize time duration decode hook (#2938) This enhances the existing time.Duration parser to allow multiple units, and implements a decode hook which can be used by koanf to decode string/integers into time.Durations as applicable. --- docs/configuration/index.md | 55 ++-- internal/commands/helpers.go | 5 +- internal/configuration/decode_hooks.go | 67 +++++ internal/configuration/decode_hooks_test.go | 270 ++++++++++++++++++ internal/configuration/provider.go | 2 +- .../configuration/schema/configuration.go | 4 +- internal/configuration/schema/ntp.go | 16 +- internal/configuration/schema/regulation.go | 14 +- internal/configuration/schema/session.go | 27 +- internal/configuration/validator/const.go | 6 +- internal/configuration/validator/ntp.go | 14 +- internal/configuration/validator/ntp_test.go | 14 +- .../configuration/validator/regulation.go | 23 +- .../validator/regulation_test.go | 20 +- internal/configuration/validator/session.go | 12 +- .../configuration/validator/session_test.go | 27 +- internal/handlers/const.go | 4 +- internal/handlers/handler_verify_test.go | 2 +- internal/ntp/ntp.go | 5 +- internal/ntp/ntp_test.go | 12 +- internal/regulation/regulator.go | 47 +-- internal/regulation/regulator_test.go | 46 +-- internal/regulation/types.go | 11 +- internal/server/server.go | 2 +- internal/session/const.go | 6 +- internal/session/provider.go | 34 +-- internal/session/provider_config.go | 84 +++--- internal/utils/const.go | 20 +- internal/utils/time.go | 98 ++++--- internal/utils/time_test.go | 82 ++++-- 30 files changed, 686 insertions(+), 343 deletions(-) create mode 100644 internal/configuration/decode_hooks_test.go diff --git a/docs/configuration/index.md b/docs/configuration/index.md index 1fb0d713..a49bca77 100644 --- a/docs/configuration/index.md +++ b/docs/configuration/index.md @@ -100,27 +100,46 @@ $ authelia validate-config --config configuration.yml # Duration Notation Format -We have implemented a string based notation for configuration options that take a duration. This section describes its -usage. You can use this implementation in: session for expiration, inactivity, and remember_me_duration; and regulation -for ban_time, and find_time. This notation also supports just providing the number of seconds instead. +We have implemented a string/integer based notation for configuration options that take a duration of time. This section +describes the implementation of this. You can use this implementation in various areas of configuration such as: -The notation is comprised of a number which must be positive and not have leading zeros, followed by a letter -denoting the unit of time measurement. The table below describes the units of time and the associated letter. +- session: + - expiration + - inactivity + - remember_me_duration +- regulation: + - ban_time + - find_time +- ntp: + - max_desync +- webauthn: + - timeout -|Unit |Associated Letter| -|:-----:|:---------------:| -|Years |y | -|Months |M | -|Weeks |w | -|Days |d | -|Hours |h | -|Minutes|m | -|Seconds|s | +The way this format works is you can either configure an integer or a string in the specific configuration areas. If you +supply an integer, it is considered a representation of seconds. If you supply a string, it parses the string in blocks +of quantities and units (number followed by a unit letter). For example `5h` indicates a quantity of 5 units of `h`. -Examples: -* 1 hour and 30 minutes: 90m -* 1 day: 1d -* 10 hours: 10h +While you can use multiple of these blocks in combination, ee suggest keeping it simple and use a single value. + +## Duration Notation Format Unit Legend + +| Unit | Associated Letter | +|:-------:|:-----------------:| +| Years | y | +| Months | M | +| Weeks | w | +| Days | d | +| Hours | h | +| Minutes | m | +| Seconds | s | + +## Duration Notation Format Examples + +| Desired Value | Configuration Examples | +|:---------------------:|:-------------------------------------:| +| 1 hour and 30 minutes | `90m` or `1h30m` or `5400` or `5400s` | +| 1 day | `1d` or `24h` or `86400` or `86400s` | +| 10 hours | `10h` or `600m` or `9h60m` or `36000` | # TLS Configuration diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go index 4c68c2ef..24424ade 100644 --- a/internal/commands/helpers.go +++ b/internal/commands/helpers.go @@ -57,10 +57,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [ notifier = notification.NewFileNotifier(*config.Notifier.FileSystem) } - var ntpProvider *ntp.Provider - if config.NTP != nil { - ntpProvider = ntp.NewProvider(config.NTP) - } + ntpProvider := ntp.NewProvider(&config.NTP) clock := utils.RealClock{} authorizer := authorization.NewAuthorizer(config) diff --git a/internal/configuration/decode_hooks.go b/internal/configuration/decode_hooks.go index 513f7fd3..cc9aac69 100644 --- a/internal/configuration/decode_hooks.go +++ b/internal/configuration/decode_hooks.go @@ -4,8 +4,11 @@ import ( "fmt" "net/mail" "reflect" + "time" "github.com/mitchellh/mapstructure" + + "github.com/authelia/authelia/v4/internal/utils" ) // StringToMailAddressFunc decodes a string into a mail.Address. @@ -33,3 +36,67 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc { return *mailAddress, nil } } + +// ToTimeDurationFunc converts string and integer types to a time.Duration. +func ToTimeDurationFunc() mapstructure.DecodeHookFuncType { + return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) { + var ( + ptr bool + ) + + switch f.Kind() { + case reflect.String, reflect.Int, reflect.Int32, reflect.Int64: + // We only allow string and integer from kinds to match. + break + default: + return data, nil + } + + typeTimeDuration := reflect.TypeOf(time.Hour) + + if t.Kind() == reflect.Ptr { + if t.Elem() != typeTimeDuration { + return data, nil + } + + ptr = true + } else if t != typeTimeDuration { + return data, nil + } + + var duration time.Duration + + switch { + case f.Kind() == reflect.String: + break + case f.Kind() == reflect.Int: + seconds := data.(int) + + duration = time.Second * time.Duration(seconds) + case f.Kind() == reflect.Int32: + seconds := data.(int32) + + duration = time.Second * time.Duration(seconds) + case f == typeTimeDuration: + duration = data.(time.Duration) + case f.Kind() == reflect.Int64: + seconds := data.(int64) + + duration = time.Second * time.Duration(seconds) + } + + if duration == 0 { + dataStr := data.(string) + + if duration, err = utils.ParseDurationString(dataStr); err != nil { + return nil, err + } + } + + if ptr { + return &duration, nil + } + + return duration, nil + } +} diff --git a/internal/configuration/decode_hooks_test.go b/internal/configuration/decode_hooks_test.go new file mode 100644 index 00000000..2b038cdb --- /dev/null +++ b/internal/configuration/decode_hooks_test.go @@ -0,0 +1,270 @@ +package configuration + +import ( + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = "1h" + expected = time.Hour + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = "1y" + expected = time.Hour * 24 * 365 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = "1M" + expected = time.Hour * 24 * 30 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = "1w" + expected = time.Hour * 24 * 7 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = "1d" + expected = time.Hour * 24 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = "abc" + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.EqualError(t, err, "could not parse 'abc' as a duration") + assert.Nil(t, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.EqualError(t, err, "could not parse 'abc' as a duration") + assert.Nil(t, result) +} + +func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = 60 + expected = time.Second * 60 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = int32(120) + expected = time.Second * 120 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = int64(30) + expected = time.Second * 30 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = time.Second * 30 + expected = time.Second * 30 + + to time.Duration + ptrTo *time.Duration + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, &expected, result) +} + +func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = int64(30) + + to string + ptrTo *string + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, from, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, from, result) +} + +func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) { + hook := ToTimeDurationFunc() + + var ( + from = true + + to string + ptrTo *string + result interface{} + err error + ) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from) + assert.NoError(t, err) + assert.Equal(t, from, result) + + result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from) + assert.NoError(t, err) + assert.Equal(t, from, result) +} diff --git a/internal/configuration/provider.go b/internal/configuration/provider.go index ba7243ee..70f8baf3 100644 --- a/internal/configuration/provider.go +++ b/internal/configuration/provider.go @@ -43,9 +43,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte c := koanf.UnmarshalConf{ DecoderConfig: &mapstructure.DecoderConfig{ DecodeHook: mapstructure.ComposeDecodeHookFunc( - mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), StringToMailAddressFunc(), + ToTimeDurationFunc(), ), Metadata: nil, Result: o, diff --git a/internal/configuration/schema/configuration.go b/internal/configuration/schema/configuration.go index 78a1dff1..b79d2a34 100644 --- a/internal/configuration/schema/configuration.go +++ b/internal/configuration/schema/configuration.go @@ -14,8 +14,8 @@ type Configuration struct { TOTP *TOTPConfiguration `koanf:"totp"` DuoAPI *DuoAPIConfiguration `koanf:"duo_api"` AccessControl AccessControlConfiguration `koanf:"access_control"` - NTP *NTPConfiguration `koanf:"ntp"` - Regulation *RegulationConfiguration `koanf:"regulation"` + NTP NTPConfiguration `koanf:"ntp"` + Regulation RegulationConfiguration `koanf:"regulation"` Storage StorageConfiguration `koanf:"storage"` Notifier *NotifierConfiguration `koanf:"notifier"` Server ServerConfiguration `koanf:"server"` diff --git a/internal/configuration/schema/ntp.go b/internal/configuration/schema/ntp.go index 5ea9b67e..a9fdd3d2 100644 --- a/internal/configuration/schema/ntp.go +++ b/internal/configuration/schema/ntp.go @@ -1,17 +1,21 @@ package schema +import ( + "time" +) + // NTPConfiguration represents the configuration related to ntp server. type NTPConfiguration struct { - Address string `koanf:"address"` - Version int `koanf:"version"` - MaximumDesync string `koanf:"max_desync"` - DisableStartupCheck bool `koanf:"disable_startup_check"` - DisableFailure bool `koanf:"disable_failure"` + Address string `koanf:"address"` + Version int `koanf:"version"` + MaximumDesync time.Duration `koanf:"max_desync"` + DisableStartupCheck bool `koanf:"disable_startup_check"` + DisableFailure bool `koanf:"disable_failure"` } // DefaultNTPConfiguration represents default configuration parameters for the NTP server. var DefaultNTPConfiguration = NTPConfiguration{ Address: "time.cloudflare.com:123", Version: 4, - MaximumDesync: "3s", + MaximumDesync: time.Second * 3, } diff --git a/internal/configuration/schema/regulation.go b/internal/configuration/schema/regulation.go index 61002736..a5cf517e 100644 --- a/internal/configuration/schema/regulation.go +++ b/internal/configuration/schema/regulation.go @@ -1,15 +1,19 @@ package schema +import ( + "time" +) + // RegulationConfiguration represents the configuration related to regulation. type RegulationConfiguration struct { - MaxRetries int `koanf:"max_retries"` - FindTime string `koanf:"find_time,weak"` - BanTime string `koanf:"ban_time,weak"` + MaxRetries int `koanf:"max_retries"` + FindTime time.Duration `koanf:"find_time,weak"` + BanTime time.Duration `koanf:"ban_time,weak"` } // DefaultRegulationConfiguration represents default configuration parameters for the regulator. var DefaultRegulationConfiguration = RegulationConfiguration{ MaxRetries: 3, - FindTime: "2m", - BanTime: "5m", + FindTime: time.Minute * 2, + BanTime: time.Minute * 5, } diff --git a/internal/configuration/schema/session.go b/internal/configuration/schema/session.go index e82004ee..6b55e8df 100644 --- a/internal/configuration/schema/session.go +++ b/internal/configuration/schema/session.go @@ -1,5 +1,9 @@ package schema +import ( + "time" +) + // RedisNode Represents a Node. type RedisNode struct { Host string `koanf:"host"` @@ -31,21 +35,22 @@ type RedisSessionConfiguration struct { // SessionConfiguration represents the configuration related to user sessions. type SessionConfiguration struct { - Name string `koanf:"name"` - Domain string `koanf:"domain"` - SameSite string `koanf:"same_site"` - Secret string `koanf:"secret"` - Expiration string `koanf:"expiration"` - Inactivity string `koanf:"inactivity"` - RememberMeDuration string `koanf:"remember_me_duration"` - Redis *RedisSessionConfiguration `koanf:"redis"` + Name string `koanf:"name"` + Domain string `koanf:"domain"` + SameSite string `koanf:"same_site"` + Secret string `koanf:"secret"` + Expiration time.Duration `koanf:"expiration"` + Inactivity time.Duration `koanf:"inactivity"` + RememberMeDuration time.Duration `koanf:"remember_me_duration"` + + Redis *RedisSessionConfiguration `koanf:"redis"` } // DefaultSessionConfiguration is the default session configuration. var DefaultSessionConfiguration = SessionConfiguration{ Name: "authelia_session", - Expiration: "1h", - Inactivity: "5m", - RememberMeDuration: "1M", + Expiration: time.Hour, + Inactivity: time.Minute * 5, + RememberMeDuration: time.Hour * 24 * 30, SameSite: "lax", } diff --git a/internal/configuration/validator/const.go b/internal/configuration/validator/const.go index 812bbab5..2d9c9f58 100644 --- a/internal/configuration/validator/const.go +++ b/internal/configuration/validator/const.go @@ -35,7 +35,6 @@ const ( // Test constants. const ( - testBadTimer = "-1" testInvalidPolicy = "invalid" testJWTSecret = "a_secret" testLDAPBaseDN = "base_dn" @@ -185,13 +184,11 @@ const ( // NTP Error constants. const ( - errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'" - errFmtNTPMaxDesync = "ntp: option 'max_desync' can't be parsed: %w" + errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'" ) // Session error constants. const ( - errFmtSessionCouldNotParseDuration = "session: option '%s' could not be parsed: %w" errFmtSessionOptionRequired = "session: option '%s' is required" errFmtSessionDomainMustBeRoot = "session: option 'domain' must be the domain you wish to protect not a wildcard domain but it is configured as '%s'" errFmtSessionSameSite = "session: option 'same_site' must be one of '%s' but is configured as '%s'" @@ -206,7 +203,6 @@ const ( // Regulation Error Consts. const ( - errFmtRegulationParseDuration = "regulation: option '%s' could not be parsed: %w" errFmtRegulationFindTimeGreaterThanBanTime = "regulation: option 'find_time' must be less than or equal to option 'ban_time'" ) diff --git a/internal/configuration/validator/ntp.go b/internal/configuration/validator/ntp.go index a99f9e9e..acd2cd0c 100644 --- a/internal/configuration/validator/ntp.go +++ b/internal/configuration/validator/ntp.go @@ -4,17 +4,10 @@ import ( "fmt" "github.com/authelia/authelia/v4/internal/configuration/schema" - "github.com/authelia/authelia/v4/internal/utils" ) // ValidateNTP validates and update NTP configuration. func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator) { - if config.NTP == nil { - config.NTP = &schema.DefaultNTPConfiguration - - return - } - if config.NTP.Address == "" { config.NTP.Address = schema.DefaultNTPConfiguration.Address } @@ -25,12 +18,7 @@ func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator validator.Push(fmt.Errorf(errFmtNTPVersion, config.NTP.Version)) } - if config.NTP.MaximumDesync == "" { + if config.NTP.MaximumDesync == 0 { config.NTP.MaximumDesync = schema.DefaultNTPConfiguration.MaximumDesync } - - _, err := utils.ParseDurationString(config.NTP.MaximumDesync) - if err != nil { - validator.Push(fmt.Errorf(errFmtNTPMaxDesync, err)) - } } diff --git a/internal/configuration/validator/ntp_test.go b/internal/configuration/validator/ntp_test.go index 34fcae22..8855e2a0 100644 --- a/internal/configuration/validator/ntp_test.go +++ b/internal/configuration/validator/ntp_test.go @@ -11,7 +11,7 @@ import ( func newDefaultNTPConfig() schema.Configuration { return schema.Configuration{ - NTP: &schema.NTPConfiguration{}, + NTP: schema.NTPConfiguration{}, } } @@ -55,18 +55,6 @@ func TestShouldSetDefaultNtpDisableStartupCheck(t *testing.T) { assert.Equal(t, schema.DefaultNTPConfiguration.DisableStartupCheck, config.NTP.DisableStartupCheck) } -func TestShouldRaiseErrorOnMaximumDesyncString(t *testing.T) { - validator := schema.NewStructValidator() - config := newDefaultNTPConfig() - config.NTP.MaximumDesync = "a second" - - ValidateNTP(&config, validator) - - require.Len(t, validator.Errors(), 1) - - assert.EqualError(t, validator.Errors()[0], "ntp: option 'max_desync' can't be parsed: could not parse 'a second' as a duration") -} - func TestShouldRaiseErrorOnInvalidNTPVersion(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultNTPConfig() diff --git a/internal/configuration/validator/regulation.go b/internal/configuration/validator/regulation.go index 4f25a607..bde3cf27 100644 --- a/internal/configuration/validator/regulation.go +++ b/internal/configuration/validator/regulation.go @@ -4,36 +4,19 @@ import ( "fmt" "github.com/authelia/authelia/v4/internal/configuration/schema" - "github.com/authelia/authelia/v4/internal/utils" ) // ValidateRegulation validates and update regulator configuration. func ValidateRegulation(config *schema.Configuration, validator *schema.StructValidator) { - if config.Regulation == nil { - config.Regulation = &schema.DefaultRegulationConfiguration - - return - } - - if config.Regulation.FindTime == "" { + if config.Regulation.FindTime == 0 { config.Regulation.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min. } - if config.Regulation.BanTime == "" { + if config.Regulation.BanTime == 0 { config.Regulation.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min. } - findTime, err := utils.ParseDurationString(config.Regulation.FindTime) - if err != nil { - validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "find_time", err)) - } - - banTime, err := utils.ParseDurationString(config.Regulation.BanTime) - if err != nil { - validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "ban_time", err)) - } - - if findTime > banTime { + if config.Regulation.FindTime > config.Regulation.BanTime { validator.Push(fmt.Errorf(errFmtRegulationFindTimeGreaterThanBanTime)) } } diff --git a/internal/configuration/validator/regulation_test.go b/internal/configuration/validator/regulation_test.go index 81893c95..12bd6988 100644 --- a/internal/configuration/validator/regulation_test.go +++ b/internal/configuration/validator/regulation_test.go @@ -2,6 +2,7 @@ package validator import ( "testing" + "time" "github.com/stretchr/testify/assert" @@ -10,7 +11,7 @@ import ( func newDefaultRegulationConfig() schema.Configuration { config := schema.Configuration{ - Regulation: &schema.RegulationConfiguration{}, + Regulation: schema.RegulationConfiguration{}, } return config @@ -39,24 +40,11 @@ func TestShouldSetDefaultRegulationFindTime(t *testing.T) { func TestShouldRaiseErrorWhenFindTimeLessThanBanTime(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultRegulationConfig() - config.Regulation.FindTime = "1m" - config.Regulation.BanTime = "10s" + config.Regulation.FindTime = time.Minute + config.Regulation.BanTime = time.Second * 10 ValidateRegulation(&config, validator) assert.Len(t, validator.Errors(), 1) assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' must be less than or equal to option 'ban_time'") } - -func TestShouldRaiseErrorOnBadDurationStrings(t *testing.T) { - validator := schema.NewStructValidator() - config := newDefaultRegulationConfig() - config.Regulation.FindTime = "a year" - config.Regulation.BanTime = "forever" - - ValidateRegulation(&config, validator) - - assert.Len(t, validator.Errors(), 2) - assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' could not be parsed: could not parse 'a year' as a duration") - assert.EqualError(t, validator.Errors()[1], "regulation: option 'ban_time' could not be parsed: could not parse 'forever' as a duration") -} diff --git a/internal/configuration/validator/session.go b/internal/configuration/validator/session.go index 83a6c39c..7286eb6a 100644 --- a/internal/configuration/validator/session.go +++ b/internal/configuration/validator/session.go @@ -27,22 +27,16 @@ func ValidateSession(config *schema.SessionConfiguration, validator *schema.Stru } func validateSession(config *schema.SessionConfiguration, validator *schema.StructValidator) { - if config.Expiration == "" { + if config.Expiration <= 0 { config.Expiration = schema.DefaultSessionConfiguration.Expiration // 1 hour. - } else if _, err := utils.ParseDurationString(config.Expiration); err != nil { - validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "expiriation", err)) } - if config.Inactivity == "" { + if config.Inactivity <= 0 { config.Inactivity = schema.DefaultSessionConfiguration.Inactivity // 5 min. - } else if _, err := utils.ParseDurationString(config.Inactivity); err != nil { - validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "inactivity", err)) } - if config.RememberMeDuration == "" { + if config.RememberMeDuration <= 0 { config.RememberMeDuration = schema.DefaultSessionConfiguration.RememberMeDuration // 1 month. - } else if _, err := utils.ParseDurationString(config.RememberMeDuration); err != nil { - validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "remember_me_duration", err)) } if config.Domain == "" { diff --git a/internal/configuration/validator/session_test.go b/internal/configuration/validator/session_test.go index 281fbae8..2ceedd57 100644 --- a/internal/configuration/validator/session_test.go +++ b/internal/configuration/validator/session_test.go @@ -420,30 +420,21 @@ func TestShouldNotRaiseErrorWhenSameSiteSetCorrectly(t *testing.T) { } } -func TestShouldRaiseErrorWhenBadInactivityAndExpirationSet(t *testing.T) { +func TestShouldSetDefaultWhenNegativeInactivityAndExpirationSet(t *testing.T) { validator := schema.NewStructValidator() config := newDefaultSessionConfig() - config.Inactivity = testBadTimer - config.Expiration = testBadTimer + config.Inactivity = -1 + config.Expiration = -1 + config.RememberMeDuration = -1 ValidateSession(&config, validator) - assert.False(t, validator.HasWarnings()) - assert.Len(t, validator.Errors(), 2) - assert.EqualError(t, validator.Errors()[0], "session: option 'expiriation' could not be parsed: could not parse '-1' as a duration") - assert.EqualError(t, validator.Errors()[1], "session: option 'inactivity' could not be parsed: could not parse '-1' as a duration") -} + assert.Len(t, validator.Warnings(), 0) + assert.Len(t, validator.Errors(), 0) -func TestShouldRaiseErrorWhenBadRememberMeDurationSet(t *testing.T) { - validator := schema.NewStructValidator() - config := newDefaultSessionConfig() - config.RememberMeDuration = "1 year" - - ValidateSession(&config, validator) - - assert.False(t, validator.HasWarnings()) - assert.Len(t, validator.Errors(), 1) - assert.EqualError(t, validator.Errors()[0], "session: option 'remember_me_duration' could not be parsed: could not parse '1 year' as a duration") + assert.Equal(t, schema.DefaultSessionConfiguration.Inactivity, config.Inactivity) + assert.Equal(t, schema.DefaultSessionConfiguration.Expiration, config.Expiration) + assert.Equal(t, schema.DefaultSessionConfiguration.RememberMeDuration, config.RememberMeDuration) } func TestShouldSetDefaultRememberMeDuration(t *testing.T) { diff --git a/internal/handlers/const.go b/internal/handlers/const.go index ba6e3ce1..fe67231e 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -1,6 +1,8 @@ package handlers import ( + "time" + "github.com/valyala/fasthttp" ) @@ -56,7 +58,7 @@ const ( ) const ( - testInactivity = "10" + testInactivity = time.Second * 10 testRedirectionURL = "http://redirection.local" testUsername = "john" ) diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go index b8f84f21..8f10edb2 100644 --- a/internal/handlers/handler_verify_test.go +++ b/internal/handlers/handler_verify_test.go @@ -602,7 +602,7 @@ func TestShouldDestroySessionWhenInactiveForTooLongUsingDurationNotation(t *test clock := mocks.TestingClock{} clock.Set(time.Now()) - mock.Ctx.Configuration.Session.Inactivity = "10s" + mock.Ctx.Configuration.Session.Inactivity = time.Second * 10 // Reload the session provider since the configuration is indirect. mock.Ctx.Providers.SessionProvider = session.NewProvider(mock.Ctx.Configuration.Session, nil) assert.Equal(t, time.Second*10, mock.Ctx.Providers.SessionProvider.Inactivity) diff --git a/internal/ntp/ntp.go b/internal/ntp/ntp.go index 36b5370b..991eec9c 100644 --- a/internal/ntp/ntp.go +++ b/internal/ntp/ntp.go @@ -8,7 +8,6 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/utils" ) // NewProvider instantiate a ntp provider given a configuration. @@ -59,11 +58,9 @@ func (p *Provider) StartupCheck() (err error) { return nil } - maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync) - ntpTime := ntpPacketToTime(resp) - if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result { + if result := ntpIsOffsetTooLarge(p.config.MaximumDesync, now, ntpTime); result { return errors.New("the system clock is not synchronized accurately enough with the configured NTP server") } diff --git a/internal/ntp/ntp_test.go b/internal/ntp/ntp_test.go index 624d13b5..3528b84c 100644 --- a/internal/ntp/ntp_test.go +++ b/internal/ntp/ntp_test.go @@ -2,6 +2,7 @@ package ntp import ( "testing" + "time" "github.com/stretchr/testify/assert" @@ -11,18 +12,17 @@ import ( func TestShouldCheckNTP(t *testing.T) { config := &schema.Configuration{ - NTP: &schema.NTPConfiguration{ - Address: "time.cloudflare.com:123", - Version: 4, - MaximumDesync: "3s", - DisableStartupCheck: false, + NTP: schema.NTPConfiguration{ + Address: "time.cloudflare.com:123", + Version: 4, + MaximumDesync: time.Second * 3, }, } sv := schema.NewStructValidator() validator.ValidateNTP(config, sv) - ntp := NewProvider(config.NTP) + ntp := NewProvider(&config.NTP) assert.NoError(t, ntp.StartupCheck()) } diff --git a/internal/regulation/regulator.go b/internal/regulation/regulator.go index 6104c150..abd0339e 100644 --- a/internal/regulation/regulator.go +++ b/internal/regulation/regulator.go @@ -2,7 +2,6 @@ package regulation import ( "context" - "fmt" "net" "time" @@ -13,33 +12,13 @@ import ( ) // NewRegulator create a regulator instance. -func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator { - regulator := &Regulator{storageProvider: provider} - regulator.clock = clock - - if configuration != nil { - findTime, err := utils.ParseDurationString(configuration.FindTime) - if err != nil { - panic(err) - } - - banTime, err := utils.ParseDurationString(configuration.BanTime) - if err != nil { - panic(err) - } - - if findTime > banTime { - panic(fmt.Errorf("find_time cannot be greater than ban_time")) - } - - // Set regulator enabled only if MaxRetries is not 0. - regulator.enabled = configuration.MaxRetries > 0 - regulator.maxRetries = configuration.MaxRetries - regulator.findTime = findTime - regulator.banTime = banTime +func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator { + return &Regulator{ + enabled: config.MaxRetries > 0, + storageProvider: provider, + clock: clock, + config: config, } - - return regulator } // Mark an authentication attempt. @@ -65,15 +44,15 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e return time.Time{}, nil } - attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0) + attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0) if err != nil { return time.Time{}, nil } - latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.maxRetries) + latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.config.MaxRetries) for _, attempt := range attempts { - if attempt.Successful || len(latestFailedAttempts) >= r.maxRetries { + if attempt.Successful || len(latestFailedAttempts) >= r.config.MaxRetries { // We stop appending failed attempts once we find the first successful attempts or we reach // the configured number of retries, meaning the user is already banned. break @@ -84,17 +63,17 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e // If the number of failed attempts within the ban time is less than the max number of retries // then the user is not banned. - if len(latestFailedAttempts) < r.maxRetries { + if len(latestFailedAttempts) < r.config.MaxRetries { return time.Time{}, nil } // Now we compute the time between the latest attempt and the MaxRetry-th one. If it's // within the FindTime then it means that the user has been banned. durationBetweenLatestAttempts := latestFailedAttempts[0].Time.Sub( - latestFailedAttempts[r.maxRetries-1].Time) + latestFailedAttempts[r.config.MaxRetries-1].Time) - if durationBetweenLatestAttempts < r.findTime { - bannedUntil := latestFailedAttempts[0].Time.Add(r.banTime) + if durationBetweenLatestAttempts < r.config.FindTime { + bannedUntil := latestFailedAttempts[0].Time.Add(r.config.BanTime) return bannedUntil, ErrUserIsBanned } diff --git a/internal/regulation/regulator_test.go b/internal/regulation/regulator_test.go index a0dbb453..22fc9858 100644 --- a/internal/regulation/regulator_test.go +++ b/internal/regulation/regulator_test.go @@ -18,11 +18,11 @@ import ( type RegulatorSuite struct { suite.Suite - ctx context.Context - ctrl *gomock.Controller - storageMock *mocks.MockStorage - configuration schema.RegulationConfiguration - clock mocks.TestingClock + ctx context.Context + ctrl *gomock.Controller + storageMock *mocks.MockStorage + config schema.RegulationConfiguration + clock mocks.TestingClock } func (s *RegulatorSuite) SetupTest() { @@ -30,10 +30,10 @@ func (s *RegulatorSuite) SetupTest() { s.storageMock = mocks.NewMockStorage(s.ctrl) s.ctx = context.Background() - s.configuration = schema.RegulationConfiguration{ + s.config = schema.RegulationConfiguration{ MaxRetries: 3, - BanTime: "180", - FindTime: "30", + BanTime: time.Second * 180, + FindTime: time.Second * 30, } s.clock.Set(time.Now()) } @@ -55,7 +55,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() { LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) @@ -86,7 +86,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) @@ -122,7 +122,7 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() { LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) @@ -155,7 +155,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() { LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) @@ -179,7 +179,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() { LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) @@ -211,7 +211,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() { LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) @@ -247,7 +247,7 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)). Return(attemptsInDB, nil) - regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) @@ -283,24 +283,24 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() { Return(attemptsInDB, nil) // Check Disabled Functionality. - configuration := schema.RegulationConfiguration{ + config := schema.RegulationConfiguration{ MaxRetries: 0, - FindTime: "180", - BanTime: "180", + FindTime: time.Second * 180, + BanTime: time.Second * 180, } - regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock) + regulator := regulation.NewRegulator(config, s.storageMock, &s.clock) _, err := regulator.Regulate(s.ctx, "john") assert.NoError(s.T(), err) // Check Enabled Functionality. - configuration = schema.RegulationConfiguration{ + config = schema.RegulationConfiguration{ MaxRetries: 1, - FindTime: "180", - BanTime: "180", + FindTime: time.Second * 180, + BanTime: time.Second * 180, } - regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock) + regulator = regulation.NewRegulator(config, s.storageMock, &s.clock) _, err = regulator.Regulate(s.ctx, "john") assert.Equal(s.T(), regulation.ErrUserIsBanned, err) } diff --git a/internal/regulation/types.go b/internal/regulation/types.go index ebaa94ca..510ac222 100644 --- a/internal/regulation/types.go +++ b/internal/regulation/types.go @@ -1,8 +1,7 @@ package regulation import ( - "time" - + "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/utils" ) @@ -11,12 +10,8 @@ import ( type Regulator struct { // Is the regulation enabled. enabled bool - // The number of failed authentication attempt before banning the user. - maxRetries int - // If a user does the max number of retries within that duration, she will be banned. - findTime time.Duration - // If a user has been banned, this duration is the timelapse during which the user is banned. - banTime time.Duration + + config schema.RegulationConfiguration storageProvider storage.RegulatorProvider diff --git a/internal/server/server.go b/internal/server/server.go index cbc53e71..fe4574e5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -28,7 +28,7 @@ var assets embed.FS func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler { autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers) - rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != "0") + rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != -1) resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword) duoSelfEnrollment := f diff --git a/internal/session/const.go b/internal/session/const.go index a4ce709d..23f047f2 100644 --- a/internal/session/const.go +++ b/internal/session/const.go @@ -1,8 +1,12 @@ package session +import ( + "time" +) + const ( testDomain = "example.com" - testExpiration = "40" + testExpiration = time.Second * 40 testName = "my_session" testUsername = "john" ) diff --git a/internal/session/provider.go b/internal/session/provider.go index 50e5c516..88faef30 100644 --- a/internal/session/provider.go +++ b/internal/session/provider.go @@ -12,7 +12,6 @@ import ( "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/utils" ) // Provider a session provider. @@ -23,38 +22,29 @@ type Provider struct { } // NewProvider instantiate a session provider given a configuration. -func NewProvider(configuration schema.SessionConfiguration, certPool *x509.CertPool) *Provider { - providerConfig := NewProviderConfig(configuration, certPool) +func NewProvider(config schema.SessionConfiguration, certPool *x509.CertPool) *Provider { + c := NewProviderConfig(config, certPool) provider := new(Provider) - provider.sessionHolder = fasthttpsession.New(providerConfig.config) + provider.sessionHolder = fasthttpsession.New(c.config) logger := logging.Logger() - duration, err := utils.ParseDurationString(configuration.RememberMeDuration) - if err != nil { - logger.Fatal(err) - } + provider.Inactivity, provider.RememberMe = config.Inactivity, config.RememberMeDuration - provider.RememberMe = duration - - duration, err = utils.ParseDurationString(configuration.Inactivity) - if err != nil { - logger.Fatal(err) - } - - provider.Inactivity = duration - - var providerImpl fasthttpsession.Provider + var ( + providerImpl fasthttpsession.Provider + err error + ) switch { - case providerConfig.redisConfig != nil: - providerImpl, err = redis.New(*providerConfig.redisConfig) + case c.redisConfig != nil: + providerImpl, err = redis.New(*c.redisConfig) if err != nil { logger.Fatal(err) } - case providerConfig.redisSentinelConfig != nil: - providerImpl, err = redis.NewFailoverCluster(*providerConfig.redisSentinelConfig) + case c.redisSentinelConfig != nil: + providerImpl, err = redis.NewFailoverCluster(*c.redisSentinelConfig) if err != nil { logger.Fatal(err) } diff --git a/internal/session/provider_config.go b/internal/session/provider_config.go index 8a2db080..e0241000 100644 --- a/internal/session/provider_config.go +++ b/internal/session/provider_config.go @@ -17,10 +17,10 @@ import ( ) // NewProviderConfig creates a configuration for creating the session provider. -func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig { - config := session.NewDefaultConfig() +func NewProviderConfig(config schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig { + c := session.NewDefaultConfig() - config.SessionIDGeneratorFunc = func() []byte { + c.SessionIDGeneratorFunc = func() []byte { bytes := make([]byte, 32) _, _ = rand.Read(bytes) @@ -33,30 +33,30 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509 } // Override the cookie name. - config.CookieName = configuration.Name + c.CookieName = config.Name // Set the cookie to the given domain. - config.Domain = configuration.Domain + c.Domain = config.Domain // Set the cookie SameSite option. - switch configuration.SameSite { + switch config.SameSite { case "strict": - config.CookieSameSite = fasthttp.CookieSameSiteStrictMode + c.CookieSameSite = fasthttp.CookieSameSiteStrictMode case "none": - config.CookieSameSite = fasthttp.CookieSameSiteNoneMode + c.CookieSameSite = fasthttp.CookieSameSiteNoneMode case "lax": - config.CookieSameSite = fasthttp.CookieSameSiteLaxMode + c.CookieSameSite = fasthttp.CookieSameSiteLaxMode default: - config.CookieSameSite = fasthttp.CookieSameSiteLaxMode + c.CookieSameSite = fasthttp.CookieSameSiteLaxMode } // Only serve the header over HTTPS. - config.Secure = true + c.Secure = true // Ignore the error as it will be handled by validator. - config.Expiration, _ = utils.ParseDurationString(configuration.Expiration) + c.Expiration = config.Expiration - config.IsSecureFunc = func(*fasthttp.RequestCtx) bool { + c.IsSecureFunc = func(*fasthttp.RequestCtx) bool { return true } @@ -68,23 +68,23 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509 // If redis configuration is provided, then use the redis provider. switch { - case configuration.Redis != nil: - serializer := NewEncryptingSerializer(configuration.Secret) + case config.Redis != nil: + serializer := NewEncryptingSerializer(config.Secret) var tlsConfig *tls.Config - if configuration.Redis.TLS != nil { - tlsConfig = utils.NewTLSConfig(configuration.Redis.TLS, tls.VersionTLS12, certPool) + if config.Redis.TLS != nil { + tlsConfig = utils.NewTLSConfig(config.Redis.TLS, tls.VersionTLS12, certPool) } - if configuration.Redis.HighAvailability != nil && configuration.Redis.HighAvailability.SentinelName != "" { + if config.Redis.HighAvailability != nil && config.Redis.HighAvailability.SentinelName != "" { addrs := make([]string, 0) - if configuration.Redis.Host != "" { - addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(configuration.Redis.Host), configuration.Redis.Port)) + if config.Redis.Host != "" { + addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(config.Redis.Host), config.Redis.Port)) } - for _, node := range configuration.Redis.HighAvailability.Nodes { + for _, node := range config.Redis.HighAvailability.Nodes { addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port) if !utils.IsStringInSlice(addr, addrs) { addrs = append(addrs, addr) @@ -94,17 +94,17 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509 providerName = "redis-sentinel" redisSentinelConfig = &redis.FailoverConfig{ Logger: &redisLogger{logger: logging.Logger()}, - MasterName: configuration.Redis.HighAvailability.SentinelName, + MasterName: config.Redis.HighAvailability.SentinelName, SentinelAddrs: addrs, - SentinelUsername: configuration.Redis.HighAvailability.SentinelUsername, - SentinelPassword: configuration.Redis.HighAvailability.SentinelPassword, - RouteByLatency: configuration.Redis.HighAvailability.RouteByLatency, - RouteRandomly: configuration.Redis.HighAvailability.RouteRandomly, - Username: configuration.Redis.Username, - Password: configuration.Redis.Password, - DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index. - PoolSize: configuration.Redis.MaximumActiveConnections, - MinIdleConns: configuration.Redis.MinimumIdleConnections, + SentinelUsername: config.Redis.HighAvailability.SentinelUsername, + SentinelPassword: config.Redis.HighAvailability.SentinelPassword, + RouteByLatency: config.Redis.HighAvailability.RouteByLatency, + RouteRandomly: config.Redis.HighAvailability.RouteRandomly, + Username: config.Redis.Username, + Password: config.Redis.Password, + DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index. + PoolSize: config.Redis.MaximumActiveConnections, + MinIdleConns: config.Redis.MinimumIdleConnections, IdleTimeout: 300, TLSConfig: tlsConfig, KeyPrefix: "authelia-session", @@ -115,36 +115,36 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509 var addr string - if configuration.Redis.Port == 0 { + if config.Redis.Port == 0 { network = "unix" - addr = configuration.Redis.Host + addr = config.Redis.Host } else { - addr = fmt.Sprintf("%s:%d", configuration.Redis.Host, configuration.Redis.Port) + addr = fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port) } redisConfig = &redis.Config{ Logger: newRedisLogger(), Network: network, Addr: addr, - Username: configuration.Redis.Username, - Password: configuration.Redis.Password, - DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index. - PoolSize: configuration.Redis.MaximumActiveConnections, - MinIdleConns: configuration.Redis.MinimumIdleConnections, + Username: config.Redis.Username, + Password: config.Redis.Password, + DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index. + PoolSize: config.Redis.MaximumActiveConnections, + MinIdleConns: config.Redis.MinimumIdleConnections, IdleTimeout: 300, TLSConfig: tlsConfig, KeyPrefix: "authelia-session", } } - config.EncodeFunc = serializer.Encode - config.DecodeFunc = serializer.Decode + c.EncodeFunc = serializer.Encode + c.DecodeFunc = serializer.Decode default: providerName = "memory" } return ProviderConfig{ - config, + c, redisConfig, redisSentinelConfig, providerName, diff --git a/internal/utils/const.go b/internal/utils/const.go index 56f5011f..649a4c63 100644 --- a/internal/utils/const.go +++ b/internal/utils/const.go @@ -53,7 +53,25 @@ const ( ) var ( - reDuration = regexp.MustCompile(`^(?P[1-9]\d*?)(?P[smhdwMy])?$`) + standardDurationUnits = []string{"ns", "us", "µs", "μs", "ms", "s", "m", "h"} + reDurationSeconds = regexp.MustCompile(`^\d+$`) + reDurationStandard = regexp.MustCompile(`(?P[1-9]\d*?)(?P[^\d\s]+)`) +) + +// Duration unit types. +const ( + DurationUnitDays = "d" + DurationUnitWeeks = "w" + DurationUnitMonths = "M" + DurationUnitYears = "y" +) + +// Number of hours in particular measurements of time. +const ( + HoursInDay = 24 + HoursInWeek = HoursInDay * 7 + HoursInMonth = HoursInDay * 30 + HoursInYear = HoursInDay * 365 ) var ( diff --git a/internal/utils/time.go b/internal/utils/time.go index 4d1f7177..51952fcd 100644 --- a/internal/utils/time.go +++ b/internal/utils/time.go @@ -6,46 +6,64 @@ import ( "time" ) -// ParseDurationString parses a string to a duration -// Duration notations are an integer followed by a unit -// Units are s = second, m = minute, d = day, w = week, M = month, y = year -// Example 1y is the same as 1 year. -func ParseDurationString(input string) (time.Duration, error) { - var duration time.Duration - - matches := reDuration.FindStringSubmatch(input) - - switch { - case len(matches) == 3 && matches[2] != "": - d, _ := strconv.Atoi(matches[1]) - - switch matches[2] { - case "y": - duration = time.Duration(d) * Year - case "M": - duration = time.Duration(d) * Month - case "w": - duration = time.Duration(d) * Week - case "d": - duration = time.Duration(d) * Day - case "h": - duration = time.Duration(d) * Hour - case "m": - duration = time.Duration(d) * time.Minute - case "s": - duration = time.Duration(d) * time.Second - } - case input == "0" || len(matches) == 3: - seconds, err := strconv.Atoi(input) - if err != nil { - return 0, fmt.Errorf("could not parse '%s' as a duration: %w", input, err) - } - - duration = time.Duration(seconds) * time.Second - case input != "": - // Throw this error if input is anything other than a blank string, blank string will default to a duration of nothing. - return 0, fmt.Errorf("could not parse '%s' as a duration", input) +// StandardizeDurationString converts units of time that stdlib is unaware of to hours. +func StandardizeDurationString(input string) (output string, err error) { + if input == "" { + return "0s", nil } - return duration, nil + matches := reDurationStandard.FindAllStringSubmatch(input, -1) + + if len(matches) == 0 { + return "", fmt.Errorf("could not parse '%s' as a duration", input) + } + + var d int + + for _, match := range matches { + if d, err = strconv.Atoi(match[1]); err != nil { + return "", fmt.Errorf("could not parse the numeric portion of '%s' in duration string '%s': %w", match[0], input, err) + } + + unit := match[2] + + switch { + case IsStringInSlice(unit, standardDurationUnits): + output += fmt.Sprintf("%d%s", d, unit) + case unit == DurationUnitDays: + output += fmt.Sprintf("%dh", d*HoursInDay) + case unit == DurationUnitWeeks: + output += fmt.Sprintf("%dh", d*HoursInWeek) + case unit == DurationUnitMonths: + output += fmt.Sprintf("%dh", d*HoursInMonth) + case unit == DurationUnitYears: + output += fmt.Sprintf("%dh", d*HoursInYear) + default: + return "", fmt.Errorf("could not parse the units portion of '%s' in duration string '%s': the unit '%s' is not valid", match[0], input, unit) + } + } + + return output, nil +} + +// ParseDurationString standardizes a duration string with StandardizeDurationString then uses time.ParseDuration to +// convert it into a time.Duration. +func ParseDurationString(input string) (duration time.Duration, err error) { + if reDurationSeconds.MatchString(input) { + var seconds int + + if seconds, err = strconv.Atoi(input); err != nil { + return 0, nil + } + + return time.Second * time.Duration(seconds), nil + } + + var out string + + if out, err = StandardizeDurationString(input); err != nil { + return 0, err + } + + return time.ParseDuration(out) } diff --git a/internal/utils/time_test.go b/internal/utils/time_test.go index f2e208c8..f75871ec 100644 --- a/internal/utils/time_test.go +++ b/internal/utils/time_test.go @@ -7,66 +7,112 @@ import ( "github.com/stretchr/testify/assert" ) -func TestShouldParseDurationString(t *testing.T) { +func TestParseDurationString_ShouldParseDurationString(t *testing.T) { duration, err := ParseDurationString("1h") + assert.NoError(t, err) assert.Equal(t, 60*time.Minute, duration) } -func TestShouldParseDurationStringAllUnits(t *testing.T) { - duration, err := ParseDurationString("1y") +func TestParseDurationString_ShouldParseBlankString(t *testing.T) { + duration, err := ParseDurationString("") + assert.NoError(t, err) - assert.Equal(t, Year, duration) + assert.Equal(t, time.Second*0, duration) +} + +func TestParseDurationString_ShouldParseDurationStringAllUnits(t *testing.T) { + duration, err := ParseDurationString("1y") + + assert.NoError(t, err) + assert.Equal(t, time.Hour*24*365, duration) duration, err = ParseDurationString("1M") + assert.NoError(t, err) - assert.Equal(t, Month, duration) + assert.Equal(t, time.Hour*24*30, duration) duration, err = ParseDurationString("1w") + assert.NoError(t, err) - assert.Equal(t, Week, duration) + assert.Equal(t, time.Hour*24*7, duration) duration, err = ParseDurationString("1d") + assert.NoError(t, err) - assert.Equal(t, Day, duration) + assert.Equal(t, time.Hour*24, duration) duration, err = ParseDurationString("1h") + assert.NoError(t, err) - assert.Equal(t, Hour, duration) + assert.Equal(t, time.Hour, duration) duration, err = ParseDurationString("1s") + assert.NoError(t, err) assert.Equal(t, time.Second, duration) } -func TestShouldParseSecondsString(t *testing.T) { +func TestParseDurationString_ShouldParseSecondsString(t *testing.T) { duration, err := ParseDurationString("100") + assert.NoError(t, err) assert.Equal(t, 100*time.Second, duration) } -func TestShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) { +func TestParseDurationString_ShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) { duration, err := ParseDurationString("h1") + assert.EqualError(t, err, "could not parse 'h1' as a duration") assert.Equal(t, time.Duration(0), duration) } -func TestShouldNotParseBadDurationString(t *testing.T) { +func TestParseDurationString_ShouldNotParseBadDurationString(t *testing.T) { duration, err := ParseDurationString("10x") - assert.EqualError(t, err, "could not parse '10x' as a duration") + + assert.EqualError(t, err, "could not parse the units portion of '10x' in duration string '10x': the unit 'x' is not valid") assert.Equal(t, time.Duration(0), duration) } -func TestShouldNotParseDurationStringWithMultiValueUnits(t *testing.T) { +func TestParseDurationString_ShouldParseDurationStringWithMultiValueUnits(t *testing.T) { duration, err := ParseDurationString("10ms") - assert.EqualError(t, err, "could not parse '10ms' as a duration") - assert.Equal(t, time.Duration(0), duration) + + assert.NoError(t, err) + assert.Equal(t, time.Duration(10)*time.Millisecond, duration) } -func TestShouldNotParseDurationStringWithLeadingZero(t *testing.T) { +func TestParseDurationString_ShouldParseDurationStringWithLeadingZero(t *testing.T) { duration, err := ParseDurationString("005h") - assert.EqualError(t, err, "could not parse '005h' as a duration") - assert.Equal(t, time.Duration(0), duration) + + assert.NoError(t, err) + assert.Equal(t, time.Duration(5)*time.Hour, duration) +} + +func TestParseDurationString_ShouldParseMultiUnitValues(t *testing.T) { + duration, err := ParseDurationString("1d3w10ms") + + assert.NoError(t, err) + assert.Equal(t, + (time.Hour*time.Duration(24))+ + (time.Hour*time.Duration(24)*time.Duration(7)*time.Duration(3))+ + (time.Millisecond*time.Duration(10)), duration) +} + +func TestParseDurationString_ShouldParseDuplicateUnitValues(t *testing.T) { + duration, err := ParseDurationString("1d4d2d") + + assert.NoError(t, err) + assert.Equal(t, + (time.Hour*time.Duration(24))+ + (time.Hour*time.Duration(24)*time.Duration(4))+ + (time.Hour*time.Duration(24)*time.Duration(2)), duration) +} + +func TestStandardizeDurationString_ShouldParseStringWithSpaces(t *testing.T) { + result, err := StandardizeDurationString("1d 1h 20m") + + assert.NoError(t, err) + assert.Equal(t, result, "24h1h20m") } func TestShouldTimeIntervalsMakeSense(t *testing.T) {