mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
refactor: include url hook func (#3022)
This adds a hook func for url.URL and *url.URL types to the configuration.
This commit is contained in:
parent
99326c2688
commit
dbe290a1c9
|
@ -3,6 +3,7 @@ package configuration
|
|||
import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
|
@ -11,8 +12,8 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// StringToMailAddressFunc decodes a string into a mail.Address.
|
||||
func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
|
||||
// StringToMailAddressHookFunc decodes a string into a mail.Address.
|
||||
func StringToMailAddressHookFunc() mapstructure.DecodeHookFuncType {
|
||||
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||
if f.Kind() != reflect.String || t != reflect.TypeOf(mail.Address{}) {
|
||||
return data, nil
|
||||
|
@ -36,12 +37,53 @@ func StringToMailAddressFunc() mapstructure.DecodeHookFunc {
|
|||
}
|
||||
}
|
||||
|
||||
// ToTimeDurationFunc converts string and integer types to a time.Duration.
|
||||
func ToTimeDurationFunc() mapstructure.DecodeHookFuncType {
|
||||
// StringToURLHookFunc converts string types into a url.URL.
|
||||
func StringToURLHookFunc() mapstructure.DecodeHookFuncType {
|
||||
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||
var (
|
||||
ptr bool
|
||||
)
|
||||
var ptr bool
|
||||
|
||||
if f.Kind() != reflect.String {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
ptr = t.Kind() == reflect.Ptr
|
||||
|
||||
typeURL := reflect.TypeOf(url.URL{})
|
||||
|
||||
if ptr && t.Elem() != typeURL {
|
||||
return data, nil
|
||||
} else if !ptr && t != typeURL {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
dataStr := data.(string)
|
||||
|
||||
var parsedURL *url.URL
|
||||
|
||||
// Return an empty URL if there is an empty string.
|
||||
if dataStr != "" {
|
||||
if parsedURL, err = url.Parse(dataStr); err != nil {
|
||||
return nil, fmt.Errorf("could not parse '%s' as a URL: %w", dataStr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if ptr {
|
||||
return parsedURL, nil
|
||||
}
|
||||
|
||||
// Return an empty URL if there is an empty string.
|
||||
if parsedURL == nil {
|
||||
return url.URL{}, nil
|
||||
}
|
||||
|
||||
return *parsedURL, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ToTimeDurationHookFunc converts string and integer types to a time.Duration.
|
||||
func ToTimeDurationHookFunc() mapstructure.DecodeHookFuncType {
|
||||
return func(f reflect.Type, t reflect.Type, data interface{}) (value interface{}, err error) {
|
||||
var ptr bool
|
||||
|
||||
switch f.Kind() {
|
||||
case reflect.String, reflect.Int, reflect.Int32, reflect.Int64:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package configuration
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -8,8 +9,134 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestStringToURLHookFunc_ShouldNotParseStrings(t *testing.T) {
|
||||
hook := StringToURLHookFunc()
|
||||
|
||||
var (
|
||||
from = "https://google.com/abc?a=123"
|
||||
|
||||
result interface{}
|
||||
err error
|
||||
|
||||
resultTo string
|
||||
resultPtrTo *time.Time
|
||||
ok bool
|
||||
)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultTo, ok = result.(string)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, from, resultTo)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultTo, ok = result.(string)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, from, resultTo)
|
||||
}
|
||||
|
||||
func TestStringToURLHookFunc_ShouldParseEmptyString(t *testing.T) {
|
||||
hook := StringToURLHookFunc()
|
||||
|
||||
var (
|
||||
from = ""
|
||||
|
||||
result interface{}
|
||||
err error
|
||||
|
||||
resultTo url.URL
|
||||
resultPtrTo *url.URL
|
||||
|
||||
ok bool
|
||||
)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultTo, ok = result.(url.URL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "", resultTo.String())
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultPtrTo, ok = result.(*url.URL)
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, resultPtrTo)
|
||||
}
|
||||
|
||||
func TestStringToURLHookFunc_ShouldNotParseBadURLs(t *testing.T) {
|
||||
hook := StringToURLHookFunc()
|
||||
|
||||
var (
|
||||
from = "*(!&@#(!*^$%"
|
||||
|
||||
result interface{}
|
||||
err error
|
||||
|
||||
resultTo url.URL
|
||||
resultPtrTo *url.URL
|
||||
)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
|
||||
assert.Nil(t, result)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||
assert.EqualError(t, err, "could not parse '*(!&@#(!*^$%' as a URL: parse \"*(!&@#(!*^$%\": invalid URL escape \"%\"")
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestStringToURLHookFunc_ShouldParseURLs(t *testing.T) {
|
||||
hook := StringToURLHookFunc()
|
||||
|
||||
var (
|
||||
from = "https://google.com/abc?a=123"
|
||||
|
||||
result interface{}
|
||||
err error
|
||||
|
||||
resultTo url.URL
|
||||
resultPtrTo *url.URL
|
||||
|
||||
ok bool
|
||||
)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultTo), from)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultTo, ok = result.(url.URL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "https", resultTo.Scheme)
|
||||
assert.Equal(t, "google.com", resultTo.Host)
|
||||
assert.Equal(t, "/abc", resultTo.Path)
|
||||
assert.Equal(t, "a=123", resultTo.RawQuery)
|
||||
|
||||
resultPtrTo, ok = result.(*url.URL)
|
||||
assert.False(t, ok)
|
||||
assert.Nil(t, resultPtrTo)
|
||||
|
||||
result, err = hook(reflect.TypeOf(from), reflect.TypeOf(resultPtrTo), from)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resultPtrTo, ok = result.(*url.URL)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, resultPtrTo)
|
||||
|
||||
assert.Equal(t, "https", resultTo.Scheme)
|
||||
assert.Equal(t, "google.com", resultTo.Host)
|
||||
assert.Equal(t, "/abc", resultTo.Path)
|
||||
assert.Equal(t, "a=123", resultTo.RawQuery)
|
||||
|
||||
resultTo, ok = result.(url.URL)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestToTimeDurationHookFunc_ShouldParse_String(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = "1h"
|
||||
|
@ -30,8 +157,8 @@ func TestToTimeDurationFunc_ShouldParse_String(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_String_Years(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = "1y"
|
||||
|
@ -52,8 +179,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Years(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_String_Months(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = "1M"
|
||||
|
@ -74,8 +201,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Months(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_String_Weeks(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = "1w"
|
||||
|
@ -96,8 +223,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Weeks(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_String_Days(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = "1d"
|
||||
|
@ -118,8 +245,8 @@ func TestToTimeDurationFunc_ShouldParse_String_Days(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = "abc"
|
||||
|
@ -139,8 +266,8 @@ func TestToTimeDurationFunc_ShouldNotParseAndRaiseErr_InvalidString(t *testing.T
|
|||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_Int(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = 60
|
||||
|
@ -161,8 +288,8 @@ func TestToTimeDurationFunc_ShouldParse_Int(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_Int32(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = int32(120)
|
||||
|
@ -183,8 +310,8 @@ func TestToTimeDurationFunc_ShouldParse_Int32(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_Int64(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = int64(30)
|
||||
|
@ -205,8 +332,8 @@ func TestToTimeDurationFunc_ShouldParse_Int64(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_Duration(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = time.Second * 30
|
||||
|
@ -227,8 +354,8 @@ func TestToTimeDurationFunc_ShouldParse_Duration(t *testing.T) {
|
|||
assert.Equal(t, &expected, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = int64(30)
|
||||
|
@ -248,8 +375,8 @@ func TestToTimeDurationFunc_ShouldNotParse_Int64ToString(t *testing.T) {
|
|||
assert.Equal(t, from, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldNotParse_FromBool(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = true
|
||||
|
@ -269,8 +396,8 @@ func TestToTimeDurationFunc_ShouldNotParse_FromBool(t *testing.T) {
|
|||
assert.Equal(t, from, result)
|
||||
}
|
||||
|
||||
func TestToTimeDurationFunc_ShouldParse_FromZero(t *testing.T) {
|
||||
hook := ToTimeDurationFunc()
|
||||
func TestToTimeDurationHookFunc_ShouldParse_FromZero(t *testing.T) {
|
||||
hook := ToTimeDurationHookFunc()
|
||||
|
||||
var (
|
||||
from = 0
|
||||
|
|
|
@ -44,8 +44,9 @@ func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o inte
|
|||
DecoderConfig: &mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToSliceHookFunc(","),
|
||||
StringToMailAddressFunc(),
|
||||
ToTimeDurationFunc(),
|
||||
StringToMailAddressHookFunc(),
|
||||
ToTimeDurationHookFunc(),
|
||||
StringToURLHookFunc(),
|
||||
),
|
||||
Metadata: nil,
|
||||
Result: o,
|
||||
|
|
Loading…
Reference in New Issue
Block a user