From 26236f491e6d2b16ae2bc8297e33a9dc883f44e5 Mon Sep 17 00:00:00 2001
From: James Elliott <james-d-elliott@users.noreply.github.com>
Date: Mon, 7 Feb 2022 00:37:28 +1100
Subject: [PATCH] fix(server): use of inconsistent methods for determining
 origin (#2848)

This unifies the methods to obtain the X-Forwarded-* header values and provides logical fallbacks. In addition, so we can ensure this functionality extends to the templated files we've converted the ServeTemplatedFile method into a function that operates as a middlewares.RequestHandler.

Fixes #2765
---
 .../handler_register_u2f_step1_test.go        |  19 ---
 .../handlers/handler_sign_u2f_step1_test.go   |   2 +-
 internal/handlers/handler_verify_test.go      |   4 +-
 internal/middlewares/authelia_context.go      | 158 ++++++++++--------
 internal/middlewares/authelia_context_test.go |  47 +++++-
 internal/middlewares/const.go                 |   6 +
 .../middlewares/identity_verification_test.go |  18 --
 internal/middlewares/strip_path.go            |   2 +-
 internal/server/server.go                     |   8 +-
 internal/server/template.go                   |  19 +--
 10 files changed, 158 insertions(+), 125 deletions(-)

diff --git a/internal/handlers/handler_register_u2f_step1_test.go b/internal/handlers/handler_register_u2f_step1_test.go
index 6223b318..3a3a160b 100644
--- a/internal/handlers/handler_register_u2f_step1_test.go
+++ b/internal/handlers/handler_register_u2f_step1_test.go
@@ -48,25 +48,6 @@ func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt
 	return ss, verification
 }
 
-func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() {
-	token, verification := createToken(s.mock, "john", ActionU2FRegistration,
-		time.Now().Add(1*time.Minute))
-	s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
-
-	s.mock.StorageMock.EXPECT().
-		FindIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
-		Return(true, nil)
-
-	s.mock.StorageMock.EXPECT().
-		ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
-		Return(nil)
-
-	SecondFactorU2FIdentityFinish(s.mock.Ctx)
-
-	assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
-	assert.Equal(s.T(), "missing header X-Forwarded-Proto", s.mock.Hook.LastEntry().Message)
-}
-
 func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
 	s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
 	token, verification := createToken(s.mock, "john", ActionU2FRegistration,
diff --git a/internal/handlers/handler_sign_u2f_step1_test.go b/internal/handlers/handler_sign_u2f_step1_test.go
index 5483c8bc..5381cbb2 100644
--- a/internal/handlers/handler_sign_u2f_step1_test.go
+++ b/internal/handlers/handler_sign_u2f_step1_test.go
@@ -27,7 +27,7 @@ func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing()
 	SecondFactorU2FSignGet(s.mock.Ctx)
 
 	assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
-	assert.Equal(s.T(), "missing header X-Forwarded-Proto", s.mock.Hook.LastEntry().Message)
+	assert.Equal(s.T(), "missing header X-Forwarded-Host", s.mock.Hook.LastEntry().Message)
 }
 
 func (s *HandlerSignU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go
index 12a37abe..b8f84f21 100644
--- a/internal/handlers/handler_verify_test.go
+++ b/internal/handlers/handler_verify_test.go
@@ -45,7 +45,7 @@ func TestShouldRaiseWhenNoHeaderProvidedToDetectTargetURL(t *testing.T) {
 	defer mock.Close()
 	_, err := mock.Ctx.GetOriginalURL()
 	assert.Error(t, err)
-	assert.Equal(t, "Missing header X-Forwarded-Proto", err.Error())
+	assert.Equal(t, "Missing header X-Forwarded-Host", err.Error())
 }
 
 func TestShouldRaiseWhenNoXForwardedHostHeaderProvidedToDetectTargetURL(t *testing.T) {
@@ -67,7 +67,7 @@ func TestShouldRaiseWhenXForwardedProtoIsNotParsable(t *testing.T) {
 
 	_, err := mock.Ctx.GetOriginalURL()
 	assert.Error(t, err)
-	assert.Equal(t, "Unable to parse URL !:;;:,://myhost.local: parse \"!:;;:,://myhost.local\": invalid URI for request", err.Error())
+	assert.Equal(t, "Unable to parse URL !:;;:,://myhost.local/: parse \"!:;;:,://myhost.local/\": invalid URI for request", err.Error())
 }
 
 func TestShouldRaiseWhenXForwardedURIIsNotParsable(t *testing.T) {
diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go
index ec4b0dce..03016a68 100644
--- a/internal/middlewares/authelia_context.go
+++ b/internal/middlewares/authelia_context.go
@@ -55,75 +55,97 @@ func AutheliaMiddleware(configuration schema.Configuration, providers Providers)
 }
 
 // Error reply with an error and display the stack trace in the logs.
-func (c *AutheliaCtx) Error(err error, message string) {
-	c.SetJSONError(message)
+func (ctx *AutheliaCtx) Error(err error, message string) {
+	ctx.SetJSONError(message)
 
-	c.Logger.Error(err)
+	ctx.Logger.Error(err)
 }
 
 // SetJSONError sets the body of the response to an JSON error KO message.
-func (c *AutheliaCtx) SetJSONError(message string) {
+func (ctx *AutheliaCtx) SetJSONError(message string) {
 	b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
 
 	if marshalErr != nil {
-		c.Logger.Error(marshalErr)
+		ctx.Logger.Error(marshalErr)
 	}
 
-	c.SetContentType(contentTypeApplicationJSON)
-	c.SetBody(b)
+	ctx.SetContentType(contentTypeApplicationJSON)
+	ctx.SetBody(b)
 }
 
 // ReplyError reply with an error but does not display any stack trace in the logs.
-func (c *AutheliaCtx) ReplyError(err error, message string) {
+func (ctx *AutheliaCtx) ReplyError(err error, message string) {
 	b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
 
 	if marshalErr != nil {
-		c.Logger.Error(marshalErr)
+		ctx.Logger.Error(marshalErr)
 	}
 
-	c.SetContentType(contentTypeApplicationJSON)
-	c.SetBody(b)
-	c.Logger.Debug(err)
+	ctx.SetContentType(contentTypeApplicationJSON)
+	ctx.SetBody(b)
+	ctx.Logger.Debug(err)
 }
 
 // ReplyUnauthorized response sent when user is unauthorized.
-func (c *AutheliaCtx) ReplyUnauthorized() {
-	c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
+func (ctx *AutheliaCtx) ReplyUnauthorized() {
+	ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
 }
 
 // ReplyForbidden response sent when access is forbidden to user.
-func (c *AutheliaCtx) ReplyForbidden() {
-	c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden)
+func (ctx *AutheliaCtx) ReplyForbidden() {
+	ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden)
 }
 
 // ReplyBadRequest response sent when bad request has been sent.
-func (c *AutheliaCtx) ReplyBadRequest() {
-	c.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest)
+func (ctx *AutheliaCtx) ReplyBadRequest() {
+	ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest)
 }
 
 // XForwardedProto return the content of the X-Forwarded-Proto header.
-func (c *AutheliaCtx) XForwardedProto() []byte {
-	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto)
+func (ctx *AutheliaCtx) XForwardedProto() (proto []byte) {
+	proto = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedProto)
+
+	if proto == nil {
+		if ctx.RequestCtx.IsTLS() {
+			return protoHTTPS
+		}
+
+		return protoHTTP
+	}
+
+	return proto
 }
 
 // XForwardedMethod return the content of the X-Forwarded-Method header.
-func (c *AutheliaCtx) XForwardedMethod() []byte {
-	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
+func (ctx *AutheliaCtx) XForwardedMethod() (method []byte) {
+	return ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedMethod)
 }
 
 // XForwardedHost return the content of the X-Forwarded-Host header.
-func (c *AutheliaCtx) XForwardedHost() []byte {
-	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost)
+func (ctx *AutheliaCtx) XForwardedHost() (host []byte) {
+	host = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedHost)
+
+	if host == nil {
+		return ctx.RequestCtx.Host()
+	}
+
+	return host
 }
 
 // XForwardedURI return the content of the X-Forwarded-URI header.
-func (c *AutheliaCtx) XForwardedURI() []byte {
-	return c.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI)
+func (ctx *AutheliaCtx) XForwardedURI() (uri []byte) {
+	uri = ctx.RequestCtx.Request.Header.PeekBytes(headerXForwardedURI)
+
+	if len(uri) == 0 {
+		return ctx.RequestCtx.RequestURI()
+	}
+
+	return uri
 }
 
 // BasePath returns the base_url as per the path visited by the client.
-func (c *AutheliaCtx) BasePath() (base string) {
-	if baseURL := c.UserValue("base_url"); baseURL != nil {
+func (ctx *AutheliaCtx) BasePath() (base string) {
+	if baseURL := ctx.UserValueBytes(UserValueKeyBaseURL); baseURL != nil {
 		return baseURL.(string)
 	}
 
@@ -131,20 +153,20 @@ func (c *AutheliaCtx) BasePath() (base string) {
 }
 
 // ExternalRootURL gets the X-Forwarded-Proto, X-Forwarded-Host headers and the BasePath and forms them into a URL.
-func (c *AutheliaCtx) ExternalRootURL() (string, error) {
-	protocol := c.XForwardedProto()
+func (ctx *AutheliaCtx) ExternalRootURL() (string, error) {
+	protocol := ctx.XForwardedProto()
 	if protocol == nil {
 		return "", errMissingXForwardedProto
 	}
 
-	host := c.XForwardedHost()
+	host := ctx.XForwardedHost()
 	if host == nil {
 		return "", errMissingXForwardedHost
 	}
 
 	externalRootURL := fmt.Sprintf("%s://%s", protocol, host)
 
-	if base := c.BasePath(); base != "" {
+	if base := ctx.BasePath(); base != "" {
 		externalBaseURL, err := url.Parse(externalRootURL)
 		if err != nil {
 			return "", err
@@ -159,15 +181,15 @@ func (c *AutheliaCtx) ExternalRootURL() (string, error) {
 }
 
 // XOriginalURL return the content of the X-Original-URL header.
-func (c *AutheliaCtx) XOriginalURL() []byte {
-	return c.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
+func (ctx *AutheliaCtx) XOriginalURL() []byte {
+	return ctx.RequestCtx.Request.Header.PeekBytes(headerXOriginalURL)
 }
 
 // GetSession return the user session. Any update will be saved in cache.
-func (c *AutheliaCtx) GetSession() session.UserSession {
-	userSession, err := c.Providers.SessionProvider.GetSession(c.RequestCtx)
+func (ctx *AutheliaCtx) GetSession() session.UserSession {
+	userSession, err := ctx.Providers.SessionProvider.GetSession(ctx.RequestCtx)
 	if err != nil {
-		c.Logger.Error("Unable to retrieve user session")
+		ctx.Logger.Error("Unable to retrieve user session")
 		return session.NewDefaultUserSession()
 	}
 
@@ -175,19 +197,19 @@ func (c *AutheliaCtx) GetSession() session.UserSession {
 }
 
 // SaveSession save the content of the session.
-func (c *AutheliaCtx) SaveSession(userSession session.UserSession) error {
-	return c.Providers.SessionProvider.SaveSession(c.RequestCtx, userSession)
+func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error {
+	return ctx.Providers.SessionProvider.SaveSession(ctx.RequestCtx, userSession)
 }
 
 // ReplyOK is a helper method to reply ok.
-func (c *AutheliaCtx) ReplyOK() {
-	c.SetContentType(contentTypeApplicationJSON)
-	c.SetBody(okMessageBytes)
+func (ctx *AutheliaCtx) ReplyOK() {
+	ctx.SetContentType(contentTypeApplicationJSON)
+	ctx.SetBody(okMessageBytes)
 }
 
 // ParseBody parse the request body into the type of value.
-func (c *AutheliaCtx) ParseBody(value interface{}) error {
-	err := json.Unmarshal(c.PostBody(), &value)
+func (ctx *AutheliaCtx) ParseBody(value interface{}) error {
+	err := json.Unmarshal(ctx.PostBody(), &value)
 
 	if err != nil {
 		return fmt.Errorf("unable to parse body: %w", err)
@@ -207,21 +229,21 @@ func (c *AutheliaCtx) ParseBody(value interface{}) error {
 }
 
 // SetJSONBody Set json body.
-func (c *AutheliaCtx) SetJSONBody(value interface{}) error {
+func (ctx *AutheliaCtx) SetJSONBody(value interface{}) error {
 	b, err := json.Marshal(OKResponse{Status: "OK", Data: value})
 	if err != nil {
 		return fmt.Errorf("unable to marshal JSON body: %w", err)
 	}
 
-	c.SetContentType(contentTypeApplicationJSON)
-	c.SetBody(b)
+	ctx.SetContentType(contentTypeApplicationJSON)
+	ctx.SetBody(b)
 
 	return nil
 }
 
 // RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
-func (c *AutheliaCtx) RemoteIP() net.IP {
-	XForwardedFor := c.Request.Header.PeekBytes(headerXForwardedFor)
+func (ctx *AutheliaCtx) RemoteIP() net.IP {
+	XForwardedFor := ctx.Request.Header.PeekBytes(headerXForwardedFor)
 	if XForwardedFor != nil {
 		ips := strings.Split(string(XForwardedFor), ",")
 
@@ -230,26 +252,24 @@ func (c *AutheliaCtx) RemoteIP() net.IP {
 		}
 	}
 
-	return c.RequestCtx.RemoteIP()
+	return ctx.RequestCtx.RemoteIP()
 }
 
-// GetOriginalURL extract the URL from the request headers (X-Original-URI or X-Forwarded-* headers).
-func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
-	originalURL := c.XOriginalURL()
+// GetOriginalURL extract the URL from the request headers (X-Original-URL or X-Forwarded-* headers).
+func (ctx *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
+	originalURL := ctx.XOriginalURL()
 	if originalURL != nil {
 		parsedURL, err := url.ParseRequestURI(string(originalURL))
 		if err != nil {
 			return nil, fmt.Errorf("Unable to parse URL extracted from X-Original-URL header: %v", err)
 		}
 
-		c.Logger.Trace("Using X-Original-URL header content as targeted site URL")
+		ctx.Logger.Trace("Using X-Original-URL header content as targeted site URL")
 
 		return parsedURL, nil
 	}
 
-	forwardedProto := c.XForwardedProto()
-	forwardedHost := c.XForwardedHost()
-	forwardedURI := c.XForwardedURI()
+	forwardedProto, forwardedHost, forwardedURI := ctx.XForwardedProto(), ctx.XForwardedHost(), ctx.XForwardedURI()
 
 	if forwardedProto == nil {
 		return nil, errMissingXForwardedProto
@@ -271,22 +291,22 @@ func (c *AutheliaCtx) GetOriginalURL() (*url.URL, error) {
 		return nil, fmt.Errorf("Unable to parse URL %s: %v", requestURI, err)
 	}
 
-	c.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " +
+	ctx.Logger.Tracef("Using X-Fowarded-Proto, X-Forwarded-Host and X-Forwarded-URI headers " +
 		"to construct targeted site URL")
 
 	return parsedURL, nil
 }
 
 // IsXHR returns true if the request is a XMLHttpRequest.
-func (c AutheliaCtx) IsXHR() (xhr bool) {
-	requestedWith := c.Request.Header.PeekBytes(headerXRequestedWith)
+func (ctx AutheliaCtx) IsXHR() (xhr bool) {
+	requestedWith := ctx.Request.Header.PeekBytes(headerXRequestedWith)
 
-	return requestedWith != nil && string(requestedWith) == headerValueXRequestedWithXHR
+	return requestedWith != nil && strings.EqualFold(string(requestedWith), headerValueXRequestedWithXHR)
 }
 
 // AcceptsMIME takes a mime type and returns true if the request accepts that type or the wildcard type.
-func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
-	accepts := strings.Split(string(c.Request.Header.PeekBytes(headerAccept)), ",")
+func (ctx AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
+	accepts := strings.Split(string(ctx.Request.Header.PeekBytes(headerAccept)), ",")
 
 	for i, accept := range accepts {
 		mimeType := strings.Trim(strings.SplitN(accept, ";", 2)[0], " ")
@@ -300,22 +320,22 @@ func (c AutheliaCtx) AcceptsMIME(mime string) (acceptsMime bool) {
 
 // SpecialRedirect performs a redirect similar to fasthttp.RequestCtx except it allows statusCode 401 and includes body
 // content in the form of a link to the location.
-func (c *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
+func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
 	if statusCode < fasthttp.StatusMovedPermanently || (statusCode > fasthttp.StatusSeeOther && statusCode != fasthttp.StatusTemporaryRedirect && statusCode != fasthttp.StatusPermanentRedirect && statusCode != fasthttp.StatusUnauthorized) {
 		statusCode = fasthttp.StatusFound
 	}
 
-	c.SetContentType(contentTypeTextHTML)
-	c.SetStatusCode(statusCode)
+	ctx.SetContentType(contentTypeTextHTML)
+	ctx.SetStatusCode(statusCode)
 
 	u := fasthttp.AcquireURI()
 
-	c.URI().CopyTo(u)
+	ctx.URI().CopyTo(u)
 	u.Update(uri)
 
-	c.Response.Header.SetBytesV("Location", u.FullURI())
+	ctx.Response.Header.SetBytesV("Location", u.FullURI())
 
-	c.SetBodyString(fmt.Sprintf("<a href=\"%s\">%s</a>", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode)))
+	ctx.SetBodyString(fmt.Sprintf("<a href=\"%s\">%s</a>", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode)))
 
 	fasthttp.ReleaseURI(u)
 }
diff --git a/internal/middlewares/authelia_context_test.go b/internal/middlewares/authelia_context_test.go
index f5ea6281..5a26f1c7 100644
--- a/internal/middlewares/authelia_context_test.go
+++ b/internal/middlewares/authelia_context_test.go
@@ -57,7 +57,7 @@ func TestShouldGetOriginalURLFromForwardedHeadersWithoutURI(t *testing.T) {
 	originalURL, err := mock.Ctx.GetOriginalURL()
 	assert.NoError(t, err)
 
-	expectedURL, err := url.ParseRequestURI("https://home.example.com")
+	expectedURL, err := url.ParseRequestURI("https://home.example.com/")
 	assert.NoError(t, err)
 	assert.Equal(t, expectedURL, originalURL)
 }
@@ -70,3 +70,48 @@ func TestShouldGetOriginalURLFromForwardedHeadersWithURI(t *testing.T) {
 	assert.Error(t, err)
 	assert.Equal(t, "Unable to parse URL extracted from X-Original-URL header: parse \"htt-ps//home?-.example.com\": invalid URI for request", err.Error())
 }
+
+func TestShouldFallbackToNonXForwardedHeaders(t *testing.T) {
+	mock := mocks.NewMockAutheliaCtx(t)
+	defer mock.Close()
+
+	mock.Ctx.RequestCtx.Request.SetRequestURI("/2fa/one-time-password")
+	mock.Ctx.RequestCtx.Request.SetHost("auth.example.com:1234")
+
+	assert.Equal(t, []byte("http"), mock.Ctx.XForwardedProto())
+	assert.Equal(t, []byte("auth.example.com:1234"), mock.Ctx.XForwardedHost())
+	assert.Equal(t, []byte("/2fa/one-time-password"), mock.Ctx.XForwardedURI())
+}
+
+func TestShouldOnlyFallbackToNonXForwardedHeadersWhenNil(t *testing.T) {
+	mock := mocks.NewMockAutheliaCtx(t)
+	defer mock.Close()
+
+	mock.Ctx.RequestCtx.Request.SetRequestURI("/2fa/one-time-password")
+	mock.Ctx.RequestCtx.Request.SetHost("localhost")
+	mock.Ctx.RequestCtx.Request.Header.Set(fasthttp.HeaderXForwardedHost, "auth.example.com:1234")
+	mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-URI", "/base/2fa/one-time-password")
+	mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-Proto", "https")
+	mock.Ctx.RequestCtx.Request.Header.Set("X-Forwarded-Method", "GET")
+
+	assert.Equal(t, []byte("https"), mock.Ctx.XForwardedProto())
+	assert.Equal(t, []byte("auth.example.com:1234"), mock.Ctx.XForwardedHost())
+	assert.Equal(t, []byte("/base/2fa/one-time-password"), mock.Ctx.XForwardedURI())
+	assert.Equal(t, []byte("GET"), mock.Ctx.XForwardedMethod())
+}
+
+func TestShouldDetectXHR(t *testing.T) {
+	mock := mocks.NewMockAutheliaCtx(t)
+	defer mock.Close()
+
+	mock.Ctx.RequestCtx.Request.Header.Set(fasthttp.HeaderXRequestedWith, "XMLHttpRequest")
+
+	assert.True(t, mock.Ctx.IsXHR())
+}
+
+func TestShouldDetectNonXHR(t *testing.T) {
+	mock := mocks.NewMockAutheliaCtx(t)
+	defer mock.Close()
+
+	assert.False(t, mock.Ctx.IsXHR())
+}
diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go
index 01fc3f1e..dbe37e6f 100644
--- a/internal/middlewares/const.go
+++ b/internal/middlewares/const.go
@@ -14,6 +14,12 @@ var (
 	headerXForwardedURI    = []byte("X-Forwarded-URI")
 	headerXOriginalURL     = []byte("X-Original-URL")
 	headerXForwardedMethod = []byte("X-Forwarded-Method")
+
+	protoHTTPS = []byte("https")
+	protoHTTP  = []byte("http")
+
+	// UserValueKeyBaseURL is the User Value key where we store the Base URL.
+	UserValueKeyBaseURL = []byte("base_url")
 )
 
 const (
diff --git a/internal/middlewares/identity_verification_test.go b/internal/middlewares/identity_verification_test.go
index 38e592c7..8bfedda6 100644
--- a/internal/middlewares/identity_verification_test.go
+++ b/internal/middlewares/identity_verification_test.go
@@ -90,24 +90,6 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
 	assert.Equal(t, "no notif", mock.Hook.LastEntry().Message)
 }
 
-func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
-	mock := mocks.NewMockAutheliaCtx(t)
-	defer mock.Close()
-
-	mock.Ctx.Configuration.JWTSecret = testJWTSecret
-	mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
-
-	mock.StorageMock.EXPECT().
-		SaveIdentityVerification(mock.Ctx, gomock.Any()).
-		Return(nil)
-
-	args := newArgs(defaultRetriever)
-	middlewares.IdentityVerificationStart(args, nil)(mock.Ctx)
-
-	assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
-	assert.Equal(t, "Missing header X-Forwarded-Proto", mock.Hook.LastEntry().Message)
-}
-
 func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
 	mock := mocks.NewMockAutheliaCtx(t)
 	defer mock.Close()
diff --git a/internal/middlewares/strip_path.go b/internal/middlewares/strip_path.go
index 2bddcb55..8079bb42 100644
--- a/internal/middlewares/strip_path.go
+++ b/internal/middlewares/strip_path.go
@@ -12,7 +12,7 @@ func StripPathMiddleware(path string, next fasthttp.RequestHandler) fasthttp.Req
 		uri := ctx.RequestURI()
 
 		if strings.HasPrefix(string(uri), path) {
-			ctx.SetUserValue("base_url", path)
+			ctx.SetUserValueBytes(UserValueKeyBaseURL, path)
 
 			newURI := strings.TrimPrefix(string(uri), path)
 			ctx.Request.SetRequestURI(newURI)
diff --git a/internal/server/server.go b/internal/server/server.go
index d75bb0d4..cbc53e71 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -46,11 +46,11 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
 	serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, configuration.Session.Name, configuration.Theme, https)
 
 	r := router.New()
-	r.GET("/", serveIndexHandler)
+	r.GET("/", autheliaMiddleware(serveIndexHandler))
 	r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
 
-	r.GET("/api/", serveSwaggerHandler)
-	r.GET("/api/"+apiFile, serveSwaggerAPIHandler)
+	r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
+	r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler))
 
 	for _, f := range rootFiles {
 		r.GET("/"+f, middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, embeddedFS))
@@ -148,7 +148,7 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
 		r.GET("/debug/vars", expvarhandler.ExpvarHandler)
 	}
 
-	r.NotFound = serveIndexHandler
+	r.NotFound = autheliaMiddleware(serveIndexHandler)
 
 	handler := middlewares.LogRequestMiddleware(r.Handler)
 	if configuration.Server.Path != "" {
diff --git a/internal/server/template.go b/internal/server/template.go
index a9592756..e3214aca 100644
--- a/internal/server/template.go
+++ b/internal/server/template.go
@@ -7,16 +7,15 @@ import (
 	"path/filepath"
 	"text/template"
 
-	"github.com/valyala/fasthttp"
-
 	"github.com/authelia/authelia/v4/internal/logging"
+	"github.com/authelia/authelia/v4/internal/middlewares"
 	"github.com/authelia/authelia/v4/internal/utils"
 )
 
 // ServeTemplatedFile serves a templated version of a specified file,
 // this is utilised to pass information between the backend and frontend
 // and generate a nonce to support a restrictive CSP while using material-ui.
-func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, session, theme string, https bool) fasthttp.RequestHandler {
+func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberMe, resetPassword, session, theme string, https bool) middlewares.RequestHandler {
 	logger := logging.Logger()
 
 	a, err := assets.Open(publicDir + file)
@@ -34,9 +33,9 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
 		logger.Fatalf("Unable to parse %s template: %s", file, err)
 	}
 
-	return func(ctx *fasthttp.RequestCtx) {
+	return func(ctx *middlewares.AutheliaCtx) {
 		base := ""
-		if baseURL := ctx.UserValue("base_url"); baseURL != nil {
+		if baseURL := ctx.UserValueBytes(middlewares.UserValueKeyBaseURL); baseURL != nil {
 			base = baseURL.(string)
 		}
 
@@ -51,16 +50,16 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
 		var scheme = "https"
 
 		if !https {
-			proto := string(ctx.Request.Header.Peek(fasthttp.HeaderXForwardedProto))
+			proto := string(ctx.XForwardedProto())
 			switch proto {
 			case "":
-				scheme = "http"
-			default:
+				break
+			case "http", "https":
 				scheme = proto
 			}
 		}
 
-		baseURL := scheme + "://" + string(ctx.Request.Host()) + base + "/"
+		baseURL := scheme + "://" + string(ctx.XForwardedHost()) + base + "/"
 		nonce := utils.RandomString(32, utils.AlphaNumericCharacters, true)
 
 		switch extension := filepath.Ext(file); extension {
@@ -81,7 +80,7 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
 
 		err := tmpl.Execute(ctx.Response.BodyWriter(), struct{ Base, BaseURL, CSPNonce, DuoSelfEnrollment, LogoOverride, RememberMe, ResetPassword, Session, Theme string }{Base: base, BaseURL: baseURL, CSPNonce: nonce, DuoSelfEnrollment: duoSelfEnrollment, LogoOverride: logoOverride, RememberMe: rememberMe, ResetPassword: resetPassword, Session: session, Theme: theme})
 		if err != nil {
-			ctx.Error("an error occurred", 503)
+			ctx.RequestCtx.Error("an error occurred", 503)
 			logger.Errorf("Unable to execute template: %v", err)
 
 			return