diff --git a/internal/handlers/const.go b/internal/handlers/const.go
index 0398cffe..3eea1161 100644
--- a/internal/handlers/const.go
+++ b/internal/handlers/const.go
@@ -29,7 +29,7 @@ var (
)
var (
- headerContentTypeValueDefault = []byte("text/plain; charset=utf-8")
+ headerContentTypeValueTextPlain = []byte("text/plain; charset=utf-8")
)
const (
diff --git a/internal/handlers/handler_oidc_wellknown.go b/internal/handlers/handler_oidc_wellknown.go
index 0efd5387..89d346e0 100644
--- a/internal/handlers/handler_oidc_wellknown.go
+++ b/internal/handlers/handler_oidc_wellknown.go
@@ -1,8 +1,6 @@
package handlers
import (
- "encoding/json"
-
"github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/middlewares"
@@ -18,19 +16,19 @@ func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
- ctx.Response.SetStatusCode(fasthttp.StatusBadRequest)
+
+ ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
return
}
wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer)
- ctx.SetContentType("application/json")
-
- if err = json.NewEncoder(ctx).Encode(wellKnown); err != nil {
+ if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
+
// TODO: Determine if this is the appropriate error code here.
- ctx.Response.SetStatusCode(fasthttp.StatusInternalServerError)
+ ctx.ReplyStatusCode(fasthttp.StatusInternalServerError)
return
}
@@ -46,19 +44,19 @@ func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
issuer, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
- ctx.Response.SetStatusCode(fasthttp.StatusBadRequest)
+
+ ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
return
}
wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer)
- ctx.SetContentType("application/json")
-
- if err = json.NewEncoder(ctx).Encode(wellKnown); err != nil {
+ if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil {
ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err)
+
// TODO: Determine if this is the appropriate error code here.
- ctx.Response.SetStatusCode(fasthttp.StatusInternalServerError)
+ ctx.ReplyStatusCode(fasthttp.StatusInternalServerError)
return
}
diff --git a/internal/handlers/handler_state_test.go b/internal/handlers/handler_state_test.go
index 9e56bbe4..2605be47 100644
--- a/internal/handlers/handler_state_test.go
+++ b/internal/handlers/handler_state_test.go
@@ -52,7 +52,7 @@ func (s *StateGetSuite) TestShouldReturnUsernameFromSession() {
err = json.Unmarshal(s.mock.Ctx.Response.Body(), &actualBody)
require.NoError(s.T(), err)
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
- assert.Equal(s.T(), []byte("application/json"), s.mock.Ctx.Response.Header.ContentType())
+ assert.Equal(s.T(), []byte("application/json; charset=utf-8"), s.mock.Ctx.Response.Header.ContentType())
assert.Equal(s.T(), expectedBody, actualBody)
}
@@ -82,7 +82,7 @@ func (s *StateGetSuite) TestShouldReturnAuthenticationLevelFromSession() {
err = json.Unmarshal(s.mock.Ctx.Response.Body(), &actualBody)
require.NoError(s.T(), err)
assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode())
- assert.Equal(s.T(), []byte("application/json"), s.mock.Ctx.Response.Header.ContentType())
+ assert.Equal(s.T(), []byte("application/json; charset=utf-8"), s.mock.Ctx.Response.Header.ContentType())
assert.Equal(s.T(), expectedBody, actualBody)
}
diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go
index 66eb23dd..0472b566 100644
--- a/internal/handlers/handler_verify_test.go
+++ b/internal/handlers/handler_verify_test.go
@@ -351,7 +351,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingNoHeader()
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
- assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
+ assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@@ -367,7 +367,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingEmptyHeader
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
- assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
+ assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@@ -387,7 +387,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingWrongPasswo
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
- assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
+ assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@@ -403,7 +403,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingWrongHeader
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode())
- assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body()))
+ assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body()))
assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate"))
assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate")))
}
@@ -721,7 +721,7 @@ func TestShouldRedirectWhenSessionInactiveForTooLongAndRDParamProvided(t *testin
mock.Ctx.Request.Header.Set("Accept", "text/html; charset=utf-8")
VerifyGET(verifyGetCfg)(mock.Ctx)
- assert.Equal(t, "Found",
+ assert.Equal(t, "302 Found",
string(mock.Ctx.Response.Body()))
assert.Equal(t, 302, mock.Ctx.Response.StatusCode())
@@ -741,7 +741,7 @@ func TestShouldRedirectWithCorrectStatusCodeBasedOnRequestMethod(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
- assert.Equal(t, "Found",
+ assert.Equal(t, "302 Found",
string(mock.Ctx.Response.Body()))
assert.Equal(t, 302, mock.Ctx.Response.StatusCode())
@@ -752,7 +752,7 @@ func TestShouldRedirectWithCorrectStatusCodeBasedOnRequestMethod(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
- assert.Equal(t, "See Other",
+ assert.Equal(t, "303 See Other",
string(mock.Ctx.Response.Body()))
assert.Equal(t, 303, mock.Ctx.Response.StatusCode())
}
@@ -809,7 +809,7 @@ func TestShouldURLEncodeRedirectionURLParameter(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
- assert.Equal(t, "Found",
+ assert.Equal(t, "302 Found",
string(mock.Ctx.Response.Body()))
}
@@ -1240,7 +1240,7 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) {
VerifyGET(verifyGetCfg)(mock.Ctx)
assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode())
- assert.Equal(t, "Unauthorized", string(mock.Ctx.Response.Body()))
+ assert.Equal(t, "401 Unauthorized", string(mock.Ctx.Response.Body()))
}
func TestGetProfileRefreshSettings(t *testing.T) {
diff --git a/internal/handlers/response.go b/internal/handlers/response.go
index 26363323..0908a5a6 100644
--- a/internal/handlers/response.go
+++ b/internal/handlers/response.go
@@ -250,7 +250,7 @@ func respondUnauthorized(ctx *middlewares.AutheliaCtx, message string) {
// *fasthttp.RequestCtx or *middlewares.AutheliaCtx.
func SetStatusCodeResponse(ctx *fasthttp.RequestCtx, statusCode int) {
ctx.Response.Reset()
- ctx.SetContentTypeBytes(headerContentTypeValueDefault)
+ ctx.SetContentTypeBytes(headerContentTypeValueTextPlain)
ctx.SetStatusCode(statusCode)
ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode)))
}
diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go
index 0c43083f..584c3df8 100644
--- a/internal/middlewares/authelia_context.go
+++ b/internal/middlewares/authelia_context.go
@@ -68,14 +68,9 @@ func (ctx *AutheliaCtx) Error(err error, message string) {
// SetJSONError sets the body of the response to an JSON error KO message.
func (ctx *AutheliaCtx) SetJSONError(message string) {
- b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message})
-
- if marshalErr != nil {
- ctx.Logger.Error(marshalErr)
+ if replyErr := ctx.ReplyJSON(ErrorResponse{Status: "KO", Message: message}, 0); replyErr != nil {
+ ctx.Logger.Error(replyErr)
}
-
- ctx.SetContentType(contentTypeApplicationJSON)
- ctx.SetBody(b)
}
// ReplyError reply with an error but does not display any stack trace in the logs.
@@ -86,24 +81,52 @@ func (ctx *AutheliaCtx) ReplyError(err error, message string) {
ctx.Logger.Error(marshalErr)
}
- ctx.SetContentType(contentTypeApplicationJSON)
+ ctx.SetContentTypeBytes(contentTypeApplicationJSON)
ctx.SetBody(b)
ctx.Logger.Debug(err)
}
+// ReplyStatusCode resets a response and replies with the given status code and relevant message.
+func (ctx *AutheliaCtx) ReplyStatusCode(statusCode int) {
+ ctx.Response.Reset()
+ ctx.SetStatusCode(statusCode)
+ ctx.SetContentTypeBytes(contentTypeTextPlain)
+ ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode)))
+}
+
+// ReplyJSON writes a JSON response.
+func (ctx *AutheliaCtx) ReplyJSON(data interface{}, statusCode int) (err error) {
+ var (
+ body []byte
+ )
+
+ if body, err = json.Marshal(data); err != nil {
+ return fmt.Errorf("unable to marshal JSON body: %w", err)
+ }
+
+ if statusCode > 0 {
+ ctx.SetStatusCode(statusCode)
+ }
+
+ ctx.SetContentTypeBytes(contentTypeApplicationJSON)
+ ctx.SetBody(body)
+
+ return nil
+}
+
// ReplyUnauthorized response sent when user is unauthorized.
func (ctx *AutheliaCtx) ReplyUnauthorized() {
- ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized)
+ ctx.ReplyStatusCode(fasthttp.StatusUnauthorized)
}
// ReplyForbidden response sent when access is forbidden to user.
func (ctx *AutheliaCtx) ReplyForbidden() {
- ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden)
+ ctx.ReplyStatusCode(fasthttp.StatusForbidden)
}
// ReplyBadRequest response sent when bad request has been sent.
func (ctx *AutheliaCtx) ReplyBadRequest() {
- ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest)
+ ctx.ReplyStatusCode(fasthttp.StatusBadRequest)
}
// XForwardedProto return the content of the X-Forwarded-Proto header.
@@ -208,7 +231,7 @@ func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error {
// ReplyOK is a helper method to reply ok.
func (ctx *AutheliaCtx) ReplyOK() {
- ctx.SetContentType(contentTypeApplicationJSON)
+ ctx.SetContentTypeBytes(contentTypeApplicationJSON)
ctx.SetBody(okMessageBytes)
}
@@ -235,15 +258,7 @@ func (ctx *AutheliaCtx) ParseBody(value interface{}) error {
// SetJSONBody Set json body.
func (ctx *AutheliaCtx) SetJSONBody(value interface{}) error {
- b, err := json.Marshal(OKResponse{Status: "OK", Data: value})
- if err != nil {
- return fmt.Errorf("unable to marshal JSON body: %w", err)
- }
-
- ctx.SetContentType(contentTypeApplicationJSON)
- ctx.SetBody(b)
-
- return nil
+ return ctx.ReplyJSON(OKResponse{Status: "OK", Data: value}, 0)
}
// RemoteIP return the remote IP taking X-Forwarded-For header into account if provided.
@@ -329,7 +344,7 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
statusCode = fasthttp.StatusFound
}
- ctx.SetContentType(contentTypeTextHTML)
+ ctx.SetContentTypeBytes(contentTypeTextHTML)
ctx.SetStatusCode(statusCode)
u := fasthttp.AcquireURI()
@@ -337,9 +352,9 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) {
ctx.URI().CopyTo(u)
u.Update(uri)
- ctx.Response.Header.SetBytesV("Location", u.FullURI())
+ ctx.Response.Header.SetBytesKV(headerLocation, u.FullURI())
- ctx.SetBodyString(fmt.Sprintf("%s", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode)))
+ ctx.SetBodyString(fmt.Sprintf("%d %s", utils.StringHTMLEscape(string(u.FullURI())), statusCode, fasthttp.StatusMessage(statusCode)))
fasthttp.ReleaseURI(u)
}
diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go
index b4ecf280..3975df0a 100644
--- a/internal/middlewares/const.go
+++ b/internal/middlewares/const.go
@@ -9,6 +9,7 @@ import (
var (
headerAccept = []byte(fasthttp.HeaderAccept)
headerContentLength = []byte(fasthttp.HeaderContentLength)
+ headerLocation = []byte(fasthttp.HeaderLocation)
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
@@ -69,12 +70,14 @@ var (
UserValueKeyBaseURL = []byte("base_url")
headerSeparator = []byte(", ")
+
+ contentTypeTextPlain = []byte("text/plain; charset=utf-8")
+ contentTypeTextHTML = []byte("text/html; charset=utf-8")
+ contentTypeApplicationJSON = []byte("application/json; charset=utf-8")
)
const (
headerValueXRequestedWithXHR = "XMLHttpRequest"
- contentTypeApplicationJSON = "application/json"
- contentTypeTextHTML = "text/html"
)
var okMessageBytes = []byte("{\"status\":\"OK\"}")
diff --git a/internal/suites/suite_standalone_test.go b/internal/suites/suite_standalone_test.go
index 6860f3a3..cdc9cc3e 100644
--- a/internal/suites/suite_standalone_test.go
+++ b/internal/suites/suite_standalone_test.go
@@ -193,7 +193,7 @@ func (s *StandaloneSuite) TestShouldRespectMethodsACL() {
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/")
- s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
+ s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
req.Header.Set("X-Forwarded-Method", "OPTIONS")
@@ -219,7 +219,7 @@ func (s *StandaloneSuite) TestShouldRespondWithCorrectStatusCode() {
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/")
- s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
+ s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
req.Header.Set("X-Forwarded-Method", "POST")
@@ -230,7 +230,7 @@ func (s *StandaloneSuite) TestShouldRespondWithCorrectStatusCode() {
s.Assert().NoError(err)
urlEncodedAdminURL = url.QueryEscape(SecureBaseURL + "/")
- s.Assert().Equal(fmt.Sprintf("See Other", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=POST", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
+ s.Assert().Equal(fmt.Sprintf("303 See Other", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=POST", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
}
// Standard case using nginx.
@@ -247,7 +247,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyUnauthorized() {
s.Assert().Equal(res.StatusCode, 401)
body, err := io.ReadAll(res.Body)
s.Assert().NoError(err)
- s.Assert().Equal("Unauthorized", string(body))
+ s.Assert().Equal("401 Unauthorized", string(body))
}
// Standard case using Kubernetes.
@@ -266,7 +266,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalURL() {
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(AdminBaseURL)
- s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
+ s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
}
func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalHostURI() {
@@ -285,7 +285,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalHostURI(
s.Assert().NoError(err)
urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/")
- s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
+ s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body))
}
func (s *StandaloneSuite) TestShouldRecordMetrics() {