From ce779b2533d82c4cd95b52f188c5f7edd797e4d7 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Fri, 8 Jul 2022 22:18:52 +1000 Subject: [PATCH] refactor(middlewares): factorize responses (#3628) --- internal/handlers/const.go | 2 +- internal/handlers/handler_oidc_wellknown.go | 22 ++++--- internal/handlers/handler_state_test.go | 4 +- internal/handlers/handler_verify_test.go | 18 +++--- internal/handlers/response.go | 2 +- internal/middlewares/authelia_context.go | 63 +++++++++++++-------- internal/middlewares/const.go | 7 ++- internal/suites/suite_standalone_test.go | 12 ++-- 8 files changed, 73 insertions(+), 57 deletions(-) diff --git a/internal/handlers/const.go b/internal/handlers/const.go index 0398cffe..3eea1161 100644 --- a/internal/handlers/const.go +++ b/internal/handlers/const.go @@ -29,7 +29,7 @@ var ( ) var ( - headerContentTypeValueDefault = []byte("text/plain; charset=utf-8") + headerContentTypeValueTextPlain = []byte("text/plain; charset=utf-8") ) const ( diff --git a/internal/handlers/handler_oidc_wellknown.go b/internal/handlers/handler_oidc_wellknown.go index 0efd5387..89d346e0 100644 --- a/internal/handlers/handler_oidc_wellknown.go +++ b/internal/handlers/handler_oidc_wellknown.go @@ -1,8 +1,6 @@ package handlers import ( - "encoding/json" - "github.com/valyala/fasthttp" "github.com/authelia/authelia/v4/internal/middlewares" @@ -18,19 +16,19 @@ func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) { issuer, err := ctx.ExternalRootURL() if err != nil { ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err) - ctx.Response.SetStatusCode(fasthttp.StatusBadRequest) + + ctx.ReplyStatusCode(fasthttp.StatusBadRequest) return } wellKnown := ctx.Providers.OpenIDConnect.GetOpenIDConnectWellKnownConfiguration(issuer) - ctx.SetContentType("application/json") - - if err = json.NewEncoder(ctx).Encode(wellKnown); err != nil { + if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil { ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err) + // TODO: Determine if this is the appropriate error code here. - ctx.Response.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.ReplyStatusCode(fasthttp.StatusInternalServerError) return } @@ -46,19 +44,19 @@ func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) { issuer, err := ctx.ExternalRootURL() if err != nil { ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err) - ctx.Response.SetStatusCode(fasthttp.StatusBadRequest) + + ctx.ReplyStatusCode(fasthttp.StatusBadRequest) return } wellKnown := ctx.Providers.OpenIDConnect.GetOAuth2WellKnownConfiguration(issuer) - ctx.SetContentType("application/json") - - if err = json.NewEncoder(ctx).Encode(wellKnown); err != nil { + if err = ctx.ReplyJSON(wellKnown, fasthttp.StatusOK); err != nil { ctx.Logger.Errorf("Error occurred in JSON encode: %+v", err) + // TODO: Determine if this is the appropriate error code here. - ctx.Response.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.ReplyStatusCode(fasthttp.StatusInternalServerError) return } diff --git a/internal/handlers/handler_state_test.go b/internal/handlers/handler_state_test.go index 9e56bbe4..2605be47 100644 --- a/internal/handlers/handler_state_test.go +++ b/internal/handlers/handler_state_test.go @@ -52,7 +52,7 @@ func (s *StateGetSuite) TestShouldReturnUsernameFromSession() { err = json.Unmarshal(s.mock.Ctx.Response.Body(), &actualBody) require.NoError(s.T(), err) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), []byte("application/json"), s.mock.Ctx.Response.Header.ContentType()) + assert.Equal(s.T(), []byte("application/json; charset=utf-8"), s.mock.Ctx.Response.Header.ContentType()) assert.Equal(s.T(), expectedBody, actualBody) } @@ -82,7 +82,7 @@ func (s *StateGetSuite) TestShouldReturnAuthenticationLevelFromSession() { err = json.Unmarshal(s.mock.Ctx.Response.Body(), &actualBody) require.NoError(s.T(), err) assert.Equal(s.T(), 200, s.mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), []byte("application/json"), s.mock.Ctx.Response.Header.ContentType()) + assert.Equal(s.T(), []byte("application/json; charset=utf-8"), s.mock.Ctx.Response.Header.ContentType()) assert.Equal(s.T(), expectedBody, actualBody) } diff --git a/internal/handlers/handler_verify_test.go b/internal/handlers/handler_verify_test.go index 66eb23dd..0472b566 100644 --- a/internal/handlers/handler_verify_test.go +++ b/internal/handlers/handler_verify_test.go @@ -351,7 +351,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingNoHeader() VerifyGET(verifyGetCfg)(mock.Ctx) assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body())) + assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body())) assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate")) assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate"))) } @@ -367,7 +367,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingEmptyHeader VerifyGET(verifyGetCfg)(mock.Ctx) assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body())) + assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body())) assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate")) assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate"))) } @@ -387,7 +387,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingWrongPasswo VerifyGET(verifyGetCfg)(mock.Ctx) assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body())) + assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body())) assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate")) assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate"))) } @@ -403,7 +403,7 @@ func (s *BasicAuthorizationSuite) TestShouldVerifyAuthBasicArgFailingWrongHeader VerifyGET(verifyGetCfg)(mock.Ctx) assert.Equal(s.T(), 401, mock.Ctx.Response.StatusCode()) - assert.Equal(s.T(), "Unauthorized", string(mock.Ctx.Response.Body())) + assert.Equal(s.T(), "401 Unauthorized", string(mock.Ctx.Response.Body())) assert.NotEmpty(s.T(), mock.Ctx.Response.Header.Peek("WWW-Authenticate")) assert.Regexp(s.T(), regexp.MustCompile("^Basic realm="), string(mock.Ctx.Response.Header.Peek("WWW-Authenticate"))) } @@ -721,7 +721,7 @@ func TestShouldRedirectWhenSessionInactiveForTooLongAndRDParamProvided(t *testin mock.Ctx.Request.Header.Set("Accept", "text/html; charset=utf-8") VerifyGET(verifyGetCfg)(mock.Ctx) - assert.Equal(t, "Found", + assert.Equal(t, "302 Found", string(mock.Ctx.Response.Body())) assert.Equal(t, 302, mock.Ctx.Response.StatusCode()) @@ -741,7 +741,7 @@ func TestShouldRedirectWithCorrectStatusCodeBasedOnRequestMethod(t *testing.T) { VerifyGET(verifyGetCfg)(mock.Ctx) - assert.Equal(t, "Found", + assert.Equal(t, "302 Found", string(mock.Ctx.Response.Body())) assert.Equal(t, 302, mock.Ctx.Response.StatusCode()) @@ -752,7 +752,7 @@ func TestShouldRedirectWithCorrectStatusCodeBasedOnRequestMethod(t *testing.T) { VerifyGET(verifyGetCfg)(mock.Ctx) - assert.Equal(t, "See Other", + assert.Equal(t, "303 See Other", string(mock.Ctx.Response.Body())) assert.Equal(t, 303, mock.Ctx.Response.StatusCode()) } @@ -809,7 +809,7 @@ func TestShouldURLEncodeRedirectionURLParameter(t *testing.T) { VerifyGET(verifyGetCfg)(mock.Ctx) - assert.Equal(t, "Found", + assert.Equal(t, "302 Found", string(mock.Ctx.Response.Body())) } @@ -1240,7 +1240,7 @@ func TestShouldCheckInvalidSessionUsernameHeaderAndReturn401(t *testing.T) { VerifyGET(verifyGetCfg)(mock.Ctx) assert.Equal(t, expectedStatusCode, mock.Ctx.Response.StatusCode()) - assert.Equal(t, "Unauthorized", string(mock.Ctx.Response.Body())) + assert.Equal(t, "401 Unauthorized", string(mock.Ctx.Response.Body())) } func TestGetProfileRefreshSettings(t *testing.T) { diff --git a/internal/handlers/response.go b/internal/handlers/response.go index 26363323..0908a5a6 100644 --- a/internal/handlers/response.go +++ b/internal/handlers/response.go @@ -250,7 +250,7 @@ func respondUnauthorized(ctx *middlewares.AutheliaCtx, message string) { // *fasthttp.RequestCtx or *middlewares.AutheliaCtx. func SetStatusCodeResponse(ctx *fasthttp.RequestCtx, statusCode int) { ctx.Response.Reset() - ctx.SetContentTypeBytes(headerContentTypeValueDefault) + ctx.SetContentTypeBytes(headerContentTypeValueTextPlain) ctx.SetStatusCode(statusCode) ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode))) } diff --git a/internal/middlewares/authelia_context.go b/internal/middlewares/authelia_context.go index 0c43083f..584c3df8 100644 --- a/internal/middlewares/authelia_context.go +++ b/internal/middlewares/authelia_context.go @@ -68,14 +68,9 @@ func (ctx *AutheliaCtx) Error(err error, message string) { // SetJSONError sets the body of the response to an JSON error KO message. func (ctx *AutheliaCtx) SetJSONError(message string) { - b, marshalErr := json.Marshal(ErrorResponse{Status: "KO", Message: message}) - - if marshalErr != nil { - ctx.Logger.Error(marshalErr) + if replyErr := ctx.ReplyJSON(ErrorResponse{Status: "KO", Message: message}, 0); replyErr != nil { + ctx.Logger.Error(replyErr) } - - ctx.SetContentType(contentTypeApplicationJSON) - ctx.SetBody(b) } // ReplyError reply with an error but does not display any stack trace in the logs. @@ -86,24 +81,52 @@ func (ctx *AutheliaCtx) ReplyError(err error, message string) { ctx.Logger.Error(marshalErr) } - ctx.SetContentType(contentTypeApplicationJSON) + ctx.SetContentTypeBytes(contentTypeApplicationJSON) ctx.SetBody(b) ctx.Logger.Debug(err) } +// ReplyStatusCode resets a response and replies with the given status code and relevant message. +func (ctx *AutheliaCtx) ReplyStatusCode(statusCode int) { + ctx.Response.Reset() + ctx.SetStatusCode(statusCode) + ctx.SetContentTypeBytes(contentTypeTextPlain) + ctx.SetBodyString(fmt.Sprintf("%d %s", statusCode, fasthttp.StatusMessage(statusCode))) +} + +// ReplyJSON writes a JSON response. +func (ctx *AutheliaCtx) ReplyJSON(data interface{}, statusCode int) (err error) { + var ( + body []byte + ) + + if body, err = json.Marshal(data); err != nil { + return fmt.Errorf("unable to marshal JSON body: %w", err) + } + + if statusCode > 0 { + ctx.SetStatusCode(statusCode) + } + + ctx.SetContentTypeBytes(contentTypeApplicationJSON) + ctx.SetBody(body) + + return nil +} + // ReplyUnauthorized response sent when user is unauthorized. func (ctx *AutheliaCtx) ReplyUnauthorized() { - ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusUnauthorized), fasthttp.StatusUnauthorized) + ctx.ReplyStatusCode(fasthttp.StatusUnauthorized) } // ReplyForbidden response sent when access is forbidden to user. func (ctx *AutheliaCtx) ReplyForbidden() { - ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusForbidden), fasthttp.StatusForbidden) + ctx.ReplyStatusCode(fasthttp.StatusForbidden) } // ReplyBadRequest response sent when bad request has been sent. func (ctx *AutheliaCtx) ReplyBadRequest() { - ctx.RequestCtx.Error(fasthttp.StatusMessage(fasthttp.StatusBadRequest), fasthttp.StatusBadRequest) + ctx.ReplyStatusCode(fasthttp.StatusBadRequest) } // XForwardedProto return the content of the X-Forwarded-Proto header. @@ -208,7 +231,7 @@ func (ctx *AutheliaCtx) SaveSession(userSession session.UserSession) error { // ReplyOK is a helper method to reply ok. func (ctx *AutheliaCtx) ReplyOK() { - ctx.SetContentType(contentTypeApplicationJSON) + ctx.SetContentTypeBytes(contentTypeApplicationJSON) ctx.SetBody(okMessageBytes) } @@ -235,15 +258,7 @@ func (ctx *AutheliaCtx) ParseBody(value interface{}) error { // SetJSONBody Set json body. func (ctx *AutheliaCtx) SetJSONBody(value interface{}) error { - b, err := json.Marshal(OKResponse{Status: "OK", Data: value}) - if err != nil { - return fmt.Errorf("unable to marshal JSON body: %w", err) - } - - ctx.SetContentType(contentTypeApplicationJSON) - ctx.SetBody(b) - - return nil + return ctx.ReplyJSON(OKResponse{Status: "OK", Data: value}, 0) } // RemoteIP return the remote IP taking X-Forwarded-For header into account if provided. @@ -329,7 +344,7 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) { statusCode = fasthttp.StatusFound } - ctx.SetContentType(contentTypeTextHTML) + ctx.SetContentTypeBytes(contentTypeTextHTML) ctx.SetStatusCode(statusCode) u := fasthttp.AcquireURI() @@ -337,9 +352,9 @@ func (ctx *AutheliaCtx) SpecialRedirect(uri string, statusCode int) { ctx.URI().CopyTo(u) u.Update(uri) - ctx.Response.Header.SetBytesV("Location", u.FullURI()) + ctx.Response.Header.SetBytesKV(headerLocation, u.FullURI()) - ctx.SetBodyString(fmt.Sprintf("%s", utils.StringHTMLEscape(string(u.FullURI())), fasthttp.StatusMessage(statusCode))) + ctx.SetBodyString(fmt.Sprintf("%d %s", utils.StringHTMLEscape(string(u.FullURI())), statusCode, fasthttp.StatusMessage(statusCode))) fasthttp.ReleaseURI(u) } diff --git a/internal/middlewares/const.go b/internal/middlewares/const.go index b4ecf280..3975df0a 100644 --- a/internal/middlewares/const.go +++ b/internal/middlewares/const.go @@ -9,6 +9,7 @@ import ( var ( headerAccept = []byte(fasthttp.HeaderAccept) headerContentLength = []byte(fasthttp.HeaderContentLength) + headerLocation = []byte(fasthttp.HeaderLocation) headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto) headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost) @@ -69,12 +70,14 @@ var ( UserValueKeyBaseURL = []byte("base_url") headerSeparator = []byte(", ") + + contentTypeTextPlain = []byte("text/plain; charset=utf-8") + contentTypeTextHTML = []byte("text/html; charset=utf-8") + contentTypeApplicationJSON = []byte("application/json; charset=utf-8") ) const ( headerValueXRequestedWithXHR = "XMLHttpRequest" - contentTypeApplicationJSON = "application/json" - contentTypeTextHTML = "text/html" ) var okMessageBytes = []byte("{\"status\":\"OK\"}") diff --git a/internal/suites/suite_standalone_test.go b/internal/suites/suite_standalone_test.go index 6860f3a3..cdc9cc3e 100644 --- a/internal/suites/suite_standalone_test.go +++ b/internal/suites/suite_standalone_test.go @@ -193,7 +193,7 @@ func (s *StandaloneSuite) TestShouldRespectMethodsACL() { s.Assert().NoError(err) urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/") - s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) + s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) req.Header.Set("X-Forwarded-Method", "OPTIONS") @@ -219,7 +219,7 @@ func (s *StandaloneSuite) TestShouldRespondWithCorrectStatusCode() { s.Assert().NoError(err) urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/") - s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) + s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=GET", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) req.Header.Set("X-Forwarded-Method", "POST") @@ -230,7 +230,7 @@ func (s *StandaloneSuite) TestShouldRespondWithCorrectStatusCode() { s.Assert().NoError(err) urlEncodedAdminURL = url.QueryEscape(SecureBaseURL + "/") - s.Assert().Equal(fmt.Sprintf("See Other", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=POST", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) + s.Assert().Equal(fmt.Sprintf("303 See Other", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s&rm=POST", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) } // Standard case using nginx. @@ -247,7 +247,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyUnauthorized() { s.Assert().Equal(res.StatusCode, 401) body, err := io.ReadAll(res.Body) s.Assert().NoError(err) - s.Assert().Equal("Unauthorized", string(body)) + s.Assert().Equal("401 Unauthorized", string(body)) } // Standard case using Kubernetes. @@ -266,7 +266,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalURL() { s.Assert().NoError(err) urlEncodedAdminURL := url.QueryEscape(AdminBaseURL) - s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) + s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) } func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalHostURI() { @@ -285,7 +285,7 @@ func (s *StandaloneSuite) TestShouldVerifyAPIVerifyRedirectFromXOriginalHostURI( s.Assert().NoError(err) urlEncodedAdminURL := url.QueryEscape(SecureBaseURL + "/") - s.Assert().Equal(fmt.Sprintf("Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) + s.Assert().Equal(fmt.Sprintf("302 Found", utils.StringHTMLEscape(fmt.Sprintf("%s/?rd=%s", GetLoginBaseURL(), urlEncodedAdminURL))), string(body)) } func (s *StandaloneSuite) TestShouldRecordMetrics() {