authelia/internal/session/provider_config.go
James Elliott 6276883f04
refactor(configuration): utilize time duration decode hook (#2938)
This enhances the existing time.Duration parser to allow multiple units, and implements a decode hook which can be used by koanf to decode string/integers into time.Durations as applicable.
2022-03-02 17:40:26 +11:00

153 lines
4.2 KiB
Go

package session
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"fmt"
"strings"
"github.com/fasthttp/session/v2"
"github.com/fasthttp/session/v2/providers/redis"
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewProviderConfig creates a configuration for creating the session provider.
func NewProviderConfig(config schema.SessionConfiguration, certPool *x509.CertPool) ProviderConfig {
c := session.NewDefaultConfig()
c.SessionIDGeneratorFunc = func() []byte {
bytes := make([]byte, 32)
_, _ = rand.Read(bytes)
for i, b := range bytes {
bytes[i] = randomSessionChars[b%byte(len(randomSessionChars))]
}
return bytes
}
// Override the cookie name.
c.CookieName = config.Name
// Set the cookie to the given domain.
c.Domain = config.Domain
// Set the cookie SameSite option.
switch config.SameSite {
case "strict":
c.CookieSameSite = fasthttp.CookieSameSiteStrictMode
case "none":
c.CookieSameSite = fasthttp.CookieSameSiteNoneMode
case "lax":
c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
default:
c.CookieSameSite = fasthttp.CookieSameSiteLaxMode
}
// Only serve the header over HTTPS.
c.Secure = true
// Ignore the error as it will be handled by validator.
c.Expiration = config.Expiration
c.IsSecureFunc = func(*fasthttp.RequestCtx) bool {
return true
}
var redisConfig *redis.Config
var redisSentinelConfig *redis.FailoverConfig
var providerName string
// If redis configuration is provided, then use the redis provider.
switch {
case config.Redis != nil:
serializer := NewEncryptingSerializer(config.Secret)
var tlsConfig *tls.Config
if config.Redis.TLS != nil {
tlsConfig = utils.NewTLSConfig(config.Redis.TLS, tls.VersionTLS12, certPool)
}
if config.Redis.HighAvailability != nil && config.Redis.HighAvailability.SentinelName != "" {
addrs := make([]string, 0)
if config.Redis.Host != "" {
addrs = append(addrs, fmt.Sprintf("%s:%d", strings.ToLower(config.Redis.Host), config.Redis.Port))
}
for _, node := range config.Redis.HighAvailability.Nodes {
addr := fmt.Sprintf("%s:%d", strings.ToLower(node.Host), node.Port)
if !utils.IsStringInSlice(addr, addrs) {
addrs = append(addrs, addr)
}
}
providerName = "redis-sentinel"
redisSentinelConfig = &redis.FailoverConfig{
Logger: &redisLogger{logger: logging.Logger()},
MasterName: config.Redis.HighAvailability.SentinelName,
SentinelAddrs: addrs,
SentinelUsername: config.Redis.HighAvailability.SentinelUsername,
SentinelPassword: config.Redis.HighAvailability.SentinelPassword,
RouteByLatency: config.Redis.HighAvailability.RouteByLatency,
RouteRandomly: config.Redis.HighAvailability.RouteRandomly,
Username: config.Redis.Username,
Password: config.Redis.Password,
DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: config.Redis.MaximumActiveConnections,
MinIdleConns: config.Redis.MinimumIdleConnections,
IdleTimeout: 300,
TLSConfig: tlsConfig,
KeyPrefix: "authelia-session",
}
} else {
providerName = "redis"
network := "tcp"
var addr string
if config.Redis.Port == 0 {
network = "unix"
addr = config.Redis.Host
} else {
addr = fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port)
}
redisConfig = &redis.Config{
Logger: newRedisLogger(),
Network: network,
Addr: addr,
Username: config.Redis.Username,
Password: config.Redis.Password,
DB: config.Redis.DatabaseIndex, // DB is the fasthttp/session property for the Redis DB Index.
PoolSize: config.Redis.MaximumActiveConnections,
MinIdleConns: config.Redis.MinimumIdleConnections,
IdleTimeout: 300,
TLSConfig: tlsConfig,
KeyPrefix: "authelia-session",
}
}
c.EncodeFunc = serializer.Encode
c.DecodeFunc = serializer.Decode
default:
providerName = "memory"
}
return ProviderConfig{
c,
redisConfig,
redisSentinelConfig,
providerName,
}
}