mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
refactor(configuration): utilize time duration decode hook (#2938)
This enhances the existing time.Duration parser to allow multiple units, and implements a decode hook which can be used by koanf to decode string/integers into time.Durations as applicable.
This commit is contained in:
parent
d867fa1a63
commit
6276883f04
|
@ -100,27 +100,46 @@ $ authelia validate-config --config configuration.yml
|
||||||
|
|
||||||
# Duration Notation Format
|
# Duration Notation Format
|
||||||
|
|
||||||
We have implemented a string based notation for configuration options that take a duration. This section describes its
|
We have implemented a string/integer based notation for configuration options that take a duration of time. This section
|
||||||
usage. You can use this implementation in: session for expiration, inactivity, and remember_me_duration; and regulation
|
describes the implementation of this. You can use this implementation in various areas of configuration such as:
|
||||||
for ban_time, and find_time. This notation also supports just providing the number of seconds instead.
|
|
||||||
|
|
||||||
The notation is comprised of a number which must be positive and not have leading zeros, followed by a letter
|
- session:
|
||||||
denoting the unit of time measurement. The table below describes the units of time and the associated letter.
|
- expiration
|
||||||
|
- inactivity
|
||||||
|
- remember_me_duration
|
||||||
|
- regulation:
|
||||||
|
- ban_time
|
||||||
|
- find_time
|
||||||
|
- ntp:
|
||||||
|
- max_desync
|
||||||
|
- webauthn:
|
||||||
|
- timeout
|
||||||
|
|
||||||
|Unit |Associated Letter|
|
The way this format works is you can either configure an integer or a string in the specific configuration areas. If you
|
||||||
|:-----:|:---------------:|
|
supply an integer, it is considered a representation of seconds. If you supply a string, it parses the string in blocks
|
||||||
|Years |y |
|
of quantities and units (number followed by a unit letter). For example `5h` indicates a quantity of 5 units of `h`.
|
||||||
|Months |M |
|
|
||||||
|Weeks |w |
|
|
||||||
|Days |d |
|
|
||||||
|Hours |h |
|
|
||||||
|Minutes|m |
|
|
||||||
|Seconds|s |
|
|
||||||
|
|
||||||
Examples:
|
While you can use multiple of these blocks in combination, ee suggest keeping it simple and use a single value.
|
||||||
* 1 hour and 30 minutes: 90m
|
|
||||||
* 1 day: 1d
|
## Duration Notation Format Unit Legend
|
||||||
* 10 hours: 10h
|
|
||||||
|
| Unit | Associated Letter |
|
||||||
|
|:-------:|:-----------------:|
|
||||||
|
| Years | y |
|
||||||
|
| Months | M |
|
||||||
|
| Weeks | w |
|
||||||
|
| Days | d |
|
||||||
|
| Hours | h |
|
||||||
|
| Minutes | m |
|
||||||
|
| Seconds | s |
|
||||||
|
|
||||||
|
## Duration Notation Format Examples
|
||||||
|
|
||||||
|
| Desired Value | Configuration Examples |
|
||||||
|
|:---------------------:|:-------------------------------------:|
|
||||||
|
| 1 hour and 30 minutes | `90m` or `1h30m` or `5400` or `5400s` |
|
||||||
|
| 1 day | `1d` or `24h` or `86400` or `86400s` |
|
||||||
|
| 10 hours | `10h` or `600m` or `9h60m` or `36000` |
|
||||||
|
|
||||||
# TLS Configuration
|
# TLS Configuration
|
||||||
|
|
||||||
|
|
|
@ -57,10 +57,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [
|
||||||
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
|
notifier = notification.NewFileNotifier(*config.Notifier.FileSystem)
|
||||||
}
|
}
|
||||||
|
|
||||||
var ntpProvider *ntp.Provider
|
ntpProvider := ntp.NewProvider(&config.NTP)
|
||||||
if config.NTP != nil {
|
|
||||||
ntpProvider = ntp.NewProvider(config.NTP)
|
|
||||||
}
|
|
||||||
|
|
||||||
clock := utils.RealClock{}
|
clock := utils.RealClock{}
|
||||||
authorizer := authorization.NewAuthorizer(config)
|
authorizer := authorization.NewAuthorizer(config)
|
||||||
|
|
|
@ -4,8 +4,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
|
|
||||||
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StringToMailAddressFunc decodes a string into a mail.Address.
|
// StringToMailAddressFunc decodes a string into a mail.Address.
|
||||||
|
@ -33,3 +36,67 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
|
||||||
return *mailAddress, nil
|
return *mailAddress, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToTimeDurationFunc converts string and integer types to a time.Duration.
|
||||||
|
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType {
|
||||||
|
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||||
|
var (
|
||||||
|
ptr bool
|
||||||
|
)
|
||||||
|
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
|
// We only allow string and integer from kinds to match.
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
typeTimeDuration := reflect.TypeOf(time.Hour)
|
||||||
|
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
if t.Elem() != typeTimeDuration {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr = true
|
||||||
|
} else if t != typeTimeDuration {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var duration time.Duration
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case f.Kind() == reflect.String:
|
||||||
|
break
|
||||||
|
case f.Kind() == reflect.Int:
|
||||||
|
seconds := data.(int)
|
||||||
|
|
||||||
|
duration = time.Second * time.Duration(seconds)
|
||||||
|
case f.Kind() == reflect.Int32:
|
||||||
|
seconds := data.(int32)
|
||||||
|
|
||||||
|
duration = time.Second * time.Duration(seconds)
|
||||||
|
case f == typeTimeDuration:
|
||||||
|
duration = data.(time.Duration)
|
||||||
|
case f.Kind() == reflect.Int64:
|
||||||
|
seconds := data.(int64)
|
||||||
|
|
||||||
|
duration = time.Second * time.Duration(seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
if duration == 0 {
|
||||||
|
dataStr := data.(string)
|
||||||
|
|
||||||
|
if duration, err = utils.ParseDurationString(dataStr); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ptr {
|
||||||
|
return &duration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return duration, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
270
internal/configuration/decode_hooks_test.go
Normal file
270
internal/configuration/decode_hooks_test.go
Normal file
|
@ -0,0 +1,270 @@
|
||||||
|
package configuration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "1h"
|
||||||
|
expected = time.Hour
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "1y"
|
||||||
|
expected = time.Hour * 24 * 365
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "1M"
|
||||||
|
expected = time.Hour * 24 * 30
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "1w"
|
||||||
|
expected = time.Hour * 24 * 7
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "1d"
|
||||||
|
expected = time.Hour * 24
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = "abc"
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.EqualError(t, err, "could not parse 'abc' as a duration")
|
||||||
|
assert.Nil(t, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.EqualError(t, err, "could not parse 'abc' as a duration")
|
||||||
|
assert.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = 60
|
||||||
|
expected = time.Second * 60
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = int32(120)
|
||||||
|
expected = time.Second * 120
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = int64(30)
|
||||||
|
expected = time.Second * 30
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = time.Second * 30
|
||||||
|
expected = time.Second * 30
|
||||||
|
|
||||||
|
to time.Duration
|
||||||
|
ptrTo *time.Duration
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = int64(30)
|
||||||
|
|
||||||
|
to string
|
||||||
|
ptrTo *string
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, from, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, from, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
|
||||||
|
hook := ToTimeDurationFunc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
from = true
|
||||||
|
|
||||||
|
to string
|
||||||
|
ptrTo *string
|
||||||
|
result interface{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(to), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, from, result)
|
||||||
|
|
||||||
|
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(ptrTo), from)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, from, result)
|
||||||
|
}
|
|
@ -43,9 +43,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
|
||||||
c := koanf.UnmarshalConf{
|
c := koanf.UnmarshalConf{
|
||||||
DecoderConfig: &mapstructure.DecoderConfig{
|
DecoderConfig: &mapstructure.DecoderConfig{
|
||||||
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
||||||
mapstructure.StringToTimeDurationHookFunc(),
|
|
||||||
mapstructure.StringToSliceHookFunc(","),
|
mapstructure.StringToSliceHookFunc(","),
|
||||||
StringToMailAddressFunc(),
|
StringToMailAddressFunc(),
|
||||||
|
ToTimeDurationFunc(),
|
||||||
),
|
),
|
||||||
Metadata: nil,
|
Metadata: nil,
|
||||||
Result: o,
|
Result: o,
|
||||||
|
|
|
@ -14,8 +14,8 @@ type Configuration struct {
|
||||||
TOTP *TOTPConfiguration `koanf:"totp"`
|
TOTP *TOTPConfiguration `koanf:"totp"`
|
||||||
DuoAPI *DuoAPIConfiguration `koanf:"duo_api"`
|
DuoAPI *DuoAPIConfiguration `koanf:"duo_api"`
|
||||||
AccessControl AccessControlConfiguration `koanf:"access_control"`
|
AccessControl AccessControlConfiguration `koanf:"access_control"`
|
||||||
NTP *NTPConfiguration `koanf:"ntp"`
|
NTP NTPConfiguration `koanf:"ntp"`
|
||||||
Regulation *RegulationConfiguration `koanf:"regulation"`
|
Regulation RegulationConfiguration `koanf:"regulation"`
|
||||||
Storage StorageConfiguration `koanf:"storage"`
|
Storage StorageConfiguration `koanf:"storage"`
|
||||||
Notifier *NotifierConfiguration `koanf:"notifier"`
|
Notifier *NotifierConfiguration `koanf:"notifier"`
|
||||||
Server ServerConfiguration `koanf:"server"`
|
Server ServerConfiguration `koanf:"server"`
|
||||||
|
|
|
@ -1,17 +1,21 @@
|
||||||
package schema
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// NTPConfiguration represents the configuration related to ntp server.
|
// NTPConfiguration represents the configuration related to ntp server.
|
||||||
type NTPConfiguration struct {
|
type NTPConfiguration struct {
|
||||||
Address string `koanf:"address"`
|
Address string `koanf:"address"`
|
||||||
Version int `koanf:"version"`
|
Version int `koanf:"version"`
|
||||||
MaximumDesync string `koanf:"max_desync"`
|
MaximumDesync time.Duration `koanf:"max_desync"`
|
||||||
DisableStartupCheck bool `koanf:"disable_startup_check"`
|
DisableStartupCheck bool `koanf:"disable_startup_check"`
|
||||||
DisableFailure bool `koanf:"disable_failure"`
|
DisableFailure bool `koanf:"disable_failure"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultNTPConfiguration represents default configuration parameters for the NTP server.
|
// DefaultNTPConfiguration represents default configuration parameters for the NTP server.
|
||||||
var DefaultNTPConfiguration = NTPConfiguration{
|
var DefaultNTPConfiguration = NTPConfiguration{
|
||||||
Address: "time.cloudflare.com:123",
|
Address: "time.cloudflare.com:123",
|
||||||
Version: 4,
|
Version: 4,
|
||||||
MaximumDesync: "3s",
|
MaximumDesync: time.Second * 3,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,19 @@
|
||||||
package schema
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// RegulationConfiguration represents the configuration related to regulation.
|
// RegulationConfiguration represents the configuration related to regulation.
|
||||||
type RegulationConfiguration struct {
|
type RegulationConfiguration struct {
|
||||||
MaxRetries int `koanf:"max_retries"`
|
MaxRetries int `koanf:"max_retries"`
|
||||||
FindTime string `koanf:"find_time,weak"`
|
FindTime time.Duration `koanf:"find_time,weak"`
|
||||||
BanTime string `koanf:"ban_time,weak"`
|
BanTime time.Duration `koanf:"ban_time,weak"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultRegulationConfiguration represents default configuration parameters for the regulator.
|
// DefaultRegulationConfiguration represents default configuration parameters for the regulator.
|
||||||
var DefaultRegulationConfiguration = RegulationConfiguration{
|
var DefaultRegulationConfiguration = RegulationConfiguration{
|
||||||
MaxRetries: 3,
|
MaxRetries: 3,
|
||||||
FindTime: "2m",
|
FindTime: time.Minute * 2,
|
||||||
BanTime: "5m",
|
BanTime: time.Minute * 5,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package schema
|
package schema
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// RedisNode Represents a Node.
|
// RedisNode Represents a Node.
|
||||||
type RedisNode struct {
|
type RedisNode struct {
|
||||||
Host string `koanf:"host"`
|
Host string `koanf:"host"`
|
||||||
|
@ -31,21 +35,22 @@ type RedisSessionConfiguration struct {
|
||||||
|
|
||||||
// SessionConfiguration represents the configuration related to user sessions.
|
// SessionConfiguration represents the configuration related to user sessions.
|
||||||
type SessionConfiguration struct {
|
type SessionConfiguration struct {
|
||||||
Name string `koanf:"name"`
|
Name string `koanf:"name"`
|
||||||
Domain string `koanf:"domain"`
|
Domain string `koanf:"domain"`
|
||||||
SameSite string `koanf:"same_site"`
|
SameSite string `koanf:"same_site"`
|
||||||
Secret string `koanf:"secret"`
|
Secret string `koanf:"secret"`
|
||||||
Expiration string `koanf:"expiration"`
|
Expiration time.Duration `koanf:"expiration"`
|
||||||
Inactivity string `koanf:"inactivity"`
|
Inactivity time.Duration `koanf:"inactivity"`
|
||||||
RememberMeDuration string `koanf:"remember_me_duration"`
|
RememberMeDuration time.Duration `koanf:"remember_me_duration"`
|
||||||
Redis *RedisSessionConfiguration `koanf:"redis"`
|
|
||||||
|
Redis *RedisSessionConfiguration `koanf:"redis"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultSessionConfiguration is the default session configuration.
|
// DefaultSessionConfiguration is the default session configuration.
|
||||||
var DefaultSessionConfiguration = SessionConfiguration{
|
var DefaultSessionConfiguration = SessionConfiguration{
|
||||||
Name: "authelia_session",
|
Name: "authelia_session",
|
||||||
Expiration: "1h",
|
Expiration: time.Hour,
|
||||||
Inactivity: "5m",
|
Inactivity: time.Minute * 5,
|
||||||
RememberMeDuration: "1M",
|
RememberMeDuration: time.Hour * 24 * 30,
|
||||||
SameSite: "lax",
|
SameSite: "lax",
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,6 @@ const (
|
||||||
|
|
||||||
// Test constants.
|
// Test constants.
|
||||||
const (
|
const (
|
||||||
testBadTimer = "-1"
|
|
||||||
testInvalidPolicy = "invalid"
|
testInvalidPolicy = "invalid"
|
||||||
testJWTSecret = "a_secret"
|
testJWTSecret = "a_secret"
|
||||||
testLDAPBaseDN = "base_dn"
|
testLDAPBaseDN = "base_dn"
|
||||||
|
@ -185,13 +184,11 @@ const (
|
||||||
|
|
||||||
// NTP Error constants.
|
// NTP Error constants.
|
||||||
const (
|
const (
|
||||||
errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'"
|
errFmtNTPVersion = "ntp: option 'version' must be either 3 or 4 but it is configured as '%d'"
|
||||||
errFmtNTPMaxDesync = "ntp: option 'max_desync' can't be parsed: %w"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Session error constants.
|
// Session error constants.
|
||||||
const (
|
const (
|
||||||
errFmtSessionCouldNotParseDuration = "session: option '%s' could not be parsed: %w"
|
|
||||||
errFmtSessionOptionRequired = "session: option '%s' is required"
|
errFmtSessionOptionRequired = "session: option '%s' is required"
|
||||||
errFmtSessionDomainMustBeRoot = "session: option 'domain' must be the domain you wish to protect not a wildcard domain but it is configured as '%s'"
|
errFmtSessionDomainMustBeRoot = "session: option 'domain' must be the domain you wish to protect not a wildcard domain but it is configured as '%s'"
|
||||||
errFmtSessionSameSite = "session: option 'same_site' must be one of '%s' but is configured as '%s'"
|
errFmtSessionSameSite = "session: option 'same_site' must be one of '%s' but is configured as '%s'"
|
||||||
|
@ -206,7 +203,6 @@ const (
|
||||||
|
|
||||||
// Regulation Error Consts.
|
// Regulation Error Consts.
|
||||||
const (
|
const (
|
||||||
errFmtRegulationParseDuration = "regulation: option '%s' could not be parsed: %w"
|
|
||||||
errFmtRegulationFindTimeGreaterThanBanTime = "regulation: option 'find_time' must be less than or equal to option 'ban_time'"
|
errFmtRegulationFindTimeGreaterThanBanTime = "regulation: option 'find_time' must be less than or equal to option 'ban_time'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,17 +4,10 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateNTP validates and update NTP configuration.
|
// ValidateNTP validates and update NTP configuration.
|
||||||
func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator) {
|
func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator) {
|
||||||
if config.NTP == nil {
|
|
||||||
config.NTP = &schema.DefaultNTPConfiguration
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.NTP.Address == "" {
|
if config.NTP.Address == "" {
|
||||||
config.NTP.Address = schema.DefaultNTPConfiguration.Address
|
config.NTP.Address = schema.DefaultNTPConfiguration.Address
|
||||||
}
|
}
|
||||||
|
@ -25,12 +18,7 @@ func ValidateNTP(config *schema.Configuration, validator *schema.StructValidator
|
||||||
validator.Push(fmt.Errorf(errFmtNTPVersion, config.NTP.Version))
|
validator.Push(fmt.Errorf(errFmtNTPVersion, config.NTP.Version))
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.NTP.MaximumDesync == "" {
|
if config.NTP.MaximumDesync == 0 {
|
||||||
config.NTP.MaximumDesync = schema.DefaultNTPConfiguration.MaximumDesync
|
config.NTP.MaximumDesync = schema.DefaultNTPConfiguration.MaximumDesync
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := utils.ParseDurationString(config.NTP.MaximumDesync)
|
|
||||||
if err != nil {
|
|
||||||
validator.Push(fmt.Errorf(errFmtNTPMaxDesync, err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
|
|
||||||
func newDefaultNTPConfig() schema.Configuration {
|
func newDefaultNTPConfig() schema.Configuration {
|
||||||
return schema.Configuration{
|
return schema.Configuration{
|
||||||
NTP: &schema.NTPConfiguration{},
|
NTP: schema.NTPConfiguration{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,18 +55,6 @@ func TestShouldSetDefaultNtpDisableStartupCheck(t *testing.T) {
|
||||||
assert.Equal(t, schema.DefaultNTPConfiguration.DisableStartupCheck, config.NTP.DisableStartupCheck)
|
assert.Equal(t, schema.DefaultNTPConfiguration.DisableStartupCheck, config.NTP.DisableStartupCheck)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldRaiseErrorOnMaximumDesyncString(t *testing.T) {
|
|
||||||
validator := schema.NewStructValidator()
|
|
||||||
config := newDefaultNTPConfig()
|
|
||||||
config.NTP.MaximumDesync = "a second"
|
|
||||||
|
|
||||||
ValidateNTP(&config, validator)
|
|
||||||
|
|
||||||
require.Len(t, validator.Errors(), 1)
|
|
||||||
|
|
||||||
assert.EqualError(t, validator.Errors()[0], "ntp: option 'max_desync' can't be parsed: could not parse 'a second' as a duration")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldRaiseErrorOnInvalidNTPVersion(t *testing.T) {
|
func TestShouldRaiseErrorOnInvalidNTPVersion(t *testing.T) {
|
||||||
validator := schema.NewStructValidator()
|
validator := schema.NewStructValidator()
|
||||||
config := newDefaultNTPConfig()
|
config := newDefaultNTPConfig()
|
||||||
|
|
|
@ -4,36 +4,19 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateRegulation validates and update regulator configuration.
|
// ValidateRegulation validates and update regulator configuration.
|
||||||
func ValidateRegulation(config *schema.Configuration, validator *schema.StructValidator) {
|
func ValidateRegulation(config *schema.Configuration, validator *schema.StructValidator) {
|
||||||
if config.Regulation == nil {
|
if config.Regulation.FindTime == 0 {
|
||||||
config.Regulation = &schema.DefaultRegulationConfiguration
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Regulation.FindTime == "" {
|
|
||||||
config.Regulation.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min.
|
config.Regulation.FindTime = schema.DefaultRegulationConfiguration.FindTime // 2 min.
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Regulation.BanTime == "" {
|
if config.Regulation.BanTime == 0 {
|
||||||
config.Regulation.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min.
|
config.Regulation.BanTime = schema.DefaultRegulationConfiguration.BanTime // 5 min.
|
||||||
}
|
}
|
||||||
|
|
||||||
findTime, err := utils.ParseDurationString(config.Regulation.FindTime)
|
if config.Regulation.FindTime > config.Regulation.BanTime {
|
||||||
if err != nil {
|
|
||||||
validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "find_time", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
banTime, err := utils.ParseDurationString(config.Regulation.BanTime)
|
|
||||||
if err != nil {
|
|
||||||
validator.Push(fmt.Errorf(errFmtRegulationParseDuration, "ban_time", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if findTime > banTime {
|
|
||||||
validator.Push(fmt.Errorf(errFmtRegulationFindTimeGreaterThanBanTime))
|
validator.Push(fmt.Errorf(errFmtRegulationFindTimeGreaterThanBanTime))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package validator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
@ -10,7 +11,7 @@ import (
|
||||||
|
|
||||||
func newDefaultRegulationConfig() schema.Configuration {
|
func newDefaultRegulationConfig() schema.Configuration {
|
||||||
config := schema.Configuration{
|
config := schema.Configuration{
|
||||||
Regulation: &schema.RegulationConfiguration{},
|
Regulation: schema.RegulationConfiguration{},
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
@ -39,24 +40,11 @@ func TestShouldSetDefaultRegulationFindTime(t *testing.T) {
|
||||||
func TestShouldRaiseErrorWhenFindTimeLessThanBanTime(t *testing.T) {
|
func TestShouldRaiseErrorWhenFindTimeLessThanBanTime(t *testing.T) {
|
||||||
validator := schema.NewStructValidator()
|
validator := schema.NewStructValidator()
|
||||||
config := newDefaultRegulationConfig()
|
config := newDefaultRegulationConfig()
|
||||||
config.Regulation.FindTime = "1m"
|
config.Regulation.FindTime = time.Minute
|
||||||
config.Regulation.BanTime = "10s"
|
config.Regulation.BanTime = time.Second * 10
|
||||||
|
|
||||||
ValidateRegulation(&config, validator)
|
ValidateRegulation(&config, validator)
|
||||||
|
|
||||||
assert.Len(t, validator.Errors(), 1)
|
assert.Len(t, validator.Errors(), 1)
|
||||||
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' must be less than or equal to option 'ban_time'")
|
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' must be less than or equal to option 'ban_time'")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldRaiseErrorOnBadDurationStrings(t *testing.T) {
|
|
||||||
validator := schema.NewStructValidator()
|
|
||||||
config := newDefaultRegulationConfig()
|
|
||||||
config.Regulation.FindTime = "a year"
|
|
||||||
config.Regulation.BanTime = "forever"
|
|
||||||
|
|
||||||
ValidateRegulation(&config, validator)
|
|
||||||
|
|
||||||
assert.Len(t, validator.Errors(), 2)
|
|
||||||
assert.EqualError(t, validator.Errors()[0], "regulation: option 'find_time' could not be parsed: could not parse 'a year' as a duration")
|
|
||||||
assert.EqualError(t, validator.Errors()[1], "regulation: option 'ban_time' could not be parsed: could not parse 'forever' as a duration")
|
|
||||||
}
|
|
||||||
|
|
|
@ -27,22 +27,16 @@ func ValidateSession(config *schema.SessionConfiguration, validator *schema.Stru
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSession(config *schema.SessionConfiguration, validator *schema.StructValidator) {
|
func validateSession(config *schema.SessionConfiguration, validator *schema.StructValidator) {
|
||||||
if config.Expiration == "" {
|
if config.Expiration <= 0 {
|
||||||
config.Expiration = schema.DefaultSessionConfiguration.Expiration // 1 hour.
|
config.Expiration = schema.DefaultSessionConfiguration.Expiration // 1 hour.
|
||||||
} else if _, err := utils.ParseDurationString(config.Expiration); err != nil {
|
|
||||||
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "expiriation", err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Inactivity == "" {
|
if config.Inactivity <= 0 {
|
||||||
config.Inactivity = schema.DefaultSessionConfiguration.Inactivity // 5 min.
|
config.Inactivity = schema.DefaultSessionConfiguration.Inactivity // 5 min.
|
||||||
} else if _, err := utils.ParseDurationString(config.Inactivity); err != nil {
|
|
||||||
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "inactivity", err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.RememberMeDuration == "" {
|
if config.RememberMeDuration <= 0 {
|
||||||
config.RememberMeDuration = schema.DefaultSessionConfiguration.RememberMeDuration // 1 month.
|
config.RememberMeDuration = schema.DefaultSessionConfiguration.RememberMeDuration // 1 month.
|
||||||
} else if _, err := utils.ParseDurationString(config.RememberMeDuration); err != nil {
|
|
||||||
validator.Push(fmt.Errorf(errFmtSessionCouldNotParseDuration, "remember_me_duration", err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Domain == "" {
|
if config.Domain == "" {
|
||||||
|
|
|
@ -420,30 +420,21 @@ func TestShouldNotRaiseErrorWhenSameSiteSetCorrectly(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldRaiseErrorWhenBadInactivityAndExpirationSet(t *testing.T) {
|
func TestShouldSetDefaultWhenNegativeInactivityAndExpirationSet(t *testing.T) {
|
||||||
validator := schema.NewStructValidator()
|
validator := schema.NewStructValidator()
|
||||||
config := newDefaultSessionConfig()
|
config := newDefaultSessionConfig()
|
||||||
config.Inactivity = testBadTimer
|
config.Inactivity = -1
|
||||||
config.Expiration = testBadTimer
|
config.Expiration = -1
|
||||||
|
config.RememberMeDuration = -1
|
||||||
|
|
||||||
ValidateSession(&config, validator)
|
ValidateSession(&config, validator)
|
||||||
|
|
||||||
assert.False(t, validator.HasWarnings())
|
assert.Len(t, validator.Warnings(), 0)
|
||||||
assert.Len(t, validator.Errors(), 2)
|
assert.Len(t, validator.Errors(), 0)
|
||||||
assert.EqualError(t, validator.Errors()[0], "session: option 'expiriation' could not be parsed: could not parse '-1' as a duration")
|
|
||||||
assert.EqualError(t, validator.Errors()[1], "session: option 'inactivity' could not be parsed: could not parse '-1' as a duration")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestShouldRaiseErrorWhenBadRememberMeDurationSet(t *testing.T) {
|
assert.Equal(t, schema.DefaultSessionConfiguration.Inactivity, config.Inactivity)
|
||||||
validator := schema.NewStructValidator()
|
assert.Equal(t, schema.DefaultSessionConfiguration.Expiration, config.Expiration)
|
||||||
config := newDefaultSessionConfig()
|
assert.Equal(t, schema.DefaultSessionConfiguration.RememberMeDuration, config.RememberMeDuration)
|
||||||
config.RememberMeDuration = "1 year"
|
|
||||||
|
|
||||||
ValidateSession(&config, validator)
|
|
||||||
|
|
||||||
assert.False(t, validator.HasWarnings())
|
|
||||||
assert.Len(t, validator.Errors(), 1)
|
|
||||||
assert.EqualError(t, validator.Errors()[0], "session: option 'remember_me_duration' could not be parsed: could not parse '1 year' as a duration")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldSetDefaultRememberMeDuration(t *testing.T) {
|
func TestShouldSetDefaultRememberMeDuration(t *testing.T) {
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -56,7 +58,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testInactivity = "10"
|
testInactivity = time.Second * 10
|
||||||
testRedirectionURL = "http://redirection.local"
|
testRedirectionURL = "http://redirection.local"
|
||||||
testUsername = "john"
|
testUsername = "john"
|
||||||
)
|
)
|
||||||
|
|
|
@ -602,7 +602,7 @@ func TestShouldDestroySessionWhenInactiveForTooLongUsingDurationNotation(t *test
|
||||||
clock := mocks.TestingClock{}
|
clock := mocks.TestingClock{}
|
||||||
clock.Set(time.Now())
|
clock.Set(time.Now())
|
||||||
|
|
||||||
mock.Ctx.Configuration.Session.Inactivity = "10s"
|
mock.Ctx.Configuration.Session.Inactivity = time.Second * 10
|
||||||
// Reload the session provider since the configuration is indirect.
|
// Reload the session provider since the configuration is indirect.
|
||||||
mock.Ctx.Providers.SessionProvider = session.NewProvider(mock.Ctx.Configuration.Session, nil)
|
mock.Ctx.Providers.SessionProvider = session.NewProvider(mock.Ctx.Configuration.Session, nil)
|
||||||
assert.Equal(t, time.Second*10, mock.Ctx.Providers.SessionProvider.Inactivity)
|
assert.Equal(t, time.Second*10, mock.Ctx.Providers.SessionProvider.Inactivity)
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
"github.com/authelia/authelia/v4/internal/logging"
|
"github.com/authelia/authelia/v4/internal/logging"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewProvider instantiate a ntp provider given a configuration.
|
// NewProvider instantiate a ntp provider given a configuration.
|
||||||
|
@ -59,11 +58,9 @@ func (p *Provider) StartupCheck() (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
maxOffset, _ := utils.ParseDurationString(p.config.MaximumDesync)
|
|
||||||
|
|
||||||
ntpTime := ntpPacketToTime(resp)
|
ntpTime := ntpPacketToTime(resp)
|
||||||
|
|
||||||
if result := ntpIsOffsetTooLarge(maxOffset, now, ntpTime); result {
|
if result := ntpIsOffsetTooLarge(p.config.MaximumDesync, now, ntpTime); result {
|
||||||
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
|
return errors.New("the system clock is not synchronized accurately enough with the configured NTP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package ntp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
@ -11,18 +12,17 @@ import (
|
||||||
|
|
||||||
func TestShouldCheckNTP(t *testing.T) {
|
func TestShouldCheckNTP(t *testing.T) {
|
||||||
config := &schema.Configuration{
|
config := &schema.Configuration{
|
||||||
NTP: &schema.NTPConfiguration{
|
NTP: schema.NTPConfiguration{
|
||||||
Address: "time.cloudflare.com:123",
|
Address: "time.cloudflare.com:123",
|
||||||
Version: 4,
|
Version: 4,
|
||||||
MaximumDesync: "3s",
|
MaximumDesync: time.Second * 3,
|
||||||
DisableStartupCheck: false,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
sv := schema.NewStructValidator()
|
sv := schema.NewStructValidator()
|
||||||
validator.ValidateNTP(config, sv)
|
validator.ValidateNTP(config, sv)
|
||||||
|
|
||||||
ntp := NewProvider(config.NTP)
|
ntp := NewProvider(&config.NTP)
|
||||||
|
|
||||||
assert.NoError(t, ntp.StartupCheck())
|
assert.NoError(t, ntp.StartupCheck())
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package regulation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -13,33 +12,13 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewRegulator create a regulator instance.
|
// NewRegulator create a regulator instance.
|
||||||
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
func NewRegulator(config schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
||||||
regulator := &Regulator{storageProvider: provider}
|
return &Regulator{
|
||||||
regulator.clock = clock
|
enabled: config.MaxRetries > 0,
|
||||||
|
storageProvider: provider,
|
||||||
if configuration != nil {
|
clock: clock,
|
||||||
findTime, err := utils.ParseDurationString(configuration.FindTime)
|
config: config,
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
banTime, err := utils.ParseDurationString(configuration.BanTime)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if findTime > banTime {
|
|
||||||
panic(fmt.Errorf("find_time cannot be greater than ban_time"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set regulator enabled only if MaxRetries is not 0.
|
|
||||||
regulator.enabled = configuration.MaxRetries > 0
|
|
||||||
regulator.maxRetries = configuration.MaxRetries
|
|
||||||
regulator.findTime = findTime
|
|
||||||
regulator.banTime = banTime
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return regulator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark an authentication attempt.
|
// Mark an authentication attempt.
|
||||||
|
@ -65,15 +44,15 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
|
||||||
return time.Time{}, nil
|
return time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0)
|
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.config.BanTime), 10, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Time{}, nil
|
return time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.maxRetries)
|
latestFailedAttempts := make([]models.AuthenticationAttempt, 0, r.config.MaxRetries)
|
||||||
|
|
||||||
for _, attempt := range attempts {
|
for _, attempt := range attempts {
|
||||||
if attempt.Successful || len(latestFailedAttempts) >= r.maxRetries {
|
if attempt.Successful || len(latestFailedAttempts) >= r.config.MaxRetries {
|
||||||
// We stop appending failed attempts once we find the first successful attempts or we reach
|
// We stop appending failed attempts once we find the first successful attempts or we reach
|
||||||
// the configured number of retries, meaning the user is already banned.
|
// the configured number of retries, meaning the user is already banned.
|
||||||
break
|
break
|
||||||
|
@ -84,17 +63,17 @@ func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, e
|
||||||
|
|
||||||
// If the number of failed attempts within the ban time is less than the max number of retries
|
// If the number of failed attempts within the ban time is less than the max number of retries
|
||||||
// then the user is not banned.
|
// then the user is not banned.
|
||||||
if len(latestFailedAttempts) < r.maxRetries {
|
if len(latestFailedAttempts) < r.config.MaxRetries {
|
||||||
return time.Time{}, nil
|
return time.Time{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now we compute the time between the latest attempt and the MaxRetry-th one. If it's
|
// Now we compute the time between the latest attempt and the MaxRetry-th one. If it's
|
||||||
// within the FindTime then it means that the user has been banned.
|
// within the FindTime then it means that the user has been banned.
|
||||||
durationBetweenLatestAttempts := latestFailedAttempts[0].Time.Sub(
|
durationBetweenLatestAttempts := latestFailedAttempts[0].Time.Sub(
|
||||||
latestFailedAttempts[r.maxRetries-1].Time)
|
latestFailedAttempts[r.config.MaxRetries-1].Time)
|
||||||
|
|
||||||
if durationBetweenLatestAttempts < r.findTime {
|
if durationBetweenLatestAttempts < r.config.FindTime {
|
||||||
bannedUntil := latestFailedAttempts[0].Time.Add(r.banTime)
|
bannedUntil := latestFailedAttempts[0].Time.Add(r.config.BanTime)
|
||||||
return bannedUntil, ErrUserIsBanned
|
return bannedUntil, ErrUserIsBanned
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,11 +18,11 @@ import (
|
||||||
type RegulatorSuite struct {
|
type RegulatorSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
storageMock *mocks.MockStorage
|
storageMock *mocks.MockStorage
|
||||||
configuration schema.RegulationConfiguration
|
config schema.RegulationConfiguration
|
||||||
clock mocks.TestingClock
|
clock mocks.TestingClock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *RegulatorSuite) SetupTest() {
|
func (s *RegulatorSuite) SetupTest() {
|
||||||
|
@ -30,10 +30,10 @@ func (s *RegulatorSuite) SetupTest() {
|
||||||
s.storageMock = mocks.NewMockStorage(s.ctrl)
|
s.storageMock = mocks.NewMockStorage(s.ctrl)
|
||||||
s.ctx = context.Background()
|
s.ctx = context.Background()
|
||||||
|
|
||||||
s.configuration = schema.RegulationConfiguration{
|
s.config = schema.RegulationConfiguration{
|
||||||
MaxRetries: 3,
|
MaxRetries: 3,
|
||||||
BanTime: "180",
|
BanTime: time.Second * 180,
|
||||||
FindTime: "30",
|
FindTime: time.Second * 30,
|
||||||
}
|
}
|
||||||
s.clock.Set(time.Now())
|
s.clock.Set(time.Now())
|
||||||
}
|
}
|
||||||
|
@ -55,7 +55,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
@ -86,7 +86,7 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
@ -122,7 +122,7 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||||
|
@ -155,7 +155,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||||
|
@ -179,7 +179,7 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
@ -211,7 +211,7 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
@ -247,7 +247,7 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
|
||||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(s.config, s.storageMock, &s.clock)
|
||||||
|
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
@ -283,24 +283,24 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
||||||
Return(attemptsInDB, nil)
|
Return(attemptsInDB, nil)
|
||||||
|
|
||||||
// Check Disabled Functionality.
|
// Check Disabled Functionality.
|
||||||
configuration := schema.RegulationConfiguration{
|
config := schema.RegulationConfiguration{
|
||||||
MaxRetries: 0,
|
MaxRetries: 0,
|
||||||
FindTime: "180",
|
FindTime: time.Second * 180,
|
||||||
BanTime: "180",
|
BanTime: time.Second * 180,
|
||||||
}
|
}
|
||||||
|
|
||||||
regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
|
regulator := regulation.NewRegulator(config, s.storageMock, &s.clock)
|
||||||
_, err := regulator.Regulate(s.ctx, "john")
|
_, err := regulator.Regulate(s.ctx, "john")
|
||||||
assert.NoError(s.T(), err)
|
assert.NoError(s.T(), err)
|
||||||
|
|
||||||
// Check Enabled Functionality.
|
// Check Enabled Functionality.
|
||||||
configuration = schema.RegulationConfiguration{
|
config = schema.RegulationConfiguration{
|
||||||
MaxRetries: 1,
|
MaxRetries: 1,
|
||||||
FindTime: "180",
|
FindTime: time.Second * 180,
|
||||||
BanTime: "180",
|
BanTime: time.Second * 180,
|
||||||
}
|
}
|
||||||
|
|
||||||
regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
|
regulator = regulation.NewRegulator(config, s.storageMock, &s.clock)
|
||||||
_, err = regulator.Regulate(s.ctx, "john")
|
_, err = regulator.Regulate(s.ctx, "john")
|
||||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
package regulation
|
package regulation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/storage"
|
"github.com/authelia/authelia/v4/internal/storage"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
@ -11,12 +10,8 @@ import (
|
||||||
type Regulator struct {
|
type Regulator struct {
|
||||||
// Is the regulation enabled.
|
// Is the regulation enabled.
|
||||||
enabled bool
|
enabled bool
|
||||||
// The number of failed authentication attempt before banning the user.
|
|
||||||
maxRetries int
|
config schema.RegulationConfiguration
|
||||||
// If a user does the max number of retries within that duration, she will be banned.
|
|
||||||
findTime time.Duration
|
|
||||||
// If a user has been banned, this duration is the timelapse during which the user is banned.
|
|
||||||
banTime time.Duration
|
|
||||||
|
|
||||||
storageProvider storage.RegulatorProvider
|
storageProvider storage.RegulatorProvider
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ var assets embed.FS
|
||||||
|
|
||||||
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||||
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
|
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
|
||||||
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != "0")
|
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != -1)
|
||||||
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
|
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
|
||||||
|
|
||||||
duoSelfEnrollment := f
|
duoSelfEnrollment := f
|
||||||
|
|
|
@ -1,8 +1,12 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testDomain = "example.com"
|
testDomain = "example.com"
|
||||||
testExpiration = "40"
|
testExpiration = time.Second * 40
|
||||||
testName = "my_session"
|
testName = "my_session"
|
||||||
testUsername = "john"
|
testUsername = "john"
|
||||||
)
|
)
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||||
"github.com/authelia/authelia/v4/internal/logging"
|
"github.com/authelia/authelia/v4/internal/logging"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider a session provider.
|
// Provider a session provider.
|
||||||
|
@ -23,38 +22,29 @@ type Provider struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewProvider instantiate a session provider given a configuration.
|
// NewProvider instantiate a session provider given a configuration.
|
||||||
func NewProvider(configuration schema.SessionConfiguration, certPool *x509.CertPool) *Provider {
|
func NewProvider(config schema.SessionConfiguration, certPool *x509.CertPool) *Provider {
|
||||||
providerConfig := NewProviderConfig(configuration, certPool)
|
c := NewProviderConfig(config, certPool)
|
||||||
|
|
||||||
provider := new(Provider)
|
provider := new(Provider)
|
||||||
provider.sessionHolder = fasthttpsession.New(providerConfig.config)
|
provider.sessionHolder = fasthttpsession.New(c.config)
|
||||||
|
|
||||||
logger := logging.Logger()
|
logger := logging.Logger()
|
||||||
|
|
||||||
duration, err := utils.ParseDurationString(configuration.RememberMeDuration)
|
provider.Inactivity, provider.RememberMe = config.Inactivity, config.RememberMeDuration
|
||||||
if err != nil {
|
|
||||||
logger.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
provider.RememberMe = duration
|
var (
|
||||||
|
providerImpl fasthttpsession.Provider
|
||||||
duration, err = utils.ParseDurationString(configuration.Inactivity)
|
err error
|
||||||
if err != nil {
|
)
|
||||||
logger.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
provider.Inactivity = duration
|
|
||||||
|
|
||||||
var providerImpl fasthttpsession.Provider
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case providerConfig.redisConfig != nil:
|
case c.redisConfig != nil:
|
||||||
providerImpl, err = redis.New(*providerConfig.redisConfig)
|
providerImpl, err = redis.New(*c.redisConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
}
|
}
|
||||||
case providerConfig.redisSentinelConfig != nil:
|
case c.redisSentinelConfig != nil:
|
||||||
providerImpl, err = redis.NewFailoverCluster(*providerConfig.redisSentinelConfig)
|
providerImpl, err = redis.NewFailoverCluster(*c.redisSentinelConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(err)
|
logger.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,10 +17,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewProviderConfig creates a configuration for creating the session provider.
|
// NewProviderConfig creates a configuration for creating the session provider.
|
||||||
func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig {
|
func NewProviderConfig(config schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig {
|
||||||
config := session.NewDefaultConfig()
|
c := session.NewDefaultConfig()
|
||||||
|
|
||||||
config.SessionIDGeneratorFunc = func() []byte {
|
c.SessionIDGeneratorFunc = func() []byte {
|
||||||
bytes := make([]byte, 32)
|
bytes := make([]byte, 32)
|
||||||
|
|
||||||
_, _ = rand.Read(bytes)
|
_, _ = rand.Read(bytes)
|
||||||
|
@ -33,30 +33,30 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override the cookie name.
|
// Override the cookie name.
|
||||||
config.CookieName = configuration.Name
|
c.CookieName = config.Name
|
||||||
|
|
||||||
// Set the cookie to the given domain.
|
// Set the cookie to the given domain.
|
||||||
config.Domain = configuration.Domain
|
c.Domain = config.Domain
|
||||||
|
|
||||||
// Set the cookie SameSite option.
|
// Set the cookie SameSite option.
|
||||||
switch configuration.SameSite {
|
switch config.SameSite {
|
||||||
case "strict":
|
case "strict":
|
||||||
config.CookieSameSite = fasthttp.CookieSameSiteStrictMode
|
c.CookieSameSite = fasthttp.CookieSameSiteStrictMode
|
||||||
case "none":
|
case "none":
|
||||||
config.CookieSameSite = fasthttp.CookieSameSiteNoneMode
|
c.CookieSameSite = fasthttp.CookieSameSiteNoneMode
|
||||||
case "lax":
|
case "lax":
|
||||||
config.CookieSameSite = fasthttp.CookieSameSiteLaxMode
|
c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
|
||||||
default:
|
default:
|
||||||
config.CookieSameSite = fasthttp.CookieSameSiteLaxMode
|
c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only serve the header over HTTPS.
|
// Only serve the header over HTTPS.
|
||||||
config.Secure = true
|
c.Secure = true
|
||||||
|
|
||||||
// Ignore the error as it will be handled by validator.
|
// Ignore the error as it will be handled by validator.
|
||||||
config.Expiration, _ = utils.ParseDurationString(configuration.Expiration)
|
c.Expiration = config.Expiration
|
||||||
|
|
||||||
config.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
|
c.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,23 +68,23 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
|
||||||
|
|
||||||
// If redis configuration is provided, then use the redis provider.
|
// If redis configuration is provided, then use the redis provider.
|
||||||
switch {
|
switch {
|
||||||
case configuration.Redis != nil:
|
case config.Redis != nil:
|
||||||
serializer := NewEncryptingSerializer(configuration.Secret)
|
serializer := NewEncryptingSerializer(config.Secret)
|
||||||
|
|
||||||
var tlsConfig *tls.Config
|
var tlsConfig *tls.Config
|
||||||
|
|
||||||
if configuration.Redis.TLS != nil {
|
if config.Redis.TLS != nil {
|
||||||
tlsConfig = utils.NewTLSConfig(configuration.Redis.TLS, tls.VersionTLS12, certPool)
|
tlsConfig = utils.NewTLSConfig(config.Redis.TLS, tls.VersionTLS12, certPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
if configuration.Redis.HighAvailability != nil && configuration.Redis.HighAvailability.SentinelName != "" {
|
if config.Redis.HighAvailability != nil && config.Redis.HighAvailability.SentinelName != "" {
|
||||||
addrs := make([]string, 0)
|
addrs := make([]string, 0)
|
||||||
|
|
||||||
if configuration.Redis.Host != "" {
|
if config.Redis.Host != "" {
|
||||||
addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(configuration.Redis.Host), configuration.Redis.Port))
|
addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(config.Redis.Host), config.Redis.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, node := range configuration.Redis.HighAvailability.Nodes {
|
for _, node := range config.Redis.HighAvailability.Nodes {
|
||||||
addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port)
|
addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port)
|
||||||
if !utils.IsStringInSlice(addr, addrs) {
|
if !utils.IsStringInSlice(addr, addrs) {
|
||||||
addrs = append(addrs, addr)
|
addrs = append(addrs, addr)
|
||||||
|
@ -94,17 +94,17 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
|
||||||
providerName = "redis-sentinel"
|
providerName = "redis-sentinel"
|
||||||
redisSentinelConfig = &redis.FailoverConfig{
|
redisSentinelConfig = &redis.FailoverConfig{
|
||||||
Logger: &redisLogger{logger: logging.Logger()},
|
Logger: &redisLogger{logger: logging.Logger()},
|
||||||
MasterName: configuration.Redis.HighAvailability.SentinelName,
|
MasterName: config.Redis.HighAvailability.SentinelName,
|
||||||
SentinelAddrs: addrs,
|
SentinelAddrs: addrs,
|
||||||
SentinelUsername: configuration.Redis.HighAvailability.SentinelUsername,
|
SentinelUsername: config.Redis.HighAvailability.SentinelUsername,
|
||||||
SentinelPassword: configuration.Redis.HighAvailability.SentinelPassword,
|
SentinelPassword: config.Redis.HighAvailability.SentinelPassword,
|
||||||
RouteByLatency: configuration.Redis.HighAvailability.RouteByLatency,
|
RouteByLatency: config.Redis.HighAvailability.RouteByLatency,
|
||||||
RouteRandomly: configuration.Redis.HighAvailability.RouteRandomly,
|
RouteRandomly: config.Redis.HighAvailability.RouteRandomly,
|
||||||
Username: configuration.Redis.Username,
|
Username: config.Redis.Username,
|
||||||
Password: configuration.Redis.Password,
|
Password: config.Redis.Password,
|
||||||
DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
|
DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
|
||||||
PoolSize: configuration.Redis.MaximumActiveConnections,
|
PoolSize: config.Redis.MaximumActiveConnections,
|
||||||
MinIdleConns: configuration.Redis.MinimumIdleConnections,
|
MinIdleConns: config.Redis.MinimumIdleConnections,
|
||||||
IdleTimeout: 300,
|
IdleTimeout: 300,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
KeyPrefix: "authelia-session",
|
KeyPrefix: "authelia-session",
|
||||||
|
@ -115,36 +115,36 @@ func NewProviderConfig(configuration schema.SessionConfiguration, certPool *x509
|
||||||
|
|
||||||
var addr string
|
var addr string
|
||||||
|
|
||||||
if configuration.Redis.Port == 0 {
|
if config.Redis.Port == 0 {
|
||||||
network = "unix"
|
network = "unix"
|
||||||
addr = configuration.Redis.Host
|
addr = config.Redis.Host
|
||||||
} else {
|
} else {
|
||||||
addr = fmt.Sprintf("%s:%d", configuration.Redis.Host, configuration.Redis.Port)
|
addr = fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
redisConfig = &redis.Config{
|
redisConfig = &redis.Config{
|
||||||
Logger: newRedisLogger(),
|
Logger: newRedisLogger(),
|
||||||
Network: network,
|
Network: network,
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Username: configuration.Redis.Username,
|
Username: config.Redis.Username,
|
||||||
Password: configuration.Redis.Password,
|
Password: config.Redis.Password,
|
||||||
DB: configuration.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
|
DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
|
||||||
PoolSize: configuration.Redis.MaximumActiveConnections,
|
PoolSize: config.Redis.MaximumActiveConnections,
|
||||||
MinIdleConns: configuration.Redis.MinimumIdleConnections,
|
MinIdleConns: config.Redis.MinimumIdleConnections,
|
||||||
IdleTimeout: 300,
|
IdleTimeout: 300,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
KeyPrefix: "authelia-session",
|
KeyPrefix: "authelia-session",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config.EncodeFunc = serializer.Encode
|
c.EncodeFunc = serializer.Encode
|
||||||
config.DecodeFunc = serializer.Decode
|
c.DecodeFunc = serializer.Decode
|
||||||
default:
|
default:
|
||||||
providerName = "memory"
|
providerName = "memory"
|
||||||
}
|
}
|
||||||
|
|
||||||
return ProviderConfig{
|
return ProviderConfig{
|
||||||
config,
|
c,
|
||||||
redisConfig,
|
redisConfig,
|
||||||
redisSentinelConfig,
|
redisSentinelConfig,
|
||||||
providerName,
|
providerName,
|
||||||
|
|
|
@ -53,7 +53,25 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
reDuration = regexp.MustCompile(`^(?P<Duration>[1-9]\d*?)(?P<Unit>[smhdwMy])?$`)
|
standardDurationUnits = []string{"ns", "us", "µs", "μs", "ms", "s", "m", "h"}
|
||||||
|
reDurationSeconds = regexp.MustCompile(`^\d+$`)
|
||||||
|
reDurationStandard = regexp.MustCompile(`(?P<Duration>[1-9]\d*?)(?P<Unit>[^\d\s]+)`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Duration unit types.
|
||||||
|
const (
|
||||||
|
DurationUnitDays = "d"
|
||||||
|
DurationUnitWeeks = "w"
|
||||||
|
DurationUnitMonths = "M"
|
||||||
|
DurationUnitYears = "y"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Number of hours in particular measurements of time.
|
||||||
|
const (
|
||||||
|
HoursInDay = 24
|
||||||
|
HoursInWeek = HoursInDay * 7
|
||||||
|
HoursInMonth = HoursInDay * 30
|
||||||
|
HoursInYear = HoursInDay * 365
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -6,46 +6,64 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseDurationString parses a string to a duration
|
// StandardizeDurationString converts units of time that stdlib is unaware of to hours.
|
||||||
// Duration notations are an integer followed by a unit
|
func StandardizeDurationString(input string) (output string, err error) {
|
||||||
// Units are s = second, m = minute, d = day, w = week, M = month, y = year
|
if input == "" {
|
||||||
// Example 1y is the same as 1 year.
|
return "0s", nil
|
||||||
func ParseDurationString(input string) (time.Duration, error) {
|
|
||||||
var duration time.Duration
|
|
||||||
|
|
||||||
matches := reDuration.FindStringSubmatch(input)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case len(matches) == 3 && matches[2] != "":
|
|
||||||
d, _ := strconv.Atoi(matches[1])
|
|
||||||
|
|
||||||
switch matches[2] {
|
|
||||||
case "y":
|
|
||||||
duration = time.Duration(d) * Year
|
|
||||||
case "M":
|
|
||||||
duration = time.Duration(d) * Month
|
|
||||||
case "w":
|
|
||||||
duration = time.Duration(d) * Week
|
|
||||||
case "d":
|
|
||||||
duration = time.Duration(d) * Day
|
|
||||||
case "h":
|
|
||||||
duration = time.Duration(d) * Hour
|
|
||||||
case "m":
|
|
||||||
duration = time.Duration(d) * time.Minute
|
|
||||||
case "s":
|
|
||||||
duration = time.Duration(d) * time.Second
|
|
||||||
}
|
|
||||||
case input == "0" || len(matches) == 3:
|
|
||||||
seconds, err := strconv.Atoi(input)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("could not parse '%s' as a duration: %w", input, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
duration = time.Duration(seconds) * time.Second
|
|
||||||
case input != "":
|
|
||||||
// Throw this error if input is anything other than a blank string, blank string will default to a duration of nothing.
|
|
||||||
return 0, fmt.Errorf("could not parse '%s' as a duration", input)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return duration, nil
|
matches := reDurationStandard.FindAllStringSubmatch(input, -1)
|
||||||
|
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return "", fmt.Errorf("could not parse '%s' as a duration", input)
|
||||||
|
}
|
||||||
|
|
||||||
|
var d int
|
||||||
|
|
||||||
|
for _, match := range matches {
|
||||||
|
if d, err = strconv.Atoi(match[1]); err != nil {
|
||||||
|
return "", fmt.Errorf("could not parse the numeric portion of '%s' in duration string '%s': %w", match[0], input, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
unit := match[2]
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case IsStringInSlice(unit, standardDurationUnits):
|
||||||
|
output += fmt.Sprintf("%d%s", d, unit)
|
||||||
|
case unit == DurationUnitDays:
|
||||||
|
output += fmt.Sprintf("%dh", d*HoursInDay)
|
||||||
|
case unit == DurationUnitWeeks:
|
||||||
|
output += fmt.Sprintf("%dh", d*HoursInWeek)
|
||||||
|
case unit == DurationUnitMonths:
|
||||||
|
output += fmt.Sprintf("%dh", d*HoursInMonth)
|
||||||
|
case unit == DurationUnitYears:
|
||||||
|
output += fmt.Sprintf("%dh", d*HoursInYear)
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("could not parse the units portion of '%s' in duration string '%s': the unit '%s' is not valid", match[0], input, unit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseDurationString standardizes a duration string with StandardizeDurationString then uses time.ParseDuration to
|
||||||
|
// convert it into a time.Duration.
|
||||||
|
func ParseDurationString(input string) (duration time.Duration, err error) {
|
||||||
|
if reDurationSeconds.MatchString(input) {
|
||||||
|
var seconds int
|
||||||
|
|
||||||
|
if seconds, err = strconv.Atoi(input); err != nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Second * time.Duration(seconds), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var out string
|
||||||
|
|
||||||
|
if out, err = StandardizeDurationString(input); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.ParseDuration(out)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,66 +7,112 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldParseDurationString(t *testing.T) {
|
func TestParseDurationString_ShouldParseDurationString(t *testing.T) {
|
||||||
duration, err := ParseDurationString("1h")
|
duration, err := ParseDurationString("1h")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 60*time.Minute, duration)
|
assert.Equal(t, 60*time.Minute, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldParseDurationStringAllUnits(t *testing.T) {
|
func TestParseDurationString_ShouldParseBlankString(t *testing.T) {
|
||||||
duration, err := ParseDurationString("1y")
|
duration, err := ParseDurationString("")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, Year, duration)
|
assert.Equal(t, time.Second*0, duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDurationString_ShouldParseDurationStringAllUnits(t *testing.T) {
|
||||||
|
duration, err := ParseDurationString("1y")
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, time.Hour*24*365, duration)
|
||||||
|
|
||||||
duration, err = ParseDurationString("1M")
|
duration, err = ParseDurationString("1M")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, Month, duration)
|
assert.Equal(t, time.Hour*24*30, duration)
|
||||||
|
|
||||||
duration, err = ParseDurationString("1w")
|
duration, err = ParseDurationString("1w")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, Week, duration)
|
assert.Equal(t, time.Hour*24*7, duration)
|
||||||
|
|
||||||
duration, err = ParseDurationString("1d")
|
duration, err = ParseDurationString("1d")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, Day, duration)
|
assert.Equal(t, time.Hour*24, duration)
|
||||||
|
|
||||||
duration, err = ParseDurationString("1h")
|
duration, err = ParseDurationString("1h")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, Hour, duration)
|
assert.Equal(t, time.Hour, duration)
|
||||||
|
|
||||||
duration, err = ParseDurationString("1s")
|
duration, err = ParseDurationString("1s")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, time.Second, duration)
|
assert.Equal(t, time.Second, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldParseSecondsString(t *testing.T) {
|
func TestParseDurationString_ShouldParseSecondsString(t *testing.T) {
|
||||||
duration, err := ParseDurationString("100")
|
duration, err := ParseDurationString("100")
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 100*time.Second, duration)
|
assert.Equal(t, 100*time.Second, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) {
|
func TestParseDurationString_ShouldNotParseDurationStringWithOutOfOrderQuantitiesAndUnits(t *testing.T) {
|
||||||
duration, err := ParseDurationString("h1")
|
duration, err := ParseDurationString("h1")
|
||||||
|
|
||||||
assert.EqualError(t, err, "could not parse 'h1' as a duration")
|
assert.EqualError(t, err, "could not parse 'h1' as a duration")
|
||||||
assert.Equal(t, time.Duration(0), duration)
|
assert.Equal(t, time.Duration(0), duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldNotParseBadDurationString(t *testing.T) {
|
func TestParseDurationString_ShouldNotParseBadDurationString(t *testing.T) {
|
||||||
duration, err := ParseDurationString("10x")
|
duration, err := ParseDurationString("10x")
|
||||||
assert.EqualError(t, err, "could not parse '10x' as a duration")
|
|
||||||
|
assert.EqualError(t, err, "could not parse the units portion of '10x' in duration string '10x': the unit 'x' is not valid")
|
||||||
assert.Equal(t, time.Duration(0), duration)
|
assert.Equal(t, time.Duration(0), duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldNotParseDurationStringWithMultiValueUnits(t *testing.T) {
|
func TestParseDurationString_ShouldParseDurationStringWithMultiValueUnits(t *testing.T) {
|
||||||
duration, err := ParseDurationString("10ms")
|
duration, err := ParseDurationString("10ms")
|
||||||
assert.EqualError(t, err, "could not parse '10ms' as a duration")
|
|
||||||
assert.Equal(t, time.Duration(0), duration)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, time.Duration(10)*time.Millisecond, duration)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldNotParseDurationStringWithLeadingZero(t *testing.T) {
|
func TestParseDurationString_ShouldParseDurationStringWithLeadingZero(t *testing.T) {
|
||||||
duration, err := ParseDurationString("005h")
|
duration, err := ParseDurationString("005h")
|
||||||
assert.EqualError(t, err, "could not parse '005h' as a duration")
|
|
||||||
assert.Equal(t, time.Duration(0), duration)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, time.Duration(5)*time.Hour, duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDurationString_ShouldParseMultiUnitValues(t *testing.T) {
|
||||||
|
duration, err := ParseDurationString("1d3w10ms")
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t,
|
||||||
|
(time.Hour*time.Duration(24))+
|
||||||
|
(time.Hour*time.Duration(24)*time.Duration(7)*time.Duration(3))+
|
||||||
|
(time.Millisecond*time.Duration(10)), duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseDurationString_ShouldParseDuplicateUnitValues(t *testing.T) {
|
||||||
|
duration, err := ParseDurationString("1d4d2d")
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t,
|
||||||
|
(time.Hour*time.Duration(24))+
|
||||||
|
(time.Hour*time.Duration(24)*time.Duration(4))+
|
||||||
|
(time.Hour*time.Duration(24)*time.Duration(2)), duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStandardizeDurationString_ShouldParseStringWithSpaces(t *testing.T) {
|
||||||
|
result, err := StandardizeDurationString("1d 1h 20m")
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, result, "24h1h20m")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestShouldTimeIntervalsMakeSense(t *testing.T) {
|
func TestShouldTimeIntervalsMakeSense(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user