diff --git a/internal/configuration/validator/identity_providers.go b/internal/configuration/validator/identity_providers.go index 85f4b9db..a51c820e 100644 --- a/internal/configuration/validator/identity_providers.go +++ b/internal/configuration/validator/identity_providers.go @@ -171,6 +171,10 @@ func validateOIDCClients(config *schema.OpenIDConnectConfiguration, validator *s func validateOIDCClientSectorIdentifier(client schema.OpenIDConnectClientConfiguration, validator *schema.StructValidator) { if client.SectorIdentifier.String() != "" { + if utils.IsURLHostComponent(client.SectorIdentifier) || utils.IsURLHostComponentWithPort(client.SectorIdentifier) { + return + } + if client.SectorIdentifier.Scheme != "" { validator.Push(fmt.Errorf(errFmtOIDCClientInvalidSectorIdentifier, client.ID, client.SectorIdentifier.String(), client.SectorIdentifier.Host, "scheme", client.SectorIdentifier.Scheme)) diff --git a/internal/configuration/validator/identity_providers_test.go b/internal/configuration/validator/identity_providers_test.go index 1b524ff2..13d138bb 100644 --- a/internal/configuration/validator/identity_providers_test.go +++ b/internal/configuration/validator/identity_providers_test.go @@ -245,6 +245,34 @@ func TestShouldRaiseErrorWhenOIDCServerClientBadValues(t *testing.T) { fmt.Sprintf(errFmtOIDCClientRedirectURIAbsolute, "client-check-uri-abs", "google.com"), }, }, + { + Name: "ValidSectorIdentifier", + Clients: []schema.OpenIDConnectClientConfiguration{ + { + ID: "client-valid-sector", + Secret: "a-secret", + Policy: policyTwoFactor, + RedirectURIs: []string{ + "https://google.com", + }, + SectorIdentifier: mustParseURL("example.com"), + }, + }, + }, + { + Name: "ValidSectorIdentifierWithPort", + Clients: []schema.OpenIDConnectClientConfiguration{ + { + ID: "client-valid-sector", + Secret: "a-secret", + Policy: policyTwoFactor, + RedirectURIs: []string{ + "https://google.com", + }, + SectorIdentifier: mustParseURL("example.com:2000"), + }, + }, + }, { Name: "InvalidSectorIdentifierInvalidURL", Clients: []schema.OpenIDConnectClientConfiguration{ diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 7fb6b8ee..d776eb7e 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "net/url" + "strconv" "strings" "time" "unicode" @@ -287,6 +288,26 @@ func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) { return joined } +// IsURLHostComponent returns true if the provided url.URL that was parsed from a string to a url.URL via url.Parse is +// just a hostname. This is needed because of the way this function parses such strings. +func IsURLHostComponent(u url.URL) (isHostComponent bool) { + return u.Path != "" && u.Scheme == "" && u.Host == "" && u.RawPath == "" && u.Opaque == "" && + u.RawQuery == "" && u.Fragment == "" && u.RawFragment == "" +} + +// IsURLHostComponentWithPort returns true if the provided url.URL that was parsed from a string to a url.URL via +// url.Parse is just a hostname with a port. This is needed because of the way this function parses such strings. +func IsURLHostComponentWithPort(u url.URL) (isHostComponentWithPort bool) { + if u.Opaque != "" && u.Scheme != "" && u.Host == "" && u.Path == "" && u.RawPath == "" && + u.RawQuery == "" && u.Fragment == "" && u.RawFragment == "" { + _, err := strconv.Atoi(u.Opaque) + + return err == nil + } + + return false +} + func init() { rand.Seed(time.Now().UnixNano()) } diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index 6b1b5e7c..5d8e5f01 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -262,3 +262,53 @@ func TestJoinAndCanonicalizeHeaders(t *testing.T) { assert.Equal(t, []byte("X-Example-One, X-Egg-Two"), result) } + +func TestIsURLHostComponent(t *testing.T) { + testCases := []struct { + desc, have string + expectedA, expectedB bool + }{ + { + desc: "ShouldBeFalseWithScheme", + have: "https://google.com", + expectedA: false, expectedB: false, + }, + { + desc: "ShouldBeTrueForHostComponentButFalseForWithPort", + have: "google.com", + expectedA: true, expectedB: false, + }, + { + desc: "ShouldBeFalseForHostComponentButTrueForWithPort", + have: "google.com:8000", + expectedA: false, expectedB: true, + }, + { + desc: "ShouldBeFalseWithPath", + have: "google.com:8000/path", + expectedA: false, expectedB: false, + }, + { + desc: "ShouldBeFalseWithFragment", + have: "google.com:8000#test", + expectedA: false, expectedB: false, + }, + { + desc: "ShouldBeFalseWithQuery", + have: "google.com:8000?test=1", + expectedA: false, expectedB: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + u, err := url.Parse(tc.have) + + require.NoError(t, err) + require.NotNil(t, u) + + assert.Equal(t, tc.expectedA, IsURLHostComponent(*u)) + assert.Equal(t, tc.expectedB, IsURLHostComponentWithPort(*u)) + }) + } +}