mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
feat(oidc): add automatic allow all cors to discovery (#2953)
This adds a Cross Origin Resource Sharing policy that automatically allows any cross-origin request to the OpenID Connect discovery documents.
This commit is contained in:
parent
a5c400cb1d
commit
a8f5a70b03
|
@ -10,8 +10,8 @@ import (
|
|||
// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout.
|
||||
func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) {
|
||||
// TODO: Add OPTIONS handler.
|
||||
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(wellKnownOpenIDConnectConfigurationGET))
|
||||
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(wellKnownOAuthAuthorizationServerGET))
|
||||
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET)))
|
||||
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET)))
|
||||
|
||||
router.GET(pathOpenIDConnectConsent, middleware(oidcConsent))
|
||||
|
||||
|
|
|
@ -15,6 +15,24 @@ var (
|
|||
headerXOriginalURL = []byte("X-Original-URL")
|
||||
headerXForwardedMethod = []byte("X-Forwarded-Method")
|
||||
|
||||
headerVary = []byte(fasthttp.HeaderVary)
|
||||
headerOrigin = []byte(fasthttp.HeaderOrigin)
|
||||
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
|
||||
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
|
||||
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
|
||||
headerAccessControlAllowOrigin = []byte(fasthttp.HeaderAccessControlAllowOrigin)
|
||||
headerAccessControlMaxAge = []byte(fasthttp.HeaderAccessControlMaxAge)
|
||||
headerAccessControlRequestHeaders = []byte(fasthttp.HeaderAccessControlRequestHeaders)
|
||||
headerAccessControlRequestMethod = []byte(fasthttp.HeaderAccessControlRequestMethod)
|
||||
)
|
||||
|
||||
var (
|
||||
headerValueFalse = []byte("false")
|
||||
headerValueMaxAge = []byte("100")
|
||||
headerValueVary = []byte("Accept-Encoding, Origin")
|
||||
)
|
||||
|
||||
var (
|
||||
protoHTTPS = []byte("https")
|
||||
protoHTTP = []byte("http")
|
||||
|
||||
|
|
53
internal/middlewares/cors.go
Normal file
53
internal/middlewares/cors.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well
|
||||
// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied
|
||||
// to both Accept-Encoding and Origin. It grants the GET Request Method only.
|
||||
func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil {
|
||||
corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin)
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) {
|
||||
originURL, err := url.Parse(string(origin))
|
||||
if err != nil || originURL.Scheme != "https" {
|
||||
return
|
||||
}
|
||||
|
||||
resp.Header.SetBytesKV(headerVary, headerValueVary)
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin)
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse)
|
||||
resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge)
|
||||
|
||||
if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
|
||||
requestedHeaders := strings.Split(string(headers), ",")
|
||||
allowHeaders := make([]string, len(requestedHeaders))
|
||||
|
||||
for i, header := range requestedHeaders {
|
||||
headerTrimmed := strings.Trim(header, " ")
|
||||
if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) {
|
||||
allowHeaders[i] = headerTrimmed
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowHeaders) != 0 {
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
|
||||
}
|
||||
}
|
65
internal/middlewares/cors_test.go
Normal file
65
internal/middlewares/cors_test.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.Response{}
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
|
||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
||||
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.Response{}
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
|
||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
||||
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
|
||||
resp := fasthttp.Response{}
|
||||
|
||||
origin := []byte("http://myapp.example.com")
|
||||
|
||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
|
||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
||||
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
Loading…
Reference in New Issue
Block a user