package middlewares

import (
	"crypto/rand"
	"math"
	"math/big"
	"sync"
	"time"
)

// TimingAttackDelayFunc describes a function for preventing timing attacks via a delay.
//
// It is called with the request context, the instant from which the execution time of the request is measured, and
// an optional pointer to the outcome of the request. The implementation returned by TimingAttackDelay treats a nil
// successful pointer as false.
type TimingAttackDelayFunc func(ctx *AutheliaCtx, requestTime time.Time, successful *bool)

// TimingAttackDelay builds a TimingAttackDelayFunc backed by a fixed-size window of recent execution durations.
//
// The window holds history entries, all of which start out as initialDelay. Each call of the returned function
// measures the time elapsed since requestTime, optionally records it with the metrics provider (when record is true
// and a provider is configured), feeds it through movingAverageIteration and calculateActualDelay together with
// minDelayMs and maxRandomMs, and finally sleeps for the resulting number of milliseconds.
func TimingAttackDelay(history int, minDelayMs float64, maxRandomMs int64, initialDelay time.Duration, record bool) TimingAttackDelayFunc {
	// State shared by every invocation of the returned closure.
	var (
		lock sync.Mutex
		next int
	)

	window := make([]time.Duration, history)
	for i := 0; i < history; i++ {
		window[i] = initialDelay
	}

	return func(ctx *AutheliaCtx, requestTime time.Time, successful *bool) {
		// A missing outcome is treated as a failure.
		ok := successful != nil && *successful

		elapsed := time.Since(requestTime)

		if record {
			if metrics := ctx.Providers.Metrics; metrics != nil {
				metrics.RecordAuthenticationDuration(ok, elapsed)
			}
		}

		avgMs := movingAverageIteration(elapsed, history, ok, &next, &window, &lock)
		delayMs := calculateActualDelay(ctx, elapsed, avgMs, minDelayMs, maxRandomMs, ok)

		// The fractional part of the delay is dropped by the conversion to time.Duration.
		wait := time.Duration(delayMs) * time.Millisecond
		time.Sleep(wait)
	}
}

// movingAverageIteration returns the mean, in whole milliseconds, of the durations currently stored in movingAvg
// and, when successful is true, afterwards stores value in the slot at cursor and advances cursor, wrapping at
// history.
//
// The mean is taken before value is stored, so it never includes the current sample. Every stored duration is
// truncated to whole milliseconds and the sum is divided by history using integer division, so the result is
// always a whole number. Callers are expected to keep history equal to len(*movingAvg).
//
// A history that is not positive, or an empty window, yields 0 and leaves cursor and the window untouched rather
// than panicking on a division by zero or an out of range index.
func movingAverageIteration(value time.Duration, history int, successful bool, cursor *int, movingAvg *[]time.Duration, mutex sync.Locker) float64 {
	mutex.Lock()
	// Released via defer so a panic inside the critical section cannot leave the lock held for later requests.
	defer mutex.Unlock()

	window := *movingAvg

	// Nothing to average and nowhere to store a sample.
	if history <= 0 || len(window) == 0 {
		return 0
	}

	var sum int64

	for _, v := range window {
		sum += v.Milliseconds()
	}

	// Only successful requests contribute to the moving average.
	if successful {
		window[*cursor] = value
		*cursor = (*cursor + 1) % history
	}

	return float64(sum / int64(history))
}

// calculateActualDelay works out how many milliseconds the caller still has to wait for the overall response time
// to reach the target delay.
//
// The target is the larger of execDurationAvgMs and minDelayMs plus a random jitter drawn from [0, maxRandomMs)
// using crypto/rand. The time already spent (execDuration, truncated to whole milliseconds) is subtracted from the
// target and the result is never less than 1. The inputs and results are logged through ctx.Logger.Tracef.
//
// A maxRandomMs of zero or less disables the jitter. If the random source fails, maxRandomMs itself is returned and
// nothing is logged.
func calculateActualDelay(ctx *AutheliaCtx, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
	var randomDelayMs int64

	// rand.Int panics when its upper bound is not positive, so only draw a value when there is a range to draw from.
	if maxRandomMs > 0 {
		n, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs))
		if err != nil {
			return float64(maxRandomMs)
		}

		randomDelayMs = n.Int64()
	}

	totalDelayMs := math.Max(execDurationAvgMs, minDelayMs) + float64(randomDelayMs)

	// Always wait at least 1ms, even when the request already took longer than the target.
	actualDelayMs = math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
	ctx.Logger.Tracef("Timing Attack Delay successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", successful, execDuration.Milliseconds(), int64(execDurationAvgMs), randomDelayMs, int64(totalDelayMs), int64(actualDelayMs))

	return actualDelayMs
}