mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
fix(middlewares): smart delay on reset password (#2767)
This adds a smart delay on reset password attempts to prevent username enumeration. Additionally utilizes crypto rand instead of math rand. It also moves the timing delay functionality into its own handler func.
This commit is contained in:
parent
97a862e81a
commit
9a8c6602dd
|
@ -61,12 +61,6 @@ const (
|
|||
testUsername = "john"
|
||||
)
|
||||
|
||||
const (
|
||||
loginDelayMovingAverageWindow = 10
|
||||
loginDelayMinimumDelayMilliseconds = float64(250)
|
||||
loginDelayMaximumRandomDelayMilliseconds = int64(85)
|
||||
)
|
||||
|
||||
// Duo constants.
|
||||
const (
|
||||
allow = "allow"
|
||||
|
|
|
@ -2,9 +2,6 @@ package handlers
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
|
@ -12,61 +9,16 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
func movingAverageIteration(value time.Duration, successful bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) float64 {
|
||||
mutex.Lock()
|
||||
if successful {
|
||||
(*execDurationMovingAverage)[*movingAverageCursor] = value
|
||||
*movingAverageCursor = (*movingAverageCursor + 1) % loginDelayMovingAverageWindow
|
||||
}
|
||||
|
||||
var sum int64
|
||||
|
||||
for _, v := range *execDurationMovingAverage {
|
||||
sum += v.Milliseconds()
|
||||
}
|
||||
mutex.Unlock()
|
||||
|
||||
return float64(sum / loginDelayMovingAverageWindow)
|
||||
}
|
||||
|
||||
func calculateActualDelay(ctx *middlewares.AutheliaCtx, execDuration time.Duration, avgExecDurationMs float64, successful *bool) float64 {
|
||||
randomDelayMs := float64(rand.Int63n(loginDelayMaximumRandomDelayMilliseconds)) //nolint:gosec // TODO: Consider use of crypto/rand, this should be benchmarked and measured first.
|
||||
totalDelayMs := math.Max(avgExecDurationMs, loginDelayMinimumDelayMilliseconds) + randomDelayMs
|
||||
actualDelayMs := math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
|
||||
ctx.Logger.Tracef("Attempt successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", *successful, execDuration.Milliseconds(), int64(avgExecDurationMs), int64(randomDelayMs), int64(totalDelayMs), int64(actualDelayMs))
|
||||
|
||||
return actualDelayMs
|
||||
}
|
||||
|
||||
func delayToPreventTimingAttacks(ctx *middlewares.AutheliaCtx, requestTime time.Time, successful *bool, movingAverageCursor *int, execDurationMovingAverage *[]time.Duration, mutex sync.Locker) {
|
||||
execDuration := time.Since(requestTime)
|
||||
avgExecDurationMs := movingAverageIteration(execDuration, *successful, movingAverageCursor, execDurationMovingAverage, mutex)
|
||||
actualDelayMs := calculateActualDelay(ctx, execDuration, avgExecDurationMs, successful)
|
||||
time.Sleep(time.Duration(actualDelayMs) * time.Millisecond)
|
||||
}
|
||||
|
||||
// FirstFactorPost is the handler performing the first factory.
|
||||
//nolint:gocyclo // TODO: Consider refactoring time permitting.
|
||||
func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middlewares.RequestHandler {
|
||||
var execDurationMovingAverage = make([]time.Duration, loginDelayMovingAverageWindow)
|
||||
|
||||
var movingAverageCursor = 0
|
||||
|
||||
var mutex = &sync.Mutex{}
|
||||
|
||||
for i := range execDurationMovingAverage {
|
||||
execDurationMovingAverage[i] = msInitialDelay * time.Millisecond
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.RequestHandler {
|
||||
return func(ctx *middlewares.AutheliaCtx) {
|
||||
var successful bool
|
||||
|
||||
requestTime := time.Now()
|
||||
|
||||
if delayEnabled {
|
||||
defer delayToPreventTimingAttacks(ctx, requestTime, &successful, &movingAverageCursor, &execDurationMovingAverage, mutex)
|
||||
if delayFunc != nil {
|
||||
defer delayFunc(ctx.Logger, requestTime, &successful)
|
||||
}
|
||||
|
||||
bodyJSON := firstFactorRequestBody{}
|
||||
|
|
|
@ -2,9 +2,7 @@ package handlers
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -33,7 +31,7 @@ func (s *FirstFactorSuite) TearDownTest() {
|
|||
}
|
||||
|
||||
func (s *FirstFactorSuite) TestShouldFailIfBodyIsNil() {
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// No body
|
||||
assert.Equal(s.T(), "Failed to parse 1FA request body: unable to parse body: unexpected end of JSON input", s.mock.Hook.LastEntry().Message)
|
||||
|
@ -45,7 +43,7 @@ func (s *FirstFactorSuite) TestShouldFailIfBodyIsInBadFormat() {
|
|||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
"username": "test"
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Failed to parse 1FA request body: unable to validate body: password: non zero value required", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -73,7 +71,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Unsuccessful 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -102,7 +100,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC
|
|||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
}
|
||||
|
||||
func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCredentials() {
|
||||
|
@ -128,7 +126,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
|
|||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
}
|
||||
|
||||
func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
|
||||
|
@ -152,7 +150,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Could not obtain profile details during 1FA authentication for user 'test': failed", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -174,7 +172,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
assert.Equal(s.T(), "Unable to mark 1FA authentication attempt by user 'test': failed", s.mock.Hook.LastEntry().Message)
|
||||
s.mock.Assert401KO(s.T(), "Authentication failed. Check your credentials.")
|
||||
|
@ -205,7 +203,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() {
|
|||
"password": "hello",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||
|
@ -246,7 +244,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() {
|
|||
"requestMethod": "GET",
|
||||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||
|
@ -290,7 +288,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess
|
|||
"requestMethod": "GET",
|
||||
"keepMeLoggedIn": true
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
|
||||
|
@ -360,7 +358,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenNoTarget
|
|||
"requestMethod": "GET",
|
||||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"})
|
||||
|
@ -381,7 +379,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldRedirectToDefaultURLWhenURLIsUns
|
|||
"targetURL": "http://notsafe.local"
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), redirectResponse{Redirect: "https://default.local"})
|
||||
|
@ -404,7 +402,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenNoTargetURLProvidedA
|
|||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), nil)
|
||||
|
@ -436,7 +434,7 @@ func (s *FirstFactorRedirectionSuite) TestShouldReply200WhenUnsafeTargetURLProvi
|
|||
"keepMeLoggedIn": false
|
||||
}`)
|
||||
|
||||
FirstFactorPost(0, false)(s.mock.Ctx)
|
||||
FirstFactorPost(nil)(s.mock.Ctx)
|
||||
|
||||
// Respond with 200.
|
||||
s.mock.Assert200OK(s.T(), nil)
|
||||
|
@ -446,57 +444,3 @@ func TestFirstFactorSuite(t *testing.T) {
|
|||
suite.Run(t, new(FirstFactorSuite))
|
||||
suite.Run(t, new(FirstFactorRedirectionSuite))
|
||||
}
|
||||
|
||||
func TestFirstFactorDelayAverages(t *testing.T) {
|
||||
execDuration := time.Millisecond * 500
|
||||
oneSecond := time.Millisecond * 1000
|
||||
durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond}
|
||||
cursor := 0
|
||||
mutex := &sync.Mutex{}
|
||||
avgExecDuration := movingAverageIteration(execDuration, false, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, float64(1000))
|
||||
|
||||
execDurations := []time.Duration{
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
}
|
||||
|
||||
current := float64(1000)
|
||||
|
||||
// Execute at 500ms for 12 requests.
|
||||
for _, execDuration = range execDurations {
|
||||
// Should not dip below 500, and should decrease in value by 50 each iteration.
|
||||
if current > 500 {
|
||||
current -= 50
|
||||
}
|
||||
|
||||
avgExecDuration := movingAverageIteration(execDuration, true, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, current)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirstFactorDelayCalculations(t *testing.T) {
|
||||
mock := mocks.NewMockAutheliaCtx(t)
|
||||
successful := false
|
||||
|
||||
execDuration := 500 * time.Millisecond
|
||||
avgExecDurationMs := 1000.0
|
||||
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds))
|
||||
}
|
||||
|
||||
execDuration = 5 * time.Millisecond
|
||||
avgExecDurationMs = 5.0
|
||||
expectedMinimumDelayMs = loginDelayMinimumDelayMilliseconds - float64(execDuration.Milliseconds())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(mock.Ctx, execDuration, avgExecDurationMs, &successful)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(loginDelayMaximumRandomDelayMilliseconds))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle
|
|||
TargetEndpoint: "/one-time-password/register",
|
||||
ActionClaim: ActionTOTPRegistration,
|
||||
IdentityRetrieverFunc: identityRetrieverFromSession,
|
||||
})
|
||||
}, nil)
|
||||
|
||||
func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
var (
|
||||
|
|
|
@ -21,7 +21,7 @@ var SecondFactorU2FIdentityStart = middlewares.IdentityVerificationStart(middlew
|
|||
TargetEndpoint: "/security-key/register",
|
||||
ActionClaim: ActionU2FRegistration,
|
||||
IdentityRetrieverFunc: identityRetrieverFromSession,
|
||||
})
|
||||
}, nil)
|
||||
|
||||
func secondFactorU2FIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
if ctx.XForwardedProto() == nil {
|
||||
|
|
|
@ -3,6 +3,7 @@ package handlers
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
|
@ -40,7 +41,7 @@ var ResetPasswordIdentityStart = middlewares.IdentityVerificationStart(middlewar
|
|||
TargetEndpoint: "/reset-password/step2",
|
||||
ActionClaim: ActionResetPassword,
|
||||
IdentityRetrieverFunc: identityRetrieverFromStorage,
|
||||
})
|
||||
}, middlewares.TimingAttackDelay(10, 250, 85, time.Millisecond*500))
|
||||
|
||||
func resetPasswordIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
userSession := ctx.GetSession()
|
||||
|
|
|
@ -13,13 +13,14 @@ import (
|
|||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// NewRequestLogger create a new request logger for the given request.
|
||||
func NewRequestLogger(ctx *AutheliaCtx) *logrus.Entry {
|
||||
return logrus.WithFields(logrus.Fields{
|
||||
return logging.Logger().WithFields(logrus.Fields{
|
||||
"method": string(ctx.Method()),
|
||||
"path": string(ctx.Path()),
|
||||
"remote_ip": ctx.RemoteIP().String(),
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/google/uuid"
|
||||
|
@ -13,12 +14,19 @@ import (
|
|||
)
|
||||
|
||||
// IdentityVerificationStart the handler for initiating the identity validation process.
|
||||
func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandler {
|
||||
func IdentityVerificationStart(args IdentityVerificationStartArgs, delayFunc TimingAttackDelayFunc) RequestHandler {
|
||||
if args.IdentityRetrieverFunc == nil {
|
||||
panic(fmt.Errorf("Identity verification requires an identity retriever"))
|
||||
}
|
||||
|
||||
return func(ctx *AutheliaCtx) {
|
||||
requestTime := time.Now()
|
||||
success := false
|
||||
|
||||
if delayFunc != nil {
|
||||
defer delayFunc(ctx.Logger, requestTime, &success)
|
||||
}
|
||||
|
||||
identity, err := args.IdentityRetrieverFunc(ctx)
|
||||
if err != nil {
|
||||
// In that case we reply ok to avoid user enumeration.
|
||||
|
@ -106,6 +114,8 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
|
|||
return
|
||||
}
|
||||
|
||||
success = true
|
||||
|
||||
ctx.ReplyOK()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ func TestShouldFailStartingProcessIfUserHasNoEmailAddress(t *testing.T) {
|
|||
return nil, fmt.Errorf("User does not have any email")
|
||||
}
|
||||
|
||||
middlewares.IdentityVerificationStart(newArgs(retriever))(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(newArgs(retriever), nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "User does not have any email", mock.Hook.LastEntry().Message)
|
||||
|
@ -61,7 +61,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) {
|
|||
Return(fmt.Errorf("cannot save"))
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "cannot save", mock.Hook.LastEntry().Message)
|
||||
|
@ -84,7 +84,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
|
|||
Return(fmt.Errorf("no notif"))
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
|
||||
|
@ -102,7 +102,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
|
|||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message)
|
||||
|
@ -120,7 +120,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
|
|||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
assert.Equal(t, "Missing header X-Forwarded-Host", mock.Hook.LastEntry().Message)
|
||||
|
@ -142,7 +142,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
|
|||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
middlewares.IdentityVerificationStart(args)(mock.Ctx)
|
||||
middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
|
||||
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
|
||||
|
|
72
internal/middlewares/timing_attack_delay.go
Normal file
72
internal/middlewares/timing_attack_delay.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math"
|
||||
"math/big"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TimingAttackDelayFunc describes a function for preventing timing attacks via a delay.
|
||||
type TimingAttackDelayFunc func(logger *logrus.Entry, requestTime time.Time, successful *bool)
|
||||
|
||||
// TimingAttackDelay creates a new standard timing delay func.
|
||||
func TimingAttackDelay(history int, minDelayMs float64, maxRandomMs int64, initialDelay time.Duration) TimingAttackDelayFunc {
|
||||
var (
|
||||
mutex = &sync.Mutex{}
|
||||
cursor = 0
|
||||
)
|
||||
|
||||
execDurationMovingAverage := make([]time.Duration, history)
|
||||
|
||||
for i := range execDurationMovingAverage {
|
||||
execDurationMovingAverage[i] = initialDelay
|
||||
}
|
||||
|
||||
return func(logger *logrus.Entry, requestTime time.Time, successful *bool) {
|
||||
successfulValue := false
|
||||
if successful != nil {
|
||||
successfulValue = *successful
|
||||
}
|
||||
|
||||
execDuration := time.Since(requestTime)
|
||||
execDurationAvgMs := movingAverageIteration(execDuration, history, successfulValue, &cursor, &execDurationMovingAverage, mutex)
|
||||
actualDelayMs := calculateActualDelay(logger, execDuration, execDurationAvgMs, minDelayMs, maxRandomMs, successfulValue)
|
||||
time.Sleep(time.Duration(actualDelayMs) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func movingAverageIteration(value time.Duration, history int, successful bool, cursor *int, movingAvg *[]time.Duration, mutex sync.Locker) float64 {
|
||||
mutex.Lock()
|
||||
|
||||
var sum int64
|
||||
|
||||
for _, v := range *movingAvg {
|
||||
sum += v.Milliseconds()
|
||||
}
|
||||
|
||||
if successful {
|
||||
(*movingAvg)[*cursor] = value
|
||||
*cursor = (*cursor + 1) % history
|
||||
}
|
||||
|
||||
mutex.Unlock()
|
||||
|
||||
return float64(sum / int64(history))
|
||||
}
|
||||
|
||||
func calculateActualDelay(logger *logrus.Entry, execDuration time.Duration, execDurationAvgMs, minDelayMs float64, maxRandomMs int64, successful bool) (actualDelayMs float64) {
|
||||
randomDelayMs, err := rand.Int(rand.Reader, big.NewInt(maxRandomMs))
|
||||
if err != nil {
|
||||
return float64(maxRandomMs)
|
||||
}
|
||||
|
||||
totalDelayMs := math.Max(execDurationAvgMs, minDelayMs) + float64(randomDelayMs.Int64())
|
||||
actualDelayMs = math.Max(totalDelayMs-float64(execDuration.Milliseconds()), 1.0)
|
||||
logger.Tracef("Timing Attack Delay successful: %t, exec duration: %d, avg execution duration: %d, random delay ms: %d, total delay ms: %d, actual delay ms: %d", successful, execDuration.Milliseconds(), int64(execDurationAvgMs), randomDelayMs.Int64(), int64(totalDelayMs), int64(actualDelayMs))
|
||||
|
||||
return actualDelayMs
|
||||
}
|
65
internal/middlewares/timing_attack_delay_test.go
Normal file
65
internal/middlewares/timing_attack_delay_test.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
func TestTimingAttackDelayAverages(t *testing.T) {
|
||||
execDuration := time.Millisecond * 500
|
||||
oneSecond := time.Millisecond * 1000
|
||||
durations := []time.Duration{oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond, oneSecond}
|
||||
cursor := 0
|
||||
mutex := &sync.Mutex{}
|
||||
avgExecDuration := movingAverageIteration(execDuration, 10, false, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, float64(1000))
|
||||
|
||||
execDurations := []time.Duration{
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500, time.Millisecond * 500,
|
||||
}
|
||||
|
||||
current := float64(1000)
|
||||
|
||||
// Execute at 500ms for 12 requests.
|
||||
for _, execDuration = range execDurations {
|
||||
avgExecDuration = movingAverageIteration(execDuration, 10, true, &cursor, &durations, mutex)
|
||||
assert.Equal(t, avgExecDuration, current)
|
||||
|
||||
// Should not dip below 500, and should decrease in value by 50 each iteration.
|
||||
if current > 500 {
|
||||
current -= 50
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimingAttackDelayCalculations(t *testing.T) {
|
||||
execDuration := 500 * time.Millisecond
|
||||
avgExecDurationMs := 1000.0
|
||||
expectedMinimumDelayMs := avgExecDurationMs - float64(execDuration.Milliseconds())
|
||||
|
||||
logger := logging.Logger().WithFields(logrus.Fields{})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(85))
|
||||
}
|
||||
|
||||
execDuration = 5 * time.Millisecond
|
||||
avgExecDurationMs = 5.0
|
||||
expectedMinimumDelayMs = 250 - float64(execDuration.Milliseconds())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
delay := calculateActualDelay(logger, execDuration, avgExecDurationMs, 250, 85, false)
|
||||
assert.True(t, delay >= expectedMinimumDelayMs)
|
||||
assert.True(t, delay <= expectedMinimumDelayMs+float64(250))
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
duoapi "github.com/duosecurity/duo_api_golang"
|
||||
"github.com/fasthttp/router"
|
||||
|
@ -69,7 +70,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
|||
|
||||
r.POST("/api/checks/safe-redirection", autheliaMiddleware(handlers.CheckSafeRedirection))
|
||||
|
||||
r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(1000, true)))
|
||||
r.POST("/api/firstfactor", autheliaMiddleware(handlers.FirstFactorPost(middlewares.TimingAttackDelay(10, 250, 85, time.Second))))
|
||||
r.POST("/api/logout", autheliaMiddleware(handlers.LogoutPost))
|
||||
|
||||
// Only register endpoints if forgot password is not disabled.
|
||||
|
|
Loading…
Reference in New Issue
Block a user