authelia/internal/middlewares/timing_attack_delay.go
James Elliott 001589cd6d
feat(metrics): implement prometheus metrics (#3234)
Adds ability to record metrics and gather them for Prometheus.
2022-06-14 17:20:13 +10:00

76 lines
2.5 KiB
Go

package middlewares
import (
"crypto/rand"
"math"
"math/big"
"sync"
"time"
)
// TimingAttackDelayFunc describes a function for preventing timing attacks via a delay.
//
// ctx is the request context. requestTime is the reference instant from which the elapsed
// execution time is measured. successful reports the outcome of the attempt being padded.
// The implementation returned by TimingAttackDelay treats a nil successful pointer as an
// unsuccessful attempt.
type TimingAttackDelayFunc func(ctx *AutheliaCtx, requestTime time.Time, successful *bool)
// TimingAttackDelay creates a new standard timing delay func.
//
// The returned func keeps a window of the last history execution durations, initially
// filled with initialDelay and guarded by a mutex shared across calls. The window is only
// updated by successful attempts. On every call it:
//   - measures the time elapsed since requestTime,
//   - optionally (when record is true and a metrics provider is configured) records that
//     duration as an authentication metric,
//   - sleeps long enough that the overall time approaches the larger of the window average
//     and minDelayMs, plus a random jitter below maxRandomMs milliseconds.
//
// NOTE(review): a negative history panics in make; confirm callers always pass a positive value.
func TimingAttackDelay(history int, minDelayMs float64, maxRandomMs int64, initialDelay time.Duration, record bool) TimingAttackDelayFunc {
	samples := make([]time.Duration, history)
	for i := 0; i < len(samples); i++ {
		samples[i] = initialDelay
	}

	var (
		lock sync.Mutex
		next int
	)

	return func(ctx *AutheliaCtx, requestTime time.Time, successful *bool) {
		var ok bool
		if successful != nil {
			ok = *successful
		}

		elapsed := time.Since(requestTime)

		// Metrics are looked up only when recording is enabled.
		if record && ctx.Providers.Metrics != nil {
			ctx.Providers.Metrics.RecordAuthenticationDuration(ok, elapsed)
		}

		avgMs := movingAverageIteration(elapsed, history, ok, &next, &samples, &lock)
		delayMs := calculateActualDelay(ctx, elapsed, avgMs, minDelayMs, maxRandomMs, ok)

		time.Sleep(time.Millisecond * time.Duration(delayMs))
	}
}
// movingAverageIteration returns the average, in milliseconds, of the durations currently held
// in movingAvg, divided over history slots. The average is computed before value is considered.
//
// When successful is true, value then overwrites the slot at *cursor and the cursor advances,
// wrapping modulo history, so the window behaves as a ring buffer of successful attempts.
// Unsuccessful attempts leave the window untouched. mutex guards movingAvg and cursor for the
// whole operation.
//
// A non-positive history has no window to average and returns 0 without modifying anything.
// It previously caused an integer division/modulo by zero.
func movingAverageIteration(value time.Duration, history int, successful bool, cursor *int, movingAvg *[]time.Duration, mutex sync.Locker) float64 {
	if history <= 0 {
		return 0
	}

	mutex.Lock()
	defer mutex.Unlock()

	// Sum full-precision durations. Truncating each sample to whole milliseconds (and then
	// integer-dividing) would discard all sub-millisecond information from the average.
	var sum time.Duration

	for _, v := range *movingAvg {
		sum += v
	}

	if successful {
		(*movingAvg)[*cursor] = value
		*cursor = (*cursor + 1) % history
	}

	return float64(sum) / float64(history) / float64(time.Millisecond)
}
func calculateActualDelay(ctx *AutheliaCtx, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs))
if err != nil {
return float64(maxRandomMs)
}
totalDelayMs := math.Max(execDurationAvgMs, minDelayMs) + float64(randomDelayMs.Int64())
actualDelayMs = math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
ctx.Logger.Tracef("Timing Attack Delay successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", successful, execDuration.Milliseconds(), int64(execDurationAvgMs), randomDelayMs.Int64(), int64(totalDelayMs), int64(actualDelayMs))
return actualDelayMs
}