From 0a970aef8a2494f24ca8c340e91189a415d43bf2 Mon Sep 17 00:00:00 2001 From: James Elliott Date: Thu, 7 Apr 2022 15:33:53 +1000 Subject: [PATCH] feat(oidc): persistent storage (#2965) This moves the OpenID Connect storage from memory into the SQL storage, making it persistent and allowing it to be used with clustered deployments like the rest of Authelia. --- docs/configuration/identity-providers/oidc.md | 35 +- docs/configuration/storage/migrations.md | 1 + docs/roadmap/oidc.md | 22 +- internal/commands/helpers.go | 2 +- .../schema/identity_providers.go | 4 +- internal/handlers/handler_firstfactor.go | 4 +- .../handlers/handler_oidc_authorization.go | 70 +-- .../handler_oidc_authorization_consent.go | 168 +++++++ internal/handlers/handler_oidc_consent.go | 145 +++--- internal/handlers/handler_oidc_userinfo.go | 7 +- internal/handlers/handler_sign_duo.go | 2 +- internal/handlers/handler_sign_totp.go | 2 +- internal/handlers/handler_sign_webauthn.go | 2 +- internal/handlers/oidc.go | 22 +- internal/handlers/oidc_test.go | 52 +-- internal/handlers/response.go | 46 +- internal/handlers/types_oidc.go | 12 - internal/mocks/storage.go | 275 +++++++++++ internal/model/oidc.go | 236 +++++++++- internal/model/types.go | 26 ++ internal/model/types_test.go | 10 + internal/model/user_opaque_identifier.go | 33 ++ internal/oidc/client.go | 41 +- internal/oidc/client_test.go | 26 +- internal/oidc/discovery.go | 79 ++++ internal/oidc/discovery_test.go | 65 +++ internal/oidc/hasher.go | 4 +- internal/oidc/hasher_test.go | 6 +- internal/oidc/provider.go | 113 ++--- internal/oidc/provider_test.go | 18 +- internal/oidc/store.go | 334 +++++++++----- internal/oidc/store_test.go | 18 +- internal/oidc/types.go | 94 ++-- internal/oidc/types_test.go | 19 +- internal/session/types.go | 6 +- internal/storage/const.go | 40 +- .../V0004.OpenIDConenct.mysql.up.sql | 188 ++++++++ .../V0004.OpenIDConenct.postgres.up.sql | 188 ++++++++ .../V0004.OpenIDConenct.sqlite.up.sql | 188 ++++++++ .../V0004.OpenIDConnect.all.down.sql | 8 + internal/storage/provider.go | 25 + internal/storage/sql_provider.go | 427 ++++++++++++++++++ .../storage/sql_provider_backend_postgres.go | 66 ++- internal/storage/sql_provider_queries.go | 104 ++++- internal/storage/sql_rows.go | 47 ++ internal/suites/OIDC/configuration.yml | 1 + internal/suites/OIDCTraefik/configuration.yml | 1 + internal/suites/action_2fa_methods.go | 14 +- internal/suites/action_login.go | 8 +- internal/suites/action_reset_password.go | 12 +- internal/suites/action_totp.go | 4 +- .../compose/oidc-client/docker-compose.yml | 2 +- .../example/compose/oidc-client/entrypoint.sh | 2 +- .../suites/scenario_available_methods_test.go | 4 +- internal/suites/scenario_oidc_test.go | 55 ++- internal/suites/scenario_regulation_test.go | 12 +- .../suites/scenario_user_preferences_test.go | 10 +- internal/suites/suite_duo_push_test.go | 12 +- .../suites/verify_is_authenticated_page.go | 2 +- internal/suites/verify_is_consent_page.go | 2 +- .../suites/verify_is_first_factor_page.go | 2 +- internal/suites/verify_is_oidc.go | 24 + .../suites/verify_is_second_factor_page.go | 2 +- internal/suites/verify_secret_authorized.go | 2 +- internal/suites/webdriver.go | 8 +- internal/utils/strings.go | 33 ++ internal/utils/strings_test.go | 45 ++ 67 files changed, 2946 insertions(+), 591 deletions(-) create mode 100644 internal/handlers/handler_oidc_authorization_consent.go delete mode 100644 internal/handlers/types_oidc.go create mode 100644 internal/model/user_opaque_identifier.go create mode 100644 internal/oidc/discovery.go create mode 100644 internal/oidc/discovery_test.go create mode 100644 internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql create mode 100644 internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql create mode 100644 internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql create mode 100644 internal/storage/migrations/V0004.OpenIDConnect.all.down.sql create mode 100644 internal/storage/sql_rows.go diff --git a/docs/configuration/identity-providers/oidc.md b/docs/configuration/identity-providers/oidc.md index f2ffd276..32f825af 100644 --- a/docs/configuration/identity-providers/oidc.md +++ b/docs/configuration/identity-providers/oidc.md @@ -232,7 +232,7 @@ Allows PKCE `plain` challenges when set to `true`. Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows you to configure the optional parts. We reply with CORS headers when the request includes the Origin header. -##### endpoints +#### endpoints
type: list(string) {: .label .label-config .label-purple } @@ -522,8 +522,8 @@ individual user. Please use the claim `preferred_username` instead._ This scope includes the groups the authentication backend reports the user is a member of in the token. -| Claim | JWT Type | Authelia Attribute | Description | -|:------:|:-------------:|:------------------:|:----------------------:| +| Claim | JWT Type | Authelia Attribute | Description | +|:------:|:-------------:|:------------------:|:------------------------------------------------------------------------------------------------------------------:| | groups | array[string] | groups | List of user's groups discovered via [authentication](https://www.authelia.com/docs/configuration/authentication/) | ### email @@ -585,30 +585,33 @@ an example of the Authelia root URL which is also the OpenID Connect issuer. These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP. -| Endpoint | Path | -|:-------------:|:---------------------------------------------------------------:| -| Discovery | https://auth.example.com/.well-known/openid-configuration | -| Metadata | https://auth.example.com/.well-known/oauth-authorization-server | +| Endpoint | Path | +|:-----------------------------------------:|:---------------------------------------------------------------:| +| [OpenID Connect Discovery] | https://auth.example.com/.well-known/openid-configuration | +| [OAuth 2.0 Authorization Server Metadata] | https://auth.example.com/.well-known/oauth-authorization-server | ### Discoverable Endpoints These endpoints implement OpenID Connect elements. -| Endpoint | Path | Discovery Attribute | -|:---------------:|:-----------------------------------------------:|:----------------------:| -| JWKS | https://auth.example.com/jwks.json | jwks_uri | -| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint | -| [Token] | https://auth.example.com/api/oidc/token | token_endpoint | -| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint | -| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint | -| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint | +| Endpoint | Path | Discovery Attribute | +|:-------------------:|:-----------------------------------------------:|:----------------------:| +| [JSON Web Key Sets] | https://auth.example.com/jwks.json | jwks_uri | +| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint | +| [Token] | https://auth.example.com/api/oidc/token | token_endpoint | +| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint | +| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint | +| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint | +[JSON Web Key Sets]: https://datatracker.ietf.org/doc/html/rfc7517#section-5 [OpenID Connect]: https://openid.net/connect/ -[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration +[OpenID Connect Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html +[OAuth 2.0 Authorization Server Metadata]: https://datatracker.ietf.org/doc/html/rfc8414 [Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint [Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint [Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo [Introspection]: https://datatracker.ietf.org/doc/html/rfc7662 [Revocation]: https://datatracker.ietf.org/doc/html/rfc7009 [RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176 +[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration diff --git a/docs/configuration/storage/migrations.md b/docs/configuration/storage/migrations.md index a66c6e30..4d9787f8 100644 --- a/docs/configuration/storage/migrations.md +++ b/docs/configuration/storage/migrations.md @@ -24,3 +24,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel | 1 | 4.33.0 | Initial migration managed version | | 2 | 4.34.0 | Webauthn - added webauthn_devices table, altered totp_config to include device created/used dates | | 3 | 4.34.2 | Webauthn - fix V2 migration kid column length and provide migration path for anyone on V2 | +| 4 | 4.35.0 | Added OpenID Connect storage tables and opaque user identifier tables | diff --git a/docs/roadmap/oidc.md b/docs/roadmap/oidc.md index 26d2d936..910517a8 100644 --- a/docs/roadmap/oidc.md +++ b/docs/roadmap/oidc.md @@ -74,7 +74,7 @@ for which stage will have each feature, and may evolve over time: Proof Key for Code Exchange (PKCE) for Authorization Code Flow - beta4 1 + beta5 (4.35.0) Token Storage @@ -83,6 +83,21 @@ for which stage will have each feature, and may evolve over time: Subject Storage + + Pairwise Subject Identifier Type + + + Per-Client Consent Pre-Configuration + + + Cross-Origin Resource Sharing Configuration + + + Authentication Methods References Claim + + + UUID v4 sub claim + beta5 1 Prompt Handling @@ -91,7 +106,7 @@ for which stage will have each feature, and may evolve over time: Display Handling - beta6 1 + beta6 1 Back-Channel Logout @@ -103,9 +118,6 @@ for which stage will have each feature, and may evolve over time: Client Secrets Hashed in Configuration - - UUID or Random String for sub claim - GA 1 General Availability after previous stages are vetted for bug fixes diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go index ee09f743..d5c3990b 100644 --- a/internal/commands/helpers.go +++ b/internal/commands/helpers.go @@ -64,7 +64,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [ sessionProvider := session.NewProvider(config.Session, autheliaCertPool) regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock) - oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC) + oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC, storageProvider) if err != nil { errors = append(errors, err) } diff --git a/internal/configuration/schema/identity_providers.go b/internal/configuration/schema/identity_providers.go index 3828172c..bb2c35bd 100644 --- a/internal/configuration/schema/identity_providers.go +++ b/internal/configuration/schema/identity_providers.go @@ -12,7 +12,6 @@ type IdentityProvidersConfiguration struct { // OpenIDConnectConfiguration configuration for OpenID Connect. type OpenIDConnectConfiguration struct { - // This secret must be 32 bytes long. HMACSecret string `koanf:"hmac_secret"` IssuerPrivateKey string `koanf:"issuer_private_key"` @@ -49,9 +48,10 @@ type OpenIDConnectClientConfiguration struct { Policy string `koanf:"authorization_policy"` + RedirectURIs []string `koanf:"redirect_uris"` + Audience []string `koanf:"audience"` Scopes []string `koanf:"scopes"` - RedirectURIs []string `koanf:"redirect_uris"` GrantTypes []string `koanf:"grant_types"` ResponseTypes []string `koanf:"response_types"` ResponseModes []string `koanf:"response_modes"` diff --git a/internal/handlers/handler_firstfactor.go b/internal/handlers/handler_firstfactor.go index 1a48b8cb..ca5b0569 100644 --- a/internal/handlers/handler_firstfactor.go +++ b/internal/handlers/handler_firstfactor.go @@ -73,7 +73,7 @@ func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re userSession := ctx.GetSession() newSession := session.NewDefaultUserSession() - newSession.OIDCWorkflowSession = userSession.OIDCWorkflowSession + newSession.ConsentChallengeID = userSession.ConsentChallengeID // Reset all values from previous session except OIDC workflow before regenerating the cookie. if err = ctx.SaveSession(newSession); err != nil { @@ -135,7 +135,7 @@ func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re successful = true - if userSession.OIDCWorkflowSession != nil { + if userSession.ConsentChallengeID != nil { handleOIDCWorkflowResponse(ctx) } else { Handle1FAResponse(ctx, bodyJSON.TargetURL, bodyJSON.RequestMethod, userSession.Username, userSession.Groups) diff --git a/internal/handlers/handler_oidc_authorization.go b/internal/handlers/handler_oidc_authorization.go index c5940410..c1d34c95 100644 --- a/internal/handlers/handler_oidc_authorization.go +++ b/internal/handlers/handler_oidc_authorization.go @@ -2,18 +2,15 @@ package handlers import ( "errors" - "fmt" "net/http" - "strings" "time" + "github.com/google/uuid" "github.com/ory/fosite" - "github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/oidc" - "github.com/authelia/authelia/v4/internal/session" ) // OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint. @@ -23,7 +20,7 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons var ( requester fosite.AuthorizeRequester responder fosite.AuthorizeResponder - client *oidc.InternalClient + client *oidc.Client authTime time.Time issuer string err error @@ -43,7 +40,7 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' is being processed", requester.GetID(), clientID) - if client, err = ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID); err != nil { + if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(clientID); err != nil { if errors.Is(err, fosite.ErrNotFound) { ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: client was not found", requester.GetID(), clientID) } else { @@ -65,31 +62,27 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons userSession := ctx.GetSession() - requestedScopes := requester.GetRequestedScopes() - requestedAudience := requester.GetRequestedAudience() + var subject uuid.UUID - isAuthInsufficient := !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) + if subject, err = ctx.Providers.OpenIDConnect.Store.GetSubject(ctx, client.GetSectorIdentifier(), userSession.Username); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred retrieving subject for user '%s': %+v", requester.GetID(), client.GetID(), userSession.Username, err) - if isAuthInsufficient || (isConsentMissing(userSession.OIDCWorkflowSession, requestedScopes, requestedAudience)) { - oidcAuthorizeHandleAuthorizationOrConsentInsufficient(ctx, userSession, client, isAuthInsufficient, rw, r, requester, issuer) + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not retrieve the subject.")) return } - extraClaims := oidcGrantRequests(requester, requestedScopes, requestedAudience, &userSession) - - workflowCreated := time.Unix(userSession.OIDCWorkflowSession.CreatedTimestamp, 0) - - userSession.OIDCWorkflowSession = nil - - if err = ctx.SaveSession(userSession); err != nil { - ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session: %+v", requester.GetID(), client.GetID(), err) - - ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session.")) + var ( + consent *model.OAuth2ConsentSession + handled bool + ) + if consent, handled = handleOIDCAuthorizationConsent(ctx, issuer, client, userSession, subject, rw, r, requester); handled { return } + extraClaims := oidcGrantRequests(requester, consent, &userSession) + if authTime, err = userSession.AuthenticatedTime(client.Policy); err != nil { ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred checking authentication time: %+v", requester.GetID(), client.GetID(), err) @@ -100,9 +93,8 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID) - subject := userSession.Username oidcSession := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(), - subject, userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, workflowCreated, requester) + userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, consent, requester) ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v", requester.GetID(), oidcSession.ClientID, oidcSession.Subject, oidcSession.Username, oidcSession.Claims) @@ -119,39 +111,13 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons return } - ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeResponse(rw, requester, responder) -} - -func oidcAuthorizeHandleAuthorizationOrConsentInsufficient( - ctx *middlewares.AutheliaCtx, userSession session.UserSession, client *oidc.InternalClient, isAuthInsufficient bool, - rw http.ResponseWriter, r *http.Request, - requester fosite.AuthorizeRequester, issuer string) { - redirectURL := fmt.Sprintf("%s%s", issuer, string(ctx.Request.RequestURI())) - - ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' requires user '%s' provides consent for scopes '%s'", - requester.GetID(), client.GetID(), userSession.Username, strings.Join(requester.GetRequestedScopes(), "', '")) - - userSession.OIDCWorkflowSession = &model.OIDCWorkflowSession{ - ClientID: client.GetID(), - RequestedScopes: requester.GetRequestedScopes(), - RequestedAudience: requester.GetRequestedAudience(), - AuthURI: redirectURL, - TargetURI: requester.GetRedirectURI().String(), - Require2FA: client.Policy == authorization.TwoFactor, - CreatedTimestamp: time.Now().Unix(), - } - - if err := ctx.SaveSession(userSession); err != nil { - ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session for consent: %+v", requester.GetID(), client.GetID(), err) + if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSessionGranted(ctx, consent.ID); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving consent session: %+v", requester.GetID(), client.GetID(), err) ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session.")) return } - if isAuthInsufficient { - http.Redirect(rw, r, issuer, http.StatusFound) - } else { - http.Redirect(rw, r, fmt.Sprintf("%s/consent", issuer), http.StatusFound) - } + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeResponse(rw, requester, responder) } diff --git a/internal/handlers/handler_oidc_authorization_consent.go b/internal/handlers/handler_oidc_authorization_consent.go new file mode 100644 index 00000000..7c94f8ea --- /dev/null +++ b/internal/handlers/handler_oidc_authorization_consent.go @@ -0,0 +1,168 @@ +package handlers + +import ( + "fmt" + "net/http" + "strings" + + "github.com/google/uuid" + "github.com/ory/fosite" + + "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/session" + "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/utils" +) + +func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client, + userSession session.UserSession, subject uuid.UUID, + rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) { + if userSession.ConsentChallengeID != nil { + return handleOIDCAuthorizationConsentWithChallengeID(ctx, rootURI, client, userSession, rw, r, requester) + } + + return handleOIDCAuthorizationConsentOrGenerate(ctx, rootURI, client, userSession, subject, rw, r, requester) +} + +func handleOIDCAuthorizationConsentWithChallengeID(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client, + userSession session.UserSession, + rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) { + var ( + err error + ) + + if consent, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred during consent session lookup: %+v", requester.GetID(), requester.GetClient().GetID(), err) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Failed to lookup consent session.")) + + userSession.ConsentChallengeID = nil + + if err = ctx.SaveSession(userSession); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred unlinking consent session challenge id: %+v", requester.GetID(), requester.GetClient().GetID(), err) + } + + return nil, true + } + + if consent.Responded() { + userSession.ConsentChallengeID = nil + + if err = ctx.SaveSession(userSession); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session: %+v", requester.GetID(), client.GetID(), err) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session.")) + + return nil, true + } + + if consent.Granted { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: this consent session with challenge id '%s' was already granted", requester.GetID(), client.GetID(), consent.ChallengeID.String()) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Authorization already granted.")) + + return nil, true + } + + ctx.Logger.Debugf("Authorization Request with id '%s' loaded consent session with id '%d' and challenge id '%s' for client id '%s' and subject '%s' and scopes '%s'", requester.GetID(), consent.ID, consent.ChallengeID.String(), client.GetID(), consent.Subject.String(), strings.Join(requester.GetRequestedScopes(), " ")) + + if consent.IsDenied() { + ctx.Logger.Warnf("Authorization Request with id '%s' and challenge id '%s' for client id '%s' and subject '%s' and scopes '%s' was not denied by the user durng the consent session", requester.GetID(), consent.ChallengeID.String(), client.GetID(), consent.Subject.String(), strings.Join(requester.GetRequestedScopes(), " ")) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrAccessDenied) + + return nil, true + } + + return consent, false + } + + handleOIDCAuthorizationConsentRedirect(rootURI, client, userSession, rw, r) + + return consent, true +} + +func handleOIDCAuthorizationConsentOrGenerate(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client, + userSession session.UserSession, subject uuid.UUID, + rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) { + var ( + rows *storage.ConsentSessionRows + scopes, audience []string + err error + ) + + if rows, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionsPreConfigured(ctx, client.GetID(), subject); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' had error looking up pre-configured consent sessions: %+v", requester.GetID(), requester.GetClient().GetID(), err) + } + + defer rows.Close() + + for rows.Next() { + if consent, err = rows.Get(); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' had error looking up pre-configured consent sessions: %+v", requester.GetID(), requester.GetClient().GetID(), err) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not lookup pre-configured consent sessions.")) + + return nil, true + } + + scopes, audience = getExpectedScopesAndAudience(requester) + + if consent.HasExactGrants(scopes, audience) && consent.CanGrant() { + break + } + } + + if consent != nil && consent.HasExactGrants(scopes, audience) && consent.CanGrant() { + return consent, false + } + + if consent, err = model.NewOAuth2ConsentSession(subject, requester); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred generating consent: %+v", requester.GetID(), requester.GetClient().GetID(), err) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not generate the consent session.")) + + return nil, true + } + + if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSession(ctx, *consent); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving consent session: %+v", requester.GetID(), client.GetID(), err) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the consent session.")) + + return nil, true + } + + userSession.ConsentChallengeID = &consent.ChallengeID + + if err = ctx.SaveSession(userSession); err != nil { + ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving user session for consent: %+v", requester.GetID(), client.GetID(), err) + + ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the user session.")) + + return nil, true + } + + handleOIDCAuthorizationConsentRedirect(rootURI, client, userSession, rw, r) + + return consent, true +} + +func handleOIDCAuthorizationConsentRedirect(destination string, client *oidc.Client, userSession session.UserSession, rw http.ResponseWriter, r *http.Request) { + if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { + destination = fmt.Sprintf("%s/consent", destination) + } + + http.Redirect(rw, r, destination, http.StatusFound) +} + +func getExpectedScopesAndAudience(requester fosite.Requester) (scopes, audience []string) { + audience = requester.GetRequestedAudience() + if !utils.IsStringInSlice(requester.GetClient().GetID(), audience) { + audience = append(audience, requester.GetClient().GetID()) + } + + return requester.GetRequestedScopes(), audience +} diff --git a/internal/handlers/handler_oidc_consent.go b/internal/handlers/handler_oidc_consent.go index b403ecb0..af41dfb5 100644 --- a/internal/handlers/handler_oidc_consent.go +++ b/internal/handlers/handler_oidc_consent.go @@ -5,116 +5,135 @@ import ( "fmt" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/oidc" + "github.com/authelia/authelia/v4/internal/session" + "github.com/authelia/authelia/v4/internal/utils" ) // OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect. func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) { - userSession := ctx.GetSession() - - if userSession.OIDCWorkflowSession == nil { - ctx.Logger.Debugf("Cannot consent for user %s when OIDC workflow has not been initiated", userSession.Username) - ctx.ReplyForbidden() - - return - } - - clientID := userSession.OIDCWorkflowSession.ClientID - client, err := ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID) - - if err != nil { - ctx.Logger.Debugf("Unable to find related client configuration with name '%s': %v", clientID, err) - ctx.ReplyForbidden() - + userSession, consent, client, handled := oidcConsentGetSessionsAndClient(ctx) + if handled { return } if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { - ctx.Logger.Debugf("Insufficient permissions to give consent during GET current level: %d, require 2FA: %t", userSession.AuthenticationLevel, userSession.OIDCWorkflowSession.Require2FA) + ctx.Logger.Errorf("Unable to perform consent without sufficient authentication for user '%s' and client id '%s'", userSession.Username, consent.ClientID) ctx.ReplyForbidden() return } - if err := ctx.SetJSONBody(client.GetConsentResponseBody(userSession.OIDCWorkflowSession)); err != nil { + if err := ctx.SetJSONBody(client.GetConsentResponseBody(consent)); err != nil { ctx.Error(fmt.Errorf("unable to set JSON body: %v", err), "Operation failed") } } // OpenIDConnectConsentPOST handles consent responses for OpenID Connect. func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) { - userSession := ctx.GetSession() + var ( + body oidc.ConsentPostRequestBody + err error + ) - if userSession.OIDCWorkflowSession == nil { - ctx.Logger.Debugf("Cannot consent for user %s when OIDC workflow has not been initiated", userSession.Username) - ctx.ReplyForbidden() + if err = json.Unmarshal(ctx.Request.Body(), &body); err != nil { + ctx.Logger.Errorf("Failed to parse JSON body in consent POST: %+v", err) + ctx.SetJSONError(messageOperationFailed) return } - client, err := ctx.Providers.OpenIDConnect.Store.GetInternalClient(userSession.OIDCWorkflowSession.ClientID) - - if err != nil { - ctx.Logger.Debugf("Unable to find related client configuration with name '%s': %v", userSession.OIDCWorkflowSession.ClientID, err) - ctx.ReplyForbidden() - + userSession, consent, client, handled := oidcConsentGetSessionsAndClient(ctx) + if handled { return } if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { - ctx.Logger.Debugf("Insufficient permissions to give consent during POST current level: %d, require 2FA: %t", userSession.AuthenticationLevel, userSession.OIDCWorkflowSession.Require2FA) + ctx.Logger.Debugf("Insufficient permissions to give consent during POST current level: %d, require 2FA: %d", userSession.AuthenticationLevel, client.Policy) ctx.ReplyForbidden() return } - var body ConsentPostRequestBody - err = json.Unmarshal(ctx.Request.Body(), &body) + if consent.ClientID != body.ClientID { + ctx.Logger.Errorf("User '%s' consented to scopes of another client (%s) than expected (%s). Beware this can be a sign of attack", + userSession.Username, body.ClientID, consent.ClientID) + ctx.SetJSONError(messageOperationFailed) - if err != nil { - ctx.Error(fmt.Errorf("unable to unmarshal body: %v", err), "Operation failed") return } - if body.AcceptOrReject != accept && body.AcceptOrReject != reject { - ctx.Logger.Infof("User %s tried to reply to consent with an unexpected verb", userSession.Username) + var ( + externalRootURL string + authorized = true + ) + + switch body.AcceptOrReject { + case accept: + if externalRootURL, err = ctx.ExternalRootURL(); err != nil { + ctx.Logger.Errorf("Could not determine the external URL during consent session processing with challenge id '%s' for user '%s': %v", consent.ChallengeID.String(), userSession.Username, err) + ctx.SetJSONError(messageOperationFailed) + + return + } + + consent.GrantedScopes = consent.RequestedScopes + consent.GrantedAudience = consent.RequestedAudience + + if !utils.IsStringInSlice(consent.ClientID, consent.GrantedAudience) { + consent.GrantedAudience = append(consent.GrantedAudience, consent.ClientID) + } + case reject: + authorized = false + default: + ctx.Logger.Warnf("User '%s' tried to reply to consent with an unexpected verb", userSession.Username) ctx.ReplyBadRequest() return } - if userSession.OIDCWorkflowSession.ClientID != body.ClientID { - ctx.Logger.Infof("User %s consented to scopes of another client (%s) than expected (%s). Beware this can be a sign of attack", - userSession.Username, body.ClientID, userSession.OIDCWorkflowSession.ClientID) - ctx.ReplyBadRequest() + if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSessionResponse(ctx, *consent, authorized); err != nil { + ctx.Logger.Errorf("Failed to save the consent session response to the database: %+v", err) + ctx.SetJSONError(messageOperationFailed) return } - var redirectionURL string + response := oidc.ConsentPostResponseBody{RedirectURI: fmt.Sprintf("%s%s?%s", externalRootURL, oidc.AuthorizationPath, consent.Form)} - if body.AcceptOrReject == accept { - redirectionURL = userSession.OIDCWorkflowSession.AuthURI - userSession.OIDCWorkflowSession.GrantedScopes = userSession.OIDCWorkflowSession.RequestedScopes - userSession.OIDCWorkflowSession.GrantedAudience = userSession.OIDCWorkflowSession.RequestedAudience - - if err := ctx.SaveSession(userSession); err != nil { - ctx.Error(fmt.Errorf("unable to write session: %v", err), "Operation failed") - return - } - } else if body.AcceptOrReject == reject { - redirectionURL = fmt.Sprintf("%s?error=access_denied&error_description=%s", - userSession.OIDCWorkflowSession.TargetURI, "User has rejected the scopes") - userSession.OIDCWorkflowSession = nil - - if err := ctx.SaveSession(userSession); err != nil { - ctx.Error(fmt.Errorf("unable to write session: %v", err), "Operation failed") - return - } - } - - response := ConsentPostResponseBody{RedirectURI: redirectionURL} - - if err := ctx.SetJSONBody(response); err != nil { + if err = ctx.SetJSONBody(response); err != nil { ctx.Error(fmt.Errorf("unable to set JSON body in response"), "Operation failed") } } + +func oidcConsentGetSessionsAndClient(ctx *middlewares.AutheliaCtx) (userSession session.UserSession, consent *model.OAuth2ConsentSession, client *oidc.Client, handled bool) { + var ( + err error + ) + + userSession = ctx.GetSession() + + if userSession.ConsentChallengeID == nil { + ctx.Logger.Errorf("Cannot consent for user '%s' when OIDC consent session has not been initiated", userSession.Username) + ctx.ReplyForbidden() + + return userSession, nil, nil, true + } + + if consent, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID); err != nil { + ctx.Logger.Errorf("Unable to load consent session with challenge id '%s': %v", userSession.ConsentChallengeID.String(), err) + ctx.ReplyForbidden() + + return userSession, nil, nil, true + } + + if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(consent.ClientID); err != nil { + ctx.Logger.Errorf("Unable to find related client configuration with name '%s': %v", consent.ClientID, err) + ctx.ReplyForbidden() + + return userSession, nil, nil, true + } + + return userSession, consent, client, false +} diff --git a/internal/handlers/handler_oidc_userinfo.go b/internal/handlers/handler_oidc_userinfo.go index 1a46ec39..0e4c5914 100644 --- a/internal/handlers/handler_oidc_userinfo.go +++ b/internal/handlers/handler_oidc_userinfo.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/oidc" ) @@ -21,7 +22,7 @@ func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, var ( tokenType fosite.TokenType requester fosite.AccessRequester - client *oidc.InternalClient + client *oidc.Client err error ) @@ -54,13 +55,13 @@ func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter, return } - if client, err = ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID); err != nil { + if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(clientID); err != nil { ctx.Providers.OpenIDConnect.WriteError(rw, req, errors.WithStack(fosite.ErrServerError.WithHint("Unable to assert type of client"))) return } - claims := requester.GetSession().(*oidc.OpenIDSession).IDTokenClaims().ToMap() + claims := requester.GetSession().(*model.OpenIDSession).IDTokenClaims().ToMap() delete(claims, "jti") delete(claims, "sid") delete(claims, "at_hash") diff --git a/internal/handlers/handler_sign_duo.go b/internal/handlers/handler_sign_duo.go index 48aebbc6..880ffdb2 100644 --- a/internal/handlers/handler_sign_duo.go +++ b/internal/handlers/handler_sign_duo.go @@ -266,7 +266,7 @@ func HandleAllow(ctx *middlewares.AutheliaCtx, targetURL string) { return } - if userSession.OIDCWorkflowSession != nil { + if userSession.ConsentChallengeID != nil { handleOIDCWorkflowResponse(ctx) } else { Handle2FAResponse(ctx, targetURL) diff --git a/internal/handlers/handler_sign_totp.go b/internal/handlers/handler_sign_totp.go index f21c2d21..348925b8 100644 --- a/internal/handlers/handler_sign_totp.go +++ b/internal/handlers/handler_sign_totp.go @@ -78,7 +78,7 @@ func SecondFactorTOTPPost(ctx *middlewares.AutheliaCtx) { return } - if userSession.OIDCWorkflowSession != nil { + if userSession.ConsentChallengeID != nil { handleOIDCWorkflowResponse(ctx) } else { Handle2FAResponse(ctx, requestBody.TargetURL) diff --git a/internal/handlers/handler_sign_webauthn.go b/internal/handlers/handler_sign_webauthn.go index 53a8ba9c..5564851a 100644 --- a/internal/handlers/handler_sign_webauthn.go +++ b/internal/handlers/handler_sign_webauthn.go @@ -197,7 +197,7 @@ func SecondFactorWebauthnAssertionPOST(ctx *middlewares.AutheliaCtx) { return } - if userSession.OIDCWorkflowSession != nil { + if userSession.ConsentChallengeID != nil { handleOIDCWorkflowResponse(ctx) } else { Handle2FAResponse(ctx, requestBody.TargetURL) diff --git a/internal/handlers/oidc.go b/internal/handlers/oidc.go index 3244ed04..20c51d2c 100644 --- a/internal/handlers/oidc.go +++ b/internal/handlers/oidc.go @@ -6,24 +6,12 @@ import ( "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/session" - "github.com/authelia/authelia/v4/internal/utils" ) -// isConsentMissing compares the requestedScopes and requestedAudience to the workflows -// GrantedScopes and GrantedAudience and returns true if they do not match or the workflow is nil. -func isConsentMissing(workflow *model.OIDCWorkflowSession, requestedScopes, requestedAudience []string) (isMissing bool) { - if workflow == nil { - return true - } - - return len(requestedScopes) > 0 && utils.IsStringSlicesDifferent(requestedScopes, workflow.GrantedScopes) || - len(requestedAudience) > 0 && utils.IsStringSlicesDifferentFold(requestedAudience, workflow.GrantedAudience) -} - -func oidcGrantRequests(ar fosite.AuthorizeRequester, scopes, audiences []string, userSession *session.UserSession) (extraClaims map[string]interface{}) { +func oidcGrantRequests(ar fosite.AuthorizeRequester, consent *model.OAuth2ConsentSession, userSession *session.UserSession) (extraClaims map[string]interface{}) { extraClaims = map[string]interface{}{} - for _, scope := range scopes { + for _, scope := range consent.GrantedScopes { if ar != nil { ar.GrantScope(scope) } @@ -47,13 +35,9 @@ func oidcGrantRequests(ar fosite.AuthorizeRequester, scopes, audiences []string, } if ar != nil { - for _, audience := range audiences { + for _, audience := range consent.GrantedAudience { ar.GrantAudience(audience) } - - if !utils.IsStringInSlice(ar.GetClient().GetID(), ar.GetGrantedAudience()) { - ar.GrantAudience(ar.GetClient().GetID()) - } } return extraClaims diff --git a/internal/handlers/oidc_test.go b/internal/handlers/oidc_test.go index 8a5b21b0..95ff6e34 100644 --- a/internal/handlers/oidc_test.go +++ b/internal/handlers/oidc_test.go @@ -11,32 +11,12 @@ import ( "github.com/authelia/authelia/v4/internal/session" ) -func TestShouldDetectIfConsentIsMissing(t *testing.T) { - var workflow *model.OIDCWorkflowSession - - requestedScopes := []string{"openid", "profile"} - requestedAudience := []string{"https://authelia.com"} - - assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience)) - - workflow = &model.OIDCWorkflowSession{ - GrantedScopes: []string{"openid", "profile"}, - GrantedAudience: []string{"https://authelia.com"}, +func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) { + consent := &model.OAuth2ConsentSession{ + GrantedScopes: []string{oidc.ScopeProfile}, } - assert.False(t, isConsentMissing(workflow, requestedScopes, requestedAudience)) - - requestedScopes = []string{"openid", "profile", "group"} - - assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience)) - - requestedScopes = []string{"openid", "profile"} - requestedAudience = []string{"https://not.authelia.com"} - assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience)) -} - -func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) { - extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeProfile}, []string{}, &oidcUserSessionJohn) + extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn) assert.Len(t, extraClaims, 2) @@ -48,7 +28,11 @@ func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) { } func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) { - extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeGroups}, []string{}, &oidcUserSessionJohn) + consent := &model.OAuth2ConsentSession{ + GrantedScopes: []string{oidc.ScopeGroups}, + } + + extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn) assert.Len(t, extraClaims, 1) @@ -57,7 +41,7 @@ func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) { assert.Contains(t, extraClaims[oidc.ClaimGroups], "admin") assert.Contains(t, extraClaims[oidc.ClaimGroups], "dev") - extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeGroups}, []string{}, &oidcUserSessionFred) + extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred) assert.Len(t, extraClaims, 1) @@ -67,7 +51,11 @@ func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) { } func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) { - extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeEmail}, []string{}, &oidcUserSessionJohn) + consent := &model.OAuth2ConsentSession{ + GrantedScopes: []string{oidc.ScopeEmail}, + } + + extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn) assert.Len(t, extraClaims, 3) @@ -81,7 +69,7 @@ func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) { require.Contains(t, extraClaims, oidc.ClaimEmailVerified) assert.Equal(t, true, extraClaims[oidc.ClaimEmailVerified]) - extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeEmail}, []string{}, &oidcUserSessionFred) + extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred) assert.Len(t, extraClaims, 2) @@ -93,7 +81,11 @@ func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) { } func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) { - extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile}, []string{}, &oidcUserSessionJohn) + consent := &model.OAuth2ConsentSession{ + GrantedScopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile}, + } + + extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn) assert.Len(t, extraClaims, 2) @@ -103,7 +95,7 @@ func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) { require.Contains(t, extraClaims, oidc.ClaimDisplayName) assert.Equal(t, "John Smith", extraClaims[oidc.ClaimDisplayName]) - extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile}, []string{}, &oidcUserSessionFred) + extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred) assert.Len(t, extraClaims, 2) diff --git a/internal/handlers/response.go b/internal/handlers/response.go index 96301c35..a9f3bb74 100644 --- a/internal/handlers/response.go +++ b/internal/handlers/response.go @@ -7,9 +7,9 @@ import ( "github.com/valyala/fasthttp" - "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/middlewares" + "github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/utils" ) @@ -17,14 +17,15 @@ import ( func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) { userSession := ctx.GetSession() - if userSession.OIDCWorkflowSession.Require2FA && userSession.AuthenticationLevel != authentication.TwoFactor { - ctx.Logger.Warnf("OpenID Connect client '%s' requires 2FA, cannot be redirected yet", userSession.OIDCWorkflowSession.ClientID) - ctx.ReplyOK() + if userSession.ConsentChallengeID == nil { + ctx.Logger.Errorf("Unable to handle OIDC workflow response because the user session doesn't contain a consent challenge id") + + respondUnauthorized(ctx, messageOperationFailed) return } - uri, err := ctx.ExternalRootURL() + externalRootURL, err := ctx.ExternalRootURL() if err != nil { ctx.Logger.Errorf("Unable to determine external Base URL: %v", err) @@ -33,18 +34,37 @@ func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) { return } - if isConsentMissing( - userSession.OIDCWorkflowSession, - userSession.OIDCWorkflowSession.RequestedScopes, - userSession.OIDCWorkflowSession.RequestedAudience) { - err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s/consent", uri)}) + consent, err := ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID) + if err != nil { + ctx.Logger.Errorf("Unable to load consent session from database: %v", err) - if err != nil { + respondUnauthorized(ctx, messageOperationFailed) + + return + } + + client, err := ctx.Providers.OpenIDConnect.Store.GetFullClient(consent.ClientID) + if err != nil { + ctx.Logger.Errorf("Unable to find client for the consent session: %v", err) + + respondUnauthorized(ctx, messageOperationFailed) + + return + } + + if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { + ctx.Logger.Warnf("OpenID Connect client '%s' requires 2FA, cannot be redirected yet", client.ID) + ctx.ReplyOK() + + return + } + + if userSession.ConsentChallengeID != nil { + if err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s/consent", externalRootURL)}); err != nil { ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err) } } else { - err = ctx.SetJSONBody(redirectResponse{Redirect: userSession.OIDCWorkflowSession.AuthURI}) - if err != nil { + if err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s%s?%s", externalRootURL, oidc.AuthorizationPath, consent.Form)}); err != nil { ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err) } } diff --git a/internal/handlers/types_oidc.go b/internal/handlers/types_oidc.go deleted file mode 100644 index 0464bde6..00000000 --- a/internal/handlers/types_oidc.go +++ /dev/null @@ -1,12 +0,0 @@ -package handlers - -// ConsentPostRequestBody schema of the request body of the consent POST endpoint. -type ConsentPostRequestBody struct { - ClientID string `json:"client_id"` - AcceptOrReject string `json:"accept_or_reject"` -} - -// ConsentPostResponseBody schema of the response body of the consent POST endpoint. -type ConsentPostResponseBody struct { - RedirectURI string `json:"redirect_uri"` -} diff --git a/internal/mocks/storage.go b/internal/mocks/storage.go index 9c1ccce4..4bcea9ad 100644 --- a/internal/mocks/storage.go +++ b/internal/mocks/storage.go @@ -10,8 +10,10 @@ import ( time "time" gomock "github.com/golang/mock/gomock" + uuid "github.com/google/uuid" model "github.com/authelia/authelia/v4/internal/model" + storage "github.com/authelia/authelia/v4/internal/storage" ) // MockStorage is a mock of Provider interface. @@ -51,6 +53,21 @@ func (mr *MockStorageMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorage)(nil).AppendAuthenticationLog), arg0, arg1) } +// BeginTX mocks base method. +func (m *MockStorage) BeginTX(arg0 context.Context) (context.Context, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BeginTX", arg0) + ret0, _ := ret[0].(context.Context) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTX indicates an expected call of BeginTX. +func (mr *MockStorageMockRecorder) BeginTX(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTX", reflect.TypeOf((*MockStorage)(nil).BeginTX), arg0) +} + // Close mocks base method. func (m *MockStorage) Close() error { m.ctrl.T.Helper() @@ -65,6 +82,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close)) } +// Commit mocks base method. +func (m *MockStorage) Commit(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockStorageMockRecorder) Commit(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockStorage)(nil).Commit), arg0) +} + // ConsumeIdentityVerification mocks base method. func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 model.NullIP) error { m.ctrl.T.Helper() @@ -79,6 +110,34 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2) } +// DeactivateOAuth2Session mocks base method. +func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeactivateOAuth2Session", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeactivateOAuth2Session indicates an expected call of DeactivateOAuth2Session. +func (mr *MockStorageMockRecorder) DeactivateOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeactivateOAuth2Session", reflect.TypeOf((*MockStorage)(nil).DeactivateOAuth2Session), arg0, arg1, arg2) +} + +// DeactivateOAuth2SessionByRequestID mocks base method. +func (m *MockStorage) DeactivateOAuth2SessionByRequestID(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeactivateOAuth2SessionByRequestID", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeactivateOAuth2SessionByRequestID indicates an expected call of DeactivateOAuth2SessionByRequestID. +func (mr *MockStorageMockRecorder) DeactivateOAuth2SessionByRequestID(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeactivateOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).DeactivateOAuth2SessionByRequestID), arg0, arg1, arg2) +} + // DeletePreferredDuoDevice mocks base method. func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() @@ -137,6 +196,66 @@ func (mr *MockStorageMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4) } +// LoadOAuth2BlacklistedJTI mocks base method. +func (m *MockStorage) LoadOAuth2BlacklistedJTI(arg0 context.Context, arg1 string) (*model.OAuth2BlacklistedJTI, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadOAuth2BlacklistedJTI", arg0, arg1) + ret0, _ := ret[0].(*model.OAuth2BlacklistedJTI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadOAuth2BlacklistedJTI indicates an expected call of LoadOAuth2BlacklistedJTI. +func (mr *MockStorageMockRecorder) LoadOAuth2BlacklistedJTI(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2BlacklistedJTI", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2BlacklistedJTI), arg0, arg1) +} + +// LoadOAuth2ConsentSessionByChallengeID mocks base method. +func (m *MockStorage) LoadOAuth2ConsentSessionByChallengeID(arg0 context.Context, arg1 uuid.UUID) (*model.OAuth2ConsentSession, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadOAuth2ConsentSessionByChallengeID", arg0, arg1) + ret0, _ := ret[0].(*model.OAuth2ConsentSession) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadOAuth2ConsentSessionByChallengeID indicates an expected call of LoadOAuth2ConsentSessionByChallengeID. +func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionByChallengeID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionByChallengeID", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionByChallengeID), arg0, arg1) +} + +// LoadOAuth2ConsentSessionsPreConfigured mocks base method. +func (m *MockStorage) LoadOAuth2ConsentSessionsPreConfigured(arg0 context.Context, arg1 string, arg2 uuid.UUID) (*storage.ConsentSessionRows, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadOAuth2ConsentSessionsPreConfigured", arg0, arg1, arg2) + ret0, _ := ret[0].(*storage.ConsentSessionRows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadOAuth2ConsentSessionsPreConfigured indicates an expected call of LoadOAuth2ConsentSessionsPreConfigured. +func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionsPreConfigured(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionsPreConfigured", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionsPreConfigured), arg0, arg1, arg2) +} + +// LoadOAuth2Session mocks base method. +func (m *MockStorage) LoadOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) (*model.OAuth2Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadOAuth2Session", arg0, arg1, arg2) + ret0, _ := ret[0].(*model.OAuth2Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadOAuth2Session indicates an expected call of LoadOAuth2Session. +func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2) +} + // LoadPreferred2FAMethod mocks base method. func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { m.ctrl.T.Helper() @@ -212,6 +331,36 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1) } +// LoadUserOpaqueIdentifier mocks base method. +func (m *MockStorage) LoadUserOpaqueIdentifier(arg0 context.Context, arg1 uuid.UUID) (*model.UserOpaqueIdentifier, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadUserOpaqueIdentifier", arg0, arg1) + ret0, _ := ret[0].(*model.UserOpaqueIdentifier) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadUserOpaqueIdentifier indicates an expected call of LoadUserOpaqueIdentifier. +func (mr *MockStorageMockRecorder) LoadUserOpaqueIdentifier(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserOpaqueIdentifier", reflect.TypeOf((*MockStorage)(nil).LoadUserOpaqueIdentifier), arg0, arg1) +} + +// LoadUserOpaqueIdentifierBySignature mocks base method. +func (m *MockStorage) LoadUserOpaqueIdentifierBySignature(arg0 context.Context, arg1, arg2, arg3 string) (*model.UserOpaqueIdentifier, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadUserOpaqueIdentifierBySignature", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*model.UserOpaqueIdentifier) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LoadUserOpaqueIdentifierBySignature indicates an expected call of LoadUserOpaqueIdentifierBySignature. +func (mr *MockStorageMockRecorder) LoadUserOpaqueIdentifierBySignature(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserOpaqueIdentifierBySignature", reflect.TypeOf((*MockStorage)(nil).LoadUserOpaqueIdentifierBySignature), arg0, arg1, arg2, arg3) +} + // LoadWebauthnDevices mocks base method. func (m *MockStorage) LoadWebauthnDevices(arg0 context.Context, arg1, arg2 int) ([]model.WebauthnDevice, error) { m.ctrl.T.Helper() @@ -242,6 +391,48 @@ func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1 inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1) } +// RevokeOAuth2Session mocks base method. +func (m *MockStorage) RevokeOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeOAuth2Session", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeOAuth2Session indicates an expected call of RevokeOAuth2Session. +func (mr *MockStorageMockRecorder) RevokeOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2Session", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2Session), arg0, arg1, arg2) +} + +// RevokeOAuth2SessionByRequestID mocks base method. +func (m *MockStorage) RevokeOAuth2SessionByRequestID(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RevokeOAuth2SessionByRequestID", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// RevokeOAuth2SessionByRequestID indicates an expected call of RevokeOAuth2SessionByRequestID. +func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2) +} + +// Rollback mocks base method. +func (m *MockStorage) Rollback(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockStorageMockRecorder) Rollback(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockStorage)(nil).Rollback), arg0) +} + // SaveIdentityVerification mocks base method. func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 model.IdentityVerification) error { m.ctrl.T.Helper() @@ -256,6 +447,76 @@ func (mr *MockStorageMockRecorder) SaveIdentityVerification(arg0, arg1 interface return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).SaveIdentityVerification), arg0, arg1) } +// SaveOAuth2BlacklistedJTI mocks base method. +func (m *MockStorage) SaveOAuth2BlacklistedJTI(arg0 context.Context, arg1 model.OAuth2BlacklistedJTI) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOAuth2BlacklistedJTI", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveOAuth2BlacklistedJTI indicates an expected call of SaveOAuth2BlacklistedJTI. +func (mr *MockStorageMockRecorder) SaveOAuth2BlacklistedJTI(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2BlacklistedJTI", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2BlacklistedJTI), arg0, arg1) +} + +// SaveOAuth2ConsentSession mocks base method. +func (m *MockStorage) SaveOAuth2ConsentSession(arg0 context.Context, arg1 model.OAuth2ConsentSession) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOAuth2ConsentSession", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveOAuth2ConsentSession indicates an expected call of SaveOAuth2ConsentSession. +func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSession(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSession", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSession), arg0, arg1) +} + +// SaveOAuth2ConsentSessionGranted mocks base method. +func (m *MockStorage) SaveOAuth2ConsentSessionGranted(arg0 context.Context, arg1 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOAuth2ConsentSessionGranted", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveOAuth2ConsentSessionGranted indicates an expected call of SaveOAuth2ConsentSessionGranted. +func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionGranted(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionGranted", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionGranted), arg0, arg1) +} + +// SaveOAuth2ConsentSessionResponse mocks base method. +func (m *MockStorage) SaveOAuth2ConsentSessionResponse(arg0 context.Context, arg1 model.OAuth2ConsentSession, arg2 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOAuth2ConsentSessionResponse", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveOAuth2ConsentSessionResponse indicates an expected call of SaveOAuth2ConsentSessionResponse. +func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionResponse(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionResponse", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionResponse), arg0, arg1, arg2) +} + +// SaveOAuth2Session mocks base method. +func (m *MockStorage) SaveOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 model.OAuth2Session) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveOAuth2Session", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveOAuth2Session indicates an expected call of SaveOAuth2Session. +func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2) +} + // SavePreferred2FAMethod mocks base method. func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() @@ -298,6 +559,20 @@ func (mr *MockStorageMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).SaveTOTPConfiguration), arg0, arg1) } +// SaveUserOpaqueIdentifier mocks base method. +func (m *MockStorage) SaveUserOpaqueIdentifier(arg0 context.Context, arg1 model.UserOpaqueIdentifier) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveUserOpaqueIdentifier", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveUserOpaqueIdentifier indicates an expected call of SaveUserOpaqueIdentifier. +func (mr *MockStorageMockRecorder) SaveUserOpaqueIdentifier(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveUserOpaqueIdentifier", reflect.TypeOf((*MockStorage)(nil).SaveUserOpaqueIdentifier), arg0, arg1) +} + // SaveWebauthnDevice mocks base method. func (m *MockStorage) SaveWebauthnDevice(arg0 context.Context, arg1 model.WebauthnDevice) error { m.ctrl.T.Helper() diff --git a/internal/model/oidc.go b/internal/model/oidc.go index d599d9f7..2c132556 100644 --- a/internal/model/oidc.go +++ b/internal/model/oidc.go @@ -1,14 +1,228 @@ package model -// OIDCWorkflowSession represent an OIDC workflow session. -type OIDCWorkflowSession struct { - ClientID string - RequestedScopes []string - GrantedScopes []string - RequestedAudience []string - GrantedAudience []string - TargetURI string - AuthURI string - Require2FA bool - CreatedTimestamp int64 +import ( + "context" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "net/url" + "time" + + "github.com/google/uuid" + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + + "github.com/authelia/authelia/v4/internal/utils" +) + +// NewOAuth2ConsentSession creates a new OAuth2ConsentSession. +func NewOAuth2ConsentSession(subject uuid.UUID, r fosite.Requester) (consent *OAuth2ConsentSession, err error) { + consent = &OAuth2ConsentSession{ + ClientID: r.GetClient().GetID(), + Subject: subject, + Form: r.GetRequestForm().Encode(), + RequestedAt: r.GetRequestedAt(), + RequestedScopes: StringSlicePipeDelimited(r.GetRequestedScopes()), + RequestedAudience: StringSlicePipeDelimited(r.GetRequestedAudience()), + GrantedScopes: StringSlicePipeDelimited(r.GetGrantedScopes()), + GrantedAudience: StringSlicePipeDelimited(r.GetGrantedAudience()), + } + + if consent.ChallengeID, err = uuid.NewRandom(); err != nil { + return nil, err + } + + return consent, nil +} + +// NewOAuth2SessionFromRequest creates a new OAuth2Session from a signature and fosite.Requester. +func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session *OAuth2Session, err error) { + var ( + subject string + openidSession *OpenIDSession + sessData []byte + ) + + openidSession = r.GetSession().(*OpenIDSession) + if openidSession == nil { + return nil, errors.New("unexpected session type") + } + + subject = openidSession.GetSubject() + + if sessData, err = json.Marshal(openidSession); err != nil { + return nil, err + } + + return &OAuth2Session{ + ChallengeID: openidSession.ChallengeID, + RequestID: r.GetID(), + ClientID: r.GetClient().GetID(), + Signature: signature, + RequestedAt: r.GetRequestedAt(), + Subject: subject, + RequestedScopes: StringSlicePipeDelimited(r.GetRequestedScopes()), + GrantedScopes: StringSlicePipeDelimited(r.GetGrantedScopes()), + RequestedAudience: StringSlicePipeDelimited(r.GetRequestedAudience()), + GrantedAudience: StringSlicePipeDelimited(r.GetGrantedAudience()), + Active: true, + Revoked: false, + Form: r.GetRequestForm().Encode(), + Session: sessData, + }, nil +} + +// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI. +func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) { + return OAuth2BlacklistedJTI{ + Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))), + ExpiresAt: exp, + } +} + +// OAuth2ConsentSession stores information about an OAuth2.0 Consent. +type OAuth2ConsentSession struct { + ID int `db:"id"` + ChallengeID uuid.UUID `db:"challenge_id"` + ClientID string `db:"client_id"` + Subject uuid.UUID `db:"subject"` + + Authorized bool `db:"authorized"` + Granted bool `db:"granted"` + + RequestedAt time.Time `db:"requested_at"` + RespondedAt *time.Time `db:"responded_at"` + ExpiresAt *time.Time `db:"expires_at"` + + Form string `db:"form_data"` + + RequestedScopes StringSlicePipeDelimited `db:"requested_scopes"` + GrantedScopes StringSlicePipeDelimited `db:"granted_scopes"` + RequestedAudience StringSlicePipeDelimited `db:"requested_audience"` + GrantedAudience StringSlicePipeDelimited `db:"granted_audience"` +} + +// HasExactGrants returns true if the granted audience and scopes of this consent matches exactly with another +// audience and set of scopes. +func (s OAuth2ConsentSession) HasExactGrants(scopes, audience []string) (has bool) { + return s.HasExactGrantedScopes(scopes) && s.HasExactGrantedAudience(audience) +} + +// HasExactGrantedAudience returns true if the granted audience of this consent matches exactly with another audience. +func (s OAuth2ConsentSession) HasExactGrantedAudience(audience []string) (has bool) { + return !utils.IsStringSlicesDifferent(s.GrantedAudience, audience) +} + +// HasExactGrantedScopes returns true if the granted scopes of this consent matches exactly with another set of scopes. +func (s OAuth2ConsentSession) HasExactGrantedScopes(scopes []string) (has bool) { + return !utils.IsStringSlicesDifferent(s.GrantedScopes, scopes) +} + +// IsAuthorized returns true if the user has responded to the consent session and it was authorized. +func (s OAuth2ConsentSession) IsAuthorized() bool { + return s.Responded() && s.Authorized +} + +// CanGrant returns true if the user has responded to the consent session, it was authorized, and it either hast not +// previously been granted or the ability to grant has not expired. +func (s OAuth2ConsentSession) CanGrant() bool { + if !s.Responded() { + return false + } + + if s.Granted && (s.ExpiresAt == nil || s.ExpiresAt.Before(time.Now())) { + return false + } + + return true +} + +// IsDenied returns true if the user has responded to the consent session and it was not authorized. +func (s OAuth2ConsentSession) IsDenied() bool { + return s.Responded() && !s.Authorized +} + +// Responded returns true if the user has responded to the consent session. +func (s OAuth2ConsentSession) Responded() bool { + return s.RespondedAt != nil +} + +// GetForm returns the form. +func (s OAuth2ConsentSession) GetForm() (form url.Values, err error) { + return url.ParseQuery(s.Form) +} + +// OAuth2BlacklistedJTI represents a blacklisted JTI used with OAuth2.0. +type OAuth2BlacklistedJTI struct { + ID int `db:"id"` + Signature string `db:"signature"` + ExpiresAt time.Time `db:"expires_at"` +} + +// OpenIDSession holds OIDC Session information. +type OpenIDSession struct { + *openid.DefaultSession `json:"id_token"` + + ChallengeID uuid.UUID `db:"challenge_id"` + ClientID string + + Extra map[string]interface{} `json:"extra"` +} + +// OAuth2Session represents a OAuth2.0 session. +type OAuth2Session struct { + ID int `db:"id"` + ChallengeID uuid.UUID `db:"challenge_id"` + RequestID string `db:"request_id"` + ClientID string `db:"client_id"` + Signature string `db:"signature"` + RequestedAt time.Time `db:"requested_at"` + Subject string `db:"subject"` + RequestedScopes StringSlicePipeDelimited `db:"requested_scopes"` + GrantedScopes StringSlicePipeDelimited `db:"granted_scopes"` + RequestedAudience StringSlicePipeDelimited `db:"requested_audience"` + GrantedAudience StringSlicePipeDelimited `db:"granted_audience"` + Active bool `db:"active"` + Revoked bool `db:"revoked"` + Form string `db:"form_data"` + Session []byte `db:"session_data"` +} + +// SetSubject implements an interface required for RFC7523. +func (s *OAuth2Session) SetSubject(subject string) { + s.Subject = subject +} + +// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage. +func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) { + sessionData := s.Session + + if session != nil { + if err = json.Unmarshal(sessionData, session); err != nil { + return nil, err + } + } + + client, err := store.GetClient(ctx, s.ClientID) + if err != nil { + return nil, err + } + + values, err := url.ParseQuery(s.Form) + if err != nil { + return nil, err + } + + return &fosite.Request{ + ID: s.RequestID, + RequestedAt: s.RequestedAt, + Client: client, + RequestedScope: fosite.Arguments(s.RequestedScopes), + GrantedScope: fosite.Arguments(s.GrantedScopes), + RequestedAudience: fosite.Arguments(s.RequestedAudience), + GrantedAudience: fosite.Arguments(s.GrantedAudience), + Form: values, + Session: session, + }, nil } diff --git a/internal/model/types.go b/internal/model/types.go index de1658b2..674e6aa4 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -1,10 +1,13 @@ package model import ( + "database/sql" "database/sql/driver" "encoding/base64" "fmt" "net" + + "github.com/authelia/authelia/v4/internal/utils" ) // NewIP easily constructs a new IP. @@ -150,3 +153,26 @@ func (b *Base64) Scan(src interface{}) (err error) { type StartupCheck interface { StartupCheck() (err error) } + +// StringSlicePipeDelimited is a string slice that is stored in the database delimited by pipes. +type StringSlicePipeDelimited []string + +// Scan is the StringSlicePipeDelimited implementation of the sql.Scanner. +func (s *StringSlicePipeDelimited) Scan(value interface{}) (err error) { + var nullStr sql.NullString + + if err = nullStr.Scan(value); err != nil { + return err + } + + if nullStr.Valid { + *s = utils.StringSplitDelimitedEscaped(nullStr.String, '|') + } + + return nil +} + +// Value is the StringSlicePipeDelimited implementation of the databases/sql driver.Valuer. +func (s StringSlicePipeDelimited) Value() (driver.Value, error) { + return utils.StringJoinDelimitedEscaped(s, '|'), nil +} diff --git a/internal/model/types_test.go b/internal/model/types_test.go index ada1e9bb..50d8bf38 100644 --- a/internal/model/types_test.go +++ b/internal/model/types_test.go @@ -1,11 +1,21 @@ package model import ( + "fmt" "testing" + "github.com/ory/fosite" "github.com/stretchr/testify/assert" ) +func Test(t *testing.T) { + args := fosite.Arguments{"abc", "123"} + + x := StringSlicePipeDelimited(args) + + fmt.Println(x) +} + func TestDatabaseModelTypeIP(t *testing.T) { ip := IP{} diff --git a/internal/model/user_opaque_identifier.go b/internal/model/user_opaque_identifier.go new file mode 100644 index 00000000..c89b8038 --- /dev/null +++ b/internal/model/user_opaque_identifier.go @@ -0,0 +1,33 @@ +package model + +import ( + "fmt" + + "github.com/google/uuid" +) + +// NewUserOpaqueIdentifier either creates a new UserOpaqueIdentifier or returns an error. +func NewUserOpaqueIdentifier(service, sectorID, username string) (id *UserOpaqueIdentifier, err error) { + var opaqueID uuid.UUID + + if opaqueID, err = uuid.NewRandom(); err != nil { + return nil, fmt.Errorf("unable to generate uuid: %w", err) + } + + return &UserOpaqueIdentifier{ + Service: service, + SectorID: sectorID, + Username: username, + Identifier: opaqueID, + }, nil +} + +// UserOpaqueIdentifier represents an opaque identifier for a user. Commonly used with OAuth 2.0 and OpenID Connect. +type UserOpaqueIdentifier struct { + ID int `db:"id"` + Service string `db:"service"` + SectorID string `db:"sector_id"` + Username string `db:"username"` + + Identifier uuid.UUID `db:"identifier"` +} diff --git a/internal/oidc/client.go b/internal/oidc/client.go index 22f8dc98..28a35af6 100644 --- a/internal/oidc/client.go +++ b/internal/oidc/client.go @@ -9,9 +9,9 @@ import ( "github.com/authelia/authelia/v4/internal/model" ) -// NewClient creates a new InternalClient. -func NewClient(config schema.OpenIDConnectClientConfiguration) (client *InternalClient) { - client = &InternalClient{ +// NewClient creates a new Client. +func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client) { + client = &Client{ ID: config.ID, Description: config.Description, Secret: []byte(config.Secret), @@ -37,42 +37,47 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Internal } // IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient. -func (c InternalClient) IsAuthenticationLevelSufficient(level authentication.Level) bool { +func (c Client) IsAuthenticationLevelSufficient(level authentication.Level) bool { return authorization.IsAuthLevelSufficient(level, c.Policy) } // GetID returns the ID. -func (c InternalClient) GetID() string { +func (c Client) GetID() string { return c.ID } -// GetConsentResponseBody returns the proper consent response body for this model.OIDCWorkflowSession. -func (c InternalClient) GetConsentResponseBody(session *model.OIDCWorkflowSession) ConsentGetResponseBody { +// GetSectorIdentifier returns the SectorIdentifier for this client. +func (c Client) GetSectorIdentifier() string { + return c.SectorIdentifier +} + +// GetConsentResponseBody returns the proper consent response body for this session.OIDCWorkflowSession. +func (c Client) GetConsentResponseBody(consent *model.OAuth2ConsentSession) ConsentGetResponseBody { body := ConsentGetResponseBody{ ClientID: c.ID, ClientDescription: c.Description, } - if session != nil { - body.Scopes = session.RequestedScopes - body.Audience = session.RequestedAudience + if consent != nil { + body.Scopes = consent.RequestedScopes + body.Audience = consent.RequestedAudience } return body } // GetHashedSecret returns the Secret. -func (c InternalClient) GetHashedSecret() []byte { +func (c Client) GetHashedSecret() []byte { return c.Secret } // GetRedirectURIs returns the RedirectURIs. -func (c InternalClient) GetRedirectURIs() []string { +func (c Client) GetRedirectURIs() []string { return c.RedirectURIs } // GetGrantTypes returns the GrantTypes. -func (c InternalClient) GetGrantTypes() fosite.Arguments { +func (c Client) GetGrantTypes() fosite.Arguments { if len(c.GrantTypes) == 0 { return fosite.Arguments{"authorization_code"} } @@ -81,7 +86,7 @@ func (c InternalClient) GetGrantTypes() fosite.Arguments { } // GetResponseTypes returns the ResponseTypes. -func (c InternalClient) GetResponseTypes() fosite.Arguments { +func (c Client) GetResponseTypes() fosite.Arguments { if len(c.ResponseTypes) == 0 { return fosite.Arguments{"code"} } @@ -90,23 +95,23 @@ func (c InternalClient) GetResponseTypes() fosite.Arguments { } // GetScopes returns the Scopes. -func (c InternalClient) GetScopes() fosite.Arguments { +func (c Client) GetScopes() fosite.Arguments { return c.Scopes } // IsPublic returns the value of the Public property. -func (c InternalClient) IsPublic() bool { +func (c Client) IsPublic() bool { return c.Public } // GetAudience returns the Audience. -func (c InternalClient) GetAudience() fosite.Arguments { +func (c Client) GetAudience() fosite.Arguments { return c.Audience } // GetResponseModes returns the valid response modes for this client. // // Implements the fosite.ResponseModeClient. -func (c InternalClient) GetResponseModes() []fosite.ResponseModeType { +func (c Client) GetResponseModes() []fosite.ResponseModeType { return c.ResponseModes } diff --git a/internal/oidc/client_test.go b/internal/oidc/client_test.go index fdad40f3..e208c082 100644 --- a/internal/oidc/client_test.go +++ b/internal/oidc/client_test.go @@ -44,7 +44,7 @@ func TestNewClient(t *testing.T) { } func TestIsAuthenticationLevelSufficient(t *testing.T) { - c := InternalClient{} + c := Client{} c.Policy = authorization.Bypass assert.True(t, c.IsAuthenticationLevelSufficient(authentication.NotAuthenticated)) @@ -68,7 +68,7 @@ func TestIsAuthenticationLevelSufficient(t *testing.T) { } func TestInternalClient_GetConsentResponseBody(t *testing.T) { - c := InternalClient{} + c := Client{} consentRequestBody := c.GetConsentResponseBody(nil) assert.Equal(t, "", consentRequestBody.ClientID) @@ -79,7 +79,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) { c.ID = "myclient" c.Description = "My Client" - workflow := &model.OIDCWorkflowSession{ + consent := &model.OAuth2ConsentSession{ RequestedAudience: []string{"https://example.com"}, RequestedScopes: []string{"openid", "groups"}, } @@ -87,7 +87,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) { expectedScopes := []string{"openid", "groups"} expectedAudiences := []string{"https://example.com"} - consentRequestBody = c.GetConsentResponseBody(workflow) + consentRequestBody = c.GetConsentResponseBody(consent) assert.Equal(t, "myclient", consentRequestBody.ClientID) assert.Equal(t, "My Client", consentRequestBody.ClientDescription) assert.Equal(t, expectedScopes, consentRequestBody.Scopes) @@ -95,7 +95,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) { } func TestInternalClient_GetAudience(t *testing.T) { - c := InternalClient{} + c := Client{} audience := c.GetAudience() assert.Len(t, audience, 0) @@ -108,7 +108,7 @@ func TestInternalClient_GetAudience(t *testing.T) { } func TestInternalClient_GetScopes(t *testing.T) { - c := InternalClient{} + c := Client{} scopes := c.GetScopes() assert.Len(t, scopes, 0) @@ -121,7 +121,7 @@ func TestInternalClient_GetScopes(t *testing.T) { } func TestInternalClient_GetGrantTypes(t *testing.T) { - c := InternalClient{} + c := Client{} grantTypes := c.GetGrantTypes() require.Len(t, grantTypes, 1) @@ -135,7 +135,7 @@ func TestInternalClient_GetGrantTypes(t *testing.T) { } func TestInternalClient_GetHashedSecret(t *testing.T) { - c := InternalClient{} + c := Client{} hashedSecret := c.GetHashedSecret() assert.Equal(t, []byte(nil), hashedSecret) @@ -147,7 +147,7 @@ func TestInternalClient_GetHashedSecret(t *testing.T) { } func TestInternalClient_GetID(t *testing.T) { - c := InternalClient{} + c := Client{} id := c.GetID() assert.Equal(t, "", id) @@ -159,7 +159,7 @@ func TestInternalClient_GetID(t *testing.T) { } func TestInternalClient_GetRedirectURIs(t *testing.T) { - c := InternalClient{} + c := Client{} redirectURIs := c.GetRedirectURIs() require.Len(t, redirectURIs, 0) @@ -172,7 +172,7 @@ func TestInternalClient_GetRedirectURIs(t *testing.T) { } func TestInternalClient_GetResponseModes(t *testing.T) { - c := InternalClient{} + c := Client{} responseModes := c.GetResponseModes() require.Len(t, responseModes, 0) @@ -191,7 +191,7 @@ func TestInternalClient_GetResponseModes(t *testing.T) { } func TestInternalClient_GetResponseTypes(t *testing.T) { - c := InternalClient{} + c := Client{} responseTypes := c.GetResponseTypes() require.Len(t, responseTypes, 1) @@ -206,7 +206,7 @@ func TestInternalClient_GetResponseTypes(t *testing.T) { } func TestInternalClient_IsPublic(t *testing.T) { - c := InternalClient{} + c := Client{} assert.False(t, c.IsPublic()) diff --git a/internal/oidc/discovery.go b/internal/oidc/discovery.go new file mode 100644 index 00000000..f2619969 --- /dev/null +++ b/internal/oidc/discovery.go @@ -0,0 +1,79 @@ +package oidc + +// NewOpenIDConnectWellKnownConfiguration generates a new OpenIDConnectWellKnownConfiguration. +func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge, pairwise bool) (config OpenIDConnectWellKnownConfiguration) { + config = OpenIDConnectWellKnownConfiguration{ + CommonDiscoveryOptions: CommonDiscoveryOptions{ + SubjectTypesSupported: []string{ + "public", + }, + ResponseTypesSupported: []string{ + "code", + "token", + "id_token", + "code token", + "code id_token", + "token id_token", + "code token id_token", + "none", + }, + ResponseModesSupported: []string{ + "form_post", + "query", + "fragment", + }, + ScopesSupported: []string{ + ScopeOfflineAccess, + ScopeOpenID, + ScopeProfile, + ScopeGroups, + ScopeEmail, + }, + ClaimsSupported: []string{ + "aud", + "exp", + "iat", + "iss", + "jti", + "rat", + "sub", + "auth_time", + "nonce", + ClaimEmail, + ClaimEmailVerified, + ClaimEmailAlts, + ClaimGroups, + ClaimPreferredUsername, + ClaimDisplayName, + }, + }, + OAuth2DiscoveryOptions: OAuth2DiscoveryOptions{ + CodeChallengeMethodsSupported: []string{ + "S256", + }, + }, + OpenIDConnectDiscoveryOptions: OpenIDConnectDiscoveryOptions{ + IDTokenSigningAlgValuesSupported: []string{ + "RS256", + }, + UserinfoSigningAlgValuesSupported: []string{ + "none", + "RS256", + }, + RequestObjectSigningAlgValuesSupported: []string{ + "none", + "RS256", + }, + }, + } + + if pairwise { + config.SubjectTypesSupported = append(config.SubjectTypesSupported, "pairwise") + } + + if enablePKCEPlainChallenge { + config.CodeChallengeMethodsSupported = append(config.CodeChallengeMethodsSupported, "plain") + } + + return config +} diff --git a/internal/oidc/discovery_test.go b/internal/oidc/discovery_test.go new file mode 100644 index 00000000..256f5f70 --- /dev/null +++ b/internal/oidc/discovery_test.go @@ -0,0 +1,65 @@ +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) { + testCases := []struct { + desc string + pkcePlainChallenge, pairwise bool + expectCodeChallengeMethodsSupported, expectSubjectTypesSupported []string + }{ + { + desc: "ShouldHaveChallengeMethodsS256ANDSubjectTypesSupportedPublic", + pkcePlainChallenge: false, + pairwise: false, + expectCodeChallengeMethodsSupported: []string{"S256"}, + expectSubjectTypesSupported: []string{"public"}, + }, + { + desc: "ShouldHaveChallengeMethodsS256PlainANDSubjectTypesSupportedPublic", + pkcePlainChallenge: true, + pairwise: false, + expectCodeChallengeMethodsSupported: []string{"S256", "plain"}, + expectSubjectTypesSupported: []string{"public"}, + }, + { + desc: "ShouldHaveChallengeMethodsS256ANDSubjectTypesSupportedPublicPairwise", + pkcePlainChallenge: false, + pairwise: true, + expectCodeChallengeMethodsSupported: []string{"S256"}, + expectSubjectTypesSupported: []string{"public", "pairwise"}, + }, + { + desc: "ShouldHaveChallengeMethodsS256PlainANDSubjectTypesSupportedPublicPairwise", + pkcePlainChallenge: true, + pairwise: true, + expectCodeChallengeMethodsSupported: []string{"S256", "plain"}, + expectSubjectTypesSupported: []string{"public", "pairwise"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + actual := NewOpenIDConnectWellKnownConfiguration(tc.pkcePlainChallenge, tc.pairwise) + for _, codeChallengeMethod := range tc.expectCodeChallengeMethodsSupported { + assert.Contains(t, actual.CodeChallengeMethodsSupported, codeChallengeMethod) + } + + for _, subjectType := range tc.expectSubjectTypesSupported { + assert.Contains(t, actual.SubjectTypesSupported, subjectType) + } + + for _, codeChallengeMethod := range actual.CodeChallengeMethodsSupported { + assert.Contains(t, tc.expectCodeChallengeMethodsSupported, codeChallengeMethod) + } + + for _, subjectType := range actual.SubjectTypesSupported { + assert.Contains(t, tc.expectSubjectTypesSupported, subjectType) + } + }) + } +} diff --git a/internal/oidc/hasher.go b/internal/oidc/hasher.go index 58f84a93..0bacbe98 100644 --- a/internal/oidc/hasher.go +++ b/internal/oidc/hasher.go @@ -6,7 +6,7 @@ import ( ) // Compare compares the hash with the data and returns an error if they don't match. -func (h AutheliaHasher) Compare(_ context.Context, hash, data []byte) (err error) { +func (h PlainTextHasher) Compare(_ context.Context, hash, data []byte) (err error) { if subtle.ConstantTimeCompare(hash, data) == 0 { return errPasswordsDoNotMatch } @@ -15,6 +15,6 @@ func (h AutheliaHasher) Compare(_ context.Context, hash, data []byte) (err error } // Hash creates a new hash from data. -func (h AutheliaHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) { +func (h PlainTextHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) { return data, nil } diff --git a/internal/oidc/hasher_test.go b/internal/oidc/hasher_test.go index bc93a235..a86123bf 100644 --- a/internal/oidc/hasher_test.go +++ b/internal/oidc/hasher_test.go @@ -8,7 +8,7 @@ import ( ) func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) { - hasher := AutheliaHasher{} + hasher := PlainTextHasher{} a := []byte("abc") b := []byte("abc") @@ -21,7 +21,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) { } func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) { - hasher := AutheliaHasher{} + hasher := PlainTextHasher{} a := []byte("abc") b := []byte("abcd") @@ -34,7 +34,7 @@ func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) { } func TestShouldHashPassword(t *testing.T) { - hasher := AutheliaHasher{} + hasher := PlainTextHasher{} data := []byte("abc") diff --git a/internal/oidc/provider.go b/internal/oidc/provider.go index 862ddce6..670d1e42 100644 --- a/internal/oidc/provider.go +++ b/internal/oidc/provider.go @@ -8,34 +8,35 @@ import ( "github.com/ory/herodot" "github.com/authelia/authelia/v4/internal/configuration/schema" + "github.com/authelia/authelia/v4/internal/storage" "github.com/authelia/authelia/v4/internal/utils" ) // NewOpenIDConnectProvider new-ups a OpenIDConnectProvider. -func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) (provider OpenIDConnectProvider, err error) { +func NewOpenIDConnectProvider(config *schema.OpenIDConnectConfiguration, storageProvider storage.Provider) (provider OpenIDConnectProvider, err error) { provider = OpenIDConnectProvider{ Fosite: nil, } - if configuration == nil { + if config == nil { return provider, nil } - provider.Store = NewOpenIDConnectStore(configuration) + provider.Store = NewOpenIDConnectStore(config, storageProvider) composeConfiguration := &compose.Config{ - AccessTokenLifespan: configuration.AccessTokenLifespan, - AuthorizeCodeLifespan: configuration.AuthorizeCodeLifespan, - IDTokenLifespan: configuration.IDTokenLifespan, - RefreshTokenLifespan: configuration.RefreshTokenLifespan, - SendDebugMessagesToClients: configuration.EnableClientDebugMessages, - MinParameterEntropy: configuration.MinimumParameterEntropy, - EnforcePKCE: configuration.EnforcePKCE == "always", - EnforcePKCEForPublicClients: configuration.EnforcePKCE != "never", - EnablePKCEPlainChallengeMethod: configuration.EnablePKCEPlainChallenge, + AccessTokenLifespan: config.AccessTokenLifespan, + AuthorizeCodeLifespan: config.AuthorizeCodeLifespan, + IDTokenLifespan: config.IDTokenLifespan, + RefreshTokenLifespan: config.RefreshTokenLifespan, + SendDebugMessagesToClients: config.EnableClientDebugMessages, + MinParameterEntropy: config.MinimumParameterEntropy, + EnforcePKCE: config.EnforcePKCE == "always", + EnforcePKCEForPublicClients: config.EnforcePKCE != "never", + EnablePKCEPlainChallengeMethod: config.EnablePKCEPlainChallenge, } - keyManager, err := NewKeyManagerWithConfiguration(configuration) + keyManager, err := NewKeyManagerWithConfiguration(config) if err != nil { return provider, err } @@ -50,7 +51,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) strategy := &compose.CommonStrategy{ CoreStrategy: compose.NewOAuth2HMACStrategy( composeConfiguration, - []byte(utils.HashSHA256FromString(configuration.HMACSecret)), + []byte(utils.HashSHA256FromString(config.HMACSecret)), nil, ), OpenIDConnectTokenStrategy: compose.NewOpenIDConnectStrategy( @@ -64,7 +65,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) composeConfiguration, provider.Store, strategy, - AutheliaHasher{}, + PlainTextHasher{}, /* These are the OAuth2 and OpenIDConnect factories. Order is important (the OAuth2 factories at the top must @@ -75,7 +76,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) compose.OAuth2AuthorizeImplicitFactory, compose.OAuth2ClientCredentialsGrantFactory, compose.OAuth2RefreshTokenGrantFactory, - compose.OAuth2ResourceOwnerPasswordCredentialsFactory, + // compose.OAuth2ResourceOwnerPasswordCredentialsFactory, // compose.RFC7523AssertionGrantFactory,. compose.OpenIDConnectExplicitFactory, @@ -89,80 +90,24 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) compose.OAuth2PKCEFactory, ) - provider.discovery = OpenIDConnectWellKnownConfiguration{ - CommonDiscoveryOptions: CommonDiscoveryOptions{ - SubjectTypesSupported: []string{ - "public", - }, - ResponseTypesSupported: []string{ - "code", - "token", - "id_token", - "code token", - "code id_token", - "token id_token", - "code token id_token", - "none", - }, - ResponseModesSupported: []string{ - "form_post", - "query", - "fragment", - }, - ScopesSupported: []string{ - ScopeOfflineAccess, - ScopeOpenID, - ScopeProfile, - ScopeGroups, - ScopeEmail, - }, - ClaimsSupported: []string{ - "aud", - "exp", - "iat", - "iss", - "jti", - "rat", - "sub", - "auth_time", - "nonce", - ClaimEmail, - ClaimEmailVerified, - ClaimEmailAlts, - ClaimGroups, - ClaimPreferredUsername, - ClaimDisplayName, - }, - }, - OAuth2DiscoveryOptions: OAuth2DiscoveryOptions{ - CodeChallengeMethodsSupported: []string{ - "S256", - }, - }, - OpenIDConnectDiscoveryOptions: OpenIDConnectDiscoveryOptions{ - IDTokenSigningAlgValuesSupported: []string{ - "RS256", - }, - UserinfoSigningAlgValuesSupported: []string{ - "none", - "RS256", - }, - RequestObjectSigningAlgValuesSupported: []string{ - "none", - "RS256", - }, - }, - } - - if configuration.EnablePKCEPlainChallenge { - provider.discovery.CodeChallengeMethodsSupported = append(provider.discovery.CodeChallengeMethodsSupported, "plain") - } + provider.discovery = NewOpenIDConnectWellKnownConfiguration(config.EnablePKCEPlainChallenge, provider.Pairwise()) provider.herodot = herodot.NewJSONWriter(nil) return provider, nil } +// Pairwise returns true if this provider is configured with clients that require pairwise. +func (p OpenIDConnectProvider) Pairwise() bool { + for _, c := range p.Store.clients { + if c.SectorIdentifier != "" { + return true + } + } + + return false +} + // Write writes data with herodot.JSONWriter. func (p OpenIDConnectProvider) Write(w http.ResponseWriter, r *http.Request, e interface{}, opts ...herodot.EncoderOptions) { p.herodot.Write(w, r, e, opts...) diff --git a/internal/oidc/provider_test.go b/internal/oidc/provider_test.go index a6f6e0cb..0a07d78e 100644 --- a/internal/oidc/provider_test.go +++ b/internal/oidc/provider_test.go @@ -12,7 +12,7 @@ import ( var exampleIssuerPrivateKey = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEAvcMVMB2vEbqI6PlSNJ4HmUyMxBDJ5iY7FS+zDDAHOZBg9S3S\nKcAn1CZcnyL0VvJ7wcdhR6oTnOwR94eKvzUyJZ+GL2hTMm27dubEYsNdhoCl6N3X\nyEEohNfoxiiCYraVauX8X3M9jFzbEz9+pacaDbHB2syaJ1qFmMNR+HSu2jPzOo7M\nlqKIOgUzA0741MaYNt47AEVg4XU5ORLdolbAkItmYg1QbyFndg9H5IvwKkYaXTGE\nlgDBcPUC0yVjAC15Mguquq+jZeQay+6PSbHTD8PQMOkLjyChI2xEhVNbdCXe676R\ncMW2R/gjrcK23zmtmTWRfdC1iZLSlHO+bJj9vQIDAQABAoIBAEZvkP/JJOCJwqPn\nV3IcbmmilmV4bdi1vByDFgyiDyx4wOSA24+PubjvfFW9XcCgRPuKjDtTj/AhWBHv\nB7stfa2lZuNV7/u562mZArA+IAr62Zp0LdIxDV8x3T8gbjVB3HhPYbv0RJZDKTYd\nzV6jhfIrVu9mHpoY6ZnodhapCPYIyk/d49KBIHZuAc25CUjMXgTeaVtf0c996036\nUxW6ef33wAOJAvW0RCvbXAJfmBeEq2qQlkjTIlpYx71fhZWexHifi8Ouv3Zonc+1\n/P2Adq5uzYVBT92f9RKHg9QxxNzVrLjSMaxyvUtWQCAQfW0tFIRdqBGsHYsQrFtI\nF4yzv8ECgYEA7ntpyN9HD9Z9lYQzPCR73sFCLM+ID99aVij0wHuxK97bkSyyvkLd\n7MyTaym3lg1UEqWNWBCLvFULZx7F0Ah6qCzD4ymm3Bj/ADpWWPgljBI0AFml+HHs\nhcATmXUrj5QbLyhiP2gmJjajp1o/rgATx6ED66seSynD6JOH8wUhhZUCgYEAy7OA\n06PF8GfseNsTqlDjNF0K7lOqd21S0prdwrsJLiVzUlfMM25MLE0XLDUutCnRheeh\nIlcuDoBsVTxz6rkvFGD74N+pgXlN4CicsBq5ofK060PbqCQhSII3fmHobrZ9Cr75\nHmBjAxHx998SKaAAGbBbcYGUAp521i1pH5CEPYkCgYEAkUd1Zf0+2RMdZhwm6hh/\nrW+l1I6IoMK70YkZsLipccRNld7Y9LbfYwYtODcts6di9AkOVfueZJiaXbONZfIE\nZrb+jkAteh9wGL9xIrnohbABJcV3Kiaco84jInUSmGDtPokncOENfHIEuEpuSJ2b\nbx1TuhmAVuGWivR0+ULC7RECgYEAgS0cDRpWc9Xzh9Cl7+PLsXEvdWNpPsL9OsEq\n0Ep7z9+/+f/jZtoTRCS/BTHUpDvAuwHglT5j3p5iFMt5VuiIiovWLwynGYwrbnNS\nqfrIrYKUaH1n1oDS+oBZYLQGCe9/7EifAjxtjYzbvSyg//SPG7tSwfBCREbpZXj2\nqSWkNsECgYA/mCDzCTlrrWPuiepo6kTmN+4TnFA+hJI6NccDVQ+jvbqEdoJ4SW4L\nzqfZSZRFJMNpSgIqkQNRPJqMP0jQ5KRtJrjMWBnYxktwKz9fDg2R2MxdFgMF2LH2\nHEMMhFHlv8NDjVOXh1KwRoltNGVWYsSrD9wKU9GhRCEfmNCGrvBcEg==\n-----END RSA PRIVATE KEY-----" func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing.T) { - provider, err := NewOpenIDConnectProvider(nil) + provider, err := NewOpenIDConnectProvider(nil, nil) assert.NoError(t, err) assert.Nil(t, provider.Fosite) @@ -22,7 +22,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing func TestOpenIDConnectProvider_NewOpenIDConnectProvider_BadIssuerKey(t *testing.T) { _, err := NewOpenIDConnectProvider(&schema.OpenIDConnectConfiguration{ IssuerPrivateKey: "BAD KEY", - }) + }, nil) assert.Error(t, err, "abc") } @@ -60,7 +60,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GoodConfiguration(t *tes }, }, }, - }) + }, nil) assert.NotNil(t, provider) assert.NoError(t, err) @@ -80,10 +80,12 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow }, }, }, - }) + }, nil) assert.NoError(t, err) + assert.False(t, provider.Pairwise()) + disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com") assert.Equal(t, "https://example.com", disco.Issuer) @@ -95,8 +97,8 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow assert.Equal(t, "https://example.com/api/oidc/revocation", disco.RevocationEndpoint) assert.Equal(t, "", disco.RegistrationEndpoint) - require.Len(t, disco.CodeChallengeMethodsSupported, 1) - assert.Equal(t, "S256", disco.CodeChallengeMethodsSupported[0]) + assert.Len(t, disco.CodeChallengeMethodsSupported, 1) + assert.Contains(t, disco.CodeChallengeMethodsSupported, "S256") assert.Len(t, disco.ScopesSupported, 5) assert.Contains(t, disco.ScopesSupported, ScopeOpenID) @@ -166,7 +168,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig }, }, }, - }) + }, nil) assert.NoError(t, err) @@ -241,7 +243,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow }, }, }, - }) + }, nil) assert.NoError(t, err) diff --git a/internal/oidc/store.go b/internal/oidc/store.go index 0b9d9a86..4a66ee9b 100644 --- a/internal/oidc/store.go +++ b/internal/oidc/store.go @@ -2,38 +2,32 @@ package oidc import ( "context" + "crypto/sha256" + "database/sql" + "errors" + "fmt" "time" + "github.com/google/uuid" "github.com/ory/fosite" - "github.com/ory/fosite/storage" - "gopkg.in/square/go-jose.v2" "github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/logging" + "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/storage" ) -// NewOpenIDConnectStore returns a new OpenIDConnectStore using the provided schema.OpenIDConnectConfiguration. -func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore) { +// NewOpenIDConnectStore returns a OpenIDConnectStore when provided with a schema.OpenIDConnectConfiguration and storage.Provider. +func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider storage.Provider) (store *OpenIDConnectStore) { logger := logging.Logger() store = &OpenIDConnectStore{ - memory: &storage.MemoryStore{ - IDSessions: map[string]fosite.Requester{}, - Users: map[string]storage.MemoryUserRelation{}, - AuthorizeCodes: map[string]storage.StoreAuthorizeCode{}, - AccessTokens: map[string]fosite.Requester{}, - RefreshTokens: map[string]storage.StoreRefreshToken{}, - PKCES: map[string]fosite.Requester{}, - BlacklistedJTIs: map[string]time.Time{}, - AccessTokenRequestIDs: map[string]string{}, - RefreshTokenRequestIDs: map[string]string{}, - }, + provider: provider, + clients: map[string]*Client{}, } - store.clients = make(map[string]*InternalClient) - - for _, client := range configuration.Clients { + for _, client := range config.Clients { policy := authorization.PolicyToLevel(client.Policy) logger.Debugf("Registering client %s with policy %s (%v)", client.ID, client.Policy, policy) @@ -43,9 +37,37 @@ func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (st return store } +// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username. +func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { + if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil { + return nil, err + } else if opaqueID == nil { + if opaqueID, err = model.NewUserOpaqueIdentifier("openid", sectorID, username); err != nil { + return nil, err + } + + if err = s.provider.SaveUserOpaqueIdentifier(ctx, *opaqueID); err != nil { + return nil, err + } + } + + return opaqueID, nil +} + +// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it. +func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) { + var opaqueID *model.UserOpaqueIdentifier + + if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil { + return uuid.UUID{}, err + } + + return opaqueID.Identifier, nil +} + // GetClientPolicy retrieves the policy from the client with the matching provided id. func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) { - client, err := s.GetInternalClient(id) + client, err := s.GetFullClient(id) if err != nil { return authorization.TwoFactor } @@ -53,8 +75,8 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve return client.Policy } -// GetInternalClient returns a fosite.Client asserted as an InternalClient matching the provided id. -func (s OpenIDConnectStore) GetInternalClient(id string) (client *InternalClient, err error) { +// GetFullClient returns a fosite.Client asserted as an Client matching the provided id. +func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) { client, ok := s.clients[id] if !ok { return nil, fosite.ErrNotFound @@ -65,142 +87,248 @@ func (s OpenIDConnectStore) GetInternalClient(id string) (client *InternalClient // IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map. func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) { - _, err := s.GetInternalClient(id) + _, err := s.GetFullClient(id) return err == nil } -// CreateOpenIDConnectSession decorates fosite's storage.MemoryStore CreateOpenIDConnectSession method. -func (s *OpenIDConnectStore) CreateOpenIDConnectSession(ctx context.Context, authorizeCode string, requester fosite.Requester) error { - return s.memory.CreateOpenIDConnectSession(ctx, authorizeCode, requester) +// BeginTX starts a transaction. +// This implements a portion of fosite storage.Transactional interface. +func (s *OpenIDConnectStore) BeginTX(ctx context.Context) (c context.Context, err error) { + return s.provider.BeginTX(ctx) } -// GetOpenIDConnectSession decorates fosite's storage.MemoryStore GetOpenIDConnectSession method. -func (s *OpenIDConnectStore) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, requester fosite.Requester) (fosite.Requester, error) { - return s.memory.GetOpenIDConnectSession(ctx, authorizeCode, requester) +// Commit completes a transaction. +// This implements a portion of fosite storage.Transactional interface. +func (s *OpenIDConnectStore) Commit(ctx context.Context) (err error) { + return s.provider.Commit(ctx) } -// DeleteOpenIDConnectSession decorates fosite's storage.MemoryStore DeleteOpenIDConnectSession method. -func (s *OpenIDConnectStore) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error { - return s.memory.DeleteOpenIDConnectSession(ctx, authorizeCode) +// Rollback rolls a transaction back. +// This implements a portion of fosite storage.Transactional interface. +func (s *OpenIDConnectStore) Rollback(ctx context.Context) (err error) { + return s.provider.Rollback(ctx) } -// GetClient decorates fosite's storage.MemoryStore GetClient method. -func (s *OpenIDConnectStore) GetClient(_ context.Context, id string) (fosite.Client, error) { - return s.GetInternalClient(id) +// GetClient loads the client by its ID or returns an error if the client does not exist or another error occurred. +// This implements a portion of fosite.ClientManager. +func (s *OpenIDConnectStore) GetClient(_ context.Context, id string) (client fosite.Client, err error) { + return s.GetFullClient(id) } -// ClientAssertionJWTValid decorates fosite's storage.MemoryStore ClientAssertionJWTValid method. -func (s *OpenIDConnectStore) ClientAssertionJWTValid(ctx context.Context, jti string) error { - return s.memory.ClientAssertionJWTValid(ctx, jti) +// ClientAssertionJWTValid returns an error if the JTI is known or the DB check failed and nil if the JTI is not known. +// This implements a portion of fosite.ClientManager. +func (s *OpenIDConnectStore) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) { + signature := fmt.Sprintf("%x", sha256.Sum256([]byte(jti))) + + blacklistedJTI, err := s.provider.LoadOAuth2BlacklistedJTI(ctx, signature) + + switch { + case errors.Is(sql.ErrNoRows, err): + return nil + case err != nil: + return err + case blacklistedJTI.ExpiresAt.After(time.Now()): + return fosite.ErrJTIKnown + default: + return nil + } } -// SetClientAssertionJWT decorates fosite's storage.MemoryStore SetClientAssertionJWT method. -func (s *OpenIDConnectStore) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) error { - return s.memory.SetClientAssertionJWT(ctx, jti, exp) +// SetClientAssertionJWT marks a JTI as known for the given expiry time. Before inserting the new JTI, it will clean +// up any existing JTIs that have expired as those tokens can not be replayed due to the expiry. +// This implements a portion of fosite.ClientManager. +func (s *OpenIDConnectStore) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) { + blacklistedJTI := model.NewOAuth2BlacklistedJTI(jti, exp) + + return s.provider.SaveOAuth2BlacklistedJTI(ctx, blacklistedJTI) } -// CreateAuthorizeCodeSession decorates fosite's storage.MemoryStore CreateAuthorizeCodeSession method. -func (s *OpenIDConnectStore) CreateAuthorizeCodeSession(ctx context.Context, code string, req fosite.Requester) error { - return s.memory.CreateAuthorizeCodeSession(ctx, code, req) +// CreateAuthorizeCodeSession stores the authorization request for a given authorization code. +// This implements a portion of oauth2.AuthorizeCodeStorage. +func (s *OpenIDConnectStore) CreateAuthorizeCodeSession(ctx context.Context, code string, request fosite.Requester) (err error) { + return s.saveSession(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, request) } -// GetAuthorizeCodeSession decorates fosite's storage.MemoryStore GetAuthorizeCodeSession method. -func (s *OpenIDConnectStore) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (fosite.Requester, error) { - return s.memory.GetAuthorizeCodeSession(ctx, code, session) +// InvalidateAuthorizeCodeSession is called when an authorize code is being used. The state of the authorization +// code should be set to invalid and consecutive requests to GetAuthorizeCodeSession should return the +// ErrInvalidatedAuthorizeCode error. +// This implements a portion of oauth2.AuthorizeCodeStorage. +func (s *OpenIDConnectStore) InvalidateAuthorizeCodeSession(ctx context.Context, code string) (err error) { + return s.provider.DeactivateOAuth2Session(ctx, storage.OAuth2SessionTypeAuthorizeCode, code) } -// InvalidateAuthorizeCodeSession decorates fosite's storage.MemoryStore InvalidateAuthorizeCodeSession method. -func (s *OpenIDConnectStore) InvalidateAuthorizeCodeSession(ctx context.Context, code string) error { - return s.memory.InvalidateAuthorizeCodeSession(ctx, code) +// GetAuthorizeCodeSession hydrates the session based on the given code and returns the authorization request. +// If the authorization code has been invalidated with `InvalidateAuthorizeCodeSession`, this +// method should return the ErrInvalidatedAuthorizeCode error. +// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedAuthorizeCode error! +// This implements a portion of oauth2.AuthorizeCodeStorage. +func (s *OpenIDConnectStore) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) { + // TODO: Implement the fosite.ErrInvalidatedAuthorizeCode error above. This requires splitting the invalidated sessions and deleted sessions. + return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session) } -// CreatePKCERequestSession decorates fosite's storage.MemoryStore CreatePKCERequestSession method. -func (s *OpenIDConnectStore) CreatePKCERequestSession(ctx context.Context, code string, req fosite.Requester) error { - return s.memory.CreatePKCERequestSession(ctx, code, req) +// CreateAccessTokenSession stores the authorization request for a given access token. +// This implements a portion of oauth2.AccessTokenStorage. +func (s *OpenIDConnectStore) CreateAccessTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) { + return s.saveSession(ctx, storage.OAuth2SessionTypeAccessToken, signature, request) } -// GetPKCERequestSession decorates fosite's storage.MemoryStore GetPKCERequestSession method. -func (s *OpenIDConnectStore) GetPKCERequestSession(ctx context.Context, code string, session fosite.Session) (fosite.Requester, error) { - return s.memory.GetPKCERequestSession(ctx, code, session) +// DeleteAccessTokenSession marks an access token session as deleted. +// This implements a portion of oauth2.AccessTokenStorage. +func (s *OpenIDConnectStore) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) { + return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature) } -// DeletePKCERequestSession decorates fosite's storage.MemoryStore DeletePKCERequestSession method. -func (s *OpenIDConnectStore) DeletePKCERequestSession(ctx context.Context, code string) error { - return s.memory.DeletePKCERequestSession(ctx, code) +// RevokeAccessToken revokes an access token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1 +// If the token passed to the request is an access token, the server MAY revoke the respective refresh token as well. +// This implements a portion of oauth2.TokenRevocationStorage. +func (s *OpenIDConnectStore) RevokeAccessToken(ctx context.Context, requestID string) (err error) { + return s.revokeSessionByRequestID(ctx, storage.OAuth2SessionTypeAccessToken, requestID) } -// CreateAccessTokenSession decorates fosite's storage.MemoryStore CreateAccessTokenSession method. -func (s *OpenIDConnectStore) CreateAccessTokenSession(ctx context.Context, signature string, req fosite.Requester) error { - return s.memory.CreateAccessTokenSession(ctx, signature, req) +// GetAccessTokenSession gets the authorization request for a given access token. +// This implements a portion of oauth2.AccessTokenStorage. +func (s *OpenIDConnectStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session) } -// GetAccessTokenSession decorates fosite's storage.MemoryStore GetAccessTokenSession method. -func (s *OpenIDConnectStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { - return s.memory.GetAccessTokenSession(ctx, signature, session) +// CreateRefreshTokenSession stores the authorization request for a given refresh token. +// This implements a portion of oauth2.RefreshTokenStorage. +func (s *OpenIDConnectStore) CreateRefreshTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) { + return s.saveSession(ctx, storage.OAuth2SessionTypeRefreshToken, signature, request) } -// DeleteAccessTokenSession decorates fosite's storage.MemoryStore DeleteAccessTokenSession method. -func (s *OpenIDConnectStore) DeleteAccessTokenSession(ctx context.Context, signature string) error { - return s.memory.DeleteAccessTokenSession(ctx, signature) +// DeleteRefreshTokenSession marks the authorization request for a given refresh token as deleted. +// This implements a portion of oauth2.RefreshTokenStorage. +func (s *OpenIDConnectStore) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) { + return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature) } -// CreateRefreshTokenSession decorates fosite's storage.MemoryStore CreateRefreshTokenSession method. -func (s *OpenIDConnectStore) CreateRefreshTokenSession(ctx context.Context, signature string, req fosite.Requester) error { - return s.memory.CreateRefreshTokenSession(ctx, signature, req) +// RevokeRefreshToken revokes a refresh token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1 +// If the particular token is a refresh token and the authorization server supports the revocation of access tokens, +// then the authorization server SHOULD also invalidate all access tokens based on the same authorization grant (see Implementation Note). +// This implements a portion of oauth2.TokenRevocationStorage. +func (s *OpenIDConnectStore) RevokeRefreshToken(ctx context.Context, requestID string) (err error) { + return s.provider.DeactivateOAuth2SessionByRequestID(ctx, storage.OAuth2SessionTypeRefreshToken, requestID) } -// GetRefreshTokenSession decorates fosite's storage.MemoryStore GetRefreshTokenSession method. -func (s *OpenIDConnectStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { - return s.memory.GetRefreshTokenSession(ctx, signature, session) +// RevokeRefreshTokenMaybeGracePeriod revokes an access token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1 +// If the token passed to the request is an access token, the server MAY revoke the respective refresh token as well. +// This implements a portion of oauth2.TokenRevocationStorage. +func (s *OpenIDConnectStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) (err error) { + return s.RevokeRefreshToken(ctx, requestID) } -// DeleteRefreshTokenSession decorates fosite's storage.MemoryStore DeleteRefreshTokenSession method. -func (s *OpenIDConnectStore) DeleteRefreshTokenSession(ctx context.Context, signature string) error { - return s.memory.DeleteRefreshTokenSession(ctx, signature) +// GetRefreshTokenSession gets the authorization request for a given refresh token. +// This implements a portion of oauth2.RefreshTokenStorage. +func (s *OpenIDConnectStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) { + return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session) } -// Authenticate decorates fosite's storage.MemoryStore Authenticate method. -func (s *OpenIDConnectStore) Authenticate(ctx context.Context, name string, secret string) error { - return s.memory.Authenticate(ctx, name, secret) +// CreatePKCERequestSession stores the authorization request for a given PKCE request. +// This implements a portion of pkce.PKCERequestStorage. +func (s *OpenIDConnectStore) CreatePKCERequestSession(ctx context.Context, signature string, request fosite.Requester) (err error) { + return s.saveSession(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, request) } -// RevokeRefreshToken decorates fosite's storage.MemoryStore RevokeRefreshToken method. -func (s *OpenIDConnectStore) RevokeRefreshToken(ctx context.Context, requestID string) error { - return s.memory.RevokeRefreshToken(ctx, requestID) +// DeletePKCERequestSession marks the authorization request for a given PKCE request as deleted. +// This implements a portion of pkce.PKCERequestStorage. +func (s *OpenIDConnectStore) DeletePKCERequestSession(ctx context.Context, signature string) (err error) { + return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature) } -// RevokeRefreshTokenMaybeGracePeriod decorates fosite's storage.MemoryStore RevokeRefreshTokenMaybeGracePeriod method. -func (s OpenIDConnectStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) error { - return s.memory.RevokeRefreshTokenMaybeGracePeriod(ctx, requestID, signature) +// GetPKCERequestSession gets the authorization request for a given PKCE request. +// This implements a portion of pkce.PKCERequestStorage. +func (s *OpenIDConnectStore) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (requester fosite.Requester, err error) { + return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session) } -// RevokeAccessToken decorates fosite's storage.MemoryStore RevokeAccessToken method. -func (s *OpenIDConnectStore) RevokeAccessToken(ctx context.Context, requestID string) error { - return s.memory.RevokeAccessToken(ctx, requestID) +// CreateOpenIDConnectSession creates an open id connect session for a given authorize code. +// This is relevant for explicit open id connect flow. +// This implements a portion of openid.OpenIDConnectRequestStorage. +func (s *OpenIDConnectStore) CreateOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (err error) { + return s.saveSession(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request) } -// GetPublicKey decorates fosite's storage.MemoryStore GetPublicKey method. -func (s *OpenIDConnectStore) GetPublicKey(ctx context.Context, issuer string, subject string, keyID string) (*jose.JSONWebKey, error) { - return s.memory.GetPublicKey(ctx, issuer, subject, keyID) +// DeleteOpenIDConnectSession just implements the method required by fosite even though it's unused. +// This implements a portion of openid.OpenIDConnectRequestStorage. +func (s *OpenIDConnectStore) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) (err error) { + return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, authorizeCode) } -// GetPublicKeys decorates fosite's storage.MemoryStore GetPublicKeys method. -func (s *OpenIDConnectStore) GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) { - return s.memory.GetPublicKeys(ctx, issuer, subject) +// GetOpenIDConnectSession returns error: +// - nil if a session was found, +// - ErrNoSessionFound if no session was found +// - or an arbitrary error if an error occurred. +// This implements a portion of openid.OpenIDConnectRequestStorage. +func (s *OpenIDConnectStore) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (r fosite.Requester, err error) { + return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession()) } -// GetPublicKeyScopes decorates fosite's storage.MemoryStore GetPublicKeyScopes method. -func (s *OpenIDConnectStore) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyID string) ([]string, error) { - return s.memory.GetPublicKeyScopes(ctx, issuer, subject, keyID) +// IsJWTUsed implements an interface required for RFC7523. +func (s *OpenIDConnectStore) IsJWTUsed(ctx context.Context, jti string) (used bool, err error) { + if err = s.ClientAssertionJWTValid(ctx, jti); err != nil { + return true, err + } + + return false, nil } -// IsJWTUsed decorates fosite's storage.MemoryStore IsJWTUsed method. -func (s *OpenIDConnectStore) IsJWTUsed(ctx context.Context, jti string) (bool, error) { - return s.memory.IsJWTUsed(ctx, jti) +// MarkJWTUsedForTime implements an interface required for rfc7523.RFC7523KeyStorage. +func (s *OpenIDConnectStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) { + return s.SetClientAssertionJWT(ctx, jti, exp) } -// MarkJWTUsedForTime decorates fosite's storage.MemoryStore MarkJWTUsedForTime method. -func (s *OpenIDConnectStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error { - return s.memory.MarkJWTUsedForTime(ctx, jti, exp) +func (s *OpenIDConnectStore) loadSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) { + var ( + sessionModel *model.OAuth2Session + ) + + sessionModel, err = s.provider.LoadOAuth2Session(ctx, sessionType, signature) + if err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, fosite.ErrNotFound + default: + return nil, err + } + } + + if r, err = sessionModel.ToRequest(ctx, session, s); err != nil { + return nil, err + } + + if !sessionModel.Active && sessionType == storage.OAuth2SessionTypeAuthorizeCode { + return r, fosite.ErrInvalidatedAuthorizeCode + } + + return r, nil +} + +func (s *OpenIDConnectStore) saveSession(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, r fosite.Requester) (err error) { + var session *model.OAuth2Session + + if session, err = model.NewOAuth2SessionFromRequest(signature, r); err != nil { + return err + } + + return s.provider.SaveOAuth2Session(ctx, sessionType, *session) +} + +func (s *OpenIDConnectStore) revokeSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string) (err error) { + return s.provider.RevokeOAuth2Session(ctx, sessionType, signature) +} + +func (s *OpenIDConnectStore) revokeSessionByRequestID(ctx context.Context, sessionType storage.OAuth2SessionType, requestID string) (err error) { + if err = s.provider.RevokeOAuth2SessionByRequestID(ctx, sessionType, requestID); err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return fosite.ErrNotFound + default: + return err + } + } + + return nil } diff --git a/internal/oidc/store_test.go b/internal/oidc/store_test.go index d69df0a8..f73e2d8c 100644 --- a/internal/oidc/store_test.go +++ b/internal/oidc/store_test.go @@ -1,15 +1,6 @@ package oidc -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/authelia/authelia/v4/internal/authorization" - "github.com/authelia/authelia/v4/internal/configuration/schema" -) +/* func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) { s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ @@ -80,7 +71,7 @@ func TestOpenIDConnectStore_GetInternalClient_ValidClient(t *testing.T) { Clients: []schema.OpenIDConnectClientConfiguration{c1}, }) - client, err := s.GetInternalClient(c1.ID) + client, err := s.GetFullClient(c1.ID) require.NoError(t, err) require.NotNil(t, client) assert.Equal(t, client.ID, c1.ID) @@ -107,7 +98,7 @@ func TestOpenIDConnectStore_GetInternalClient_InvalidClient(t *testing.T) { Clients: []schema.OpenIDConnectClientConfiguration{c1}, }) - client, err := s.GetInternalClient("another-client") + client, err := s.GetFullClient("another-client") assert.Nil(t, client) assert.EqualError(t, err, "not_found") } @@ -131,4 +122,5 @@ func TestOpenIDConnectStore_IsValidClientID(t *testing.T) { assert.True(t, validClient) assert.False(t, invalidClient) -} +}. +*/ diff --git a/internal/oidc/types.go b/internal/oidc/types.go index 9129a38a..e8859d75 100644 --- a/internal/oidc/types.go +++ b/internal/oidc/types.go @@ -6,17 +6,19 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" - "github.com/ory/fosite/storage" "github.com/ory/fosite/token/jwt" "github.com/ory/herodot" "gopkg.in/square/go-jose.v2" "github.com/authelia/authelia/v4/internal/authorization" + "github.com/authelia/authelia/v4/internal/model" + "github.com/authelia/authelia/v4/internal/storage" + "github.com/authelia/authelia/v4/internal/utils" ) -// NewSession creates a new OpenIDSession struct. -func NewSession() (session *OpenIDSession) { - return &OpenIDSession{ +// NewSession creates a new empty OpenIDSession struct. +func NewSession() (session *model.OpenIDSession) { + return &model.OpenIDSession{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Extra: map[string]interface{}{}, @@ -30,19 +32,19 @@ func NewSession() (session *OpenIDSession) { } // NewSessionWithAuthorizeRequest uses details from an AuthorizeRequester to generate an OpenIDSession. -func NewSessionWithAuthorizeRequest(issuer, kid, subject, username string, amr []string, extra map[string]interface{}, - authTime, requestedAt time.Time, requester fosite.AuthorizeRequester) (session *OpenIDSession) { +func NewSessionWithAuthorizeRequest(issuer, kid, username string, amr []string, extra map[string]interface{}, + authTime time.Time, consent *model.OAuth2ConsentSession, requester fosite.AuthorizeRequester) (session *model.OpenIDSession) { if extra == nil { extra = make(map[string]interface{}) } - return &OpenIDSession{ + session = &model.OpenIDSession{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ - Subject: subject, + Subject: consent.Subject.String(), Issuer: issuer, AuthTime: authTime, - RequestedAt: requestedAt, + RequestedAt: consent.RequestedAt, IssuedAt: time.Now(), Nonce: requester.GetRequestForm().Get("nonce"), Audience: requester.GetGrantedAudience(), @@ -55,12 +57,20 @@ func NewSessionWithAuthorizeRequest(issuer, kid, subject, username string, amr [ "kid": kid, }, }, - Subject: subject, + Subject: consent.Subject.String(), Username: username, }, - Extra: map[string]interface{}{}, - ClientID: requester.GetClient().GetID(), + Extra: map[string]interface{}{}, + ClientID: requester.GetClient().GetID(), + ChallengeID: consent.ChallengeID, } + + // Ensure required audience value of the client_id exists. + if !utils.IsStringInSlice(requester.GetClient().GetID(), session.Claims.Audience) { + session.Claims.Audience = append(session.Claims.Audience, requester.GetClient().GetID()) + } + + return session } // OpenIDConnectProvider for OpenID Connect. @@ -74,33 +84,34 @@ type OpenIDConnectProvider struct { discovery OpenIDConnectWellKnownConfiguration } -// OpenIDConnectStore is Authelia's internal representation of the fosite.Storage interface. -// -// Currently it is mostly just implementing a decorator pattern other then GetInternalClient. -// The long term plan is to have these methods interact with the Authelia storage and -// session providers where applicable. +// OpenIDConnectStore is Authelia's internal representation of the fosite.Storage interface. It maps the following +// interfaces to the storage.Provider interface: +// fosite.Storage, fosite.ClientManager, storage.Transactional, oauth2.AuthorizeCodeStorage, oauth2.AccessTokenStorage, +// oauth2.RefreshTokenStorage, oauth2.TokenRevocationStorage, pkce.PKCERequestStorage, +// openid.OpenIDConnectRequestStorage, and partially implements rfc7523.RFC7523KeyStorage. type OpenIDConnectStore struct { - clients map[string]*InternalClient - memory *storage.MemoryStore + provider storage.Provider + clients map[string]*Client } -// InternalClient represents the client internally. -type InternalClient struct { - ID string `json:"id"` - Description string `json:"-"` - Secret []byte `json:"client_secret,omitempty"` - Public bool `json:"public"` +// Client represents the client internally. +type Client struct { + ID string + SectorIdentifier string + Description string + Secret []byte + Public bool - Policy authorization.Level `json:"-"` + Policy authorization.Level - Audience []string `json:"audience"` - Scopes []string `json:"scopes"` - RedirectURIs []string `json:"redirect_uris"` - GrantTypes []string `json:"grant_types"` - ResponseTypes []string `json:"response_types"` - ResponseModes []fosite.ResponseModeType `json:"response_modes"` + Audience []string + Scopes []string + RedirectURIs []string + GrantTypes []string + ResponseTypes []string + ResponseModes []fosite.ResponseModeType - UserinfoSigningAlgorithm string `json:"userinfo_signed_response_alg,omitempty"` + UserinfoSigningAlgorithm string } // KeyManager keeps track of all of the active/inactive rsa keys and provides them to services requiring them. @@ -112,8 +123,8 @@ type KeyManager struct { strategy *RS256JWTStrategy } -// AutheliaHasher implements the fosite.Hasher interface without an actual hashing algo. -type AutheliaHasher struct{} +// PlainTextHasher implements the fosite.Hasher interface without an actual hashing algo. +type PlainTextHasher struct{} // ConsentGetResponseBody schema of the response body of the consent GET endpoint. type ConsentGetResponseBody struct { @@ -123,12 +134,15 @@ type ConsentGetResponseBody struct { Audience []string `json:"audience"` } -// OpenIDSession holds OIDC Session information. -type OpenIDSession struct { - *openid.DefaultSession `json:"idToken"` +// ConsentPostRequestBody schema of the request body of the consent POST endpoint. +type ConsentPostRequestBody struct { + ClientID string `json:"client_id"` + AcceptOrReject string `json:"accept_or_reject"` +} - Extra map[string]interface{} `json:"extra"` - ClientID string +// ConsentPostResponseBody schema of the response body of the consent POST endpoint. +type ConsentPostResponseBody struct { + RedirectURI string `json:"redirect_uri"` } /* diff --git a/internal/oidc/types_test.go b/internal/oidc/types_test.go index f8d31dac..7a96558a 100644 --- a/internal/oidc/types_test.go +++ b/internal/oidc/types_test.go @@ -9,6 +9,8 @@ import ( "github.com/ory/fosite" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/authelia/authelia/v4/internal/model" ) func TestNewSession(t *testing.T) { @@ -38,7 +40,7 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) { Request: fosite.Request{ ID: requestID.String(), Form: formValues, - Client: &InternalClient{ID: "example"}, + Client: &Client{ID: "example"}, }, } @@ -51,7 +53,13 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) { issuer := "https://example.com" amr := []string{AMRPasswordBasedAuthentication} - session := NewSessionWithAuthorizeRequest(issuer, "primary", subject.String(), "john", amr, extra, authAt, requested, request) + consent := &model.OAuth2ConsentSession{ + ChallengeID: uuid.New(), + RequestedAt: requested, + Subject: subject, + } + + session := NewSessionWithAuthorizeRequest(issuer, "primary", "john", amr, extra, authAt, consent, request) require.NotNil(t, session) require.NotNil(t, session.Extra) @@ -78,7 +86,12 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) { require.Contains(t, session.Claims.Extra, "preferred_username") - session = NewSessionWithAuthorizeRequest(issuer, "primary", subject.String(), "john", nil, nil, authAt, requested, request) + consent = &model.OAuth2ConsentSession{ + ChallengeID: uuid.New(), + RequestedAt: requested, + } + + session = NewSessionWithAuthorizeRequest(issuer, "primary", "john", nil, nil, authAt, consent, request) require.NotNil(t, session) require.NotNil(t, session.Claims) diff --git a/internal/session/types.go b/internal/session/types.go index 235df041..b73937d7 100644 --- a/internal/session/types.go +++ b/internal/session/types.go @@ -7,11 +7,11 @@ import ( "github.com/fasthttp/session/v2" "github.com/fasthttp/session/v2/providers/redis" "github.com/go-webauthn/webauthn/webauthn" + "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/logging" - "github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/oidc" ) @@ -43,8 +43,8 @@ type UserSession struct { // Webauthn holds the session registration data for this session. Webauthn *webauthn.SessionData - // Represent an OIDC workflow session initiated by the client if not null. - OIDCWorkflowSession *model.OIDCWorkflowSession + // ConsentChallengeID is the OpenID Connect Consent Session challenge ID. + ConsentChallengeID *uuid.UUID // This boolean is set to true after identity verification and checked // while doing the query actually updating the password. diff --git a/internal/storage/const.go b/internal/storage/const.go index affa0bca..cf188b10 100644 --- a/internal/storage/const.go +++ b/internal/storage/const.go @@ -5,18 +5,40 @@ import ( ) const ( - tableUserPreferences = "user_preferences" + tableAuthenticationLogs = "authentication_logs" + tableDuoDevices = "duo_devices" tableIdentityVerification = "identity_verification" tableTOTPConfigurations = "totp_configurations" + tableUserOpaqueIdentifier = "user_opaque_identifier" + tableUserPreferences = "user_preferences" tableWebauthnDevices = "webauthn_devices" - tableDuoDevices = "duo_devices" - tableAuthenticationLogs = "authentication_logs" - tableMigrations = "migrations" - tableEncryption = "encryption" + + tableOAuth2ConsentSession = "oauth2_consent_session" + tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session" + tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential. + tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential. + tableOAuth2PKCERequestSession = "oauth2_pkce_request_session" + tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session" + tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti" + + tableMigrations = "migrations" + tableEncryption = "encryption" tablePrefixBackup = "_bkp_" ) +// OAuth2SessionType represents the potential OAuth 2.0 session types. +type OAuth2SessionType string + +// Representation of specific OAuth 2.0 session types. +const ( + OAuth2SessionTypeAuthorizeCode OAuth2SessionType = "authorization code" + OAuth2SessionTypeAccessToken OAuth2SessionType = "access token" + OAuth2SessionTypeRefreshToken OAuth2SessionType = "refresh token" + OAuth2SessionTypePKCEChallenge OAuth2SessionType = "pkce challenge" + OAuth2SessionTypeOpenIDConnect OAuth2SessionType = "openid connect" +) + const ( encryptionNameCheck = "check" ) @@ -56,7 +78,7 @@ const ( const ( // This is the latest schema version for the purpose of tests. - testLatestVersion = 3 + testLatestVersion = 4 ) const ( @@ -64,6 +86,12 @@ const ( SchemaLatest = 2147483647 ) +type ctxKey int + +const ( + ctxKeyTransaction ctxKey = iota +) + var ( reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`) ) diff --git a/internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql b/internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql new file mode 100644 index 00000000..4fc3adc7 --- /dev/null +++ b/internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql @@ -0,0 +1,188 @@ +CREATE TABLE IF NOT EXISTS user_opaque_identifier ( + id INTEGER AUTO_INCREMENT, + service VARCHAR(20) NOT NULL, + sector_id VARCHAR(255) NOT NULL, + username VARCHAR(100) NOT NULL, + identifier CHAR(36) NOT NULL, + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username); +CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier); + +CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti ( + id INTEGER AUTO_INCREMENT, + signature VARCHAR(64) NOT NULL, + expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature); + +CREATE TABLE IF NOT EXISTS oauth2_consent_session ( + id INTEGER AUTO_INCREMENT, + challenge_id CHAR(36) NOT NULL, + client_id VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + authorized BOOLEAN NOT NULL DEFAULT FALSE, + granted BOOLEAN NOT NULL DEFAULT FALSE, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + responded_at TIMESTAMP NULL DEFAULT NULL, + expires_at TIMESTAMP NULL DEFAULT NULL, + form_data TEXT NOT NULL, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL, + granted_audience TEXT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_consent_subject_fkey + FOREIGN KEY (subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id); + +CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session ( + id INTEGER AUTO_INCREMENT, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL, + granted_audience TEXT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey + FOREIGN KEY (challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_authorization_code_session_subject_fkey + FOREIGN KEY (subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id); +CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id); +CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_access_token_session ( + id INTEGER AUTO_INCREMENT, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL, + granted_audience TEXT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_access_token_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_access_token_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id); +CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id); +CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session ( + id INTEGER AUTO_INCREMENT, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL, + granted_audience TEXT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_refresh_token_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id); +CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id); +CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session ( + id INTEGER AUTO_INCREMENT, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL, + granted_audience TEXT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_pkce_request_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id); +CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id); +CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session ( + id INTEGER AUTO_INCREMENT, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL, + granted_audience TEXT NULL, + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_openid_connect_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id); +CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id); +CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject); \ No newline at end of file diff --git a/internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql b/internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql new file mode 100644 index 00000000..c5685dd2 --- /dev/null +++ b/internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql @@ -0,0 +1,188 @@ +CREATE TABLE IF NOT EXISTS user_opaque_identifier ( + id SERIAL, + service VARCHAR(20) NOT NULL, + sector_id VARCHAR(255) NOT NULL, + username VARCHAR(100) NOT NULL, + identifier CHAR(36) NOT NULL, + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username); +CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier); + +CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti ( + id SERIAL, + signature VARCHAR(64) NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature); + +CREATE TABLE IF NOT EXISTS oauth2_consent_session ( + id SERIAL, + challenge_id CHAR(36) NOT NULL, + client_id VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + authorized BOOLEAN NOT NULL DEFAULT FALSE, + granted BOOLEAN NOT NULL DEFAULT FALSE, + requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + responded_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, + expires_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL, + form_data TEXT NOT NULL, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + PRIMARY KEY (id), + CONSTRAINT oauth2_consent_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id); + +CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session ( + id SERIAL, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36), + requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BYTEA NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_authorization_code_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id); +CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id); +CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_access_token_session ( + id SERIAL, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BYTEA NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_access_token_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_access_token_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id); +CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id); +CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session ( + id SERIAL, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BYTEA NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_refresh_token_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id); +CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id); +CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session ( + id SERIAL, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BYTEA NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_pkce_request_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id); +CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id); +CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session ( + id SERIAL, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BYTEA NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_openid_connect_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id); +CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id); +CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject); diff --git a/internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql b/internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql new file mode 100644 index 00000000..372fbeb5 --- /dev/null +++ b/internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql @@ -0,0 +1,188 @@ +CREATE TABLE IF NOT EXISTS user_opaque_identifier ( + id INTEGER, + service VARCHAR(20) NOT NULL, + sector_id VARCHAR(255) NOT NULL, + username VARCHAR(100) NOT NULL, + identifier CHAR(36) NOT NULL, + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username); +CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier); + +CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti ( + id INTEGER, + signature VARCHAR(64) NOT NULL, + expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (id) +); + +CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature); + +CREATE TABLE IF NOT EXISTS oauth2_consent_session ( + id INTEGER, + challenge_id CHAR(36) NOT NULL, + client_id VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + authorized BOOLEAN NOT NULL DEFAULT FALSE, + granted BOOLEAN NOT NULL DEFAULT FALSE, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + responded_at TIMESTAMP NULL DEFAULT NULL, + expires_at TIMESTAMP NULL DEFAULT NULL, + form_data TEXT NOT NULL, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + PRIMARY KEY (id), + CONSTRAINT oauth2_consent_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id); + +CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session ( + id INTEGER, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_authorization_code_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id); +CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id); +CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_access_token_session ( + id INTEGER, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_access_token_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_access_token_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id); +CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id); +CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session ( + id INTEGER, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_refresh_token_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id); +CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id); +CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session ( + id INTEGER, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_pkce_request_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id); +CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id); +CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject); + +CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session ( + id INTEGER, + challenge_id CHAR(36) NOT NULL, + request_id VARCHAR(40) NOT NULL, + client_id VARCHAR(255) NOT NULL, + signature VARCHAR(255) NOT NULL, + subject CHAR(36) NOT NULL, + requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + requested_scopes TEXT NOT NULL, + granted_scopes TEXT NOT NULL, + requested_audience TEXT NULL DEFAULT '', + granted_audience TEXT NULL DEFAULT '', + active BOOLEAN NOT NULL DEFAULT FALSE, + revoked BOOLEAN NOT NULL DEFAULT FALSE, + form_data TEXT NOT NULL, + session_data BLOB NOT NULL, + PRIMARY KEY (id), + CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey + FOREIGN KEY(challenge_id) + REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT oauth2_openid_connect_session_subject_fkey + FOREIGN KEY(subject) + REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT +); + +CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id); +CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id); +CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject); diff --git a/internal/storage/migrations/V0004.OpenIDConnect.all.down.sql b/internal/storage/migrations/V0004.OpenIDConnect.all.down.sql new file mode 100644 index 00000000..481f502f --- /dev/null +++ b/internal/storage/migrations/V0004.OpenIDConnect.all.down.sql @@ -0,0 +1,8 @@ +DROP TABLE IF EXISTS oauth2_blacklisted_jti; +DROP TABLE IF EXISTS oauth2_authorization_code_session; +DROP TABLE IF EXISTS oauth2_access_token_session; +DROP TABLE IF EXISTS oauth2_refresh_token_session; +DROP TABLE IF EXISTS oauth2_pkce_request_session; +DROP TABLE IF EXISTS oauth2_openid_connect_session; +DROP TABLE IF EXISTS oauth2_consent_session; +DROP TABLE IF EXISTS user_opaque_identifier; \ No newline at end of file diff --git a/internal/storage/provider.go b/internal/storage/provider.go index 7805e15c..742b3600 100644 --- a/internal/storage/provider.go +++ b/internal/storage/provider.go @@ -4,6 +4,9 @@ import ( "context" "time" + "github.com/google/uuid" + "github.com/ory/fosite/storage" + "github.com/authelia/authelia/v4/internal/model" ) @@ -13,10 +16,16 @@ type Provider interface { RegulatorProvider + storage.Transactional + SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) + SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error) + LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (subject *model.UserOpaqueIdentifier, err error) + LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error) + SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) @@ -36,6 +45,22 @@ type Provider interface { DeletePreferredDuoDevice(ctx context.Context, username string) (err error) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) + SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) + SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, rejection bool) (err error) + SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) + LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) + LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error) + + SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) + RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) + RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) + DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) + DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) + LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) + + SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) + LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) + SchemaTables(ctx context.Context) (tables []string, err error) SchemaVersion(ctx context.Context) (version int, err error) SchemaLatestVersion() (version int, err error) diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go index 126c4ddb..94b9c839 100644 --- a/internal/storage/sql_provider.go +++ b/internal/storage/sql_provider.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" @@ -63,6 +64,54 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebauthnDevices, tableDuoDevices, tableUserPreferences), + sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier), + sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier), + sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier), + + sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), + sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession), + sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession), + + sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession), + sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession), + sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession), + sqlRevokeOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AccessTokenSession), + sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession), + sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession), + + sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession), + sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession), + sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession), + sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), + sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession), + sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession), + + sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession), + sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession), + sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession), + sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), + sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession), + sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession), + + sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession), + sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession), + sqlRevokeOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2OpenIDConnectSession), + sqlRevokeOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession), + sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession), + sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession), + + sqlInsertOAuth2ConsentSession: fmt.Sprintf(queryFmtInsertOAuth2ConsentSession, tableOAuth2ConsentSession), + sqlUpdateOAuth2ConsentSessionResponse: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionResponse, tableOAuth2ConsentSession), + sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession), + sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession), + sqlSelectOAuth2ConsentSessionsPreConfigured: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionsPreConfigured, tableOAuth2ConsentSession), + + sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), + sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI), + sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations), @@ -128,6 +177,11 @@ type SQLProvider struct { sqlSelectPreferred2FAMethod string sqlSelectUserInfo string + // Table: user_opaque_identifier. + sqlInsertUserOpaqueIdentifier string + sqlSelectUserOpaqueIdentifier string + sqlSelectUserOpaqueIdentifierBySignature string + // Table: migrations. sqlInsertMigration string sqlSelectMigrations string @@ -137,6 +191,56 @@ type SQLProvider struct { sqlUpsertEncryptionValue string sqlSelectEncryptionValue string + // Table: oauth2_authorization_code_session. + sqlInsertOAuth2AuthorizeCodeSession string + sqlSelectOAuth2AuthorizeCodeSession string + sqlRevokeOAuth2AuthorizeCodeSession string + sqlRevokeOAuth2AuthorizeCodeSessionByRequestID string + sqlDeactivateOAuth2AuthorizeCodeSession string + sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID string + + // Table: oauth2_access_token_session. + sqlInsertOAuth2AccessTokenSession string + sqlSelectOAuth2AccessTokenSession string + sqlRevokeOAuth2AccessTokenSession string + sqlRevokeOAuth2AccessTokenSessionByRequestID string + sqlDeactivateOAuth2AccessTokenSession string + sqlDeactivateOAuth2AccessTokenSessionByRequestID string + + // Table: oauth2_refresh_token_session. + sqlInsertOAuth2RefreshTokenSession string + sqlSelectOAuth2RefreshTokenSession string + sqlRevokeOAuth2RefreshTokenSession string + sqlRevokeOAuth2RefreshTokenSessionByRequestID string + sqlDeactivateOAuth2RefreshTokenSession string + sqlDeactivateOAuth2RefreshTokenSessionByRequestID string + + // Table: oauth2_pkce_request_session. + sqlInsertOAuth2PKCERequestSession string + sqlSelectOAuth2PKCERequestSession string + sqlRevokeOAuth2PKCERequestSession string + sqlRevokeOAuth2PKCERequestSessionByRequestID string + sqlDeactivateOAuth2PKCERequestSession string + sqlDeactivateOAuth2PKCERequestSessionByRequestID string + + // Table: oauth2_openid_connect_session. + sqlInsertOAuth2OpenIDConnectSession string + sqlSelectOAuth2OpenIDConnectSession string + sqlRevokeOAuth2OpenIDConnectSession string + sqlRevokeOAuth2OpenIDConnectSessionByRequestID string + sqlDeactivateOAuth2OpenIDConnectSession string + sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string + + // Table: oauth2_consent_session. + sqlInsertOAuth2ConsentSession string + sqlUpdateOAuth2ConsentSessionResponse string + sqlUpdateOAuth2ConsentSessionGranted string + sqlSelectOAuth2ConsentSessionByChallengeID string + sqlSelectOAuth2ConsentSessionsPreConfigured string + + sqlUpsertOAuth2BlacklistedJTI string + sqlSelectOAuth2BlacklistedJTI string + // Utility. sqlSelectExistingTables string sqlFmtRenameTable string @@ -187,6 +291,329 @@ func (p *SQLProvider) StartupCheck() (err error) { } } +// BeginTX begins a transaction. +func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) { + var tx *sql.Tx + + if tx, err = p.db.Begin(); err != nil { + return nil, err + } + + return context.WithValue(ctx, ctxKeyTransaction, tx), nil +} + +// Commit performs a database commit. +func (p *SQLProvider) Commit(ctx context.Context) (err error) { + tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx) + + if !ok { + return errors.New("could not retrieve tx") + } + + return tx.Commit() +} + +// Rollback performs a database rollback. +func (p *SQLProvider) Rollback(ctx context.Context) (err error) { + tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx) + + if !ok { + return errors.New("could not retrieve tx") + } + + return tx.Rollback() +} + +// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database. +func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil { + return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err) + } + + return nil +} + +// LoadUserOpaqueIdentifier selects an opaque user identifier from the database. +func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) { + opaqueID = &model.UserOpaqueIdentifier{} + + if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, nil + default: + return nil, err + } + } + + return opaqueID, nil +} + +// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username. +func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) { + opaqueID = &model.UserOpaqueIdentifier{} + + if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil { + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, nil + default: + return nil, err + } + } + + return opaqueID, nil +} + +// SaveOAuth2ConsentSession inserts an OAuth2.0 consent. +func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession, + consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted, + consent.RequestedAt, consent.RespondedAt, consent.ExpiresAt, consent.Form, + consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience); err != nil { + return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.String(), err) + } + + return nil +} + +// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent with the consent response. +func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) { + _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.ExpiresAt, consent.GrantedScopes, consent.GrantedAudience, consent.ID) + if err != nil { + return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject, err) + } + + return nil +} + +// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint. +func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil { + return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err) + } + + return nil +} + +// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID. +func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) { + consent = &model.OAuth2ConsentSession{} + + if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil { + return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err) + } + + return consent, nil +} + +// LoadOAuth2ConsentSessionsPreConfigured returns an OAuth2.0 consents that are pre-configured given the consent signature. +func (p *SQLProvider) LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error) { + var r *sqlx.Rows + + if r, err = p.db.QueryxContext(ctx, p.sqlSelectOAuth2ConsentSessionsPreConfigured, clientID, subject); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return &ConsentSessionRows{}, nil + } + + return &ConsentSessionRows{}, fmt.Errorf("error selecting oauth2 consent session by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err) + } + + return &ConsentSessionRows{rows: r}, nil +} + +// SaveOAuth2Session saves a OAuth2Session to the database. +func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) { + var query string + + switch sessionType { + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlInsertOAuth2AuthorizeCodeSession + case OAuth2SessionTypeAccessToken: + query = p.sqlInsertOAuth2AccessTokenSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlInsertOAuth2RefreshTokenSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlInsertOAuth2PKCERequestSession + case OAuth2SessionTypeOpenIDConnect: + query = p.sqlInsertOAuth2OpenIDConnectSession + default: + return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType) + } + + if session.Session, err = p.encrypt(session.Session); err != nil { + return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) + } + + _, err = p.db.ExecContext(ctx, query, + session.ChallengeID, session.RequestID, session.ClientID, session.Signature, + session.Subject, session.RequestedAt, session.RequestedScopes, session.GrantedScopes, + session.RequestedAudience, session.GrantedAudience, + session.Active, session.Revoked, session.Form, session.Session) + + if err != nil { + return fmt.Errorf("error inserting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err) + } + + return nil +} + +// RevokeOAuth2Session marks a OAuth2Session as revoked in the database. +func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) { + var query string + + switch sessionType { + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlRevokeOAuth2AuthorizeCodeSession + case OAuth2SessionTypeAccessToken: + query = p.sqlRevokeOAuth2AccessTokenSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlRevokeOAuth2RefreshTokenSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlRevokeOAuth2PKCERequestSession + case OAuth2SessionTypeOpenIDConnect: + query = p.sqlRevokeOAuth2OpenIDConnectSession + default: + return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType) + } + + if _, err = p.db.ExecContext(ctx, query, signature); err != nil { + return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err) + } + + return nil +} + +// RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database. +func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) { + var query string + + switch sessionType { + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID + case OAuth2SessionTypeAccessToken: + query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID + case OAuth2SessionTypeRefreshToken: + query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID + case OAuth2SessionTypePKCEChallenge: + query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID + case OAuth2SessionTypeOpenIDConnect: + query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID + default: + return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType) + } + + if _, err = p.db.ExecContext(ctx, query, requestID); err != nil { + return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err) + } + + return nil +} + +// DeactivateOAuth2Session marks a OAuth2Session as inactive in the database. +func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) { + var query string + + switch sessionType { + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlDeactivateOAuth2AuthorizeCodeSession + case OAuth2SessionTypeAccessToken: + query = p.sqlDeactivateOAuth2AccessTokenSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlDeactivateOAuth2RefreshTokenSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlDeactivateOAuth2PKCERequestSession + case OAuth2SessionTypeOpenIDConnect: + query = p.sqlDeactivateOAuth2OpenIDConnectSession + default: + return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType) + } + + if _, err = p.db.ExecContext(ctx, query, signature); err != nil { + return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err) + } + + return nil +} + +// DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database. +func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) { + var query string + + switch sessionType { + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlDeactivateOAuth2AuthorizeCodeSession + case OAuth2SessionTypeAccessToken: + query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID + case OAuth2SessionTypeRefreshToken: + query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID + case OAuth2SessionTypePKCEChallenge: + query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID + case OAuth2SessionTypeOpenIDConnect: + query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID + default: + return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType) + } + + if _, err = p.db.ExecContext(ctx, query, requestID); err != nil { + return fmt.Errorf("error deactivating oauth2 %s session with request id '%s': %w", sessionType, requestID, err) + } + + return nil +} + +// LoadOAuth2Session saves a OAuth2Session from the database. +func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) { + var query string + + switch sessionType { + case OAuth2SessionTypeAuthorizeCode: + query = p.sqlSelectOAuth2AuthorizeCodeSession + case OAuth2SessionTypeAccessToken: + query = p.sqlSelectOAuth2AccessTokenSession + case OAuth2SessionTypeRefreshToken: + query = p.sqlSelectOAuth2RefreshTokenSession + case OAuth2SessionTypePKCEChallenge: + query = p.sqlSelectOAuth2PKCERequestSession + case OAuth2SessionTypeOpenIDConnect: + query = p.sqlSelectOAuth2OpenIDConnectSession + default: + return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType) + } + + session = &model.OAuth2Session{} + + if err = p.db.GetContext(ctx, session, query, signature); err != nil { + return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err) + } + + if session.Session, err = p.decrypt(session.Session); err != nil { + return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject, session.RequestID, err) + } + + return session, nil +} + +// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database. +func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) { + if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil { + return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) + } + + return nil +} + +// LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database. +func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) { + blacklistedJTI = &model.OAuth2BlacklistedJTI{} + + if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil { + return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err) + } + + return blacklistedJTI, nil +} + // SavePreferred2FAMethod save the preferred method for 2FA to the database. func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil { diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go index 36c959e7..e1fdf630 100644 --- a/internal/storage/sql_provider_backend_postgres.go +++ b/internal/storage/sql_provider_backend_postgres.go @@ -26,19 +26,27 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr // Specific alterations to this provider. // PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead. - provider.sqlUpsertWebauthnDevice = fmt.Sprintf(queryFmtPostgresUpsertWebauthnDevice, tableWebauthnDevices) - provider.sqlUpsertDuoDevice = fmt.Sprintf(queryFmtPostgresUpsertDuoDevice, tableDuoDevices) - provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations) - provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences) - provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtPostgresUpsertEncryptionValue, tableEncryption) + provider.sqlUpsertWebauthnDevice = fmt.Sprintf(queryFmtUpsertWebauthnDevicePostgreSQL, tableWebauthnDevices) + provider.sqlUpsertDuoDevice = fmt.Sprintf(queryFmtUpsertDuoDevicePostgreSQL, tableDuoDevices) + provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtUpsertTOTPConfigurationPostgreSQL, tableTOTPConfigurations) + provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtUpsertPreferred2FAMethodPostgreSQL, tableUserPreferences) + provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtUpsertEncryptionValuePostgreSQL, tableEncryption) + provider.sqlUpsertOAuth2BlacklistedJTI = fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTIPostgreSQL, tableOAuth2BlacklistedJTI) // PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders. provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable) + provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod) provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo) + + provider.sqlInsertUserOpaqueIdentifier = provider.db.Rebind(provider.sqlInsertUserOpaqueIdentifier) + provider.sqlSelectUserOpaqueIdentifier = provider.db.Rebind(provider.sqlSelectUserOpaqueIdentifier) + provider.sqlSelectUserOpaqueIdentifierBySignature = provider.db.Rebind(provider.sqlSelectUserOpaqueIdentifierBySignature) + provider.sqlSelectIdentityVerification = provider.db.Rebind(provider.sqlSelectIdentityVerification) provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification) provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification) + provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig) provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn) provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername) @@ -46,21 +54,69 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs) provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret) provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername) + provider.sqlSelectWebauthnDevices = provider.db.Rebind(provider.sqlSelectWebauthnDevices) provider.sqlSelectWebauthnDevicesByUsername = provider.db.Rebind(provider.sqlSelectWebauthnDevicesByUsername) provider.sqlUpdateWebauthnDevicePublicKey = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKey) provider.sqlUpdateWebauthnDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKeyByUsername) provider.sqlUpdateWebauthnDeviceRecordSignIn = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignIn) provider.sqlUpdateWebauthnDeviceRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignInByUsername) + provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice) provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice) + provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt) provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername) + provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration) provider.sqlSelectMigrations = provider.db.Rebind(provider.sqlSelectMigrations) provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration) + provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue) + provider.sqlInsertOAuth2ConsentSession = provider.db.Rebind(provider.sqlInsertOAuth2ConsentSession) + provider.sqlUpdateOAuth2ConsentSessionResponse = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionResponse) + provider.sqlUpdateOAuth2ConsentSessionGranted = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionGranted) + provider.sqlSelectOAuth2ConsentSessionByChallengeID = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionByChallengeID) + provider.sqlSelectOAuth2ConsentSessionsPreConfigured = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionsPreConfigured) + + provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession) + provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession) + provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID) + provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession) + provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID) + provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession) + + provider.sqlInsertOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2AccessTokenSession) + provider.sqlRevokeOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSession) + provider.sqlRevokeOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSessionByRequestID) + provider.sqlDeactivateOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSession) + provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID) + provider.sqlSelectOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2AccessTokenSession) + + provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession) + provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession) + provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID) + provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession) + provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID) + provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession) + + provider.sqlInsertOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlInsertOAuth2PKCERequestSession) + provider.sqlRevokeOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSession) + provider.sqlRevokeOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSessionByRequestID) + provider.sqlDeactivateOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSession) + provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID) + provider.sqlSelectOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlSelectOAuth2PKCERequestSession) + + provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession) + provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession) + provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID) + provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession) + provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID) + provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession) + + provider.sqlSelectOAuth2BlacklistedJTI = provider.db.Rebind(provider.sqlSelectOAuth2BlacklistedJTI) + provider.schema = config.Storage.PostgreSQL.Schema return provider diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go index 43009413..57ecbc66 100644 --- a/internal/storage/sql_provider_queries.go +++ b/internal/storage/sql_provider_queries.go @@ -48,7 +48,7 @@ const ( REPLACE INTO %s (username, second_factor_method) VALUES (?, ?);` - queryFmtPostgresUpsertPreferred2FAMethod = ` + queryFmtUpsertPreferred2FAMethodPostgreSQL = ` INSERT INTO %s (username, second_factor_method) VALUES ($1, $2) ON CONFLICT (username) @@ -99,7 +99,7 @@ const ( REPLACE INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret) VALUES (?, ?, ?, ?, ?, ?, ?, ?);` - queryFmtPostgresUpsertTOTPConfiguration = ` + queryFmtUpsertTOTPConfigurationPostgreSQL = ` INSERT INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (username) @@ -160,7 +160,7 @@ const ( REPLACE INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` - queryFmtPostgresUpsertWebauthnDevice = ` + queryFmtUpsertWebauthnDevicePostgreSQL = ` INSERT INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (username, description) @@ -172,7 +172,7 @@ const ( REPLACE INTO %s (username, device, method) VALUES (?, ?, ?);` - queryFmtPostgresUpsertDuoDevice = ` + queryFmtUpsertDuoDevicePostgreSQL = ` INSERT INTO %s (username, device, method) VALUES ($1, $2, $3) ON CONFLICT (username) @@ -214,9 +214,103 @@ const ( REPLACE INTO %s (name, value) VALUES (?, ?);` - queryFmtPostgresUpsertEncryptionValue = ` + queryFmtUpsertEncryptionValuePostgreSQL = ` INSERT INTO %s (name, value) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET value = $2;` ) + +const ( + queryFmtSelectOAuth2ConsentSessionByChallengeID = ` + SELECT id, challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at, + form_data, requested_scopes, granted_scopes, requested_audience, granted_audience + FROM %s + WHERE challenge_id = ?;` + + queryFmtSelectOAuth2ConsentSessionsPreConfigured = ` + SELECT id, challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at, + form_data, requested_scopes, granted_scopes, requested_audience, granted_audience + FROM %s + WHERE client_id = ? AND subject = ? AND + authorized = TRUE AND granted = TRUE AND expires_at IS NOT NULL AND expires_at >= CURRENT_TIMESTAMP;` + + queryFmtInsertOAuth2ConsentSession = ` + INSERT INTO %s (challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at, + form_data, requested_scopes, granted_scopes, requested_audience, granted_audience) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` + + queryFmtUpdateOAuth2ConsentSessionResponse = ` + UPDATE %s + SET authorized = ?, responded_at = CURRENT_TIMESTAMP, expires_at = ?, granted_scopes = ?, granted_audience = ? + WHERE id = ? AND responded_at IS NULL;` + + queryFmtUpdateOAuth2ConsentSessionGranted = ` + UPDATE %s + SET granted = TRUE + WHERE id = ? AND responded_at IS NOT NULL;` + + queryFmtSelectOAuth2Session = ` + SELECT id, challenge_id, request_id, client_id, signature, subject, requested_at, + requested_scopes, granted_scopes, requested_audience, granted_audience, + active, revoked, form_data, session_data + FROM %s + WHERE signature = ? AND revoked = FALSE;` + + queryFmtInsertOAuth2Session = ` + INSERT INTO %s (challenge_id, request_id, client_id, signature, subject, requested_at, + requested_scopes, granted_scopes, requested_audience, granted_audience, + active, revoked, form_data, session_data) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` + + queryFmtRevokeOAuth2Session = ` + UPDATE %s + SET revoked = TRUE + WHERE signature = ?;` + + queryFmtRevokeOAuth2SessionByRequestID = ` + UPDATE %s + SET revoked = TRUE + WHERE request_id = ?;` + + queryFmtDeactivateOAuth2Session = ` + UPDATE %s + SET active = FALSE + WHERE signature = ?;` + + queryFmtDeactivateOAuth2SessionByRequestID = ` + UPDATE %s + SET active = FALSE + WHERE request_id = ?;"` + + queryFmtSelectOAuth2BlacklistedJTI = ` + SELECT id, signature, expires_at + FROM %s + WHERE signature = ?;` + + queryFmtUpsertOAuth2BlacklistedJTI = ` + REPLACE INTO %s (signature, expires_at) + VALUES(?, ?);` + + queryFmtUpsertOAuth2BlacklistedJTIPostgreSQL = ` + INSERT INTO %s (signature, expires_at) + VALUES ($1, $2) + ON CONFLICT (signature) + DO UPDATE SET expires_at = $2;` +) + +const ( + queryFmtInsertUserOpaqueIdentifier = ` + INSERT INTO %s (service, sector_id, username, identifier) + VALUES(?, ?, ?, ?);` + + queryFmtSelectUserOpaqueIdentifier = ` + SELECT id, sector_id, username, identifier + FROM %s + WHERE identifier = ?;` + + queryFmtSelectUserOpaqueIdentifierBySignature = ` + SELECT id, service, sector_id, username, identifier + FROM %s + WHERE service = ? AND sector_id = ? AND username = ?;` +) diff --git a/internal/storage/sql_rows.go b/internal/storage/sql_rows.go new file mode 100644 index 00000000..148dbf38 --- /dev/null +++ b/internal/storage/sql_rows.go @@ -0,0 +1,47 @@ +package storage + +import ( + "database/sql" + + "github.com/jmoiron/sqlx" + + "github.com/authelia/authelia/v4/internal/model" +) + +// ConsentSessionRows holds and assists with retrieving multiple model.OAuth2ConsentSession rows. +type ConsentSessionRows struct { + rows *sqlx.Rows +} + +// Next is the row iterator. +func (r *ConsentSessionRows) Next() bool { + if r.rows == nil { + return false + } + + return r.rows.Next() +} + +// Close the rows. +func (r *ConsentSessionRows) Close() (err error) { + if r.rows == nil { + return nil + } + + return r.rows.Close() +} + +// Get returns the *model.OAuth2ConsentSession or scan error. +func (r *ConsentSessionRows) Get() (consent *model.OAuth2ConsentSession, err error) { + if r.rows == nil { + return nil, sql.ErrNoRows + } + + consent = &model.OAuth2ConsentSession{} + + if err = r.rows.StructScan(consent); err != nil { + return nil, err + } + + return consent, nil +} diff --git a/internal/suites/OIDC/configuration.yml b/internal/suites/OIDC/configuration.yml index c63892dc..a68a82a6 100644 --- a/internal/suites/OIDC/configuration.yml +++ b/internal/suites/OIDC/configuration.yml @@ -58,6 +58,7 @@ notifier: identity_providers: oidc: + enable_client_debug_messages: true hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm issuer_private_key: | -----BEGIN RSA PRIVATE KEY----- diff --git a/internal/suites/OIDCTraefik/configuration.yml b/internal/suites/OIDCTraefik/configuration.yml index cc949e56..11fb3e73 100644 --- a/internal/suites/OIDCTraefik/configuration.yml +++ b/internal/suites/OIDCTraefik/configuration.yml @@ -60,6 +60,7 @@ notifier: identity_providers: oidc: + enable_client_debug_messages: true hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm issuer_private_key: | -----BEGIN RSA PRIVATE KEY----- diff --git a/internal/suites/action_2fa_methods.go b/internal/suites/action_2fa_methods.go index 89a1f241..b012c110 100644 --- a/internal/suites/action_2fa_methods.go +++ b/internal/suites/action_2fa_methods.go @@ -9,26 +9,26 @@ import ( ) func (rs *RodSession) doChangeMethod(t *testing.T, page *rod.Page, method string) { - err := rs.WaitElementLocatedByCSSSelector(t, page, "methods-button").Click("left") + err := rs.WaitElementLocatedByID(t, page, "methods-button").Click("left") require.NoError(t, err) - rs.WaitElementLocatedByCSSSelector(t, page, "methods-dialog") - err = rs.WaitElementLocatedByCSSSelector(t, page, fmt.Sprintf("%s-option", method)).Click("left") + rs.WaitElementLocatedByID(t, page, "methods-dialog") + err = rs.WaitElementLocatedByID(t, page, fmt.Sprintf("%s-option", method)).Click("left") require.NoError(t, err) } func (rs *RodSession) doChangeDevice(t *testing.T, page *rod.Page, deviceID string) { - err := rs.WaitElementLocatedByCSSSelector(t, page, "selection-link").Click("left") + err := rs.WaitElementLocatedByID(t, page, "selection-link").Click("left") require.NoError(t, err) rs.doSelectDevice(t, page, deviceID) } func (rs *RodSession) doSelectDevice(t *testing.T, page *rod.Page, deviceID string) { - rs.WaitElementLocatedByCSSSelector(t, page, "device-selection") - err := rs.WaitElementLocatedByCSSSelector(t, page, fmt.Sprintf("device-%s", deviceID)).Click("left") + rs.WaitElementLocatedByID(t, page, "device-selection") + err := rs.WaitElementLocatedByID(t, page, fmt.Sprintf("device-%s", deviceID)).Click("left") require.NoError(t, err) } func (rs *RodSession) doClickButton(t *testing.T, page *rod.Page, buttonID string) { - err := rs.WaitElementLocatedByCSSSelector(t, page, buttonID).Click("left") + err := rs.WaitElementLocatedByID(t, page, buttonID).Click("left") require.NoError(t, err) } diff --git a/internal/suites/action_login.go b/internal/suites/action_login.go index fd4af24b..73e7b807 100644 --- a/internal/suites/action_login.go +++ b/internal/suites/action_login.go @@ -9,21 +9,21 @@ import ( ) func (rs *RodSession) doFillLoginPageAndClick(t *testing.T, page *rod.Page, username, password string, keepMeLoggedIn bool) { - usernameElement := rs.WaitElementLocatedByCSSSelector(t, page, "username-textfield") + usernameElement := rs.WaitElementLocatedByID(t, page, "username-textfield") err := usernameElement.Input(username) require.NoError(t, err) - passwordElement := rs.WaitElementLocatedByCSSSelector(t, page, "password-textfield") + passwordElement := rs.WaitElementLocatedByID(t, page, "password-textfield") err = passwordElement.Input(password) require.NoError(t, err) if keepMeLoggedIn { - keepMeLoggedInElement := rs.WaitElementLocatedByCSSSelector(t, page, "remember-checkbox") + keepMeLoggedInElement := rs.WaitElementLocatedByID(t, page, "remember-checkbox") err = keepMeLoggedInElement.Click("left") require.NoError(t, err) } - buttonElement := rs.WaitElementLocatedByCSSSelector(t, page, "sign-in-button") + buttonElement := rs.WaitElementLocatedByID(t, page, "sign-in-button") err = buttonElement.Click("left") require.NoError(t, err) } diff --git a/internal/suites/action_reset_password.go b/internal/suites/action_reset_password.go index 2dcf0a25..5a3ca5d5 100644 --- a/internal/suites/action_reset_password.go +++ b/internal/suites/action_reset_password.go @@ -9,13 +9,13 @@ import ( ) func (rs *RodSession) doInitiatePasswordReset(t *testing.T, page *rod.Page, username string) { - err := rs.WaitElementLocatedByCSSSelector(t, page, "reset-password-button").Click("left") + err := rs.WaitElementLocatedByID(t, page, "reset-password-button").Click("left") require.NoError(t, err) // Fill in username. - err = rs.WaitElementLocatedByCSSSelector(t, page, "username-textfield").Input(username) + err = rs.WaitElementLocatedByID(t, page, "username-textfield").Input(username) require.NoError(t, err) // And click on the reset button. - err = rs.WaitElementLocatedByCSSSelector(t, page, "reset-button").Click("left") + err = rs.WaitElementLocatedByID(t, page, "reset-button").Click("left") require.NoError(t, err) } @@ -25,15 +25,15 @@ func (rs *RodSession) doCompletePasswordReset(t *testing.T, page *rod.Page, newP time.Sleep(1 * time.Second) - err := rs.WaitElementLocatedByCSSSelector(t, page, "password1-textfield").Input(newPassword1) + err := rs.WaitElementLocatedByID(t, page, "password1-textfield").Input(newPassword1) require.NoError(t, err) time.Sleep(1 * time.Second) - err = rs.WaitElementLocatedByCSSSelector(t, page, "password2-textfield").Input(newPassword2) + err = rs.WaitElementLocatedByID(t, page, "password2-textfield").Input(newPassword2) require.NoError(t, err) - err = rs.WaitElementLocatedByCSSSelector(t, page, "reset-button").Click("left") + err = rs.WaitElementLocatedByID(t, page, "reset-button").Click("left") require.NoError(t, err) } diff --git a/internal/suites/action_totp.go b/internal/suites/action_totp.go index 07d00736..bf3d16bc 100644 --- a/internal/suites/action_totp.go +++ b/internal/suites/action_totp.go @@ -12,7 +12,7 @@ import ( ) func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string { - err := rs.WaitElementLocatedByCSSSelector(t, page, "register-link").Click("left") + err := rs.WaitElementLocatedByID(t, page, "register-link").Click("left") require.NoError(t, err) rs.verifyMailNotificationDisplayed(t, page) link := doGetLinkFromLastMail(t) @@ -28,7 +28,7 @@ func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string { } func (rs *RodSession) doEnterOTP(t *testing.T, page *rod.Page, code string) { - inputs := rs.WaitElementsLocatedByCSSSelector(t, page, "otp-input input") + inputs := rs.WaitElementsLocatedByID(t, page, "otp-input input") for i := 0; i < len(code); i++ { _ = inputs[i].Input(string(code[i])) diff --git a/internal/suites/example/compose/oidc-client/docker-compose.yml b/internal/suites/example/compose/oidc-client/docker-compose.yml index 40b211a5..0a150c36 100644 --- a/internal/suites/example/compose/oidc-client/docker-compose.yml +++ b/internal/suites/example/compose/oidc-client/docker-compose.yml @@ -2,7 +2,7 @@ version: '3' services: oidc-client: - image: ghcr.io/authelia/oidc-tester-app:master-89622a8 + image: ghcr.io/authelia/oidc-tester-app:master-01ff268 command: /entrypoint.sh depends_on: - authelia-backend diff --git a/internal/suites/example/compose/oidc-client/entrypoint.sh b/internal/suites/example/compose/oidc-client/entrypoint.sh index 6cf8792e..0fe5729f 100755 --- a/internal/suites/example/compose/oidc-client/entrypoint.sh +++ b/internal/suites/example/compose/oidc-client/entrypoint.sh @@ -2,6 +2,6 @@ while true; do - oidc-tester-app --issuer https://login.example.com:8080 --id oidc-tester-app --secret foobar --scopes openid,profile,email --public-url https://oidc.example.com:8080 + oidc-tester-app --issuer=https://login.example.com:8080 --id=oidc-tester-app --secret=foobar --scopes=openid,profile,email,groups --public-url=https://oidc.example.com:8080 sleep 5 done \ No newline at end of file diff --git a/internal/suites/scenario_available_methods_test.go b/internal/suites/scenario_available_methods_test.go index 976fb70d..7135f7af 100644 --- a/internal/suites/scenario_available_methods_test.go +++ b/internal/suites/scenario_available_methods_test.go @@ -58,11 +58,11 @@ func (s *AvailableMethodsScenario) TestShouldCheckAvailableMethods() { s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") - methodsButton := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "methods-button") + methodsButton := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "methods-button") err := methodsButton.Click("left") s.Assert().NoError(err) - methodsDialog := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "methods-dialog") + methodsDialog := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "methods-dialog") options, err := methodsDialog.Elements(".method-option") s.Assert().NoError(err) s.Assert().Len(options, len(s.methods)) diff --git a/internal/suites/scenario_oidc_test.go b/internal/suites/scenario_oidc_test.go index 6adc04a6..7add41b4 100644 --- a/internal/suites/scenario_oidc_test.go +++ b/internal/suites/scenario_oidc_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "regexp" "testing" "time" @@ -89,11 +90,50 @@ func (s *OIDCScenario) TestShouldAuthorizeAccessToOIDCApp() { assert.NoError(s.T(), err) s.verifyIsConsentPage(s.T(), s.Context(ctx)) - err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "accept-button").Click("left") + err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "accept-button").Click("left") assert.NoError(s.T(), err) // Verify that the app is showing the info related to the user stored in the JWT token. - s.waitBodyContains(s.T(), s.Context(ctx), "Logged in as john!") + + rUUID := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) + rInteger := regexp.MustCompile(`^\d+$`) + rBoolean := regexp.MustCompile(`^(true|false)$`) + rBase64 := regexp.MustCompile(`^[-_A-Za-z0-9+\\/]+([=]{0,3})$`) + + testCases := []struct { + desc, elementID, elementText string + pattern *regexp.Regexp + }{ + {"welcome", "welcome", "Logged in as john!", nil}, + {"at_hash", "claim-at_hash", "", rBase64}, + {"jti", "claim-jti", "", rUUID}, + {"iat", "claim-iat", "", rInteger}, + {"nbf", "claim-nbf", "", rInteger}, + {"rat", "claim-rat", "", rInteger}, + {"expires", "claim-exp", "", rInteger}, + {"amr", "claim-amr", "pwd, otp, mfa", nil}, + {"acr", "claim-acr", "", nil}, + {"issuer", "claim-iss", "https://login.example.com:8080", nil}, + {"name", "claim-name", "John Doe", nil}, + {"preferred_username", "claim-preferred_username", "john", nil}, + {"groups", "claim-groups", "admins, dev", nil}, + {"email", "claim-email", "john.doe@authelia.com", nil}, + {"email_verified", "claim-email_verified", "", rBoolean}, + } + + var text string + + for _, tc := range testCases { + s.T().Run(fmt.Sprintf("check_claims/%s", tc.desc), func(t *testing.T) { + text, err = s.WaitElementLocatedByID(t, s.Context(ctx), tc.elementID).Text() + assert.NoError(t, err) + if tc.pattern == nil { + assert.Equal(t, tc.elementText, text) + } else { + assert.Regexp(t, tc.pattern, text) + } + }) + } } func (s *OIDCScenario) TestShouldDenyConsent() { @@ -117,10 +157,17 @@ func (s *OIDCScenario) TestShouldDenyConsent() { s.verifyIsConsentPage(s.T(), s.Context(ctx)) - err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "deny-button").Click("left") + err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "deny-button").Click("left") assert.NoError(s.T(), err) - s.verifyIsOIDC(s.T(), s.Context(ctx), "oauth2:", "https://oidc.example.com:8080/oauth2/callback?error=access_denied&error_description=User%20has%20rejected%20the%20scopes") + s.verifyIsOIDC(s.T(), s.Context(ctx), "access_denied", "https://oidc.example.com:8080/error?error=access_denied&error_description=The+resource+owner+or+authorization+server+denied+the+request.+Make+sure+that+the+request+you+are+making+is+valid.+Maybe+the+credential+or+request+parameters+you+are+using+are+limited+in+scope+or+otherwise+restricted.&state=random-string-here") + + errorDescription := "The resource owner or authorization server denied the request. Make sure that the request " + + "you are making is valid. Maybe the credential or request parameters you are using are limited in scope or " + + "otherwise restricted." + + s.verifyIsOIDCErrorPage(s.T(), s.Context(ctx), "access_denied", errorDescription, "", + "random-string-here") } func TestRunOIDCScenario(t *testing.T) { diff --git a/internal/suites/scenario_regulation_test.go b/internal/suites/scenario_regulation_test.go index 4d7e20b0..0eba4661 100644 --- a/internal/suites/scenario_regulation_test.go +++ b/internal/suites/scenario_regulation_test.go @@ -60,16 +60,16 @@ func (s *RegulationScenario) TestShouldBanUserAfterTooManyAttempt() { s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.") for i := 0; i < 3; i++ { - err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("bad-password") + err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("bad-password") require.NoError(s.T(), err) - err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left") + err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left") require.NoError(s.T(), err) } // Enter the correct password and test the regulation lock out. - err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("password") + err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("password") require.NoError(s.T(), err) - err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left") + err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left") require.NoError(s.T(), err) s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.") @@ -77,9 +77,9 @@ func (s *RegulationScenario) TestShouldBanUserAfterTooManyAttempt() { time.Sleep(10 * time.Second) // Enter the correct password and test a successful login. - err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("password") + err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("password") require.NoError(s.T(), err) - err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left") + err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left") require.NoError(s.T(), err) s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) } diff --git a/internal/suites/scenario_user_preferences_test.go b/internal/suites/scenario_user_preferences_test.go index 6b35e805..f8e76b23 100644 --- a/internal/suites/scenario_user_preferences_test.go +++ b/internal/suites/scenario_user_preferences_test.go @@ -60,7 +60,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() { // Then switch to push notification method. s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method") // Switch context to clean up state in portal. s.doVisit(s.T(), s.Context(ctx), HomeBaseURL) @@ -70,7 +70,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() { s.doVisit(s.T(), s.Context(ctx), GetLoginBaseURL()) s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) // And check the latest method is still used. - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method") // Meaning the authentication is successful. s.verifyIsHome(s.T(), s.Context(ctx)) @@ -78,7 +78,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() { s.doLogout(s.T(), s.Context(ctx)) s.doLoginOneFactor(s.T(), s.Context(ctx), "harry", "password", false, "") s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "one-time-password-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "one-time-password-method") s.doLogout(s.T(), s.Context(ctx)) s.verifyIsFirstFactorPage(s.T(), s.Context(ctx)) @@ -86,7 +86,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() { // Then log back as previous user and verify the push notification is still the default method. s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method") s.verifyIsHome(s.T(), s.Context(ctx)) s.doLogout(s.T(), s.Context(ctx)) @@ -94,7 +94,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() { // Eventually restore the default method. s.doChangeMethod(s.T(), s.Context(ctx), "one-time-password") - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "one-time-password-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "one-time-password-method") } func TestUserPreferencesScenario(t *testing.T) { diff --git a/internal/suites/suite_duo_push_test.go b/internal/suites/suite_duo_push_test.go index 79ca2dc6..4573c5b0 100644 --- a/internal/suites/suite_duo_push_test.go +++ b/internal/suites/suite_duo_push_test.go @@ -110,7 +110,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAskUserToRegister() { s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "state-not-registered") s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "No compatible device found") enrollPage := s.Page.MustWaitOpen() - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "register-link").MustClick() + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "register-link").MustClick() s.Page = enrollPage() assert.Contains(s.T(), s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "description").MustText(), "This enrollment code has expired. Contact your administrator to get a new enrollment code.") @@ -142,7 +142,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAutoSelectDevice() { s.doLogout(s.T(), s.Context(ctx)) s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") // And check the latest method and device is still used. - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method") // Meaning the authentication is successful. s.verifyIsHome(s.T(), s.Context(ctx)) } @@ -176,7 +176,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() { // Switch Method where Device Selection should open automatically. s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") // Check for available Device 1. - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "device-12345ABCDEFGHIJ67890") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "device-12345ABCDEFGHIJ67890") // Test Back button. s.doClickButton(s.T(), s.Context(ctx), "device-selection-back") // then select Device 2 for further use and be redirected. @@ -187,7 +187,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() { s.doLogout(s.T(), s.Context(ctx)) s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") // And check the latest method and device is still used. - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method") // Meaning the authentication is successful. s.verifyIsHome(s.T(), s.Context(ctx)) } @@ -238,7 +238,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectNewDeviceAfterSavedDeviceMethodI s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") - s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "device-selection") + s.WaitElementLocatedByID(s.T(), s.Context(ctx), "device-selection") s.doSelectDevice(s.T(), s.Context(ctx), "12345ABCDEFGHIJ67890") s.verifyIsHome(s.T(), s.Context(ctx)) } @@ -303,7 +303,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailSelectionBecauseOfSelectionDenied( s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") - err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "selection-link").Click("left") + err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "selection-link").Click("left") require.NoError(s.T(), err) s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Device selection was denied by Duo policy") } diff --git a/internal/suites/verify_is_authenticated_page.go b/internal/suites/verify_is_authenticated_page.go index 5cb24d4b..1e9bf7dd 100644 --- a/internal/suites/verify_is_authenticated_page.go +++ b/internal/suites/verify_is_authenticated_page.go @@ -7,5 +7,5 @@ import ( ) func (rs *RodSession) verifyIsAuthenticatedPage(t *testing.T, page *rod.Page) { - rs.WaitElementLocatedByCSSSelector(t, page, "authenticated-stage") + rs.WaitElementLocatedByID(t, page, "authenticated-stage") } diff --git a/internal/suites/verify_is_consent_page.go b/internal/suites/verify_is_consent_page.go index de143a83..0862b611 100644 --- a/internal/suites/verify_is_consent_page.go +++ b/internal/suites/verify_is_consent_page.go @@ -7,5 +7,5 @@ import ( ) func (rs *RodSession) verifyIsConsentPage(t *testing.T, page *rod.Page) { - rs.WaitElementLocatedByCSSSelector(t, page, "consent-stage") + rs.WaitElementLocatedByID(t, page, "consent-stage") } diff --git a/internal/suites/verify_is_first_factor_page.go b/internal/suites/verify_is_first_factor_page.go index 208194e3..40a50895 100644 --- a/internal/suites/verify_is_first_factor_page.go +++ b/internal/suites/verify_is_first_factor_page.go @@ -7,5 +7,5 @@ import ( ) func (rs *RodSession) verifyIsFirstFactorPage(t *testing.T, page *rod.Page) { - rs.WaitElementLocatedByCSSSelector(t, page, "first-factor-stage") + rs.WaitElementLocatedByID(t, page, "first-factor-stage") } diff --git a/internal/suites/verify_is_oidc.go b/internal/suites/verify_is_oidc.go index 9e8e2219..a6579bea 100644 --- a/internal/suites/verify_is_oidc.go +++ b/internal/suites/verify_is_oidc.go @@ -4,9 +4,33 @@ import ( "testing" "github.com/go-rod/rod" + "github.com/stretchr/testify/assert" ) func (rs *RodSession) verifyIsOIDC(t *testing.T, page *rod.Page, pattern, url string) { page.MustElementR("body", pattern) rs.verifyURLIs(t, page, url) } + +func (rs *RodSession) verifyIsOIDCErrorPage(t *testing.T, page *rod.Page, errorCode, errorDescription, errorURI, state string) { + testCases := []struct { + ElementID, ElementText string + }{ + {"error", errorCode}, + {"error_description", errorDescription}, + {"error_uri", errorURI}, + {"state", state}, + } + + for _, tc := range testCases { + t.Run(tc.ElementID, func(t *testing.T) { + if tc.ElementText == "" { + t.Skip("Test Skipped as the element is not expected.") + } + + text, err := rs.WaitElementLocatedByID(t, page, tc.ElementID).Text() + assert.NoError(t, err) + assert.Equal(t, tc.ElementText, text) + }) + } +} diff --git a/internal/suites/verify_is_second_factor_page.go b/internal/suites/verify_is_second_factor_page.go index a37ffc27..8813afc6 100644 --- a/internal/suites/verify_is_second_factor_page.go +++ b/internal/suites/verify_is_second_factor_page.go @@ -7,5 +7,5 @@ import ( ) func (rs *RodSession) verifyIsSecondFactorPage(t *testing.T, page *rod.Page) { - rs.WaitElementLocatedByCSSSelector(t, page, "second-factor-stage") + rs.WaitElementLocatedByID(t, page, "second-factor-stage") } diff --git a/internal/suites/verify_secret_authorized.go b/internal/suites/verify_secret_authorized.go index 2c6c94f4..18f6bbe8 100644 --- a/internal/suites/verify_secret_authorized.go +++ b/internal/suites/verify_secret_authorized.go @@ -7,5 +7,5 @@ import ( ) func (rs *RodSession) verifySecretAuthorized(t *testing.T, page *rod.Page) { - rs.WaitElementLocatedByCSSSelector(t, page, "secret") + rs.WaitElementLocatedByID(t, page, "secret") } diff --git a/internal/suites/webdriver.go b/internal/suites/webdriver.go index e93b2353..28b7750b 100644 --- a/internal/suites/webdriver.go +++ b/internal/suites/webdriver.go @@ -82,8 +82,8 @@ func (rs *RodSession) WaitElementLocatedByClassName(t *testing.T, page *rod.Page return e } -// WaitElementLocatedByCSSSelector wait an element is located by class name. -func (rs *RodSession) WaitElementLocatedByCSSSelector(t *testing.T, page *rod.Page, cssSelector string) *rod.Element { +// WaitElementLocatedByID waits for an element located by an id. +func (rs *RodSession) WaitElementLocatedByID(t *testing.T, page *rod.Page, cssSelector string) *rod.Element { e, err := page.Element("#" + cssSelector) require.NoError(t, err) require.NotNil(t, e) @@ -91,8 +91,8 @@ func (rs *RodSession) WaitElementLocatedByCSSSelector(t *testing.T, page *rod.Pa return e } -// WaitElementsLocatedByCSSSelector wait an element is located by CSS selector. -func (rs *RodSession) WaitElementsLocatedByCSSSelector(t *testing.T, page *rod.Page, cssSelector string) rod.Elements { +// WaitElementsLocatedByID waits for an elements located by an id. +func (rs *RodSession) WaitElementsLocatedByID(t *testing.T, page *rod.Page, cssSelector string) rod.Elements { e, err := page.Elements("#" + cssSelector) require.NoError(t, err) require.NotNil(t, e) diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 231621f2..7fb6b8ee 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -241,6 +241,39 @@ func StringHTMLEscape(input string) (output string) { return htmlEscaper.Replace(input) } +// StringJoinDelimitedEscaped joins a string with a specified rune delimiter after escaping any instance of that string +// in the string slice. Used with StringSplitDelimitedEscaped. +func StringJoinDelimitedEscaped(value []string, delimiter rune) string { + escaped := make([]string, len(value)) + for k, v := range value { + escaped[k] = strings.ReplaceAll(v, string(delimiter), "\\"+string(delimiter)) + } + + return strings.Join(escaped, string(delimiter)) +} + +// StringSplitDelimitedEscaped splits a string with a specified rune delimiter after unescaping any instance of that +// string in the string slice that has been escaped. Used with StringJoinDelimitedEscaped. +func StringSplitDelimitedEscaped(value string, delimiter rune) (out []string) { + var escape bool + + split := strings.FieldsFunc(value, func(r rune) bool { + if r == '\\' { + escape = !escape + } else if escape && r != delimiter { + escape = false + } + + return !escape && r == delimiter + }) + + for k, v := range split { + split[k] = strings.ReplaceAll(v, "\\"+string(delimiter), string(delimiter)) + } + + return split +} + // JoinAndCanonicalizeHeaders join header strings by a given sep. func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) { for i, header := range headers { diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go index 2c9ad981..6b1b5e7c 100644 --- a/internal/utils/strings_test.go +++ b/internal/utils/strings_test.go @@ -8,6 +8,51 @@ import ( "github.com/stretchr/testify/require" ) +func TestStringSplitDelimitedEscaped(t *testing.T) { + testCases := []struct { + desc, have string + delimiter rune + want []string + }{ + {desc: "ShouldSplitNormalString", have: "abc,123,456", delimiter: ',', want: []string{"abc", "123", "456"}}, + {desc: "ShouldSplitEscapedString", have: "a\\,bc,123,456", delimiter: ',', want: []string{"a,bc", "123", "456"}}, + {desc: "ShouldSplitEscapedStringPipe", have: "a\\|bc|123|456", delimiter: '|', want: []string{"a|bc", "123", "456"}}, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + actual := StringSplitDelimitedEscaped(tc.have, tc.delimiter) + + assert.Equal(t, tc.want, actual) + }) + } +} + +func TestStringJoinDelimitedEscaped(t *testing.T) { + testCases := []struct { + desc, want string + delimiter rune + have []string + }{ + {desc: "ShouldJoinNormalStringSlice", have: []string{"abc", "123", "456"}, delimiter: ',', want: "abc,123,456"}, + {desc: "ShouldJoinEscapeNeededStringSlice", have: []string{"abc", "1,23", "456"}, delimiter: ',', want: "abc,1\\,23,456"}, + {desc: "ShouldJoinEscapeNeededStringSlicePipe", have: []string{"abc", "1|23", "456"}, delimiter: '|', want: "abc|1\\|23|456"}, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + actual := StringJoinDelimitedEscaped(tc.have, tc.delimiter) + + assert.Equal(t, tc.want, actual) + + // Ensure splitting again also works fine. + split := StringSplitDelimitedEscaped(actual, tc.delimiter) + + assert.Equal(t, tc.have, split) + }) + } +} + func TestShouldNotGenerateSameRandomString(t *testing.T) { randomStringOne := RandomString(10, AlphaNumericCharacters, false) randomStringTwo := RandomString(10, AlphaNumericCharacters, false)