refactor(configuration): utilize time duration decode hook (#2938)

This enhances the existing time.Duration parser to allow multiple units, and implements a decode hook which can be used by koanf to decode string/integers into time.Durations as applicable.
This commit is contained in:
James Elliott 2022-03-02 17:40:26 +11:00 committed by GitHub
parent d867fa1a63
commit 6276883f04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 686 additions and 343 deletions

View File

@ -100,27 +100,46 @@ $ authelia validate-config --config configuration.yml
# Duration Notation Format
We have implemented a string based notation for configuration options that take a duration. This section describes its
usage. You can use this implementation in: session for expiration, inactivity, and remember_me_duration; and regulation
for ban_time, and find_time. This notation also supports just providing the number of seconds instead.
We have implemented a string/integer based notation for configuration options that take a duration of time. This section
describes the implementation of this. You can use this implementation in various areas of configuration such as:
The notation is comprised of a number which must be positive and not have leading zeros, followed by a letter
denoting the unit of time measurement. The table below describes the units of time and the associated letter.
- session:
- expiration
- inactivity
- remember_me_duration
- regulation:
- ban_time
- find_time
- ntp:
- max_desync
- webauthn:
- timeout
|Unit |Associated Letter|
|:-----:|:---------------:|
|Years |y |
|Months |M |
|Weeks |w |
|Days |d |
|Hours |h |
|Minutes|m |
|Seconds|s |
The way this format works is you can either configure an integer or a string in the specific configuration areas. If you
supply an integer, it is considered a representation of seconds. If you supply a string, it parses the string in blocks
of quantities and units (number followed by a unit letter). For example `5h` indicates a quantity of 5 units of `h`.
Examples:
* 1 hour and 30 minutes: 90m
* 1 day: 1d
* 10 hours: 10h
While you can use multiple of these blocks in combination, ee suggest keeping it simple and use a single value.
## Duration Notation Format Unit Legend
| Unit | Associated Letter |
|:-------:|:-----------------:|
| Years | y |
| Months | M |
| Weeks | w |
| Days | d |
| Hours | h |
| Minutes | m |
| Seconds | s |
## Duration Notation Format Examples
| Desired Value | Configuration Examples |
|:---------------------:|:-------------------------------------:|
| 1 hour and 30 minutes | `90m` or `1h30m` or `5400` or `5400s` |
| 1 day | `1d` or `24h` or `86400` or `86400s` |
| 10 hours | `10h` or `600m` or `9h60m` or `36000` |
# TLS Configuration

View File

@ -57,10 +57,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
}
var ntpProvider *ntp.Provider
if config.NTP != nil {
ntpProvider = ntp.NewProvider(config.NTP)
}
ntpProvider := ntp.NewProvider(&config.NTP)
clock := utils.RealClock{}
authorizer := authorization.NewAuthorizer(config)

View File

@ -4,8 +4,11 @@ import (
"fmt"
"net/mail"
"reflect"
"time"
"github.com/mitchellh/mapstructure"
"github.com/authelia/authelia/v4/internal/utils"
)
// StringToMailAddressFunc decodes a string into a mail.Address.
@ -33,3 +36,67 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
return *mailAddress, nil
}
}
// ToTimeDurationFunc converts string and integer types to a time.Duration.
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType {
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
var (
ptr bool
)
switch f.Kind() {
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:
// We only allow string and integer from kinds to match.
break
default:
return data, nil
}
typeTimeDuration := reflect.TypeOf(time.Hour)
if t.Kind() == reflect.Ptr {
if t.Elem() != typeTimeDuration {
return data, nil
}
ptr = true
} else if t != typeTimeDuration {
return data, nil
}
var duration time.Duration
switch {
case f.Kind() == reflect.String:
break
case f.Kind() == reflect.Int:
seconds := data.(int)
duration = time.Second * time.Duration(seconds)
case f.Kind() == reflect.Int32:
seconds := data.(int32)
duration = time.Second * time.Duration(seconds)
case f == typeTimeDuration:
duration = data.(time.Duration)
case f.Kind() == reflect.Int64:
seconds := data.(int64)
duration = time.Second * time.Duration(seconds)
}
if duration == 0 {
dataStr := data.(string)
if duration, err = utils.ParseDurationString(dataStr); err != nil {
return nil, err
}
}
if ptr {
return &duration, nil
}
return duration, nil
}
}

View File

@ -0,0 +1,270 @@
package configuration
import (
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1h"
expected = time.Hour
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1y"
expected = time.Hour * 24 * 365
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1M"
expected = time.Hour * 24 * 30
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1w"
expected = time.Hour * 24 * 7
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "1d"
expected = time.Hour * 24
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = "abc"
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.EqualError(t, err, "could not parse 'abc' as a duration")
assert.Nil(t, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.EqualError(t, err, "could not parse 'abc' as a duration")
assert.Nil(t, result)
}
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = 60
expected = time.Second * 60
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = int32(120)
expected = time.Second * 120
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = int64(30)
expected = time.Second * 30
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = time.Second * 30
expected = time.Second * 30
to time.Duration
ptrTo *time.Duration
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, expected, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, &expected, result)
}
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = int64(30)
to string
ptrTo *string
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
}
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
hook := ToTimeDurationFunc()
var (
from = true
to string
ptrTo *string
result interface{}
err error
)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
assert.NoError(t, err)
assert.Equal(t, from, result)
}

View File

@ -43,9 +43,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
c := koanf.UnmarshalConf{
DecoderConfig: &mapstructure.DecoderConfig{
DecodeHook: mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
StringToMailAddressFunc(),
ToTimeDurationFunc(),
),
Metadata: nil,
Result: o,

View File

@ -14,8 +14,8 @@ type Configuration struct {
TOTP *TOTPConfiguration `koanf:"totp"`
DuoAPI *DuoAPIConfiguration `koanf:"duo_api"`
AccessControl AccessControlConfiguration `koanf:"access_control"`
NTP *NTPConfiguration `koanf:"ntp"`
Regulation *RegulationConfiguration `koanf:"regulation"`
NTP NTPConfiguration `koanf:"ntp"`
Regulation RegulationConfiguration `koanf:"regulation"`
Storage StorageConfiguration `koanf:"storage"`
Notifier *NotifierConfiguration `koanf:"notifier"`
Server ServerConfiguration `koanf:"server"`

View File

@ -1,10 +1,14 @@
package schema
import (
"time"
)
// NTPConfiguration represents the configuration related to ntp server.
type NTPConfiguration struct {
Address string `koanf:"address"`
Version int `koanf:"version"`
MaximumDesync string `koanf:"max_desync"`
MaximumDesync time.Duration `koanf:"max_desync"`
DisableStartupCheck bool `koanf:"disable_startup_check"`
DisableFailure bool `koanf:"disable_failure"`
}
@ -13,5 +17,5 @@ type NTPConfiguration struct {
var DefaultNTPConfiguration = NTPConfiguration{
Address: "time.cloudflare.com:123",
Version: 4,
MaximumDesync: "3s",
MaximumDesync: time.Second * 3,
}

View File

@ -1,15 +1,19 @@
package schema
import (
"time"
)
// RegulationConfiguration represents the configuration related to regulation.
type RegulationConfiguration struct {
MaxRetries int `koanf:"max_retries"`
FindTime string `koanf:"find_time,weak"`
BanTime string `koanf:"ban_time,weak"`
FindTime time.Duration `koanf:"find_time,weak"`
BanTime time.Duration `koanf:"ban_time,weak"`
}
// DefaultRegulationConfiguration represents default configuration parameters for the regulator.
var DefaultRegulationConfiguration = RegulationConfiguration{
MaxRetries: 3,
FindTime: "2m",
BanTime: "5m",
FindTime: time.Minute * 2,
BanTime: time.Minute * 5,
}

View File

@ -1,5 +1,9 @@
package schema
import (
"time"
)
// RedisNode Represents a Node.
type RedisNode struct {
Host string `koanf:"host"`
@ -35,17 +39,18 @@ type SessionConfiguration struct {
Domain string `koanf:"domain"`
SameSite string `koanf:"same_site"`
Secret string `koanf:"secret"`
Expiration string `koanf:"expiration"`
Inactivity string `koanf:"inactivity"`
RememberMeDuration string `koanf:"remember_me_duration"`
Expiration time.Duration `koanf:"expiration"`
Inactivity time.Duration `koanf:"inactivity"`
RememberMeDuration time.Duration `koanf:"remember_me_duration"`
Redis *RedisSessionConfiguration `koanf:"redis"`
}
// DefaultSessionConfiguration is the default session configuration.
var DefaultSessionConfiguration = SessionConfiguration{
Name: "authelia_session",
Expiration: "1h",
Inactivity: "5m",
RememberMeDuration: "1M",
Expiration: time.Hour,
Inactivity: time.Minute * 5,
RememberMeDuration: time.Hour * 24 * 30,
SameSite: "lax",
}

View File

@ -35,7 +35,6 @@ const (
// Test constants.
const (
testBadTimer = "-1"
testInvalidPolicy = "invalid"
testJWTSecret = "a_secret"
testLDAPBaseDN = "base_dn"
@ -186,12 +185,10 @@ const (
// NTP Error constants.
const (
errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'"
errFmtNTPMaxDesync = "ntp: option 'max_desync' can't be parsed: %w"
)
// Session error constants.
const (
errFmtSessionCouldNotParseDuration = "session: option '%s' could not be parsed: %w"
errFmtSessionOptionRequired = "session: option '%s' is required"
errFmtSessionDomainMustBeRoot = "session: option 'domain' must be the domain you wish to protect not a wildcard domain but it is configured as '%s'"
errFmtSessionSameSite = "session: option 'same_site' must be one of '%s' but is configured as '%s'"
@ -206,7 +203,6 @@ const (
// Regulation Error Consts.
const (
errFmtRegulationParseDuration = "regulation: option '%s' could not be parsed: %w"
errFmtRegulationFindTimeGreaterThanBanTime = "regulation: option 'find_time' must be less than or equal to option 'ban_time'"
)

View File

@ -4,17 +4,10 @@ import (
"fmt"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils"
)
// ValidateNTP validates and update NTP configuration.
func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator) {
if config.NTP == nil {
config.NTP = &schema.DefaultNTPConfiguration
return
}
if config.NTP.Address == "" {
config.NTP.Address = schema.DefaultNTPConfiguration.Address
}
@ -25,12 +18,7 @@ func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator
validator.Push(fmt.Errorf(errFmtNTPVersion, config.NTP.Version))
}
if config.NTP.MaximumDesync == "" {
if config.NTP.MaximumDesync == 0 {
config.NTP.MaximumDesync = schema.DefaultNTPConfiguration.MaximumDesync
}
_, err := utils.ParseDurationString(config.NTP.MaximumDesync)
if err != nil {
validator.Push(fmt.Errorf(errFmtNTPMaxDesync, err))
}
}

View File

@ -11,7 +11,7 @@ import (
func newDefaultNTPConfig() schema.Configuration {
return schema.Configuration{
NTP: &schema.NTPConfiguration{},
NTP: schema.NTPConfiguration{},
}
}
@ -55,18 +55,6 @@ func TestShouldSetDefaultNtpDisableStartupCheck(t *testing.T) {
assert.Equal(t, schema.DefaultNTPConfiguration.DisableStartupCheck, config.NTP.DisableStartupCheck)
}
func TestShouldRaiseErrorOnMaximumDesyncString(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultNTPConfig()
config.NTP.MaximumDesync = "a second"
ValidateNTP(&config, validator)
require.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "ntp: option 'max_desync' can't be parsed: could not parse 'a second' as a duration")
}
func TestShouldRaiseErrorOnInvalidNTPVersion(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultNTPConfig()

View File

@ -4,36 +4,19 @@ import (
"fmt"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/utils"
)
// ValidateRegulation validates and update regulator configuration.
func ValidateRegulation(config *schema.Configuration, validator *schema.StructValidator) {
if config.Regulation == nil {
config.Regulation = &schema.DefaultRegulationConfiguration
return
}
if config.Regulation.FindTime == "" {
if config.Regulation.FindTime == 0 {
config.Regulation.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min.
}
if config.Regulation.BanTime == "" {
if config.Regulation.BanTime == 0 {
config.Regulation.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min.
}
findTime, err := utils.ParseDurationString(config.Regulation.FindTime)
if err != nil {
validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "find_time", err))
}
banTime, err := utils.ParseDurationString(config.Regulation.BanTime)
if err != nil {
validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "ban_time", err))
}
if findTime > banTime {
if config.Regulation.FindTime > config.Regulation.BanTime {
validator.Push(fmt.Errorf(errFmtRegulationFindTimeGreaterThanBanTime))
}
}

View File

@ -2,6 +2,7 @@ package validator
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -10,7 +11,7 @@ import (
func newDefaultRegulationConfig() schema.Configuration {
config := schema.Configuration{
Regulation: &schema.RegulationConfiguration{},
Regulation: schema.RegulationConfiguration{},
}
return config
@ -39,24 +40,11 @@ func TestShouldSetDefaultRegulationFindTime(t *testing.T) {
func TestShouldRaiseErrorWhenFindTimeLessThanBanTime(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultRegulationConfig()
config.Regulation.FindTime = "1m"
config.Regulation.BanTime = "10s"
config.Regulation.FindTime = time.Minute
config.Regulation.BanTime = time.Second * 10
ValidateRegulation(&config, validator)
assert.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' must be less than or equal to option 'ban_time'")
}
func TestShouldRaiseErrorOnBadDurationStrings(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultRegulationConfig()
config.Regulation.FindTime = "a year"
config.Regulation.BanTime = "forever"
ValidateRegulation(&config, validator)
assert.Len(t, validator.Errors(), 2)
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' could not be parsed: could not parse 'a year' as a duration")
assert.EqualError(t, validator.Errors()[1], "regulation: option 'ban_time' could not be parsed: could not parse 'forever' as a duration")
}

View File

@ -27,22 +27,16 @@ func ValidateSession(config *schema.SessionConfiguration, validator *schema.Stru
}
func validateSession(config *schema.SessionConfiguration, validator *schema.StructValidator) {
if config.Expiration == "" {
if config.Expiration <= 0 {
config.Expiration = schema.DefaultSessionConfiguration.Expiration // 1 hour.
} else if _, err := utils.ParseDurationString(config.Expiration); err != nil {
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "expiriation", err))
}
if config.Inactivity == "" {
if config.Inactivity <= 0 {
config.Inactivity = schema.DefaultSessionConfiguration.Inactivity // 5 min.
} else if _, err := utils.ParseDurationString(config.Inactivity); err != nil {
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "inactivity", err))
}
if config.RememberMeDuration == "" {
if config.RememberMeDuration <= 0 {
config.RememberMeDuration = schema.DefaultSessionConfiguration.RememberMeDuration // 1 month.
} else if _, err := utils.ParseDurationString(config.RememberMeDuration); err != nil {
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "remember_me_duration", err))
}
if config.Domain == "" {

View File

@ -420,30 +420,21 @@ func TestShouldNotRaiseErrorWhenSameSiteSetCorrectly(t *testing.T) {
}
}
func TestShouldRaiseErrorWhenBadInactivityAndExpirationSet(t *testing.T) {
func TestShouldSetDefaultWhenNegativeInactivityAndExpirationSet(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultSessionConfig()
config.Inactivity = testBadTimer
config.Expiration = testBadTimer
config.Inactivity = -1
config.Expiration = -1
config.RememberMeDuration = -1
ValidateSession(&config, validator)
assert.False(t, validator.HasWarnings())
assert.Len(t, validator.Errors(), 2)
assert.EqualError(t, validator.Errors()[0], "session: option 'expiriation' could not be parsed: could not parse '-1' as a duration")
assert.EqualError(t, validator.Errors()[1], "session: option 'inactivity' could not be parsed: could not parse '-1' as a duration")
}
assert.Len(t, validator.Warnings(), 0)
assert.Len(t, validator.Errors(), 0)
func TestShouldRaiseErrorWhenBadRememberMeDurationSet(t *testing.T) {
validator := schema.NewStructValidator()
config := newDefaultSessionConfig()
config.RememberMeDuration = "1 year"
ValidateSession(&config, validator)
assert.False(t, validator.HasWarnings())
assert.Len(t, validator.Errors(), 1)
assert.EqualError(t, validator.Errors()[0], "session: option 'remember_me_duration' could not be parsed: could not parse '1 year' as a duration")
assert.Equal(t, schema.DefaultSessionConfiguration.Inactivity, config.Inactivity)
assert.Equal(t, schema.DefaultSessionConfiguration.Expiration, config.Expiration)
assert.Equal(t, schema.DefaultSessionConfiguration.RememberMeDuration, config.RememberMeDuration)
}
func TestShouldSetDefaultRememberMeDuration(t *testing.T) {

View File

@ -1,6 +1,8 @@
package handlers
import (
"time"
"github.com/valyala/fasthttp"
)
@ -56,7 +58,7 @@ const (
)
const (
testInactivity = "10"
testInactivity = time.Second * 10
testRedirectionURL = "http://redirection.local"
testUsername = "john"
)

View File

@ -602,7 +602,7 @@ func TestShouldDestroySessionWhenInactiveForTooLongUsingDurationNotation(t *test
clock := mocks.TestingClock{}
clock.Set(time.Now())
mock.Ctx.Configuration.Session.Inactivity = "10s"
mock.Ctx.Configuration.Session.Inactivity = time.Second * 10
// Reload the session provider since the configuration is indirect.
mock.Ctx.Providers.SessionProvider = session.NewProvider(mock.Ctx.Configuration.Session, nil)
assert.Equal(t, time.Second*10, mock.Ctx.Providers.SessionProvider.Inactivity)

View File

@ -8,7 +8,6 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewProvider instantiate a ntp provider given a configuration.
@ -59,11 +58,9 @@ func (p *Provider) StartupCheck() (err error) {
return nil
}
maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync)
ntpTime := ntpPacketToTime(resp)
if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result {
if result := ntpIsOffsetTooLarge(p.config.MaximumDesync, now, ntpTime); result {
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
}

View File

@ -2,6 +2,7 @@ package ntp
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -11,18 +12,17 @@ import (
func TestShouldCheckNTP(t *testing.T) {
config := &schema.Configuration{
NTP: &schema.NTPConfiguration{
NTP: schema.NTPConfiguration{
Address: "time.cloudflare.com:123",
Version: 4,
MaximumDesync: "3s",
DisableStartupCheck: false,
MaximumDesync: time.Second * 3,
},
}
sv := schema.NewStructValidator()
validator.ValidateNTP(config, sv)
ntp := NewProvider(config.NTP)
ntp := NewProvider(&config.NTP)
assert.NoError(t, ntp.StartupCheck())
}

View File

@ -2,7 +2,6 @@ package regulation
import (
"context"
"fmt"
"net"
"time"
@ -13,33 +12,13 @@ import (
)
// NewRegulator create a regulator instance.
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
regulator := &Regulator{storageProvider: provider}
regulator.clock = clock
if configuration != nil {
findTime, err := utils.ParseDurationString(configuration.FindTime)
if err != nil {
panic(err)
func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
return &Regulator{
enabled: config.MaxRetries > 0,
storageProvider: provider,
clock: clock,
config: config,
}
banTime, err := utils.ParseDurationString(configuration.BanTime)
if err != nil {
panic(err)
}
if findTime > banTime {
panic(fmt.Errorf("find_time cannot be greater than ban_time"))
}
// Set regulator enabled only if MaxRetries is not 0.
regulator.enabled = configuration.MaxRetries > 0
regulator.maxRetries = configuration.MaxRetries
regulator.findTime = findTime
regulator.banTime = banTime
}
return regulator
}
// Mark an authentication attempt.
@ -65,15 +44,15 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
return time.Time{}, nil
}
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0)
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
if err != nil {
return time.Time{}, nil
}
latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.maxRetries)
latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.config.MaxRetries)
for _, attempt := range attempts {
if attempt.Successful || len(latestFailedAttempts) >= r.maxRetries {
if attempt.Successful || len(latestFailedAttempts) >= r.config.MaxRetries {
// We stop appending failed attempts once we find the first successful attempts or we reach
// the configured number of retries, meaning the user is already banned.
break
@ -84,17 +63,17 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
// If the number of failed attempts within the ban time is less than the max number of retries
// then the user is not banned.
if len(latestFailedAttempts) < r.maxRetries {
if len(latestFailedAttempts) < r.config.MaxRetries {
return time.Time{}, nil
}
// Now we compute the time between the latest attempt and the MaxRetry-th one. If it's
// within the FindTime then it means that the user has been banned.
durationBetweenLatestAttempts := latestFailedAttempts[0].Time.Sub(
latestFailedAttempts[r.maxRetries-1].Time)
latestFailedAttempts[r.config.MaxRetries-1].Time)
if durationBetweenLatestAttempts < r.findTime {
bannedUntil := latestFailedAttempts[0].Time.Add(r.banTime)
if durationBetweenLatestAttempts < r.config.FindTime {
bannedUntil := latestFailedAttempts[0].Time.Add(r.config.BanTime)
return bannedUntil, ErrUserIsBanned
}

View File

@ -21,7 +21,7 @@ type RegulatorSuite struct {
ctx context.Context
ctrl *gomock.Controller
storageMock *mocks.MockStorage
configuration schema.RegulationConfiguration
config schema.RegulationConfiguration
clock mocks.TestingClock
}
@ -30,10 +30,10 @@ func (s *RegulatorSuite) SetupTest() {
s.storageMock = mocks.NewMockStorage(s.ctrl)
s.ctx = context.Background()
s.configuration = schema.RegulationConfiguration{
s.config = schema.RegulationConfiguration{
MaxRetries: 3,
BanTime: "180",
FindTime: "30",
BanTime: time.Second * 180,
FindTime: time.Second * 30,
}
s.clock.Set(time.Now())
}
@ -55,7 +55,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
@ -86,7 +86,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
@ -122,7 +122,7 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
@ -155,7 +155,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
@ -179,7 +179,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
@ -211,7 +211,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
@ -247,7 +247,7 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
@ -283,24 +283,24 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
Return(attemptsInDB, nil)
// Check Disabled Functionality.
configuration := schema.RegulationConfiguration{
config := schema.RegulationConfiguration{
MaxRetries: 0,
FindTime: "180",
BanTime: "180",
FindTime: time.Second * 180,
BanTime: time.Second * 180,
}
regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
regulator := regulation.NewRegulator(config, s.storageMock, &s.clock)
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
// Check Enabled Functionality.
configuration = schema.RegulationConfiguration{
config = schema.RegulationConfiguration{
MaxRetries: 1,
FindTime: "180",
BanTime: "180",
FindTime: time.Second * 180,
BanTime: time.Second * 180,
}
regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
regulator = regulation.NewRegulator(config, s.storageMock, &s.clock)
_, err = regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
}

View File

@ -1,8 +1,7 @@
package regulation
import (
"time"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils"
)
@ -11,12 +10,8 @@ import (
type Regulator struct {
// Is the regulation enabled.
enabled bool
// The number of failed authentication attempt before banning the user.
maxRetries int
// If a user does the max number of retries within that duration, she will be banned.
findTime time.Duration
// If a user has been banned, this duration is the timelapse during which the user is banned.
banTime time.Duration
config schema.RegulationConfiguration
storageProvider storage.RegulatorProvider

View File

@ -28,7 +28,7 @@ var assets embed.FS
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != "0")
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != -1)
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
duoSelfEnrollment := f

View File

@ -1,8 +1,12 @@
package session
import (
"time"
)
const (
testDomain = "example.com"
testExpiration = "40"
testExpiration = time.Second * 40
testName = "my_session"
testUsername = "john"
)

View File

@ -12,7 +12,6 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
)
// Provider a session provider.
@ -23,38 +22,29 @@ type Provider struct {
}
// NewProvider instantiate a session provider given a configuration.
func NewProvider(configuration schema.SessionConfiguration, certPool *x509.CertPool) *Provider {
providerConfig := NewProviderConfig(configuration, certPool)
func NewProvider(config schema.SessionConfiguration, certPool *x509.CertPool) *Provider {
c := NewProviderConfig(config, certPool)
provider := new(Provider)
provider.sessionHolder = fasthttpsession.New(providerConfig.config)
provider.sessionHolder = fasthttpsession.New(c.config)
logger := logging.Logger()
duration, err := utils.ParseDurationString(configuration.RememberMeDuration)
if err != nil {
logger.Fatal(err)
}
provider.Inactivity, provider.RememberMe = config.Inactivity, config.RememberMeDuration
provider.RememberMe = duration
duration, err = utils.ParseDurationString(configuration.Inactivity)
if err != nil {
logger.Fatal(err)
}
provider.Inactivity = duration
var providerImpl fasthttpsession.Provider
var (
providerImpl fasthttpsession.Provider
err error
)
switch {
case providerConfig.redisConfig != nil:
providerImpl, err = redis.New(*providerConfig.redisConfig)
case c.redisConfig != nil:
providerImpl, err = redis.New(*c.redisConfig)
if err != nil {
logger.Fatal(err)
}
case providerConfig.redisSentinelConfig != nil:
providerImpl, err = redis.NewFailoverCluster(*providerConfig.redisSentinelConfig)
case c.redisSentinelConfig != nil:
providerImpl, err = redis.NewFailoverCluster(*c.redisSentinelConfig)
if err != nil {
logger.Fatal(err)
}

View File

@ -17,10 +17,10 @@ import (
)
// NewProviderConfig creates a configuration for creating the session provider.
func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig {
config := session.NewDefaultConfig()
func NewProviderConfig(config schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig {
c := session.NewDefaultConfig()
config.SessionIDGeneratorFunc = func() []byte {
c.SessionIDGeneratorFunc = func() []byte {
bytes := make([]byte, 32)
_, _ = rand.Read(bytes)
@ -33,30 +33,30 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
}
// Override the cookie name.
config.CookieName = configuration.Name
c.CookieName = config.Name
// Set the cookie to the given domain.
config.Domain = configuration.Domain
c.Domain = config.Domain
// Set the cookie SameSite option.
switch configuration.SameSite {
switch config.SameSite {
case "strict":
config.CookieSameSite = fasthttp.CookieSameSiteStrictMode
c.CookieSameSite = fasthttp.CookieSameSiteStrictMode
case "none":
config.CookieSameSite = fasthttp.CookieSameSiteNoneMode
c.CookieSameSite = fasthttp.CookieSameSiteNoneMode
case "lax":
config.CookieSameSite = fasthttp.CookieSameSiteLaxMode
c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
default:
config.CookieSameSite = fasthttp.CookieSameSiteLaxMode
c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
}
// Only serve the header over HTTPS.
config.Secure = true
c.Secure = true
// Ignore the error as it will be handled by validator.
config.Expiration, _ = utils.ParseDurationString(configuration.Expiration)
c.Expiration = config.Expiration
config.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
c.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
return true
}
@ -68,23 +68,23 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
// If redis configuration is provided, then use the redis provider.
switch {
case configuration.Redis != nil:
serializer := NewEncryptingSerializer(configuration.Secret)
case config.Redis != nil:
serializer := NewEncryptingSerializer(config.Secret)
var tlsConfig *tls.Config
if configuration.Redis.TLS != nil {
tlsConfig = utils.NewTLSConfig(configuration.Redis.TLS, tls.VersionTLS12, certPool)
if config.Redis.TLS != nil {
tlsConfig = utils.NewTLSConfig(config.Redis.TLS, tls.VersionTLS12, certPool)
}
if configuration.Redis.HighAvailability != nil && configuration.Redis.HighAvailability.SentinelName != "" {
if config.Redis.HighAvailability != nil && config.Redis.HighAvailability.SentinelName != "" {
addrs := make([]string, 0)
if configuration.Redis.Host != "" {
addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(configuration.Redis.Host), configuration.Redis.Port))
if config.Redis.Host != "" {
addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(config.Redis.Host), config.Redis.Port))
}
for _, node := range configuration.Redis.HighAvailability.Nodes {
for _, node := range config.Redis.HighAvailability.Nodes {
addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port)
if !utils.IsStringInSlice(addr, addrs) {
addrs = append(addrs, addr)
@ -94,17 +94,17 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
providerName = "redis-sentinel"
redisSentinelConfig = &redis.FailoverConfig{
Logger: &redisLogger{logger: logging.Logger()},
MasterName: configuration.Redis.HighAvailability.SentinelName,
MasterName: config.Redis.HighAvailability.SentinelName,
SentinelAddrs: addrs,
SentinelUsername: configuration.Redis.HighAvailability.SentinelUsername,
SentinelPassword: configuration.Redis.HighAvailability.SentinelPassword,
RouteByLatency: configuration.Redis.HighAvailability.RouteByLatency,
RouteRandomly: configuration.Redis.HighAvailability.RouteRandomly,
Username: configuration.Redis.Username,
Password: configuration.Redis.Password,
DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: configuration.Redis.MaximumActiveConnections,
MinIdleConns: configuration.Redis.MinimumIdleConnections,
SentinelUsername: config.Redis.HighAvailability.SentinelUsername,
SentinelPassword: config.Redis.HighAvailability.SentinelPassword,
RouteByLatency: config.Redis.HighAvailability.RouteByLatency,
RouteRandomly: config.Redis.HighAvailability.RouteRandomly,
Username: config.Redis.Username,
Password: config.Redis.Password,
DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: config.Redis.MaximumActiveConnections,
MinIdleConns: config.Redis.MinimumIdleConnections,
IdleTimeout: 300,
TLSConfig: tlsConfig,
KeyPrefix: "authelia-session",
@ -115,36 +115,36 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
var addr string
if configuration.Redis.Port == 0 {
if config.Redis.Port == 0 {
network = "unix"
addr = configuration.Redis.Host
addr = config.Redis.Host
} else {
addr = fmt.Sprintf("%s:%d", configuration.Redis.Host, configuration.Redis.Port)
addr = fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port)
}
redisConfig = &redis.Config{
Logger: newRedisLogger(),
Network: network,
Addr: addr,
Username: configuration.Redis.Username,
Password: configuration.Redis.Password,
DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: configuration.Redis.MaximumActiveConnections,
MinIdleConns: configuration.Redis.MinimumIdleConnections,
Username: config.Redis.Username,
Password: config.Redis.Password,
DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: config.Redis.MaximumActiveConnections,
MinIdleConns: config.Redis.MinimumIdleConnections,
IdleTimeout: 300,
TLSConfig: tlsConfig,
KeyPrefix: "authelia-session",
}
}
config.EncodeFunc = serializer.Encode
config.DecodeFunc = serializer.Decode
c.EncodeFunc = serializer.Encode
c.DecodeFunc = serializer.Decode
default:
providerName = "memory"
}
return ProviderConfig{
config,
c,
redisConfig,
redisSentinelConfig,
providerName,

View File

@ -53,7 +53,25 @@ const (
)
var (
reDuration = regexp.MustCompile(`^(?P<Duration>[1-9]\d*?)(?P<Unit>[smhdwMy])?$`)
standardDurationUnits = []string{"ns", "us", "µs", "μs", "ms", "s", "m", "h"}
reDurationSeconds = regexp.MustCompile(`^\d+$`)
reDurationStandard = regexp.MustCompile(`(?P<Duration>[1-9]\d*?)(?P<Unit>[^\d\s]+)`)
)
// Duration unit types.
const (
DurationUnitDays = "d"
DurationUnitWeeks = "w"
DurationUnitMonths = "M"
DurationUnitYears = "y"
)
// Number of hours in particular measurements of time.
const (
HoursInDay = 24
HoursInWeek = HoursInDay * 7
HoursInMonth = HoursInDay * 30
HoursInYear = HoursInDay * 365
)
var (

View File

@ -6,46 +6,64 @@ import (
"time"
)
// ParseDurationString parses a string to a duration
// Duration notations are an integer followed by a unit
// Units are s = second, m = minute, d = day, w = week, M = month, y = year
// Example 1y is the same as 1 year.
func ParseDurationString(input string) (time.Duration, error) {
var duration time.Duration
// StandardizeDurationString converts units of time that stdlib is unaware of to hours.
func StandardizeDurationString(input string) (output string, err error) {
if input == "" {
return "0s", nil
}
matches := reDuration.FindStringSubmatch(input)
matches := reDurationStandard.FindAllStringSubmatch(input, -1)
if len(matches) == 0 {
return "", fmt.Errorf("could not parse '%s' as a duration", input)
}
var d int
for _, match := range matches {
if d, err = strconv.Atoi(match[1]); err != nil {
return "", fmt.Errorf("could not parse the numeric portion of '%s' in duration string '%s': %w", match[0], input, err)
}
unit := match[2]
switch {
case len(matches) == 3 && matches[2] != "":
d, _ := strconv.Atoi(matches[1])
switch matches[2] {
case "y":
duration = time.Duration(d) * Year
case "M":
duration = time.Duration(d) * Month
case "w":
duration = time.Duration(d) * Week
case "d":
duration = time.Duration(d) * Day
case "h":
duration = time.Duration(d) * Hour
case "m":
duration = time.Duration(d) * time.Minute
case "s":
duration = time.Duration(d) * time.Second
case IsStringInSlice(unit, standardDurationUnits):
output += fmt.Sprintf("%d%s", d, unit)
case unit == DurationUnitDays:
output += fmt.Sprintf("%dh", d*HoursInDay)
case unit == DurationUnitWeeks:
output += fmt.Sprintf("%dh", d*HoursInWeek)
case unit == DurationUnitMonths:
output += fmt.Sprintf("%dh", d*HoursInMonth)
case unit == DurationUnitYears:
output += fmt.Sprintf("%dh", d*HoursInYear)
default:
return "", fmt.Errorf("could not parse the units portion of '%s' in duration string '%s': the unit '%s' is not valid", match[0], input, unit)
}
case input == "0" || len(matches) == 3:
seconds, err := strconv.Atoi(input)
if err != nil {
return 0, fmt.Errorf("could not parse '%s' as a duration: %w", input, err)
}
duration = time.Duration(seconds) * time.Second
case input != "":
// Throw this error if input is anything other than a blank string, blank string will default to a duration of nothing.
return 0, fmt.Errorf("could not parse '%s' as a duration", input)
}
return duration, nil
return output, nil
}
// ParseDurationString standardizes a duration string with StandardizeDurationString then uses time.ParseDuration to
// convert it into a time.Duration.
func ParseDurationString(input string) (duration time.Duration, err error) {
if reDurationSeconds.MatchString(input) {
var seconds int
if seconds, err = strconv.Atoi(input); err != nil {
return 0, nil
}
return time.Second * time.Duration(seconds), nil
}
var out string
if out, err = StandardizeDurationString(input); err != nil {
return 0, err
}
return time.ParseDuration(out)
}

View File

@ -7,66 +7,112 @@ import (
"github.com/stretchr/testify/assert"
)
func TestShouldParseDurationString(t *testing.T) {
func TestParseDurationString_ShouldParseDurationString(t *testing.T) {
duration, err := ParseDurationString("1h")
assert.NoError(t, err)
assert.Equal(t, 60*time.Minute, duration)
}
func TestShouldParseDurationStringAllUnits(t *testing.T) {
duration, err := ParseDurationString("1y")
func TestParseDurationString_ShouldParseBlankString(t *testing.T) {
duration, err := ParseDurationString("")
assert.NoError(t, err)
assert.Equal(t, Year, duration)
assert.Equal(t, time.Second*0, duration)
}
func TestParseDurationString_ShouldParseDurationStringAllUnits(t *testing.T) {
duration, err := ParseDurationString("1y")
assert.NoError(t, err)
assert.Equal(t, time.Hour*24*365, duration)
duration, err = ParseDurationString("1M")
assert.NoError(t, err)
assert.Equal(t, Month, duration)
assert.Equal(t, time.Hour*24*30, duration)
duration, err = ParseDurationString("1w")
assert.NoError(t, err)
assert.Equal(t, Week, duration)
assert.Equal(t, time.Hour*24*7, duration)
duration, err = ParseDurationString("1d")
assert.NoError(t, err)
assert.Equal(t, Day, duration)
assert.Equal(t, time.Hour*24, duration)
duration, err = ParseDurationString("1h")
assert.NoError(t, err)
assert.Equal(t, Hour, duration)
assert.Equal(t, time.Hour, duration)
duration, err = ParseDurationString("1s")
assert.NoError(t, err)
assert.Equal(t, time.Second, duration)
}
func TestShouldParseSecondsString(t *testing.T) {
func TestParseDurationString_ShouldParseSecondsString(t *testing.T) {
duration, err := ParseDurationString("100")
assert.NoError(t, err)
assert.Equal(t, 100*time.Second, duration)
}
func TestShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) {
func TestParseDurationString_ShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) {
duration, err := ParseDurationString("h1")
assert.EqualError(t, err, "could not parse 'h1' as a duration")
assert.Equal(t, time.Duration(0), duration)
}
func TestShouldNotParseBadDurationString(t *testing.T) {
func TestParseDurationString_ShouldNotParseBadDurationString(t *testing.T) {
duration, err := ParseDurationString("10x")
assert.EqualError(t, err, "could not parse '10x' as a duration")
assert.EqualError(t, err, "could not parse the units portion of '10x' in duration string '10x': the unit 'x' is not valid")
assert.Equal(t, time.Duration(0), duration)
}
func TestShouldNotParseDurationStringWithMultiValueUnits(t *testing.T) {
func TestParseDurationString_ShouldParseDurationStringWithMultiValueUnits(t *testing.T) {
duration, err := ParseDurationString("10ms")
assert.EqualError(t, err, "could not parse '10ms' as a duration")
assert.Equal(t, time.Duration(0), duration)
assert.NoError(t, err)
assert.Equal(t, time.Duration(10)*time.Millisecond, duration)
}
func TestShouldNotParseDurationStringWithLeadingZero(t *testing.T) {
func TestParseDurationString_ShouldParseDurationStringWithLeadingZero(t *testing.T) {
duration, err := ParseDurationString("005h")
assert.EqualError(t, err, "could not parse '005h' as a duration")
assert.Equal(t, time.Duration(0), duration)
assert.NoError(t, err)
assert.Equal(t, time.Duration(5)*time.Hour, duration)
}
func TestParseDurationString_ShouldParseMultiUnitValues(t *testing.T) {
duration, err := ParseDurationString("1d3w10ms")
assert.NoError(t, err)
assert.Equal(t,
(time.Hour*time.Duration(24))+
(time.Hour*time.Duration(24)*time.Duration(7)*time.Duration(3))+
(time.Millisecond*time.Duration(10)), duration)
}
func TestParseDurationString_ShouldParseDuplicateUnitValues(t *testing.T) {
duration, err := ParseDurationString("1d4d2d")
assert.NoError(t, err)
assert.Equal(t,
(time.Hour*time.Duration(24))+
(time.Hour*time.Duration(24)*time.Duration(4))+
(time.Hour*time.Duration(24)*time.Duration(2)), duration)
}
func TestStandardizeDurationString_ShouldParseStringWithSpaces(t *testing.T) {
result, err := StandardizeDurationString("1d 1h 20m")
assert.NoError(t, err)
assert.Equal(t, result, "24h1h20m")
}
func TestShouldTimeIntervalsMakeSense(t *testing.T) {