From bf9ab360bd8dd46739d4aa0018ceb4c08e05dba8 Mon Sep 17 00:00:00 2001
From: James Elliott <james-d-elliott@users.noreply.github.com>
Date: Thu, 2 Dec 2021 13:21:46 +1100
Subject: [PATCH] refactor(handlers): utilize referer for auth logging rm/rd
 (#2655)

This utilizes the referrer query parameters instead of current request query parameters for logging the requested URI and method. Minor performance improvements to header peek/sets.
---
 internal/handlers/const.go               | 25 ++++++++++-----------
 internal/handlers/handler_verify.go      | 28 ++++++++++++------------
 internal/handlers/handler_verify_test.go | 16 +++++++-------
 internal/handlers/response.go            | 15 ++++++++++++-
 internal/middlewares/authelia_context.go | 16 +++++++-------
 internal/middlewares/const.go            | 28 ++++++++++++++----------
 6 files changed, 72 insertions(+), 56 deletions(-)

diff --git a/internal/handlers/const.go b/internal/handlers/const.go
index fe292e72..d3bc4e98 100644
--- a/internal/handlers/const.go
+++ b/internal/handlers/const.go
@@ -1,5 +1,9 @@
 package handlers
 
+import (
+	"github.com/valyala/fasthttp"
+)
+
 const (
 	// ActionTOTPRegistration is the string representation of the action for which the token has been produced.
 	ActionTOTPRegistration = "RegisterTOTPDevice"
@@ -11,20 +15,15 @@ const (
 	ActionResetPassword = "ResetPassword"
 )
 
-const (
-	// HeaderProxyAuthorization is the basic-auth HTTP header Authelia utilises.
-	HeaderProxyAuthorization = "Proxy-Authorization"
+var (
+	headerAuthorization      = []byte(fasthttp.HeaderAuthorization)
+	headerProxyAuthorization = []byte(fasthttp.HeaderProxyAuthorization)
 
-	// HeaderAuthorization is the basic-auth HTTP header Authelia utilises with "auth=basic" query param.
-	HeaderAuthorization = "Authorization"
-
-	// HeaderSessionUsername is used as additional protection to validate a user for things like pam_exec.
-	HeaderSessionUsername = "Session-Username"
-
-	headerRemoteUser   = "Remote-User"
-	headerRemoteName   = "Remote-Name"
-	headerRemoteEmail  = "Remote-Email"
-	headerRemoteGroups = "Remote-Groups"
+	headerSessionUsername = []byte("Session-Username")
+	headerRemoteUser      = []byte("Remote-User")
+	headerRemoteGroups    = []byte("Remote-Groups")
+	headerRemoteName      = []byte("Remote-Name")
+	headerRemoteEmail     = []byte("Remote-Email")
 )
 
 const (
diff --git a/internal/handlers/handler_verify.go b/internal/handlers/handler_verify.go
index 6fce4474..ba3f71f6 100644
--- a/internal/handlers/handler_verify.go
+++ b/internal/handlers/handler_verify.go
@@ -33,7 +33,7 @@ func isSchemeWSS(url *url.URL) bool {
 
 // parseBasicAuth parses an HTTP Basic Authentication string.
 // "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
-func parseBasicAuth(header, auth string) (username, password string, err error) {
+func parseBasicAuth(header []byte, auth string) (username, password string, err error) {
 	if !strings.HasPrefix(auth, authPrefix) {
 		return "", "", fmt.Errorf("%s prefix not found in %s header", strings.Trim(authPrefix, " "), header)
 	}
@@ -85,7 +85,7 @@ func isTargetURLAuthorized(authorizer *authorization.Authorizer, targetURL url.U
 
 // verifyBasicAuth verify that the provided username and password are correct and
 // that the user is authorized to target the resource.
-func verifyBasicAuth(header string, auth []byte, ctx *middlewares.AutheliaCtx) (username, name string, groups, emails []string, authLevel authentication.Level, err error) {
+func verifyBasicAuth(ctx *middlewares.AutheliaCtx, header, auth []byte) (username, name string, groups, emails []string, authLevel authentication.Level, err error) {
 	username, password, err := parseBasicAuth(header, string(auth))
 
 	if err != nil {
@@ -116,14 +116,14 @@ func verifyBasicAuth(header string, auth []byte, ctx *middlewares.AutheliaCtx) (
 // setForwardedHeaders set the forwarded User, Groups, Name and Email headers.
 func setForwardedHeaders(headers *fasthttp.ResponseHeader, username, name string, groups, emails []string) {
 	if username != "" {
-		headers.Set(headerRemoteUser, username)
-		headers.Set(headerRemoteGroups, strings.Join(groups, ","))
-		headers.Set(headerRemoteName, name)
+		headers.SetBytesK(headerRemoteUser, username)
+		headers.SetBytesK(headerRemoteGroups, strings.Join(groups, ","))
+		headers.SetBytesK(headerRemoteName, name)
 
 		if emails != nil {
-			headers.Set(headerRemoteEmail, emails[0])
+			headers.SetBytesK(headerRemoteEmail, emails[0])
 		} else {
-			headers.Set(headerRemoteEmail, "")
+			headers.SetBytesK(headerRemoteEmail, "")
 		}
 	}
 }
@@ -403,13 +403,13 @@ func getProfileRefreshSettings(cfg schema.AuthenticationBackendConfiguration) (r
 }
 
 func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile bool, refreshProfileInterval time.Duration) (isBasicAuth bool, username, name string, groups, emails []string, authLevel authentication.Level, err error) {
-	authHeader := HeaderProxyAuthorization
+	authHeader := headerProxyAuthorization
 	if bytes.Equal(ctx.QueryArgs().Peek("auth"), []byte("basic")) {
-		authHeader = HeaderAuthorization
+		authHeader = headerAuthorization
 		isBasicAuth = true
 	}
 
-	authValue := ctx.Request.Header.Peek(authHeader)
+	authValue := ctx.Request.Header.PeekBytes(authHeader)
 	if authValue != nil {
 		isBasicAuth = true
 	} else if isBasicAuth {
@@ -418,23 +418,23 @@ func verifyAuth(ctx *middlewares.AutheliaCtx, targetURL *url.URL, refreshProfile
 	}
 
 	if isBasicAuth {
-		username, name, groups, emails, authLevel, err = verifyBasicAuth(authHeader, authValue, ctx)
+		username, name, groups, emails, authLevel, err = verifyBasicAuth(ctx, authHeader, authValue)
 		return
 	}
 
 	userSession := ctx.GetSession()
 	username, name, groups, emails, authLevel, err = verifySessionCookie(ctx, targetURL, &userSession, refreshProfile, refreshProfileInterval)
 
-	sessionUsername := ctx.Request.Header.Peek(HeaderSessionUsername)
+	sessionUsername := ctx.Request.Header.PeekBytes(headerSessionUsername)
 	if sessionUsername != nil && !strings.EqualFold(string(sessionUsername), username) {
 		ctx.Logger.Warnf("Possible cookie hijack or attempt to bypass security detected destroying the session and sending 401 response")
 
 		err = ctx.Providers.SessionProvider.DestroySession(ctx.RequestCtx)
 		if err != nil {
-			ctx.Logger.Errorf("Unable to destroy user session after handler could not match them to their %s header: %s", HeaderSessionUsername, err)
+			ctx.Logger.Errorf("Unable to destroy user session after handler could not match them to their %s header: %s", headerSessionUsername, err)
 		}
 
-		err = fmt.Errorf("could not match user %s to their %s header with a value of %s when visiting %s", username, HeaderSessionUsername, sessionUsername, targetURL.String())
+		err = fmt.Errorf("could not match user %s to their %s header with a value of %s when visiting %s", username, headerSessionUsername, sessionUsername, targetURL.String())
 	}
 
 	return
diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go
index cf1faa17..bb78ff02 100644
--- a/internal/handlers/handler_verify_test.go
+++ b/internal/handlers/handler_verify_test.go
@@ -85,34 +85,34 @@ func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) {
 
 // Test parseBasicAuth.
 func TestShouldRaiseWhenHeaderDoesNotContainBasicPrefix(t *testing.T) {
-	_, _, err := parseBasicAuth(HeaderProxyAuthorization, "alzefzlfzemjfej==")
+	_, _, err := parseBasicAuth(headerProxyAuthorization, "alzefzlfzemjfej==")
 	assert.Error(t, err)
 	assert.Equal(t, "Basic prefix not found in Proxy-Authorization header", err.Error())
 }
 
 func TestShouldRaiseWhenCredentialsAreNotInBase64(t *testing.T) {
-	_, _, err := parseBasicAuth(HeaderProxyAuthorization, "Basic alzefzlfzemjfej==")
+	_, _, err := parseBasicAuth(headerProxyAuthorization, "Basic alzefzlfzemjfej==")
 	assert.Error(t, err)
 	assert.Equal(t, "illegal base64 data at input byte 16", err.Error())
 }
 
 func TestShouldRaiseWhenCredentialsAreNotInCorrectForm(t *testing.T) {
 	// The decoded format should be user:password.
-	_, _, err := parseBasicAuth(HeaderProxyAuthorization, "Basic am9obiBwYXNzd29yZA==")
+	_, _, err := parseBasicAuth(headerProxyAuthorization, "Basic am9obiBwYXNzd29yZA==")
 	assert.Error(t, err)
 	assert.Equal(t, "format of Proxy-Authorization header must be user:password", err.Error())
 }
 
 func TestShouldUseProvidedHeaderName(t *testing.T) {
 	// The decoded format should be user:password.
-	_, _, err := parseBasicAuth("HeaderName", "")
+	_, _, err := parseBasicAuth([]byte("HeaderName"), "")
 	assert.Error(t, err)
 	assert.Equal(t, "Basic prefix not found in HeaderName header", err.Error())
 }
 
 func TestShouldReturnUsernameAndPassword(t *testing.T) {
 	// the decoded format should be user:password.
-	user, password, err := parseBasicAuth(HeaderProxyAuthorization, "Basic am9objpwYXNzd29yZA==")
+	user, password, err := parseBasicAuth(headerProxyAuthorization, "Basic am9objpwYXNzd29yZA==")
 	assert.NoError(t, err)
 	assert.Equal(t, "john", user)
 	assert.Equal(t, "password", password)
@@ -176,7 +176,7 @@ func TestShouldVerifyWrongCredentials(t *testing.T) {
 		CheckUserPassword(gomock.Eq("john"), gomock.Eq("password")).
 		Return(false, nil)
 
-	_, _, _, _, _, err := verifyBasicAuth(HeaderProxyAuthorization, []byte("Basic am9objpwYXNzd29yZA=="), mock.Ctx)
+	_, _, _, _, _, err := verifyBasicAuth(mock.Ctx, headerProxyAuthorization, []byte("Basic am9objpwYXNzd29yZA=="))
 
 	assert.Error(t, err)
 }
@@ -1211,7 +1211,7 @@ func TestShouldCheckValidSessionUsernameHeaderAndReturn200(t *testing.T) {
 	require.NoError(t, err)
 
 	mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com")
-	mock.Ctx.Request.Header.Set(HeaderSessionUsername, testUsername)
+	mock.Ctx.Request.Header.SetBytesK(headerSessionUsername, testUsername)
 	VerifyGet(verifyGetCfg)(mock.Ctx)
 
 	assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode())
@@ -1235,7 +1235,7 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) {
 	require.NoError(t, err)
 
 	mock.Ctx.Request.Header.Set("X-Original-URL", "https://one-factor.example.com")
-	mock.Ctx.Request.Header.Set(HeaderSessionUsername, "root")
+	mock.Ctx.Request.Header.SetBytesK(headerSessionUsername, "root")
 	VerifyGet(verifyGetCfg)(mock.Ctx)
 
 	assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode())
diff --git a/internal/handlers/response.go b/internal/handlers/response.go
index 34a26860..1f3f5220 100644
--- a/internal/handlers/response.go
+++ b/internal/handlers/response.go
@@ -150,7 +150,20 @@ func markAuthenticationAttempt(ctx *middlewares.AutheliaCtx, successful bool, ba
 	// We only Mark if there was no underlying error.
 	ctx.Logger.Debugf("Mark %s authentication attempt made by user '%s'", authType, username)
 
-	if err = ctx.Providers.Regulator.Mark(ctx, successful, bannedUntil != nil, username, string(ctx.RequestCtx.QueryArgs().Peek("rd")), string(ctx.RequestCtx.QueryArgs().Peek("rm")), authType, ctx.RemoteIP()); err != nil {
+	var (
+		requestURI, requestMethod string
+	)
+
+	referer := ctx.Request.Header.Referer()
+	if referer != nil {
+		refererURL, err := url.Parse(string(referer))
+		if err == nil {
+			requestURI = refererURL.Query().Get("rd")
+			requestMethod = refererURL.Query().Get("rm")
+		}
+	}
+
+	if err = ctx.Providers.Regulator.Mark(ctx, successful, bannedUntil != nil, username, requestURI, requestMethod, authType, ctx.RemoteIP()); err != nil {
 		ctx.Logger.Errorf("Unable to mark %s authentication attempt by user '%s': %+v", authType, username, err)
 
 		return err
diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go
index 4a1242ca..1552cd3d 100644
--- a/internal/middlewares/authelia_context.go
+++ b/internal/middlewares/authelia_context.go
@@ -102,22 +102,22 @@ func (c *AutheliaCtx) ReplyBadRequest() {
 
 // XForwardedProto return the content of the X-Forwarded-Proto header.
 func (c *AutheliaCtx) XForwardedProto() []byte {
-	return c.RequestCtx.Request.Header.Peek(headerXForwardedProto)
+	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto)
 }
 
 // XForwardedMethod return the content of the X-Forwarded-Method header.
 func (c *AutheliaCtx) XForwardedMethod() []byte {
-	return c.RequestCtx.Request.Header.Peek(headerXForwardedMethod)
+	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
 }
 
 // XForwardedHost return the content of the X-Forwarded-Host header.
 func (c *AutheliaCtx) XForwardedHost() []byte {
-	return c.RequestCtx.Request.Header.Peek(headerXForwardedHost)
+	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost)
 }
 
 // XForwardedURI return the content of the X-Forwarded-URI header.
 func (c *AutheliaCtx) XForwardedURI() []byte {
-	return c.RequestCtx.Request.Header.Peek(headerXForwardedURI)
+	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI)
 }
 
 // BasePath returns the base_url as per the path visited by the client.
@@ -159,7 +159,7 @@ func (c *AutheliaCtx) ExternalRootURL() (string, error) {
 
 // XOriginalURL return the content of the X-Original-URL header.
 func (c *AutheliaCtx) XOriginalURL() []byte {
-	return c.RequestCtx.Request.Header.Peek(headerXOriginalURL)
+	return c.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
 }
 
 // GetSession return the user session. Any update will be saved in cache.
@@ -220,7 +220,7 @@ func (c *AutheliaCtx) SetJSONBody(value interface{}) error {
 
 // RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
 func (c *AutheliaCtx) RemoteIP() net.IP {
-	XForwardedFor := c.Request.Header.Peek("X-Forwarded-For")
+	XForwardedFor := c.Request.Header.PeekBytes(headerXForwardedFor)
 	if XForwardedFor != nil {
 		ips := strings.Split(string(XForwardedFor), ",")
 
@@ -278,14 +278,14 @@ func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
 
 // IsXHR returns true if the request is a XMLHttpRequest.
 func (c AutheliaCtx) IsXHR() (xhr bool) {
-	requestedWith := c.Request.Header.Peek(headerXRequestedWith)
+	requestedWith := c.Request.Header.PeekBytes(headerXRequestedWith)
 
 	return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR
 }
 
 // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
 func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
-	accepts := strings.Split(string(c.Request.Header.Peek("Accept")), ",")
+	accepts := strings.Split(string(c.Request.Header.PeekBytes(headerAccept)), ",")
 
 	for i, accept := range accepts {
 		mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ")
diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go
index 23de4427..01fc3f1e 100644
--- a/internal/middlewares/const.go
+++ b/internal/middlewares/const.go
@@ -1,21 +1,25 @@
 package middlewares
 
-const (
-	headerXForwardedProto  = "X-Forwarded-Proto"
-	headerXForwardedMethod = "X-Forwarded-Method"
-	headerXForwardedHost   = "X-Forwarded-Host"
-	headerXForwardedURI    = "X-Forwarded-URI"
-	headerXOriginalURL     = "X-Original-URL"
-	headerXRequestedWith   = "X-Requested-With"
+import (
+	"github.com/valyala/fasthttp"
+)
+
+var (
+	headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
+	headerXForwardedHost  = []byte(fasthttp.HeaderXForwardedHost)
+	headerXForwardedFor   = []byte(fasthttp.HeaderXForwardedFor)
+	headerXRequestedWith  = []byte(fasthttp.HeaderXRequestedWith)
+	headerAccept          = []byte(fasthttp.HeaderAccept)
+
+	headerXForwardedURI    = []byte("X-Forwarded-URI")
+	headerXOriginalURL     = []byte("X-Original-URL")
+	headerXForwardedMethod = []byte("X-Forwarded-Method")
 )
 
 const (
 	headerValueXRequestedWithXHR = "XMLHttpRequest"
-)
-
-const (
-	contentTypeApplicationJSON = "application/json"
-	contentTypeTextHTML        = "text/html"
+	contentTypeApplicationJSON   = "application/json"
+	contentTypeTextHTML          = "text/html"
 )
 
 var okMessageBytes = []byte("{\"status\":\"OK\"}")