mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
feat(oidc): provide cors config including options handlers (#3005)
This adjusts the CORS headers appropriately for OpenID Connect. This includes responding to OPTIONS requests appropriately. Currently this is only configured to operate when the Origin scheme is HTTPS; but can easily be expanded in the future to include additional Origins.
This commit is contained in:
parent
a694cf851f
commit
4ebd8fdf4e
|
@ -767,6 +767,26 @@ notifier:
|
|||
## for security reasons.
|
||||
# enforce_pkce: public_clients_only
|
||||
|
||||
## Cross-Origin Resource Sharing (CORS) settings.
|
||||
# cors:
|
||||
## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
|
||||
# endpoints:
|
||||
# - authorization
|
||||
# - token
|
||||
# - revocation
|
||||
# - introspection
|
||||
# - userinfo
|
||||
|
||||
## List of allowed origins.
|
||||
## Any origin with https is permitted unless this option is configured or the
|
||||
## allowed_origins_from_client_redirect_uris option is enabled.
|
||||
# allowed_origins:
|
||||
# - https://example.com
|
||||
|
||||
## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
|
||||
## provided they have the scheme http or https and do not have the hostname of localhost.
|
||||
# allowed_origins_from_client_redirect_uris: false
|
||||
|
||||
## Clients is a list of known clients and their configuration.
|
||||
# clients:
|
||||
# -
|
||||
|
|
|
@ -35,6 +35,15 @@ identity_providers:
|
|||
refresh_token_lifespan: 90m
|
||||
enable_client_debug_messages: false
|
||||
enforce_pkce: public_clients_only
|
||||
cors:
|
||||
endpoints:
|
||||
- authorization
|
||||
- token
|
||||
- revocation
|
||||
- introspection
|
||||
allowed_origins:
|
||||
- https://example.com
|
||||
allowed_origins_from_client_redirect_uris: false
|
||||
clients:
|
||||
- id: myapp
|
||||
description: My Application
|
||||
|
@ -218,6 +227,79 @@ Allows PKCE `plain` challenges when set to `true`.
|
|||
|
||||
***Security Notice:*** Changing this value is generally discouraged. Applications should use the `S256` PKCE challenge method instead.
|
||||
|
||||
### cors
|
||||
|
||||
Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
|
||||
you to configure the optional parts. We reply with CORS headers when the request includes the Origin header.
|
||||
|
||||
##### endpoints
|
||||
<div markdown="1">
|
||||
type: list(string)
|
||||
{: .label .label-config .label-purple }
|
||||
default: empty
|
||||
{: .label .label-config .label-blue }
|
||||
required: no
|
||||
{: .label .label-config .label-green }
|
||||
</div>
|
||||
|
||||
A list of endpoints to configure with cross-origin resource sharing headers. It is recommended that the `userinfo`
|
||||
option is at least in this list. The potential endpoints which this can be enabled on are as follows:
|
||||
|
||||
* authorization
|
||||
* token
|
||||
* revocation
|
||||
* introspection
|
||||
* userinfo
|
||||
|
||||
#### allowed_origins
|
||||
<div markdown="1">
|
||||
type: list(string)
|
||||
{: .label .label-config .label-purple }
|
||||
default: empty
|
||||
{: .label .label-config .label-blue }
|
||||
required: no
|
||||
{: .label .label-config .label-green }
|
||||
</div>
|
||||
|
||||
A list of permitted origins.
|
||||
|
||||
Any origin with https is permitted unless this option is configured or the allowed_origins_from_client_redirect_uris
|
||||
option is enabled. This means you must configure this option manually if you want http endpoints to be permitted to
|
||||
make cross-origin requests to the OpenID Connect endpoints, however this is not recommended.
|
||||
|
||||
Origins must only have the scheme, hostname and port, they may not have a trailing slash or path.
|
||||
|
||||
In addition to an Origin URI, you may specify the wildcard origin in the allowed_origins. It MUST be specified by itself
|
||||
and the allowed_origins_from_client_redirect_uris MUST NOT be enabled. The wildcard origin is denoted as `*`. Examples:
|
||||
|
||||
```yaml
|
||||
identity_providers:
|
||||
oidc:
|
||||
cors:
|
||||
allowed_origins: "*"
|
||||
```
|
||||
|
||||
```yaml
|
||||
identity_providers:
|
||||
oidc:
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "*"
|
||||
```
|
||||
|
||||
#### allowed_origins_from_client_redirect_uris
|
||||
<div markdown="1">
|
||||
type: boolean
|
||||
{: .label .label-config .label-purple }
|
||||
default: false
|
||||
{: .label .label-config .label-blue }
|
||||
required: no
|
||||
{: .label .label-config .label-green }
|
||||
</div>
|
||||
|
||||
Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins, provided they
|
||||
have the scheme http or https and do not have the hostname of localhost.
|
||||
|
||||
### clients
|
||||
|
||||
A list of clients to configure. The options for each client are described below.
|
||||
|
@ -487,22 +569,46 @@ Below is a list of the potential values we place in the claim and their meaning:
|
|||
|
||||
## Endpoint Implementations
|
||||
|
||||
This is a table of the endpoints we currently support and their paths. This can be requrired information for some RP's,
|
||||
particularly those that don't use [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html). The paths are
|
||||
appended to the end of the primary URL used to access Authelia. For example in the Discovery example provided you access
|
||||
Authelia via https://auth.example.com, the discovery URL is https://auth.example.com/.well-known/openid-configuration.
|
||||
The following section documents the endpoints we implement and their respective paths. This information can traditionally
|
||||
be discovered by relying parties that utilize [discovery](https://openid.net/specs/openid-connect-discovery-1_0.html),
|
||||
however this information may be useful for clients which do not implement this.
|
||||
|
||||
| Endpoint | Path |
|
||||
|:-------------:|:---------------------------------------------:|
|
||||
| Discovery | [root]/.well-known/openid-configuration |
|
||||
| Metadata | [root]/.well-known/oauth-authorization-server |
|
||||
| JWKS | [root]/api/oidc/jwks |
|
||||
| Authorization | [root]/api/oidc/authorization |
|
||||
| Token | [root]/api/oidc/token |
|
||||
| Introspection | [root]/api/oidc/introspection |
|
||||
| Revocation | [root]/api/oidc/revocation |
|
||||
| Userinfo | [root]/api/oidc/userinfo |
|
||||
The endpoints can be discovered easily by visiting the Discovery and Metadata endpoints. It is recommended regardless
|
||||
of your version of Authelia that you utilize this version as it will always produce the correct endpoint URLs. The paths
|
||||
for the Discovery/Metadata endpoints are part of IANA's well known registration but are also documented in a table below.
|
||||
|
||||
These tables document the endpoints we currently support and their paths in the most recent version of Authelia. The paths
|
||||
are appended to the end of the primary URL used to access Authelia. The tables use the url https://auth.example.com as
|
||||
an example of the Authelia root URL which is also the OpenID Connect issuer.
|
||||
|
||||
### Well Known Discovery Endpoints
|
||||
|
||||
These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP.
|
||||
|
||||
| Endpoint | Path |
|
||||
|:-------------:|:---------------------------------------------------------------:|
|
||||
| Discovery | https://auth.example.com/.well-known/openid-configuration |
|
||||
| Metadata | https://auth.example.com/.well-known/oauth-authorization-server |
|
||||
|
||||
|
||||
### Discoverable Endpoints
|
||||
|
||||
These endpoints implement OpenID Connect elements.
|
||||
|
||||
| Endpoint | Path | Discovery Attribute |
|
||||
|:---------------:|:-----------------------------------------------:|:----------------------:|
|
||||
| JWKS | https://auth.example.com/jwks.json | jwks_uri |
|
||||
| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
|
||||
| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
|
||||
| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
|
||||
| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
|
||||
| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
|
||||
|
||||
[OpenID Connect]: https://openid.net/connect/
|
||||
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
|
||||
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
|
||||
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
|
||||
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
|
||||
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
|
||||
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
|
||||
|
|
|
@ -767,6 +767,26 @@ notifier:
|
|||
## for security reasons.
|
||||
# enforce_pkce: public_clients_only
|
||||
|
||||
## Cross-Origin Resource Sharing (CORS) settings.
|
||||
# cors:
|
||||
## List of endpoints in addition to the metadata endpoints to permit cross-origin requests on.
|
||||
# endpoints:
|
||||
# - authorization
|
||||
# - token
|
||||
# - revocation
|
||||
# - introspection
|
||||
# - userinfo
|
||||
|
||||
## List of allowed origins.
|
||||
## Any origin with https is permitted unless this option is configured or the
|
||||
## allowed_origins_from_client_redirect_uris option is enabled.
|
||||
# allowed_origins:
|
||||
# - https://example.com
|
||||
|
||||
## Automatically adds the origin portion of all redirect URI's on all clients to the list of allowed_origins,
|
||||
## provided they have the scheme http or https and do not have the hostname of localhost.
|
||||
# allowed_origins_from_client_redirect_uris: false
|
||||
|
||||
## Clients is a list of known clients and their configuration.
|
||||
# clients:
|
||||
# -
|
||||
|
|
|
@ -201,6 +201,24 @@ func TestShouldValidateConfigurationWithEnvSecrets(t *testing.T) {
|
|||
assert.Equal(t, "example_secret value", config.Storage.EncryptionKey)
|
||||
}
|
||||
|
||||
func TestShouldLoadURLList(t *testing.T) {
|
||||
testReset()
|
||||
|
||||
val := schema.NewStructValidator()
|
||||
keys, config, err := Load(val, NewDefaultSources([]string{"./test_resources/config_oidc.yml"}, DefaultEnvPrefix, DefaultEnvDelimiter)...)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
validator.ValidateKeys(keys, DefaultEnvPrefix, val)
|
||||
|
||||
assert.Len(t, val.Errors(), 0)
|
||||
assert.Len(t, val.Warnings(), 0)
|
||||
|
||||
require.Len(t, config.IdentityProviders.OIDC.CORS.AllowedOrigins, 2)
|
||||
assert.Equal(t, "https://google.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[0].String())
|
||||
assert.Equal(t, "https://example.com", config.IdentityProviders.OIDC.CORS.AllowedOrigins[1].String())
|
||||
}
|
||||
|
||||
func TestShouldValidateAndRaiseErrorsOnBadConfiguration(t *testing.T) {
|
||||
testReset()
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package schema
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdentityProvidersConfiguration represents the IdentityProviders 2.0 configuration for Authelia.
|
||||
type IdentityProvidersConfiguration struct {
|
||||
|
@ -24,9 +27,19 @@ type OpenIDConnectConfiguration struct {
|
|||
EnforcePKCE string `koanf:"enforce_pkce"`
|
||||
EnablePKCEPlainChallenge bool `koanf:"enable_pkce_plain_challenge"`
|
||||
|
||||
CORS OpenIDConnectCORSConfiguration `koanf:"cors"`
|
||||
|
||||
Clients []OpenIDConnectClientConfiguration `koanf:"clients"`
|
||||
}
|
||||
|
||||
// OpenIDConnectCORSConfiguration represents an OpenID Connect CORS config.
|
||||
type OpenIDConnectCORSConfiguration struct {
|
||||
Endpoints []string `koanf:"endpoints"`
|
||||
AllowedOrigins []url.URL `koanf:"allowed_origins"`
|
||||
|
||||
AllowedOriginsFromClientRedirectURIs bool `koanf:"allowed_origins_from_client_redirect_uris"`
|
||||
}
|
||||
|
||||
// OpenIDConnectClientConfiguration configuration for an OpenID Connect client.
|
||||
type OpenIDConnectClientConfiguration struct {
|
||||
ID string `koanf:"id"`
|
||||
|
|
133
internal/configuration/test_resources/config_oidc.yml
Normal file
133
internal/configuration/test_resources/config_oidc.yml
Normal file
|
@ -0,0 +1,133 @@
|
|||
---
|
||||
default_redirection_url: https://home.example.com:8080/
|
||||
|
||||
server:
|
||||
host: 127.0.0.1
|
||||
port: 9091
|
||||
|
||||
log:
|
||||
level: debug
|
||||
|
||||
totp:
|
||||
issuer: authelia.com
|
||||
|
||||
duo_api:
|
||||
hostname: api-123456789.example.com
|
||||
integration_key: ABCDEF
|
||||
|
||||
authentication_backend:
|
||||
ldap:
|
||||
url: ldap://127.0.0.1
|
||||
base_dn: dc=example,dc=com
|
||||
username_attribute: uid
|
||||
additional_users_dn: ou=users
|
||||
users_filter: (&({username_attribute}={input})(objectCategory=person)(objectClass=user))
|
||||
additional_groups_dn: ou=groups
|
||||
groups_filter: (&(member={dn})(objectClass=groupOfNames))
|
||||
group_name_attribute: cn
|
||||
mail_attribute: mail
|
||||
user: cn=admin,dc=example,dc=com
|
||||
|
||||
access_control:
|
||||
default_policy: deny
|
||||
|
||||
rules:
|
||||
# Rules applied to everyone
|
||||
- domain: public.example.com
|
||||
policy: bypass
|
||||
|
||||
- domain: secure.example.com
|
||||
policy: one_factor
|
||||
# Network based rule, if not provided any network matches.
|
||||
networks:
|
||||
- 192.168.1.0/24
|
||||
- domain: secure.example.com
|
||||
policy: two_factor
|
||||
|
||||
- domain: [singlefactor.example.com, onefactor.example.com]
|
||||
policy: one_factor
|
||||
|
||||
# Rules applied to 'admins' group
|
||||
- domain: "mx2.mail.example.com"
|
||||
subject: "group:admins"
|
||||
policy: deny
|
||||
- domain: "*.example.com"
|
||||
subject: "group:admins"
|
||||
policy: two_factor
|
||||
|
||||
# Rules applied to 'dev' group
|
||||
- domain: dev.example.com
|
||||
resources:
|
||||
- "^/groups/dev/.*$"
|
||||
subject: "group:dev"
|
||||
policy: two_factor
|
||||
|
||||
# Rules applied to user 'john'
|
||||
- domain: dev.example.com
|
||||
resources:
|
||||
- "^/users/john/.*$"
|
||||
subject: "user:john"
|
||||
policy: two_factor
|
||||
|
||||
# Rules applied to 'dev' group and user 'john'
|
||||
- domain: dev.example.com
|
||||
resources:
|
||||
- "^/deny-all.*$"
|
||||
subject: ["group:dev", "user:john"]
|
||||
policy: deny
|
||||
|
||||
# Rules applied to user 'harry'
|
||||
- domain: dev.example.com
|
||||
resources:
|
||||
- "^/users/harry/.*$"
|
||||
subject: "user:harry"
|
||||
policy: two_factor
|
||||
|
||||
# Rules applied to user 'bob'
|
||||
- domain: "*.mail.example.com"
|
||||
subject: "user:bob"
|
||||
policy: two_factor
|
||||
- domain: "dev.example.com"
|
||||
resources:
|
||||
- "^/users/bob/.*$"
|
||||
subject: "user:bob"
|
||||
policy: two_factor
|
||||
|
||||
session:
|
||||
name: authelia_session
|
||||
expiration: 3600000 # 1 hour
|
||||
inactivity: 300000 # 5 minutes
|
||||
domain: example.com
|
||||
redis:
|
||||
host: 127.0.0.1
|
||||
port: 6379
|
||||
high_availability:
|
||||
sentinel_name: test
|
||||
|
||||
regulation:
|
||||
max_retries: 3
|
||||
find_time: 120
|
||||
ban_time: 300
|
||||
|
||||
storage:
|
||||
mysql:
|
||||
host: 127.0.0.1
|
||||
port: 3306
|
||||
database: authelia
|
||||
username: authelia
|
||||
|
||||
notifier:
|
||||
smtp:
|
||||
username: test
|
||||
host: 127.0.0.1
|
||||
port: 1025
|
||||
sender: admin@example.com
|
||||
disable_require_tls: true
|
||||
|
||||
identity_providers:
|
||||
oidc:
|
||||
cors:
|
||||
allowed_origins:
|
||||
- https://google.com
|
||||
- https://example.com
|
||||
...
|
|
@ -123,11 +123,15 @@ const (
|
|||
const (
|
||||
errFmtOIDCNoClientsConfigured = "identity_providers: oidc: option 'clients' must have one or " +
|
||||
"more clients configured"
|
||||
errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
|
||||
|
||||
errFmtOIDCNoPrivateKey = "identity_providers: oidc: option 'issuer_private_key' is required"
|
||||
errFmtOIDCEnforcePKCEInvalidValue = "identity_providers: oidc: option 'enforce_pkce' must be 'never', " +
|
||||
"'public_clients_only' or 'always', but it is configured as '%s'"
|
||||
|
||||
errFmtOIDCCORSInvalidOrigin = "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value '%s' as it has a %s: origins must only be scheme, hostname, and an optional port"
|
||||
errFmtOIDCCORSInvalidOriginWildcard = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself"
|
||||
errFmtOIDCCORSInvalidOriginWildcardWithClients = "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled"
|
||||
errFmtOIDCCORSInvalidEndpoint = "identity_providers: oidc: cors: option 'endpoints' contains an invalid value '%s': must be one of '%s'"
|
||||
|
||||
errFmtOIDCClientsDuplicateID = "identity_providers: oidc: one or more clients have the same id but all client" +
|
||||
"id's must be unique"
|
||||
errFmtOIDCClientsWithEmptyID = "identity_providers: oidc: one or more clients have been configured with " +
|
||||
|
@ -275,6 +279,7 @@ var validOIDCScopes = []string{oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeProf
|
|||
var validOIDCGrantTypes = []string{"implicit", "refresh_token", "authorization_code", "password", "client_credentials"}
|
||||
var validOIDCResponseModes = []string{"form_post", "query", "fragment"}
|
||||
var validOIDCUserinfoAlgorithms = []string{"none", "RS256"}
|
||||
var validOIDCCORSEndpoints = []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint}
|
||||
|
||||
var reKeyReplacer = regexp.MustCompile(`\[\d+]`)
|
||||
|
||||
|
@ -471,6 +476,9 @@ var ValidKeys = []string{
|
|||
"identity_providers.oidc.enable_pkce_plain_challenge",
|
||||
"identity_providers.oidc.enable_client_debug_messages",
|
||||
"identity_providers.oidc.minimum_parameter_entropy",
|
||||
"identity_providers.oidc.cors.endpoints",
|
||||
"identity_providers.oidc.cors.allowed_origins",
|
||||
"identity_providers.oidc.cors.enable_origins_from_clients",
|
||||
"identity_providers.oidc.clients",
|
||||
"identity_providers.oidc.clients[].id",
|
||||
"identity_providers.oidc.clients[].description",
|
||||
|
|
|
@ -49,6 +49,7 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
|
|||
validator.Push(fmt.Errorf(errFmtOIDCEnforcePKCEInvalidValue, config.EnforcePKCE))
|
||||
}
|
||||
|
||||
validateOIDCOptionsCORS(config, validator)
|
||||
validateOIDCClients(config, validator)
|
||||
|
||||
if len(config.Clients) == 0 {
|
||||
|
@ -57,6 +58,64 @@ func validateOIDC(config *schema.OpenIDConnectConfiguration, validator *schema.S
|
|||
}
|
||||
}
|
||||
|
||||
func validateOIDCOptionsCORS(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||
validateOIDCOptionsCORSAllowedOrigins(config, validator)
|
||||
|
||||
if config.CORS.AllowedOriginsFromClientRedirectURIs {
|
||||
validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config)
|
||||
}
|
||||
|
||||
validateOIDCOptionsCORSEndpoints(config, validator)
|
||||
}
|
||||
|
||||
func validateOIDCOptionsCORSAllowedOrigins(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||
for _, origin := range config.CORS.AllowedOrigins {
|
||||
if origin.String() == "*" {
|
||||
if len(config.CORS.AllowedOrigins) != 1 {
|
||||
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcard))
|
||||
}
|
||||
|
||||
if config.CORS.AllowedOriginsFromClientRedirectURIs {
|
||||
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOriginWildcardWithClients))
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if origin.Path != "" {
|
||||
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "path"))
|
||||
}
|
||||
|
||||
if origin.RawQuery != "" {
|
||||
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidOrigin, origin.String(), "query string"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateOIDCOptionsCORSAllowedOriginsFromClientRedirectURIs(config *schema.OpenIDConnectConfiguration) {
|
||||
for _, client := range config.Clients {
|
||||
for _, redirectURI := range client.RedirectURIs {
|
||||
uri, err := url.Parse(redirectURI)
|
||||
if err != nil || (uri.Scheme != schemeHTTP && uri.Scheme != schemeHTTPS) || uri.Hostname() == "localhost" {
|
||||
continue
|
||||
}
|
||||
|
||||
origin := utils.OriginFromURL(*uri)
|
||||
|
||||
if !utils.IsURLInSlice(origin, config.CORS.AllowedOrigins) {
|
||||
config.CORS.AllowedOrigins = append(config.CORS.AllowedOrigins, origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func validateOIDCOptionsCORSEndpoints(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||
for _, endpoint := range config.CORS.Endpoints {
|
||||
if !utils.IsStringInSlice(endpoint, validOIDCCORSEndpoints) {
|
||||
validator.Push(fmt.Errorf(errFmtOIDCCORSInvalidEndpoint, endpoint, strings.Join(validOIDCCORSEndpoints, "', '")))
|
||||
}
|
||||
}
|
||||
}
|
||||
func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *schema.StructValidator) {
|
||||
invalidID, duplicateIDs := false, false
|
||||
|
||||
|
@ -97,7 +156,6 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s
|
|||
validateOIDCClientResponseTypes(c, config, validator)
|
||||
validateOIDCClientResponseModes(c, config, validator)
|
||||
validateOIDDClientUserinfoAlgorithm(c, config, validator)
|
||||
|
||||
validateOIDCClientRedirectURIs(client, validator)
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
|
||||
|
@ -29,6 +31,54 @@ func TestShouldRaiseErrorWhenInvalidOIDCServerConfiguration(t *testing.T) {
|
|||
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
|
||||
}
|
||||
|
||||
func TestShouldNotRaiseErrorWhenCORSEndpointsValid(t *testing.T) {
|
||||
validator := schema.NewStructValidator()
|
||||
config := &schema.IdentityProvidersConfiguration{
|
||||
OIDC: &schema.OpenIDConnectConfiguration{
|
||||
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
|
||||
IssuerPrivateKey: "key-material",
|
||||
CORS: schema.OpenIDConnectCORSConfiguration{
|
||||
Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint},
|
||||
},
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||
{
|
||||
ID: "example",
|
||||
Secret: "example",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ValidateIdentityProviders(config, validator)
|
||||
|
||||
assert.Len(t, validator.Errors(), 0)
|
||||
}
|
||||
|
||||
func TestShouldRaiseErrorWhenCORSEndpointsNotValid(t *testing.T) {
|
||||
validator := schema.NewStructValidator()
|
||||
config := &schema.IdentityProvidersConfiguration{
|
||||
OIDC: &schema.OpenIDConnectConfiguration{
|
||||
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
|
||||
IssuerPrivateKey: "key-material",
|
||||
CORS: schema.OpenIDConnectCORSConfiguration{
|
||||
Endpoints: []string{oidc.AuthorizationEndpoint, oidc.TokenEndpoint, oidc.IntrospectionEndpoint, oidc.RevocationEndpoint, oidc.UserinfoEndpoint, "invalid_endpoint"},
|
||||
},
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||
{
|
||||
ID: "example",
|
||||
Secret: "example",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ValidateIdentityProviders(config, validator)
|
||||
|
||||
require.Len(t, validator.Errors(), 1)
|
||||
|
||||
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'endpoints' contains an invalid value 'invalid_endpoint': must be one of 'authorization', 'token', 'introspection', 'revocation', 'userinfo'")
|
||||
}
|
||||
|
||||
func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
|
||||
validator := schema.NewStructValidator()
|
||||
config := &schema.IdentityProvidersConfiguration{
|
||||
|
@ -47,7 +97,44 @@ func TestShouldRaiseErrorWhenOIDCPKCEEnforceValueInvalid(t *testing.T) {
|
|||
assert.EqualError(t, validator.Errors()[1], errFmtOIDCNoClientsConfigured)
|
||||
}
|
||||
|
||||
func TestShouldRaiseErrorWhenOIDCServerIssuerPrivateKeyPathInvalid(t *testing.T) {
|
||||
func TestShouldRaiseErrorWhenOIDCCORSOriginsHasInvalidValues(t *testing.T) {
|
||||
validator := schema.NewStructValidator()
|
||||
|
||||
config := &schema.IdentityProvidersConfiguration{
|
||||
OIDC: &schema.OpenIDConnectConfiguration{
|
||||
HMACSecret: "rLABDrx87et5KvRHVUgTm3pezWWd8LMN",
|
||||
IssuerPrivateKey: "key-material",
|
||||
CORS: schema.OpenIDConnectCORSConfiguration{
|
||||
AllowedOrigins: utils.URLsFromStringSlice([]string{"https://example.com/", "https://site.example.com/subpath", "https://site.example.com?example=true", "*"}),
|
||||
AllowedOriginsFromClientRedirectURIs: true,
|
||||
},
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||
{
|
||||
ID: "myclient",
|
||||
Secret: "jk12nb3klqwmnelqkwenm",
|
||||
Policy: "two_factor",
|
||||
RedirectURIs: []string{"https://example.com/oauth2_callback", "https://localhost:566/callback", "http://an.example.com/callback", "file://a/file"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ValidateIdentityProviders(config, validator)
|
||||
|
||||
require.Len(t, validator.Errors(), 6)
|
||||
assert.EqualError(t, validator.Errors()[0], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://example.com/' as it has a path: origins must only be scheme, hostname, and an optional port")
|
||||
assert.EqualError(t, validator.Errors()[1], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com/subpath' as it has a path: origins must only be scheme, hostname, and an optional port")
|
||||
assert.EqualError(t, validator.Errors()[2], "identity_providers: oidc: cors: option 'allowed_origins' contains an invalid value 'https://site.example.com?example=true' as it has a query string: origins must only be scheme, hostname, and an optional port")
|
||||
assert.EqualError(t, validator.Errors()[3], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' with more than one origin but the wildcard origin must be defined by itself")
|
||||
assert.EqualError(t, validator.Errors()[4], "identity_providers: oidc: cors: option 'allowed_origins' contains the wildcard origin '*' cannot be specified with option 'allowed_origins_from_client_redirect_uris' enabled")
|
||||
assert.EqualError(t, validator.Errors()[5], "identity_providers: oidc: client 'myclient': option 'redirect_uris' has an invalid value: redirect uri 'file://a/file' must have a scheme of 'http' or 'https' but 'file' is configured")
|
||||
|
||||
require.Len(t, config.OIDC.CORS.AllowedOrigins, 6)
|
||||
assert.Equal(t, "*", config.OIDC.CORS.AllowedOrigins[3].String())
|
||||
assert.Equal(t, "https://example.com", config.OIDC.CORS.AllowedOrigins[4].String())
|
||||
}
|
||||
|
||||
func TestShouldRaiseErrorWhenOIDCServerNoClients(t *testing.T) {
|
||||
validator := schema.NewStructValidator()
|
||||
config := &schema.IdentityProvidersConfiguration{
|
||||
OIDC: &schema.OpenIDConnectConfiguration{
|
||||
|
|
|
@ -72,16 +72,6 @@ const (
|
|||
auth = "auth"
|
||||
)
|
||||
|
||||
// OIDC constants.
|
||||
const (
|
||||
pathLegacyOpenIDConnectAuthorization = "/api/oidc/authorize"
|
||||
pathLegacyOpenIDConnectIntrospection = "/api/oidc/introspect"
|
||||
pathLegacyOpenIDConnectRevocation = "/api/oidc/revoke"
|
||||
|
||||
// Note: If you change this const you must also do so in the frontend at web/src/services/Api.ts.
|
||||
pathOpenIDConnectConsent = "/api/oidc/consent"
|
||||
)
|
||||
|
||||
const (
|
||||
accept = "accept"
|
||||
reject = "reject"
|
||||
|
|
|
@ -6,7 +6,8 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
)
|
||||
|
||||
func oidcJWKs(ctx *middlewares.AutheliaCtx) {
|
||||
// JSONWebKeySetGET returns the JSON Web Key Set. Used in OAuth 2.0 and OpenID Connect 1.0.
|
||||
func JSONWebKeySetGET(ctx *middlewares.AutheliaCtx) {
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
if err := json.NewEncoder(ctx).Encode(ctx.Providers.OpenIDConnect.KeyManager.GetKeySet()); err != nil {
|
|
@ -9,7 +9,10 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
)
|
||||
|
||||
func oidcIntrospection(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
// OAuthIntrospectionPOST handles POST requests to the OAuth 2.0 Introspection endpoint.
|
||||
//
|
||||
// https://datatracker.ietf.org/doc/html/rfc7662
|
||||
func OAuthIntrospectionPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
var (
|
||||
responder fosite.IntrospectionResponder
|
||||
err error
|
|
@ -8,7 +8,10 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
)
|
||||
|
||||
func oidcRevocation(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
// OAuthRevocationPOST handles POST requests to the OAuth 2.0 Revocation endpoint.
|
||||
//
|
||||
// https://datatracker.ietf.org/doc/html/rfc7009
|
||||
func OAuthRevocationPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
var err error
|
||||
|
||||
if err = ctx.Providers.OpenIDConnect.Fosite.NewRevocationRequest(ctx, req); err != nil {
|
|
@ -16,7 +16,10 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
func oidcAuthorization(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
|
||||
// OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
|
||||
func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
requester fosite.AuthorizeRequester
|
||||
responder fosite.AuthorizeResponder
|
||||
|
|
|
@ -7,7 +7,8 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
)
|
||||
|
||||
func oidcConsent(ctx *middlewares.AutheliaCtx) {
|
||||
// OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect.
|
||||
func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) {
|
||||
userSession := ctx.GetSession()
|
||||
|
||||
if userSession.OIDCWorkflowSession == nil {
|
||||
|
@ -39,7 +40,8 @@ func oidcConsent(ctx *middlewares.AutheliaCtx) {
|
|||
}
|
||||
}
|
||||
|
||||
func oidcConsentPOST(ctx *middlewares.AutheliaCtx) {
|
||||
// OpenIDConnectConsentPOST handles consent responses for OpenID Connect.
|
||||
func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
|
||||
userSession := ctx.GetSession()
|
||||
|
||||
if userSession.OIDCWorkflowSession == nil {
|
||||
|
|
|
@ -9,7 +9,10 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
)
|
||||
|
||||
func oidcToken(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
// OpenIDConnectTokenPOST handles POST requests to the OpenID Connect 1.0 Token endpoint.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
|
||||
func OpenIDConnectTokenPOST(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
var (
|
||||
requester fosite.AccessRequester
|
||||
responder fosite.AccessResponder
|
||||
|
|
|
@ -14,7 +14,10 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
)
|
||||
|
||||
func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
// OpenIDConnectUserinfo handles GET/POST requests to the OpenID Connect 1.0 UserInfo endpoint.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *http.Request) {
|
||||
var (
|
||||
tokenType fosite.TokenType
|
||||
requester fosite.AccessRequester
|
||||
|
@ -97,7 +100,7 @@ func oidcUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, req *htt
|
|||
var jti uuid.UUID
|
||||
|
||||
if jti, err = uuid.NewRandom(); err != nil {
|
||||
ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JWT ID."))
|
||||
ctx.Providers.OpenIDConnect.WriteError(rw, req, fosite.ErrServerError.WithHintf("Could not generate JTI."))
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -8,7 +8,13 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
)
|
||||
|
||||
func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
|
||||
// OpenIDConnectConfigurationWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
||||
// OpenID Connect Discovery 1.0 metadata.
|
||||
//
|
||||
// https://datatracker.ietf.org/doc/html/rfc5785
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
func OpenIDConnectConfigurationWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||
issuer, err := ctx.ExternalRootURL()
|
||||
if err != nil {
|
||||
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
|
||||
|
@ -30,7 +36,13 @@ func wellKnownOpenIDConnectConfigurationGET(ctx *middlewares.AutheliaCtx) {
|
|||
}
|
||||
}
|
||||
|
||||
func wellKnownOAuthAuthorizationServerGET(ctx *middlewares.AutheliaCtx) {
|
||||
// OAuthAuthorizationServerWellKnownGET handles requests to a .well-known endpoint (RFC5785) which returns the
|
||||
// OAuth 2.0 Authorization Server Metadata (RFC8414).
|
||||
//
|
||||
// https://datatracker.ietf.org/doc/html/rfc5785
|
||||
//
|
||||
// https://datatracker.ietf.org/doc/html/rfc8414
|
||||
func OAuthAuthorizationServerWellKnownGET(ctx *middlewares.AutheliaCtx) {
|
||||
issuer, err := ctx.ExternalRootURL()
|
||||
if err != nil {
|
||||
ctx.Logger.Errorf("Error occurred determining OpenID Connect issuer details: %+v", err)
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/fasthttp/router"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
)
|
||||
|
||||
// RegisterOIDC registers the handlers with the fasthttp *router.Router. TODO: Add paths for Flush, Logout.
|
||||
func RegisterOIDC(router *router.Router, middleware middlewares.RequestHandlerBridge) {
|
||||
// TODO: Add OPTIONS handler.
|
||||
router.GET(oidc.WellKnownOpenIDConfigurationPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOpenIDConnectConfigurationGET)))
|
||||
router.GET(oidc.WellKnownOAuthAuthorizationServerPath, middleware(middlewares.CORSApplyAutomaticAllowAllPolicy(wellKnownOAuthAuthorizationServerGET)))
|
||||
|
||||
router.GET(pathOpenIDConnectConsent, middleware(oidcConsent))
|
||||
|
||||
router.POST(pathOpenIDConnectConsent, middleware(oidcConsentPOST))
|
||||
|
||||
router.GET(oidc.JWKsPath, middleware(oidcJWKs))
|
||||
|
||||
router.GET(oidc.AuthorizationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
|
||||
router.GET(pathLegacyOpenIDConnectAuthorization, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcAuthorization)))
|
||||
|
||||
// TODO: Add OPTIONS handler.
|
||||
router.POST(oidc.TokenPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcToken)))
|
||||
|
||||
router.POST(oidc.IntrospectionPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
|
||||
router.GET(pathLegacyOpenIDConnectIntrospection, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcIntrospection)))
|
||||
|
||||
router.GET(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
|
||||
router.POST(oidc.UserinfoPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcUserinfo)))
|
||||
|
||||
// TODO: Add OPTIONS handler.
|
||||
router.POST(oidc.RevocationPath, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
|
||||
router.POST(pathLegacyOpenIDConnectRevocation, middleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(oidcRevocation)))
|
||||
}
|
|
@ -7,18 +7,22 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
headerAccept = []byte(fasthttp.HeaderAccept)
|
||||
headerContentLength = []byte(fasthttp.HeaderContentLength)
|
||||
|
||||
headerXForwardedProto = []byte(fasthttp.HeaderXForwardedProto)
|
||||
headerXForwardedHost = []byte(fasthttp.HeaderXForwardedHost)
|
||||
headerXForwardedFor = []byte(fasthttp.HeaderXForwardedFor)
|
||||
headerXRequestedWith = []byte(fasthttp.HeaderXRequestedWith)
|
||||
headerAccept = []byte(fasthttp.HeaderAccept)
|
||||
|
||||
headerXForwardedURI = []byte("X-Forwarded-URI")
|
||||
headerXOriginalURL = []byte("X-Original-URL")
|
||||
headerXForwardedMethod = []byte("X-Forwarded-Method")
|
||||
|
||||
headerVary = []byte(fasthttp.HeaderVary)
|
||||
headerOrigin = []byte(fasthttp.HeaderOrigin)
|
||||
headerVary = []byte(fasthttp.HeaderVary)
|
||||
headerAllow = []byte(fasthttp.HeaderAllow)
|
||||
headerOrigin = []byte(fasthttp.HeaderOrigin)
|
||||
|
||||
headerAccessControlAllowCredentials = []byte(fasthttp.HeaderAccessControlAllowCredentials)
|
||||
headerAccessControlAllowHeaders = []byte(fasthttp.HeaderAccessControlAllowHeaders)
|
||||
headerAccessControlAllowMethods = []byte(fasthttp.HeaderAccessControlAllowMethods)
|
||||
|
@ -29,9 +33,13 @@ var (
|
|||
)
|
||||
|
||||
var (
|
||||
headerValueFalse = []byte("false")
|
||||
headerValueMaxAge = []byte("100")
|
||||
headerValueVary = []byte("Accept-Encoding, Origin")
|
||||
headerValueFalse = []byte("false")
|
||||
headerValueTrue = []byte("true")
|
||||
headerValueMaxAge = []byte("100")
|
||||
headerValueVary = []byte("Accept-Encoding, Origin")
|
||||
headerValueVaryWildcard = []byte("Accept-Encoding")
|
||||
headerValueOriginWildcard = []byte("*")
|
||||
headerValueZero = []byte("0")
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -40,6 +48,8 @@ var (
|
|||
|
||||
// UserValueKeyBaseURL is the User Value key where we store the Base URL.
|
||||
UserValueKeyBaseURL = []byte("base_url")
|
||||
|
||||
headerSeparator = []byte(", ")
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -1,53 +1,347 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// CORSApplyAutomaticAllowAllPolicy applies a CORS policy that automatically grants all Origins as well
|
||||
// as all Request Headers other than Cookie and *. It does not allow credentials, and has a max age of 100. Vary is applied
|
||||
// to both Accept-Encoding and Origin. It grants the GET Request Method only.
|
||||
func CORSApplyAutomaticAllowAllPolicy(next RequestHandler) RequestHandler {
|
||||
return func(ctx *AutheliaCtx) {
|
||||
if origin := ctx.Request.Header.PeekBytes(headerOrigin); origin != nil {
|
||||
corsApplyAutomaticAllowAllPolicy(&ctx.Request, &ctx.Response, origin)
|
||||
// NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy which adds the Vary header
|
||||
// with a value reflecting that the Origin header will Vary this response, then if the Origin header has a https scheme
|
||||
// it makes the following additional adjustments: copies the Origin header to the Access-Control-Allow-Origin header
|
||||
// effectively allowing all origins, sets the Access-Control-Allow-Credentials header to false which disallows CORS
|
||||
// requests from sending cookies etc, sets the Access-Control-Allow-Headers header to the value specified by
|
||||
// Access-Control-Request-Headers in the request excluding the Cookie/Authorization/Proxy-Authorization and special *
|
||||
// values, sets Access-Control-Allow-Methods to the value specified by the Access-Control-Request-Method header, sets
|
||||
// the Access-Control-Max-Age header to 100.
|
||||
//
|
||||
// These behaviours can be overridden by the With methods on the returned policy.
|
||||
func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) {
|
||||
return &CORSPolicyBuilder{
|
||||
enabled: true,
|
||||
maxAge: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// CORSPolicyBuilder is a special middleware which provides CORS headers via handlers and middleware methods which can be
|
||||
// configured. It aims to simplify CORS configurations.
|
||||
type CORSPolicyBuilder struct {
|
||||
enabled bool
|
||||
varyOnly bool
|
||||
varySet bool
|
||||
methods []string
|
||||
headers []string
|
||||
origins []string
|
||||
credentials bool
|
||||
vary []string
|
||||
maxAge int
|
||||
}
|
||||
|
||||
// Build reads the CORSPolicyBuilder configuration and generates a CORSPolicy.
|
||||
func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) {
|
||||
policy = &CORSPolicy{
|
||||
enabled: b.enabled,
|
||||
varyOnly: b.varyOnly,
|
||||
credentials: []byte(strconv.FormatBool(b.credentials)),
|
||||
origins: b.buildOrigins(),
|
||||
headers: b.buildHeaders(),
|
||||
vary: b.buildVary(),
|
||||
}
|
||||
|
||||
if len(b.methods) != 0 {
|
||||
policy.methods = []byte(strings.Join(b.methods, ", "))
|
||||
}
|
||||
|
||||
if b.maxAge <= 0 {
|
||||
policy.maxAge = headerValueMaxAge
|
||||
} else {
|
||||
policy.maxAge = []byte(strconv.Itoa(b.maxAge))
|
||||
}
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) {
|
||||
if len(b.origins) != 0 {
|
||||
if len(b.origins) == 1 && b.origins[0] == "*" {
|
||||
origins = append(origins, []byte(b.origins[0]))
|
||||
} else {
|
||||
for _, origin := range b.origins {
|
||||
origins = append(origins, []byte(origin))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return origins
|
||||
}
|
||||
|
||||
func (b CORSPolicyBuilder) buildHeaders() (headers []byte) {
|
||||
if len(b.headers) != 0 {
|
||||
h := b.headers
|
||||
|
||||
if b.credentials {
|
||||
if !utils.IsStringInSliceFold(fasthttp.HeaderCookie, h) {
|
||||
h = append(h, fasthttp.HeaderCookie)
|
||||
}
|
||||
|
||||
if !utils.IsStringInSliceFold(fasthttp.HeaderAuthorization, h) {
|
||||
h = append(h, fasthttp.HeaderAuthorization)
|
||||
}
|
||||
|
||||
if !utils.IsStringInSliceFold(fasthttp.HeaderProxyAuthorization, h) {
|
||||
h = append(h, fasthttp.HeaderProxyAuthorization)
|
||||
}
|
||||
}
|
||||
|
||||
headers = utils.JoinAndCanonicalizeHeaders(headerSeparator, h...)
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
func (b CORSPolicyBuilder) buildVary() (vary []byte) {
|
||||
if b.varySet {
|
||||
if len(b.vary) != 0 {
|
||||
vary = utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...)
|
||||
}
|
||||
} else {
|
||||
if len(b.origins) == 1 && b.origins[0] == "*" {
|
||||
vary = headerValueVaryWildcard
|
||||
} else {
|
||||
vary = headerValueVary
|
||||
}
|
||||
}
|
||||
|
||||
return vary
|
||||
}
|
||||
|
||||
// WithEnabled changes the enabled state of the middleware. If the middleware is initialized with NewCORSPolicyBuilder this
|
||||
// value will be true but this function can override the value. Setting it to false prevents the middleware from adding
|
||||
// any CORS headers. The only effect this middleware has after disabling this is the HandleOPTIONS and HandleOnlyOPTIONS
|
||||
// handlers still function to return a HTTP 204 No Content, with the Allow header communicating the available HTTP
|
||||
// method verbs. The main benefit of this option is that you don't have to implement complex logic to add/remove the
|
||||
// middleware, you can just add it with the Middleware method, and adjust it using the WithEnabled method.
|
||||
func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) {
|
||||
b.enabled = enabled
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithAllowedMethods takes a list or HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with
|
||||
// that value.
|
||||
func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) {
|
||||
b.methods = methods
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the origin matches one of these.
|
||||
func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) {
|
||||
b.origins = origins
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header.
|
||||
func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) {
|
||||
b.headers = headers
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithAllowCredentials takes bool and alters the default Access-Control-Allow-Credentials header.
|
||||
func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) {
|
||||
b.credentials = allow
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithVary takes a list of header strings and alters the default Vary header.
|
||||
func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) {
|
||||
b.vary = headers
|
||||
b.varySet = true
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithVaryOnly just adds the Vary header.
|
||||
func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) {
|
||||
b.varyOnly = varyOnly
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header.
|
||||
func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) {
|
||||
b.maxAge = age
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// CORSPolicy is a middleware that handles adding CORS headers.
|
||||
type CORSPolicy struct {
|
||||
enabled bool
|
||||
varyOnly bool
|
||||
methods []byte
|
||||
headers []byte
|
||||
origins [][]byte
|
||||
credentials []byte
|
||||
vary []byte
|
||||
maxAge []byte
|
||||
}
|
||||
|
||||
// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 204
|
||||
// without a body. This handler should generally not be used without using WithAllowedMethods.
|
||||
func (p CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) {
|
||||
p.handleOPTIONS(ctx)
|
||||
p.handle(ctx)
|
||||
}
|
||||
|
||||
// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 204
|
||||
// without a body. This handler should generally not be used without using WithAllowedMethods.
|
||||
func (p CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) {
|
||||
p.handleOPTIONS(ctx)
|
||||
}
|
||||
|
||||
// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicyBuilder.
|
||||
func (p CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
p.handle(ctx)
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func corsApplyAutomaticAllowAllPolicy(req *fasthttp.Request, resp *fasthttp.Response, origin []byte) {
|
||||
originURL, err := url.Parse(string(origin))
|
||||
if err != nil || originURL.Scheme != "https" {
|
||||
func (p CORSPolicy) handle(ctx *fasthttp.RequestCtx) {
|
||||
if !p.enabled {
|
||||
return
|
||||
}
|
||||
|
||||
resp.Header.SetBytesKV(headerVary, headerValueVary)
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowOrigin, origin)
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowCredentials, headerValueFalse)
|
||||
resp.Header.SetBytesKV(headerAccessControlMaxAge, headerValueMaxAge)
|
||||
p.handleVary(ctx)
|
||||
|
||||
if headers := req.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
|
||||
requestedHeaders := strings.Split(string(headers), ",")
|
||||
allowHeaders := make([]string, len(requestedHeaders))
|
||||
if !p.varyOnly {
|
||||
p.handleCORS(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
for i, header := range requestedHeaders {
|
||||
headerTrimmed := strings.Trim(header, " ")
|
||||
if !strings.EqualFold("*", headerTrimmed) && !strings.EqualFold("Cookie", headerTrimmed) {
|
||||
allowHeaders[i] = headerTrimmed
|
||||
func (p CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) {
|
||||
ctx.Response.ResetBody()
|
||||
|
||||
/* The OPTIONS method should not return a 204 as per the following specifications when read together:
|
||||
|
||||
RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7):
|
||||
A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in
|
||||
the response.
|
||||
|
||||
RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2):
|
||||
A server MUST NOT send a Content-Length header field in any response with a status code of 1xx (Informational)
|
||||
or 204 (No Content).
|
||||
*/
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.Response.Header.SetBytesKV(headerContentLength, headerValueZero)
|
||||
|
||||
if len(p.methods) != 0 {
|
||||
ctx.Response.Header.SetBytesKV(headerAllow, p.methods)
|
||||
}
|
||||
}
|
||||
|
||||
func (p CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) {
|
||||
if len(p.vary) != 0 {
|
||||
ctx.Response.Header.SetBytesKV(headerVary, p.vary)
|
||||
}
|
||||
}
|
||||
|
||||
func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
|
||||
var (
|
||||
originURL *url.URL
|
||||
err error
|
||||
)
|
||||
|
||||
origin := ctx.Request.Header.PeekBytes(headerOrigin)
|
||||
|
||||
// Skip processing of any `https` scheme URL that has not expressly been configured.
|
||||
if originURL, err = url.Parse(string(origin)); err != nil || (originURL.Scheme != "https" && p.origins == nil) {
|
||||
return
|
||||
}
|
||||
|
||||
var allowedOrigin []byte
|
||||
|
||||
switch len(p.origins) {
|
||||
case 0:
|
||||
allowedOrigin = origin
|
||||
default:
|
||||
for i := 0; i < len(p.origins); i++ {
|
||||
if bytes.Equal(p.origins[i], headerValueOriginWildcard) {
|
||||
allowedOrigin = headerValueOriginWildcard
|
||||
} else if bytes.Equal(p.origins[i], origin) {
|
||||
allowedOrigin = origin
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowHeaders) != 0 {
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
|
||||
if len(allowedOrigin) == 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if requestMethods := req.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
|
||||
resp.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin)
|
||||
|
||||
if len(p.credentials) != 0 {
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials)
|
||||
}
|
||||
|
||||
if len(p.maxAge) != 0 {
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge)
|
||||
}
|
||||
|
||||
p.handleAllowedHeaders(ctx)
|
||||
p.handleAllowedMethods(ctx)
|
||||
}
|
||||
|
||||
func (p CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) {
|
||||
switch len(p.methods) {
|
||||
case 0:
|
||||
// TODO: It may be beneficial to be able to control this automatic behaviour.
|
||||
if requestMethods := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod); requestMethods != nil {
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requestMethods)
|
||||
}
|
||||
default:
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods)
|
||||
}
|
||||
}
|
||||
|
||||
func (p CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) {
|
||||
switch len(p.headers) {
|
||||
case 0:
|
||||
// TODO: It may be beneficial to be able to control this automatic behaviour.
|
||||
if headers := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders); headers != nil {
|
||||
requestedHeaders := strings.Split(string(headers), ",")
|
||||
allowHeaders := make([]string, 0, len(requestedHeaders))
|
||||
|
||||
for i := 0; i < len(requestedHeaders); i++ {
|
||||
headerTrimmed := strings.Trim(requestedHeaders[i], " ")
|
||||
|
||||
if headerTrimmed == "*" {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(p.credentials, headerValueTrue) ||
|
||||
(!strings.EqualFold(fasthttp.HeaderCookie, headerTrimmed) &&
|
||||
!strings.EqualFold(fasthttp.HeaderAuthorization, headerTrimmed) &&
|
||||
!strings.EqualFold(fasthttp.HeaderProxyAuthorization, headerTrimmed)) {
|
||||
allowHeaders = append(allowHeaders, headerTrimmed)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allowHeaders) != 0 {
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowHeaders, ", ")))
|
||||
}
|
||||
}
|
||||
default:
|
||||
ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,61 +5,587 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.Response{}
|
||||
func TestNewCORSMiddleware(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Equal(t, 100, cors.maxAge)
|
||||
assert.Equal(t, false, cors.credentials)
|
||||
|
||||
assert.Nil(t, cors.methods)
|
||||
assert.Nil(t, cors.origins)
|
||||
assert.Nil(t, cors.headers)
|
||||
assert.Nil(t, cors.vary)
|
||||
assert.False(t, cors.varyOnly)
|
||||
assert.False(t, cors.varySet)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithEnabled(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.True(t, cors.enabled)
|
||||
|
||||
cors.WithEnabled(false)
|
||||
assert.False(t, cors.enabled)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithVary(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Nil(t, cors.vary)
|
||||
assert.False(t, cors.varyOnly)
|
||||
assert.False(t, cors.varySet)
|
||||
|
||||
cors.WithVary()
|
||||
assert.Nil(t, cors.vary)
|
||||
assert.False(t, cors.varyOnly)
|
||||
assert.True(t, cors.varySet)
|
||||
|
||||
cors.WithVary("Origin", "Example", "Test")
|
||||
|
||||
assert.Equal(t, []string{"Origin", "Example", "Test"}, cors.vary)
|
||||
assert.False(t, cors.varyOnly)
|
||||
assert.True(t, cors.varySet)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithAllowedMethods(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Nil(t, cors.methods)
|
||||
|
||||
cors.WithAllowedMethods("GET")
|
||||
|
||||
assert.Equal(t, []string{"GET"}, cors.methods)
|
||||
|
||||
cors.WithAllowedMethods("POST", "PATCH")
|
||||
|
||||
assert.Equal(t, []string{"POST", "PATCH"}, cors.methods)
|
||||
|
||||
cors.WithAllowedMethods()
|
||||
|
||||
assert.Nil(t, cors.methods)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithAllowedOrigins(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Nil(t, cors.origins)
|
||||
|
||||
cors.WithAllowedOrigins("https://google.com", "http://localhost")
|
||||
|
||||
assert.Equal(t, []string{"https://google.com", "http://localhost"}, cors.origins)
|
||||
|
||||
cors.WithAllowedOrigins()
|
||||
|
||||
assert.Nil(t, cors.origins)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithAllowedHeaders(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Nil(t, cors.headers)
|
||||
|
||||
cors.WithAllowedHeaders("Example", "Another")
|
||||
|
||||
assert.Equal(t, []string{"Example", "Another"}, cors.headers)
|
||||
|
||||
cors.WithAllowedHeaders()
|
||||
|
||||
assert.Nil(t, cors.headers)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithAllowCredentials(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Equal(t, false, cors.credentials)
|
||||
|
||||
cors.WithAllowCredentials(false)
|
||||
|
||||
assert.Equal(t, false, cors.credentials)
|
||||
|
||||
cors.WithAllowCredentials(true)
|
||||
|
||||
assert.Equal(t, true, cors.credentials)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithVaryOnly(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.False(t, cors.varyOnly)
|
||||
|
||||
cors.WithVaryOnly(false)
|
||||
|
||||
assert.False(t, cors.varyOnly)
|
||||
|
||||
cors.WithVaryOnly(true)
|
||||
|
||||
cors.WithVaryOnly(true)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithMaxAge(t *testing.T) {
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Equal(t, 100, cors.maxAge)
|
||||
|
||||
cors.WithMaxAge(20)
|
||||
|
||||
assert.Equal(t, 20, cors.maxAge)
|
||||
|
||||
cors.WithMaxAge(0)
|
||||
|
||||
assert.Equal(t, 0, cors.maxAge)
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_HandleOPTIONS(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
||||
cors := NewCORSPolicyBuilder()
|
||||
policy := cors.Build()
|
||||
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOnlyOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithEnabled(false)
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_HandleOPTIONS_WithoutOrigin(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
policy := cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
|
||||
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedOrigins(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
cors.WithAllowedOrigins("https://myapp.example.com")
|
||||
|
||||
policy := cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithAllowedOrigins("https://anotherapp.example.com")
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithAllowedOrigins("*")
|
||||
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_WithAllowedOrigins_DoesntOverrideVary(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
cors.WithVary("Accept-Encoding", "Origin", "Test")
|
||||
cors.WithAllowedOrigins("*")
|
||||
|
||||
policy := cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin, Test"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, headerValueOriginWildcard, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_HandleOPTIONSWithVaryOnly(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
cors.WithVaryOnly(true)
|
||||
|
||||
policy := cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_HandleOPTIONSWithAllowedHeaders(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
cors.WithAllowedHeaders("Example", "Test")
|
||||
|
||||
policy := cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithAllowedMethods("GET", "OPTIONS")
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("Example, Test"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
|
||||
ctx = newFastHTTPRequestCtx()
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors.WithAllowCredentials(true)
|
||||
|
||||
policy = cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueTrue, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("Example, Test, Cookie, Authorization, Proxy-Authorization"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func TestCORSPolicyBuilder_HandleOPTIONS_ShouldNotAllowWildcardInRequestedHeaders(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "*")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
policy := cors.Build()
|
||||
policy.HandleOPTIONS(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, headerValueZero, ctx.Response.Header.PeekBytes(headerContentLength))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAllow))
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithoutRequestMethod(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
policy := cors.Build()
|
||||
policy.handle(ctx)
|
||||
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_WithRequestMethod(t *testing.T) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.Response{}
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
|
||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
||||
cors := NewCORSPolicyBuilder()
|
||||
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), resp.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, resp.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET"), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
policy := cors.Build()
|
||||
policy.handle(ctx)
|
||||
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func Test_CORSApplyAutomaticAllowAllPolicy_ShouldNotModifyFotNonHTTPSRequests(t *testing.T) {
|
||||
req := fasthttp.AcquireRequest()
|
||||
|
||||
resp := fasthttp.Response{}
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("http://myapp.example.com")
|
||||
|
||||
req.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
req.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
|
||||
corsApplyAutomaticAllowAllPolicy(req, &resp, origin)
|
||||
cors := NewCORSPolicyBuilder().WithVary()
|
||||
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), resp.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
policy := cors.Build()
|
||||
policy.handle(ctx)
|
||||
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte(nil), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func Test_CORSMiddleware_AsMiddleware(t *testing.T) {
|
||||
ctx := newFastHTTPRequestCtx()
|
||||
|
||||
origin := []byte("https://myapp.example.com")
|
||||
|
||||
ctx.Request.Header.SetBytesKV(headerOrigin, origin)
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestHeaders, "X-Example-Header")
|
||||
ctx.Request.Header.SetBytesK(headerAccessControlRequestMethod, "GET")
|
||||
|
||||
autheliaMiddleware := AutheliaMiddleware(schema.Configuration{}, Providers{})
|
||||
|
||||
cors := NewCORSPolicyBuilder().WithAllowedMethods("GET", "OPTIONS")
|
||||
|
||||
policy := cors.Build()
|
||||
|
||||
route := policy.Middleware(autheliaMiddleware(testNilHandler))
|
||||
|
||||
route(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, []byte("Accept-Encoding, Origin"), ctx.Response.Header.PeekBytes(headerVary))
|
||||
assert.Equal(t, origin, ctx.Response.Header.PeekBytes(headerAccessControlAllowOrigin))
|
||||
assert.Equal(t, headerValueFalse, ctx.Response.Header.PeekBytes(headerAccessControlAllowCredentials))
|
||||
assert.Equal(t, headerValueMaxAge, ctx.Response.Header.PeekBytes(headerAccessControlMaxAge))
|
||||
assert.Equal(t, []byte("X-Example-Header"), ctx.Response.Header.PeekBytes(headerAccessControlAllowHeaders))
|
||||
assert.Equal(t, []byte("GET, OPTIONS"), ctx.Response.Header.PeekBytes(headerAccessControlAllowMethods))
|
||||
}
|
||||
|
||||
func testNilHandler(_ *AutheliaCtx) {}
|
||||
|
||||
func newFastHTTPRequestCtx() (ctx *fasthttp.RequestCtx) {
|
||||
return &fasthttp.RequestCtx{
|
||||
Request: fasthttp.Request{},
|
||||
Response: fasthttp.Response{},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,17 +19,28 @@ const (
|
|||
ClaimEmailAlts = "alt_emails"
|
||||
)
|
||||
|
||||
// Endpoints.
|
||||
const (
|
||||
AuthorizationEndpoint = "authorization"
|
||||
TokenEndpoint = "token"
|
||||
UserinfoEndpoint = "userinfo"
|
||||
IntrospectionEndpoint = "introspection"
|
||||
RevocationEndpoint = "revocation"
|
||||
)
|
||||
|
||||
// Paths.
|
||||
const (
|
||||
WellKnownOpenIDConfigurationPath = "/.well-known/openid-configuration"
|
||||
WellKnownOAuthAuthorizationServerPath = "/.well-known/oauth-authorization-server"
|
||||
JWKsPath = "/jwks.json"
|
||||
|
||||
JWKsPath = "/api/oidc/jwks"
|
||||
AuthorizationPath = "/api/oidc/authorization"
|
||||
TokenPath = "/api/oidc/token" //nolint:gosec // This is not a hard coded credential, it's a path.
|
||||
IntrospectionPath = "/api/oidc/introspection"
|
||||
RevocationPath = "/api/oidc/revocation"
|
||||
UserinfoPath = "/api/oidc/userinfo"
|
||||
RootPath = "/api/oidc"
|
||||
|
||||
AuthorizationPath = RootPath + "/" + AuthorizationEndpoint
|
||||
TokenPath = RootPath + "/" + TokenEndpoint
|
||||
UserinfoPath = RootPath + "/" + UserinfoEndpoint
|
||||
IntrospectionPath = RootPath + "/" + IntrospectionEndpoint
|
||||
RevocationPath = RootPath + "/" + RevocationEndpoint
|
||||
)
|
||||
|
||||
// Authentication Method Reference Values https://datatracker.ietf.org/doc/html/rfc8176
|
||||
|
|
|
@ -87,7 +87,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
|
|||
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
|
||||
|
||||
assert.Equal(t, "https://example.com", disco.Issuer)
|
||||
assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI)
|
||||
assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
|
||||
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
|
||||
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
|
||||
assert.Equal(t, "https://example.com/api/oidc/userinfo", disco.UserinfoEndpoint)
|
||||
|
@ -173,7 +173,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig
|
|||
disco := provider.GetOAuth2WellKnownConfiguration("https://example.com")
|
||||
|
||||
assert.Equal(t, "https://example.com", disco.Issuer)
|
||||
assert.Equal(t, "https://example.com/api/oidc/jwks", disco.JWKSURI)
|
||||
assert.Equal(t, "https://example.com/jwks.json", disco.JWKSURI)
|
||||
assert.Equal(t, "https://example.com/api/oidc/authorization", disco.AuthorizationEndpoint)
|
||||
assert.Equal(t, "https://example.com/api/oidc/token", disco.TokenEndpoint)
|
||||
assert.Equal(t, "https://example.com/api/oidc/introspection", disco.IntrospectionEndpoint)
|
||||
|
|
|
@ -37,16 +37,17 @@ var (
|
|||
{name: "/api", prefix: "/api/"},
|
||||
{name: "/.well-known", prefix: "/.well-known/"},
|
||||
{name: "/static", prefix: "/static/"},
|
||||
{name: "/locales", prefix: "/locales/"},
|
||||
}
|
||||
)
|
||||
|
||||
const schemeHTTP = "http"
|
||||
const schemeHTTPS = "https"
|
||||
|
||||
const (
|
||||
dev = "dev"
|
||||
f = "false"
|
||||
t = "true"
|
||||
dev = "dev"
|
||||
f = "false"
|
||||
t = "true"
|
||||
localhost = "localhost"
|
||||
schemeHTTP = "http"
|
||||
schemeHTTPS = "https"
|
||||
)
|
||||
|
||||
const healthCheckEnv = `# Written by Authelia Process
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
// Replacement for the default error handler in fasthttp.
|
||||
func autheliaErrorHandler(ctx *fasthttp.RequestCtx, err error) {
|
||||
logger := logging.Logger()
|
||||
|
||||
if _, ok := err.(*fasthttp.ErrSmallBuffer); ok {
|
||||
// Note: Getting X-Forwarded-For or Request URI is impossible for ths error.
|
||||
logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
|
||||
ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
|
||||
} else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() {
|
||||
// TODO: Add X-Forwarded-For Check here.
|
||||
logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
|
||||
ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
|
||||
} else {
|
||||
// TODO: Add X-Forwarded-For Check here.
|
||||
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
|
||||
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
|
||||
}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/handlers"
|
||||
)
|
||||
|
||||
func handleNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
path := strings.ToLower(string(ctx.Path()))
|
||||
|
||||
for i := 0; i < len(httpServerDirs); i++ {
|
||||
if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
|
||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
56
internal/server/handlers.go
Normal file
56
internal/server/handlers.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/handlers"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
// Replacement for the default error handler in fasthttp.
|
||||
func handlerErrors(ctx *fasthttp.RequestCtx, err error) {
|
||||
logger := logging.Logger()
|
||||
|
||||
switch e := err.(type) {
|
||||
case *fasthttp.ErrSmallBuffer:
|
||||
logger.Tracef("Request was too large to handle from client %s. Response Code %d.", ctx.RemoteIP().String(), fasthttp.StatusRequestHeaderFieldsTooLarge)
|
||||
ctx.Error("request header too large", fasthttp.StatusRequestHeaderFieldsTooLarge)
|
||||
case *net.OpError:
|
||||
if e.Timeout() {
|
||||
// TODO: Add X-Forwarded-For Check here.
|
||||
logger.Tracef("Request timeout occurred while handling from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusRequestTimeout)
|
||||
ctx.Error("request timeout", fasthttp.StatusRequestTimeout)
|
||||
} else {
|
||||
// TODO: Add X-Forwarded-For Check here.
|
||||
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
|
||||
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
|
||||
}
|
||||
default:
|
||||
// TODO: Add X-Forwarded-For Check here.
|
||||
logger.Tracef("An unknown error occurred while handling a request from client %s: %s. Response Code %d.", ctx.RemoteIP().String(), ctx.RequestURI(), fasthttp.StatusBadRequest)
|
||||
ctx.Error("error when parsing request", fasthttp.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func handlerNotFound(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
path := strings.ToLower(string(ctx.Path()))
|
||||
|
||||
for i := 0; i < len(httpServerDirs); i++ {
|
||||
if path == httpServerDirs[i].name || strings.HasPrefix(path, httpServerDirs[i].prefix) {
|
||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusNotFound)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func handlerMethodNotAllowed(ctx *fasthttp.RequestCtx) {
|
||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
)
|
||||
|
||||
func handleOPTIONS(ctx *middlewares.AutheliaCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusNoContent)
|
||||
}
|
|
@ -19,10 +19,12 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/handlers"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// TODO: move to its own file and rename configuration -> config.
|
||||
func registerRoutes(configuration schema.Configuration, providers middlewares.Providers) fasthttp.RequestHandler {
|
||||
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
|
||||
rememberMe := strconv.FormatBool(configuration.Session.RememberMeDuration != schema.RememberMeDisabled)
|
||||
resetPassword := strconv.FormatBool(!configuration.AuthenticationBackend.DisableResetPassword)
|
||||
|
||||
|
@ -33,37 +35,49 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
|||
duoSelfEnrollment = strconv.FormatBool(configuration.DuoAPI.EnableSelfEnrollment)
|
||||
}
|
||||
|
||||
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
||||
handlerLocales := newLocalesEmbeddedHandler()
|
||||
|
||||
https := configuration.Server.TLS.Key != "" && configuration.Server.TLS.Certificate != ""
|
||||
|
||||
serveIndexHandler := ServeTemplatedFile(embeddedAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
||||
serveSwaggerHandler := ServeTemplatedFile(swaggerAssets, indexFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
||||
serveSwaggerAPIHandler := ServeTemplatedFile(swaggerAssets, apiFile, configuration.Server.AssetPath, duoSelfEnrollment, rememberMe, resetPassword, resetPasswordCustomURL, configuration.Session.Name, configuration.Theme, https)
|
||||
|
||||
handlerPublicHTML := newPublicHTMLEmbeddedHandler()
|
||||
handlerLocales := newLocalesEmbeddedHandler()
|
||||
|
||||
autheliaMiddleware := middlewares.AutheliaMiddleware(configuration, providers)
|
||||
|
||||
policyCORSPublicGET := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowedMethods("OPTIONS", "GET").
|
||||
WithAllowedOrigins("*").
|
||||
Build()
|
||||
|
||||
r := router.New()
|
||||
|
||||
// Static Assets.
|
||||
r.GET("/", autheliaMiddleware(serveIndexHandler))
|
||||
r.OPTIONS("/", autheliaMiddleware(handleOPTIONS))
|
||||
|
||||
for _, f := range rootFiles {
|
||||
r.GET("/"+f, handlerPublicHTML)
|
||||
}
|
||||
|
||||
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
|
||||
r.GET("/api/"+apiFile, autheliaMiddleware(serveSwaggerAPIHandler))
|
||||
|
||||
for _, file := range swaggerFiles {
|
||||
r.GET("/api/"+file, handlerPublicHTML)
|
||||
}
|
||||
|
||||
r.GET("/favicon.ico", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerPublicHTML))
|
||||
r.GET("/static/media/logo.png", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 2, handlerPublicHTML))
|
||||
r.GET("/static/{filepath:*}", handlerPublicHTML)
|
||||
|
||||
// Locales.
|
||||
r.GET("/locales/{language:[a-z]{1,3}}-{variant:[a-z0-9-]+}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
|
||||
r.GET("/locales/{language:[a-z]{1,3}}/{namespace:[a-z]+}.json", middlewares.AssetOverrideMiddleware(configuration.Server.AssetPath, 0, handlerLocales))
|
||||
|
||||
// Swagger.
|
||||
r.GET("/api/", autheliaMiddleware(serveSwaggerHandler))
|
||||
r.OPTIONS("/api/", policyCORSPublicGET.HandleOPTIONS)
|
||||
r.GET("/api/"+apiFile, policyCORSPublicGET.Middleware(autheliaMiddleware(serveSwaggerAPIHandler)))
|
||||
r.OPTIONS("/api/"+apiFile, policyCORSPublicGET.HandleOPTIONS)
|
||||
|
||||
for _, file := range swaggerFiles {
|
||||
r.GET("/api/"+file, handlerPublicHTML)
|
||||
}
|
||||
|
||||
r.GET("/api/health", autheliaMiddleware(handlers.HealthGet))
|
||||
r.GET("/api/state", autheliaMiddleware(handlers.StateGet))
|
||||
|
||||
|
@ -161,22 +175,98 @@ func registerRoutes(configuration schema.Configuration, providers middlewares.Pr
|
|||
r.GET("/debug/vars", expvarhandler.ExpvarHandler)
|
||||
}
|
||||
|
||||
r.NotFound = handleNotFound(autheliaMiddleware(serveIndexHandler))
|
||||
if providers.OpenIDConnect.Fosite != nil {
|
||||
r.GET("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentGET))
|
||||
r.POST("/api/oidc/consent", autheliaMiddleware(handlers.OpenIDConnectConsentPOST))
|
||||
|
||||
allowedOrigins := utils.StringSliceFromURLs(configuration.IdentityProviders.OIDC.CORS.AllowedOrigins)
|
||||
|
||||
r.OPTIONS(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.HandleOPTIONS)
|
||||
r.GET(oidc.WellKnownOpenIDConfigurationPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OpenIDConnectConfigurationWellKnownGET)))
|
||||
|
||||
r.OPTIONS(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.HandleOPTIONS)
|
||||
r.GET(oidc.WellKnownOAuthAuthorizationServerPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.OAuthAuthorizationServerWellKnownGET)))
|
||||
|
||||
r.OPTIONS(oidc.JWKsPath, policyCORSPublicGET.HandleOPTIONS)
|
||||
r.GET(oidc.JWKsPath, policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
|
||||
|
||||
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
|
||||
r.OPTIONS("/api/oidc/jwks", policyCORSPublicGET.HandleOPTIONS)
|
||||
r.GET("/api/oidc/jwks", policyCORSPublicGET.Middleware(autheliaMiddleware(handlers.JSONWebKeySetGET)))
|
||||
|
||||
policyCORSAuthorization := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowedMethods("OPTIONS", "GET").
|
||||
WithAllowedOrigins(allowedOrigins...).
|
||||
WithEnabled(utils.IsStringInSlice(oidc.AuthorizationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||
Build()
|
||||
|
||||
r.OPTIONS(oidc.AuthorizationPath, policyCORSAuthorization.HandleOnlyOPTIONS)
|
||||
r.GET(oidc.AuthorizationPath, autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
|
||||
|
||||
// TODO (james-d-elliott): Remove in GA. This is a legacy endpoint.
|
||||
r.OPTIONS("/api/oidc/authorize", policyCORSAuthorization.HandleOnlyOPTIONS)
|
||||
r.GET("/api/oidc/authorize", autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectAuthorizationGET)))
|
||||
|
||||
policyCORSToken := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowCredentials(true).
|
||||
WithAllowedMethods("OPTIONS", "POST").
|
||||
WithAllowedOrigins(allowedOrigins...).
|
||||
WithEnabled(utils.IsStringInSlice(oidc.TokenEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||
Build()
|
||||
|
||||
r.OPTIONS(oidc.TokenPath, policyCORSToken.HandleOPTIONS)
|
||||
r.POST(oidc.TokenPath, policyCORSToken.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectTokenPOST))))
|
||||
|
||||
policyCORSUserinfo := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowCredentials(true).
|
||||
WithAllowedMethods("OPTIONS", "GET", "POST").
|
||||
WithAllowedOrigins(allowedOrigins...).
|
||||
WithEnabled(utils.IsStringInSlice(oidc.UserinfoEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||
Build()
|
||||
|
||||
r.OPTIONS(oidc.UserinfoPath, policyCORSUserinfo.HandleOPTIONS)
|
||||
r.GET(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
|
||||
r.POST(oidc.UserinfoPath, policyCORSUserinfo.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OpenIDConnectUserinfo))))
|
||||
|
||||
policyCORSIntrospection := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowCredentials(true).
|
||||
WithAllowedMethods("OPTIONS", "POST").
|
||||
WithAllowedOrigins(allowedOrigins...).
|
||||
WithEnabled(utils.IsStringInSlice(oidc.IntrospectionEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||
Build()
|
||||
|
||||
r.OPTIONS(oidc.IntrospectionPath, policyCORSIntrospection.HandleOPTIONS)
|
||||
r.POST(oidc.IntrospectionPath, policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
|
||||
|
||||
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
|
||||
r.OPTIONS("/api/oidc/introspect", policyCORSIntrospection.HandleOPTIONS)
|
||||
r.POST("/api/oidc/introspect", policyCORSIntrospection.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthIntrospectionPOST))))
|
||||
|
||||
policyCORSRevocation := middlewares.NewCORSPolicyBuilder().
|
||||
WithAllowCredentials(true).
|
||||
WithAllowedMethods("OPTIONS", "POST").
|
||||
WithAllowedOrigins(allowedOrigins...).
|
||||
WithEnabled(utils.IsStringInSlice(oidc.RevocationEndpoint, configuration.IdentityProviders.OIDC.CORS.Endpoints)).
|
||||
Build()
|
||||
|
||||
r.OPTIONS(oidc.RevocationPath, policyCORSRevocation.HandleOPTIONS)
|
||||
r.POST(oidc.RevocationPath, policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
|
||||
|
||||
// TODO (james-d-elliott): Remove in GA. This is a legacy implementation of the above endpoint.
|
||||
r.OPTIONS("/api/oidc/revoke", policyCORSRevocation.HandleOPTIONS)
|
||||
r.POST("/api/oidc/revoke", policyCORSRevocation.Middleware(autheliaMiddleware(middlewares.NewHTTPToAutheliaHandlerAdaptor(handlers.OAuthRevocationPOST))))
|
||||
}
|
||||
|
||||
r.NotFound = handlerNotFound(autheliaMiddleware(serveIndexHandler))
|
||||
|
||||
r.HandleMethodNotAllowed = true
|
||||
r.MethodNotAllowed = func(ctx *fasthttp.RequestCtx) {
|
||||
handlers.SetStatusCodeResponse(ctx, fasthttp.StatusMethodNotAllowed)
|
||||
}
|
||||
r.MethodNotAllowed = handlerMethodNotAllowed
|
||||
|
||||
handler := middlewares.LogRequestMiddleware(r.Handler)
|
||||
if configuration.Server.Path != "" {
|
||||
handler = middlewares.StripPathMiddleware(configuration.Server.Path, handler)
|
||||
}
|
||||
|
||||
if providers.OpenIDConnect.Fosite != nil {
|
||||
handlers.RegisterOIDC(r, autheliaMiddleware)
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
|
@ -185,12 +275,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
|||
handler := registerRoutes(configuration, providers)
|
||||
|
||||
server := &fasthttp.Server{
|
||||
ErrorHandler: autheliaErrorHandler,
|
||||
ErrorHandler: handlerErrors,
|
||||
Handler: handler,
|
||||
NoDefaultServerHeader: true,
|
||||
ReadBufferSize: configuration.Server.ReadBufferSize,
|
||||
WriteBufferSize: configuration.Server.WriteBufferSize,
|
||||
}
|
||||
|
||||
logger := logging.Logger()
|
||||
|
||||
address := net.JoinHostPort(configuration.Server.Host, strconv.Itoa(configuration.Server.Port))
|
||||
|
@ -204,9 +295,8 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
|||
|
||||
if configuration.Server.TLS.Certificate != "" && configuration.Server.TLS.Key != "" {
|
||||
connectionType, connectionScheme = "TLS", schemeHTTPS
|
||||
err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key)
|
||||
|
||||
if err != nil {
|
||||
if err = server.AppendCert(configuration.Server.TLS.Certificate, configuration.Server.TLS.Key); err != nil {
|
||||
logger.Fatalf("unable to load certificate: %v", err)
|
||||
}
|
||||
|
||||
|
@ -228,14 +318,13 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
|||
server.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone())
|
||||
if err != nil {
|
||||
if listener, err = tls.Listen("tcp", address, server.TLSConfig.Clone()); err != nil {
|
||||
logger.Fatalf("Error initializing listener: %s", err)
|
||||
}
|
||||
} else {
|
||||
connectionType, connectionScheme = "non-TLS", schemeHTTP
|
||||
listener, err = net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
|
||||
if listener, err = net.Listen("tcp", address); err != nil {
|
||||
logger.Fatalf("Error initializing listener: %s", err)
|
||||
}
|
||||
}
|
||||
|
@ -245,11 +334,10 @@ func CreateServer(configuration schema.Configuration, providers middlewares.Prov
|
|||
logger.Fatalf("Could not configure healthcheck: %v", err)
|
||||
}
|
||||
|
||||
actualAddress := listener.Addr().String()
|
||||
if configuration.Server.Path == "" {
|
||||
logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, actualAddress)
|
||||
logger.Infof("Initializing server for %s connections on '%s' path '/'", connectionType, listener.Addr().String())
|
||||
} else {
|
||||
logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, actualAddress, configuration.Server.Path)
|
||||
logger.Infof("Initializing server for %s connections on '%s' paths '/' and '%s'", connectionType, listener.Addr().String(), configuration.Server.Path)
|
||||
}
|
||||
|
||||
return server, listener
|
||||
|
|
|
@ -48,14 +48,14 @@ func ServeTemplatedFile(publicDir, file, assetPath, duoSelfEnrollment, rememberM
|
|||
}
|
||||
}
|
||||
|
||||
var scheme = "https"
|
||||
var scheme = schemeHTTPS
|
||||
|
||||
if !https {
|
||||
proto := string(ctx.XForwardedProto())
|
||||
switch proto {
|
||||
case "":
|
||||
break
|
||||
case "http", "https":
|
||||
case schemeHTTP, schemeHTTPS:
|
||||
scheme = proto
|
||||
}
|
||||
}
|
||||
|
@ -116,7 +116,7 @@ func writeHealthCheckEnv(disabled bool, scheme, host, path string, port int) (er
|
|||
}()
|
||||
|
||||
if host == "0.0.0.0" {
|
||||
host = "localhost"
|
||||
host = localhost
|
||||
} else if strings.Contains(host, ":") {
|
||||
host = "[" + host + "]"
|
||||
}
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// IsStringAbsURL checks a string can be parsed as a URL and that is IsAbs and if it can't it returns an error
|
||||
|
@ -145,6 +147,52 @@ func IsStringSlicesDifferentFold(a, b []string) (different bool) {
|
|||
return isStringSlicesDifferent(a, b, IsStringInSliceFold)
|
||||
}
|
||||
|
||||
// IsURLInSlice returns true if the needle url.URL is in the []url.URL haystack.
|
||||
func IsURLInSlice(needle url.URL, haystack []url.URL) (has bool) {
|
||||
for i := 0; i < len(haystack); i++ {
|
||||
if strings.EqualFold(needle.String(), haystack[i].String()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// StringSliceFromURLs returns a []string from a []url.URL.
|
||||
func StringSliceFromURLs(urls []url.URL) []string {
|
||||
result := make([]string, len(urls))
|
||||
|
||||
for i := 0; i < len(urls); i++ {
|
||||
result[i] = urls[i].String()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// URLsFromStringSlice returns a []url.URL from a []string.
|
||||
func URLsFromStringSlice(urls []string) []url.URL {
|
||||
var result []url.URL
|
||||
|
||||
for i := 0; i < len(urls); i++ {
|
||||
u, err := url.Parse(urls[i])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, *u)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// OriginFromURL returns an origin url.URL given another url.URL.
|
||||
func OriginFromURL(u url.URL) (origin url.URL) {
|
||||
return url.URL{
|
||||
Scheme: u.Scheme,
|
||||
Host: u.Host,
|
||||
}
|
||||
}
|
||||
|
||||
// StringSlicesDelta takes a before and after []string and compares them returning a added and removed []string.
|
||||
func StringSlicesDelta(before, after []string) (added, removed []string) {
|
||||
for _, s := range before {
|
||||
|
@ -193,6 +241,19 @@ func StringHTMLEscape(input string) (output string) {
|
|||
return htmlEscaper.Replace(input)
|
||||
}
|
||||
|
||||
// JoinAndCanonicalizeHeaders join header strings by a given sep.
|
||||
func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) {
|
||||
for i, header := range headers {
|
||||
if i != 0 {
|
||||
joined = append(joined, sep...)
|
||||
}
|
||||
|
||||
joined = fasthttp.AppendNormalizedHeaderKey(joined, header)
|
||||
}
|
||||
|
||||
return joined
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -171,3 +172,48 @@ func TestIsStringSliceContainsAny(t *testing.T) {
|
|||
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
|
||||
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
|
||||
}
|
||||
|
||||
func TestStringSliceURLConversionFuncs(t *testing.T) {
|
||||
urls := URLsFromStringSlice([]string{"https://google.com", "abc", "%*()@#$J(@*#$J@#($H"})
|
||||
|
||||
require.Len(t, urls, 2)
|
||||
assert.Equal(t, "https://google.com", urls[0].String())
|
||||
assert.Equal(t, "abc", urls[1].String())
|
||||
|
||||
strs := StringSliceFromURLs(urls)
|
||||
|
||||
require.Len(t, strs, 2)
|
||||
assert.Equal(t, "https://google.com", strs[0])
|
||||
assert.Equal(t, "abc", strs[1])
|
||||
}
|
||||
|
||||
func TestIsURLInSlice(t *testing.T) {
|
||||
urls := URLsFromStringSlice([]string{"https://google.com", "https://example.com"})
|
||||
|
||||
google, err := url.Parse("https://google.com")
|
||||
assert.NoError(t, err)
|
||||
|
||||
microsoft, err := url.Parse("https://microsoft.com")
|
||||
assert.NoError(t, err)
|
||||
|
||||
example, err := url.Parse("https://example.com")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, IsURLInSlice(*google, urls))
|
||||
assert.False(t, IsURLInSlice(*microsoft, urls))
|
||||
assert.True(t, IsURLInSlice(*example, urls))
|
||||
}
|
||||
|
||||
func TestOriginFromURL(t *testing.T) {
|
||||
google, err := url.Parse("https://google.com/abc?a=123#five")
|
||||
assert.NoError(t, err)
|
||||
|
||||
origin := OriginFromURL(*google)
|
||||
assert.Equal(t, "https://google.com", origin.String())
|
||||
}
|
||||
|
||||
func TestJoinAndCanonicalizeHeaders(t *testing.T) {
|
||||
result := JoinAndCanonicalizeHeaders([]byte(", "), "x-example-ONE", "X-EGG-Two")
|
||||
|
||||
assert.Equal(t, []byte("X-Example-One, X-Egg-Two"), result)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user