diff --git a/docs/configuration/identity-providers/oidc.md b/docs/configuration/identity-providers/oidc.md
index f2ffd276..32f825af 100644
--- a/docs/configuration/identity-providers/oidc.md
+++ b/docs/configuration/identity-providers/oidc.md
@@ -232,7 +232,7 @@ Allows PKCE `plain` challenges when set to `true`.
Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
you to configure the optional parts. We reply with CORS headers when the request includes the Origin header.
-##### endpoints
+#### endpoints
type: list(string)
{: .label .label-config .label-purple }
@@ -522,8 +522,8 @@ individual user. Please use the claim `preferred_username` instead._
This scope includes the groups the authentication backend reports the user is a member of in the token.
-| Claim | JWT Type | Authelia Attribute | Description |
-|:------:|:-------------:|:------------------:|:----------------------:|
+| Claim | JWT Type | Authelia Attribute | Description |
+|:------:|:-------------:|:------------------:|:------------------------------------------------------------------------------------------------------------------:|
| groups | array[string] | groups | List of user's groups discovered via [authentication](https://www.authelia.com/docs/configuration/authentication/) |
### email
@@ -585,30 +585,33 @@ an example of the Authelia root URL which is also the OpenID Connect issuer.
These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP.
-| Endpoint | Path |
-|:-------------:|:---------------------------------------------------------------:|
-| Discovery | https://auth.example.com/.well-known/openid-configuration |
-| Metadata | https://auth.example.com/.well-known/oauth-authorization-server |
+| Endpoint | Path |
+|:-----------------------------------------:|:---------------------------------------------------------------:|
+| [OpenID Connect Discovery] | https://auth.example.com/.well-known/openid-configuration |
+| [OAuth 2.0 Authorization Server Metadata] | https://auth.example.com/.well-known/oauth-authorization-server |
### Discoverable Endpoints
These endpoints implement OpenID Connect elements.
-| Endpoint | Path | Discovery Attribute |
-|:---------------:|:-----------------------------------------------:|:----------------------:|
-| JWKS | https://auth.example.com/jwks.json | jwks_uri |
-| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
-| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
-| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
-| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
-| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
+| Endpoint | Path | Discovery Attribute |
+|:-------------------:|:-----------------------------------------------:|:----------------------:|
+| [JSON Web Key Sets] | https://auth.example.com/jwks.json | jwks_uri |
+| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
+| [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
+| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
+| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
+| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
+[JSON Web Key Sets]: https://datatracker.ietf.org/doc/html/rfc7517#section-5
[OpenID Connect]: https://openid.net/connect/
-[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
+[OpenID Connect Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
+[OAuth 2.0 Authorization Server Metadata]: https://datatracker.ietf.org/doc/html/rfc8414
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
+[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration
diff --git a/docs/configuration/storage/migrations.md b/docs/configuration/storage/migrations.md
index a66c6e30..4d9787f8 100644
--- a/docs/configuration/storage/migrations.md
+++ b/docs/configuration/storage/migrations.md
@@ -24,3 +24,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
| 1 | 4.33.0 | Initial migration managed version |
| 2 | 4.34.0 | Webauthn - added webauthn_devices table, altered totp_config to include device created/used dates |
| 3 | 4.34.2 | Webauthn - fix V2 migration kid column length and provide migration path for anyone on V2 |
+| 4 | 4.35.0 | Added OpenID Connect storage tables and opaque user identifier tables |
diff --git a/docs/roadmap/oidc.md b/docs/roadmap/oidc.md
index 26d2d936..910517a8 100644
--- a/docs/roadmap/oidc.md
+++ b/docs/roadmap/oidc.md
@@ -74,7 +74,7 @@ for which stage will have each feature, and may evolve over time:
Proof Key for Code Exchange (PKCE) for Authorization Code Flow |
-
+
Token Storage |
@@ -83,6 +83,21 @@ for which stage will have each feature, and may evolve over time:
Subject Storage |
+
+ Pairwise Subject Identifier Type |
+
+
+ Per-Client Consent Pre-Configuration |
+
+
+ Cross-Origin Resource Sharing Configuration |
+
+
+ Authentication Methods References Claim |
+
+
+ UUID v4 sub claim |
+
Prompt Handling |
@@ -91,7 +106,7 @@ for which stage will have each feature, and may evolve over time:
Display Handling |
-
+
Back-Channel Logout |
@@ -103,9 +118,6 @@ for which stage will have each feature, and may evolve over time:
Client Secrets Hashed in Configuration |
-
- UUID or Random String for sub claim |
-
General Availability after previous stages are vetted for bug fixes |
diff --git a/internal/commands/helpers.go b/internal/commands/helpers.go
index ee09f743..d5c3990b 100644
--- a/internal/commands/helpers.go
+++ b/internal/commands/helpers.go
@@ -64,7 +64,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [
sessionProvider := session.NewProvider(config.Session, autheliaCertPool)
regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock)
- oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC)
+ oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC, storageProvider)
if err != nil {
errors = append(errors, err)
}
diff --git a/internal/configuration/schema/identity_providers.go b/internal/configuration/schema/identity_providers.go
index 3828172c..bb2c35bd 100644
--- a/internal/configuration/schema/identity_providers.go
+++ b/internal/configuration/schema/identity_providers.go
@@ -12,7 +12,6 @@ type IdentityProvidersConfiguration struct {
// OpenIDConnectConfiguration configuration for OpenID Connect.
type OpenIDConnectConfiguration struct {
- // This secret must be 32 bytes long.
HMACSecret string `koanf:"hmac_secret"`
IssuerPrivateKey string `koanf:"issuer_private_key"`
@@ -49,9 +48,10 @@ type OpenIDConnectClientConfiguration struct {
Policy string `koanf:"authorization_policy"`
+ RedirectURIs []string `koanf:"redirect_uris"`
+
Audience []string `koanf:"audience"`
Scopes []string `koanf:"scopes"`
- RedirectURIs []string `koanf:"redirect_uris"`
GrantTypes []string `koanf:"grant_types"`
ResponseTypes []string `koanf:"response_types"`
ResponseModes []string `koanf:"response_modes"`
diff --git a/internal/handlers/handler_firstfactor.go b/internal/handlers/handler_firstfactor.go
index 1a48b8cb..ca5b0569 100644
--- a/internal/handlers/handler_firstfactor.go
+++ b/internal/handlers/handler_firstfactor.go
@@ -73,7 +73,7 @@ func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re
userSession := ctx.GetSession()
newSession := session.NewDefaultUserSession()
- newSession.OIDCWorkflowSession = userSession.OIDCWorkflowSession
+ newSession.ConsentChallengeID = userSession.ConsentChallengeID
// Reset all values from previous session except OIDC workflow before regenerating the cookie.
if err = ctx.SaveSession(newSession); err != nil {
@@ -135,7 +135,7 @@ func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re
successful = true
- if userSession.OIDCWorkflowSession != nil {
+ if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx)
} else {
Handle1FAResponse(ctx, bodyJSON.TargetURL, bodyJSON.RequestMethod, userSession.Username, userSession.Groups)
diff --git a/internal/handlers/handler_oidc_authorization.go b/internal/handlers/handler_oidc_authorization.go
index c5940410..c1d34c95 100644
--- a/internal/handlers/handler_oidc_authorization.go
+++ b/internal/handlers/handler_oidc_authorization.go
@@ -2,18 +2,15 @@ package handlers
import (
"errors"
- "fmt"
"net/http"
- "strings"
"time"
+ "github.com/google/uuid"
"github.com/ory/fosite"
- "github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc"
- "github.com/authelia/authelia/v4/internal/session"
)
// OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint.
@@ -23,7 +20,7 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
var (
requester fosite.AuthorizeRequester
responder fosite.AuthorizeResponder
- client *oidc.InternalClient
+ client *oidc.Client
authTime time.Time
issuer string
err error
@@ -43,7 +40,7 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' is being processed", requester.GetID(), clientID)
- if client, err = ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID); err != nil {
+ if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(clientID); err != nil {
if errors.Is(err, fosite.ErrNotFound) {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: client was not found", requester.GetID(), clientID)
} else {
@@ -65,31 +62,27 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
userSession := ctx.GetSession()
- requestedScopes := requester.GetRequestedScopes()
- requestedAudience := requester.GetRequestedAudience()
+ var subject uuid.UUID
- isAuthInsufficient := !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel)
+ if subject, err = ctx.Providers.OpenIDConnect.Store.GetSubject(ctx, client.GetSectorIdentifier(), userSession.Username); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred retrieving subject for user '%s': %+v", requester.GetID(), client.GetID(), userSession.Username, err)
- if isAuthInsufficient || (isConsentMissing(userSession.OIDCWorkflowSession, requestedScopes, requestedAudience)) {
- oidcAuthorizeHandleAuthorizationOrConsentInsufficient(ctx, userSession, client, isAuthInsufficient, rw, r, requester, issuer)
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not retrieve the subject."))
return
}
- extraClaims := oidcGrantRequests(requester, requestedScopes, requestedAudience, &userSession)
-
- workflowCreated := time.Unix(userSession.OIDCWorkflowSession.CreatedTimestamp, 0)
-
- userSession.OIDCWorkflowSession = nil
-
- if err = ctx.SaveSession(userSession); err != nil {
- ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session: %+v", requester.GetID(), client.GetID(), err)
-
- ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session."))
+ var (
+ consent *model.OAuth2ConsentSession
+ handled bool
+ )
+ if consent, handled = handleOIDCAuthorizationConsent(ctx, issuer, client, userSession, subject, rw, r, requester); handled {
return
}
+ extraClaims := oidcGrantRequests(requester, consent, &userSession)
+
if authTime, err = userSession.AuthenticatedTime(client.Policy); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred checking authentication time: %+v", requester.GetID(), client.GetID(), err)
@@ -100,9 +93,8 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID)
- subject := userSession.Username
oidcSession := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(),
- subject, userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, workflowCreated, requester)
+ userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, consent, requester)
ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v",
requester.GetID(), oidcSession.ClientID, oidcSession.Subject, oidcSession.Username, oidcSession.Claims)
@@ -119,39 +111,13 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
return
}
- ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeResponse(rw, requester, responder)
-}
-
-func oidcAuthorizeHandleAuthorizationOrConsentInsufficient(
- ctx *middlewares.AutheliaCtx, userSession session.UserSession, client *oidc.InternalClient, isAuthInsufficient bool,
- rw http.ResponseWriter, r *http.Request,
- requester fosite.AuthorizeRequester, issuer string) {
- redirectURL := fmt.Sprintf("%s%s", issuer, string(ctx.Request.RequestURI()))
-
- ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' requires user '%s' provides consent for scopes '%s'",
- requester.GetID(), client.GetID(), userSession.Username, strings.Join(requester.GetRequestedScopes(), "', '"))
-
- userSession.OIDCWorkflowSession = &model.OIDCWorkflowSession{
- ClientID: client.GetID(),
- RequestedScopes: requester.GetRequestedScopes(),
- RequestedAudience: requester.GetRequestedAudience(),
- AuthURI: redirectURL,
- TargetURI: requester.GetRedirectURI().String(),
- Require2FA: client.Policy == authorization.TwoFactor,
- CreatedTimestamp: time.Now().Unix(),
- }
-
- if err := ctx.SaveSession(userSession); err != nil {
- ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session for consent: %+v", requester.GetID(), client.GetID(), err)
+ if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSessionGranted(ctx, consent.ID); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving consent session: %+v", requester.GetID(), client.GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session."))
return
}
- if isAuthInsufficient {
- http.Redirect(rw, r, issuer, http.StatusFound)
- } else {
- http.Redirect(rw, r, fmt.Sprintf("%s/consent", issuer), http.StatusFound)
- }
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeResponse(rw, requester, responder)
}
diff --git a/internal/handlers/handler_oidc_authorization_consent.go b/internal/handlers/handler_oidc_authorization_consent.go
new file mode 100644
index 00000000..7c94f8ea
--- /dev/null
+++ b/internal/handlers/handler_oidc_authorization_consent.go
@@ -0,0 +1,168 @@
+package handlers
+
+import (
+ "fmt"
+ "net/http"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/ory/fosite"
+
+ "github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/model"
+ "github.com/authelia/authelia/v4/internal/oidc"
+ "github.com/authelia/authelia/v4/internal/session"
+ "github.com/authelia/authelia/v4/internal/storage"
+ "github.com/authelia/authelia/v4/internal/utils"
+)
+
+func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client,
+ userSession session.UserSession, subject uuid.UUID,
+ rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) {
+ if userSession.ConsentChallengeID != nil {
+ return handleOIDCAuthorizationConsentWithChallengeID(ctx, rootURI, client, userSession, rw, r, requester)
+ }
+
+ return handleOIDCAuthorizationConsentOrGenerate(ctx, rootURI, client, userSession, subject, rw, r, requester)
+}
+
+func handleOIDCAuthorizationConsentWithChallengeID(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client,
+ userSession session.UserSession,
+ rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) {
+ var (
+ err error
+ )
+
+ if consent, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred during consent session lookup: %+v", requester.GetID(), requester.GetClient().GetID(), err)
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Failed to lookup consent session."))
+
+ userSession.ConsentChallengeID = nil
+
+ if err = ctx.SaveSession(userSession); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred unlinking consent session challenge id: %+v", requester.GetID(), requester.GetClient().GetID(), err)
+ }
+
+ return nil, true
+ }
+
+ if consent.Responded() {
+ userSession.ConsentChallengeID = nil
+
+ if err = ctx.SaveSession(userSession); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session: %+v", requester.GetID(), client.GetID(), err)
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session."))
+
+ return nil, true
+ }
+
+ if consent.Granted {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: this consent session with challenge id '%s' was already granted", requester.GetID(), client.GetID(), consent.ChallengeID.String())
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Authorization already granted."))
+
+ return nil, true
+ }
+
+ ctx.Logger.Debugf("Authorization Request with id '%s' loaded consent session with id '%d' and challenge id '%s' for client id '%s' and subject '%s' and scopes '%s'", requester.GetID(), consent.ID, consent.ChallengeID.String(), client.GetID(), consent.Subject.String(), strings.Join(requester.GetRequestedScopes(), " "))
+
+ if consent.IsDenied() {
+ ctx.Logger.Warnf("Authorization Request with id '%s' and challenge id '%s' for client id '%s' and subject '%s' and scopes '%s' was not denied by the user durng the consent session", requester.GetID(), consent.ChallengeID.String(), client.GetID(), consent.Subject.String(), strings.Join(requester.GetRequestedScopes(), " "))
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrAccessDenied)
+
+ return nil, true
+ }
+
+ return consent, false
+ }
+
+ handleOIDCAuthorizationConsentRedirect(rootURI, client, userSession, rw, r)
+
+ return consent, true
+}
+
+func handleOIDCAuthorizationConsentOrGenerate(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client,
+ userSession session.UserSession, subject uuid.UUID,
+ rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) {
+ var (
+ rows *storage.ConsentSessionRows
+ scopes, audience []string
+ err error
+ )
+
+ if rows, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionsPreConfigured(ctx, client.GetID(), subject); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' had error looking up pre-configured consent sessions: %+v", requester.GetID(), requester.GetClient().GetID(), err)
+ }
+
+ defer rows.Close()
+
+ for rows.Next() {
+ if consent, err = rows.Get(); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' had error looking up pre-configured consent sessions: %+v", requester.GetID(), requester.GetClient().GetID(), err)
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not lookup pre-configured consent sessions."))
+
+ return nil, true
+ }
+
+ scopes, audience = getExpectedScopesAndAudience(requester)
+
+ if consent.HasExactGrants(scopes, audience) && consent.CanGrant() {
+ break
+ }
+ }
+
+ if consent != nil && consent.HasExactGrants(scopes, audience) && consent.CanGrant() {
+ return consent, false
+ }
+
+ if consent, err = model.NewOAuth2ConsentSession(subject, requester); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred generating consent: %+v", requester.GetID(), requester.GetClient().GetID(), err)
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not generate the consent session."))
+
+ return nil, true
+ }
+
+ if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSession(ctx, *consent); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving consent session: %+v", requester.GetID(), client.GetID(), err)
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the consent session."))
+
+ return nil, true
+ }
+
+ userSession.ConsentChallengeID = &consent.ChallengeID
+
+ if err = ctx.SaveSession(userSession); err != nil {
+ ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving user session for consent: %+v", requester.GetID(), client.GetID(), err)
+
+ ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the user session."))
+
+ return nil, true
+ }
+
+ handleOIDCAuthorizationConsentRedirect(rootURI, client, userSession, rw, r)
+
+ return consent, true
+}
+
+func handleOIDCAuthorizationConsentRedirect(destination string, client *oidc.Client, userSession session.UserSession, rw http.ResponseWriter, r *http.Request) {
+ if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
+ destination = fmt.Sprintf("%s/consent", destination)
+ }
+
+ http.Redirect(rw, r, destination, http.StatusFound)
+}
+
+func getExpectedScopesAndAudience(requester fosite.Requester) (scopes, audience []string) {
+ audience = requester.GetRequestedAudience()
+ if !utils.IsStringInSlice(requester.GetClient().GetID(), audience) {
+ audience = append(audience, requester.GetClient().GetID())
+ }
+
+ return requester.GetRequestedScopes(), audience
+}
diff --git a/internal/handlers/handler_oidc_consent.go b/internal/handlers/handler_oidc_consent.go
index b403ecb0..af41dfb5 100644
--- a/internal/handlers/handler_oidc_consent.go
+++ b/internal/handlers/handler_oidc_consent.go
@@ -5,116 +5,135 @@ import (
"fmt"
"github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/model"
+ "github.com/authelia/authelia/v4/internal/oidc"
+ "github.com/authelia/authelia/v4/internal/session"
+ "github.com/authelia/authelia/v4/internal/utils"
)
// OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect.
func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) {
- userSession := ctx.GetSession()
-
- if userSession.OIDCWorkflowSession == nil {
- ctx.Logger.Debugf("Cannot consent for user %s when OIDC workflow has not been initiated", userSession.Username)
- ctx.ReplyForbidden()
-
- return
- }
-
- clientID := userSession.OIDCWorkflowSession.ClientID
- client, err := ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID)
-
- if err != nil {
- ctx.Logger.Debugf("Unable to find related client configuration with name '%s': %v", clientID, err)
- ctx.ReplyForbidden()
-
+ userSession, consent, client, handled := oidcConsentGetSessionsAndClient(ctx)
+ if handled {
return
}
if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
- ctx.Logger.Debugf("Insufficient permissions to give consent during GET current level: %d, require 2FA: %t", userSession.AuthenticationLevel, userSession.OIDCWorkflowSession.Require2FA)
+ ctx.Logger.Errorf("Unable to perform consent without sufficient authentication for user '%s' and client id '%s'", userSession.Username, consent.ClientID)
ctx.ReplyForbidden()
return
}
- if err := ctx.SetJSONBody(client.GetConsentResponseBody(userSession.OIDCWorkflowSession)); err != nil {
+ if err := ctx.SetJSONBody(client.GetConsentResponseBody(consent)); err != nil {
ctx.Error(fmt.Errorf("unable to set JSON body: %v", err), "Operation failed")
}
}
// OpenIDConnectConsentPOST handles consent responses for OpenID Connect.
func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
- userSession := ctx.GetSession()
+ var (
+ body oidc.ConsentPostRequestBody
+ err error
+ )
- if userSession.OIDCWorkflowSession == nil {
- ctx.Logger.Debugf("Cannot consent for user %s when OIDC workflow has not been initiated", userSession.Username)
- ctx.ReplyForbidden()
+ if err = json.Unmarshal(ctx.Request.Body(), &body); err != nil {
+ ctx.Logger.Errorf("Failed to parse JSON body in consent POST: %+v", err)
+ ctx.SetJSONError(messageOperationFailed)
return
}
- client, err := ctx.Providers.OpenIDConnect.Store.GetInternalClient(userSession.OIDCWorkflowSession.ClientID)
-
- if err != nil {
- ctx.Logger.Debugf("Unable to find related client configuration with name '%s': %v", userSession.OIDCWorkflowSession.ClientID, err)
- ctx.ReplyForbidden()
-
+ userSession, consent, client, handled := oidcConsentGetSessionsAndClient(ctx)
+ if handled {
return
}
if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
- ctx.Logger.Debugf("Insufficient permissions to give consent during POST current level: %d, require 2FA: %t", userSession.AuthenticationLevel, userSession.OIDCWorkflowSession.Require2FA)
+ ctx.Logger.Debugf("Insufficient permissions to give consent during POST current level: %d, require 2FA: %d", userSession.AuthenticationLevel, client.Policy)
ctx.ReplyForbidden()
return
}
- var body ConsentPostRequestBody
- err = json.Unmarshal(ctx.Request.Body(), &body)
+ if consent.ClientID != body.ClientID {
+ ctx.Logger.Errorf("User '%s' consented to scopes of another client (%s) than expected (%s). Beware this can be a sign of attack",
+ userSession.Username, body.ClientID, consent.ClientID)
+ ctx.SetJSONError(messageOperationFailed)
- if err != nil {
- ctx.Error(fmt.Errorf("unable to unmarshal body: %v", err), "Operation failed")
return
}
- if body.AcceptOrReject != accept && body.AcceptOrReject != reject {
- ctx.Logger.Infof("User %s tried to reply to consent with an unexpected verb", userSession.Username)
+ var (
+ externalRootURL string
+ authorized = true
+ )
+
+ switch body.AcceptOrReject {
+ case accept:
+ if externalRootURL, err = ctx.ExternalRootURL(); err != nil {
+ ctx.Logger.Errorf("Could not determine the external URL during consent session processing with challenge id '%s' for user '%s': %v", consent.ChallengeID.String(), userSession.Username, err)
+ ctx.SetJSONError(messageOperationFailed)
+
+ return
+ }
+
+ consent.GrantedScopes = consent.RequestedScopes
+ consent.GrantedAudience = consent.RequestedAudience
+
+ if !utils.IsStringInSlice(consent.ClientID, consent.GrantedAudience) {
+ consent.GrantedAudience = append(consent.GrantedAudience, consent.ClientID)
+ }
+ case reject:
+ authorized = false
+ default:
+ ctx.Logger.Warnf("User '%s' tried to reply to consent with an unexpected verb", userSession.Username)
ctx.ReplyBadRequest()
return
}
- if userSession.OIDCWorkflowSession.ClientID != body.ClientID {
- ctx.Logger.Infof("User %s consented to scopes of another client (%s) than expected (%s). Beware this can be a sign of attack",
- userSession.Username, body.ClientID, userSession.OIDCWorkflowSession.ClientID)
- ctx.ReplyBadRequest()
+ if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSessionResponse(ctx, *consent, authorized); err != nil {
+ ctx.Logger.Errorf("Failed to save the consent session response to the database: %+v", err)
+ ctx.SetJSONError(messageOperationFailed)
return
}
- var redirectionURL string
+ response := oidc.ConsentPostResponseBody{RedirectURI: fmt.Sprintf("%s%s?%s", externalRootURL, oidc.AuthorizationPath, consent.Form)}
- if body.AcceptOrReject == accept {
- redirectionURL = userSession.OIDCWorkflowSession.AuthURI
- userSession.OIDCWorkflowSession.GrantedScopes = userSession.OIDCWorkflowSession.RequestedScopes
- userSession.OIDCWorkflowSession.GrantedAudience = userSession.OIDCWorkflowSession.RequestedAudience
-
- if err := ctx.SaveSession(userSession); err != nil {
- ctx.Error(fmt.Errorf("unable to write session: %v", err), "Operation failed")
- return
- }
- } else if body.AcceptOrReject == reject {
- redirectionURL = fmt.Sprintf("%s?error=access_denied&error_description=%s",
- userSession.OIDCWorkflowSession.TargetURI, "User has rejected the scopes")
- userSession.OIDCWorkflowSession = nil
-
- if err := ctx.SaveSession(userSession); err != nil {
- ctx.Error(fmt.Errorf("unable to write session: %v", err), "Operation failed")
- return
- }
- }
-
- response := ConsentPostResponseBody{RedirectURI: redirectionURL}
-
- if err := ctx.SetJSONBody(response); err != nil {
+ if err = ctx.SetJSONBody(response); err != nil {
ctx.Error(fmt.Errorf("unable to set JSON body in response"), "Operation failed")
}
}
+
+func oidcConsentGetSessionsAndClient(ctx *middlewares.AutheliaCtx) (userSession session.UserSession, consent *model.OAuth2ConsentSession, client *oidc.Client, handled bool) {
+ var (
+ err error
+ )
+
+ userSession = ctx.GetSession()
+
+ if userSession.ConsentChallengeID == nil {
+ ctx.Logger.Errorf("Cannot consent for user '%s' when OIDC consent session has not been initiated", userSession.Username)
+ ctx.ReplyForbidden()
+
+ return userSession, nil, nil, true
+ }
+
+ if consent, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID); err != nil {
+ ctx.Logger.Errorf("Unable to load consent session with challenge id '%s': %v", userSession.ConsentChallengeID.String(), err)
+ ctx.ReplyForbidden()
+
+ return userSession, nil, nil, true
+ }
+
+ if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(consent.ClientID); err != nil {
+ ctx.Logger.Errorf("Unable to find related client configuration with name '%s': %v", consent.ClientID, err)
+ ctx.ReplyForbidden()
+
+ return userSession, nil, nil, true
+ }
+
+ return userSession, consent, client, false
+}
diff --git a/internal/handlers/handler_oidc_userinfo.go b/internal/handlers/handler_oidc_userinfo.go
index 1a46ec39..0e4c5914 100644
--- a/internal/handlers/handler_oidc_userinfo.go
+++ b/internal/handlers/handler_oidc_userinfo.go
@@ -11,6 +11,7 @@ import (
"github.com/pkg/errors"
"github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc"
)
@@ -21,7 +22,7 @@ func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter,
var (
tokenType fosite.TokenType
requester fosite.AccessRequester
- client *oidc.InternalClient
+ client *oidc.Client
err error
)
@@ -54,13 +55,13 @@ func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter,
return
}
- if client, err = ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID); err != nil {
+ if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(clientID); err != nil {
ctx.Providers.OpenIDConnect.WriteError(rw, req, errors.WithStack(fosite.ErrServerError.WithHint("Unable to assert type of client")))
return
}
- claims := requester.GetSession().(*oidc.OpenIDSession).IDTokenClaims().ToMap()
+ claims := requester.GetSession().(*model.OpenIDSession).IDTokenClaims().ToMap()
delete(claims, "jti")
delete(claims, "sid")
delete(claims, "at_hash")
diff --git a/internal/handlers/handler_sign_duo.go b/internal/handlers/handler_sign_duo.go
index 48aebbc6..880ffdb2 100644
--- a/internal/handlers/handler_sign_duo.go
+++ b/internal/handlers/handler_sign_duo.go
@@ -266,7 +266,7 @@ func HandleAllow(ctx *middlewares.AutheliaCtx, targetURL string) {
return
}
- if userSession.OIDCWorkflowSession != nil {
+ if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx)
} else {
Handle2FAResponse(ctx, targetURL)
diff --git a/internal/handlers/handler_sign_totp.go b/internal/handlers/handler_sign_totp.go
index f21c2d21..348925b8 100644
--- a/internal/handlers/handler_sign_totp.go
+++ b/internal/handlers/handler_sign_totp.go
@@ -78,7 +78,7 @@ func SecondFactorTOTPPost(ctx *middlewares.AutheliaCtx) {
return
}
- if userSession.OIDCWorkflowSession != nil {
+ if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx)
} else {
Handle2FAResponse(ctx, requestBody.TargetURL)
diff --git a/internal/handlers/handler_sign_webauthn.go b/internal/handlers/handler_sign_webauthn.go
index 53a8ba9c..5564851a 100644
--- a/internal/handlers/handler_sign_webauthn.go
+++ b/internal/handlers/handler_sign_webauthn.go
@@ -197,7 +197,7 @@ func SecondFactorWebauthnAssertionPOST(ctx *middlewares.AutheliaCtx) {
return
}
- if userSession.OIDCWorkflowSession != nil {
+ if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx)
} else {
Handle2FAResponse(ctx, requestBody.TargetURL)
diff --git a/internal/handlers/oidc.go b/internal/handlers/oidc.go
index 3244ed04..20c51d2c 100644
--- a/internal/handlers/oidc.go
+++ b/internal/handlers/oidc.go
@@ -6,24 +6,12 @@ import (
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/session"
- "github.com/authelia/authelia/v4/internal/utils"
)
-// isConsentMissing compares the requestedScopes and requestedAudience to the workflows
-// GrantedScopes and GrantedAudience and returns true if they do not match or the workflow is nil.
-func isConsentMissing(workflow *model.OIDCWorkflowSession, requestedScopes, requestedAudience []string) (isMissing bool) {
- if workflow == nil {
- return true
- }
-
- return len(requestedScopes) > 0 && utils.IsStringSlicesDifferent(requestedScopes, workflow.GrantedScopes) ||
- len(requestedAudience) > 0 && utils.IsStringSlicesDifferentFold(requestedAudience, workflow.GrantedAudience)
-}
-
-func oidcGrantRequests(ar fosite.AuthorizeRequester, scopes, audiences []string, userSession *session.UserSession) (extraClaims map[string]interface{}) {
+func oidcGrantRequests(ar fosite.AuthorizeRequester, consent *model.OAuth2ConsentSession, userSession *session.UserSession) (extraClaims map[string]interface{}) {
extraClaims = map[string]interface{}{}
- for _, scope := range scopes {
+ for _, scope := range consent.GrantedScopes {
if ar != nil {
ar.GrantScope(scope)
}
@@ -47,13 +35,9 @@ func oidcGrantRequests(ar fosite.AuthorizeRequester, scopes, audiences []string,
}
if ar != nil {
- for _, audience := range audiences {
+ for _, audience := range consent.GrantedAudience {
ar.GrantAudience(audience)
}
-
- if !utils.IsStringInSlice(ar.GetClient().GetID(), ar.GetGrantedAudience()) {
- ar.GrantAudience(ar.GetClient().GetID())
- }
}
return extraClaims
diff --git a/internal/handlers/oidc_test.go b/internal/handlers/oidc_test.go
index 8a5b21b0..95ff6e34 100644
--- a/internal/handlers/oidc_test.go
+++ b/internal/handlers/oidc_test.go
@@ -11,32 +11,12 @@ import (
"github.com/authelia/authelia/v4/internal/session"
)
-func TestShouldDetectIfConsentIsMissing(t *testing.T) {
- var workflow *model.OIDCWorkflowSession
-
- requestedScopes := []string{"openid", "profile"}
- requestedAudience := []string{"https://authelia.com"}
-
- assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
-
- workflow = &model.OIDCWorkflowSession{
- GrantedScopes: []string{"openid", "profile"},
- GrantedAudience: []string{"https://authelia.com"},
+func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) {
+ consent := &model.OAuth2ConsentSession{
+ GrantedScopes: []string{oidc.ScopeProfile},
}
- assert.False(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
-
- requestedScopes = []string{"openid", "profile", "group"}
-
- assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
-
- requestedScopes = []string{"openid", "profile"}
- requestedAudience = []string{"https://not.authelia.com"}
- assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
-}
-
-func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) {
- extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeProfile}, []string{}, &oidcUserSessionJohn)
+ extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 2)
@@ -48,7 +28,11 @@ func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) {
}
func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) {
- extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeGroups}, []string{}, &oidcUserSessionJohn)
+ consent := &model.OAuth2ConsentSession{
+ GrantedScopes: []string{oidc.ScopeGroups},
+ }
+
+ extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 1)
@@ -57,7 +41,7 @@ func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) {
assert.Contains(t, extraClaims[oidc.ClaimGroups], "admin")
assert.Contains(t, extraClaims[oidc.ClaimGroups], "dev")
- extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeGroups}, []string{}, &oidcUserSessionFred)
+ extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred)
assert.Len(t, extraClaims, 1)
@@ -67,7 +51,11 @@ func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) {
}
func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) {
- extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeEmail}, []string{}, &oidcUserSessionJohn)
+ consent := &model.OAuth2ConsentSession{
+ GrantedScopes: []string{oidc.ScopeEmail},
+ }
+
+ extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 3)
@@ -81,7 +69,7 @@ func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) {
require.Contains(t, extraClaims, oidc.ClaimEmailVerified)
assert.Equal(t, true, extraClaims[oidc.ClaimEmailVerified])
- extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeEmail}, []string{}, &oidcUserSessionFred)
+ extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred)
assert.Len(t, extraClaims, 2)
@@ -93,7 +81,11 @@ func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) {
}
func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) {
- extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile}, []string{}, &oidcUserSessionJohn)
+ consent := &model.OAuth2ConsentSession{
+ GrantedScopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile},
+ }
+
+ extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 2)
@@ -103,7 +95,7 @@ func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) {
require.Contains(t, extraClaims, oidc.ClaimDisplayName)
assert.Equal(t, "John Smith", extraClaims[oidc.ClaimDisplayName])
- extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile}, []string{}, &oidcUserSessionFred)
+ extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred)
assert.Len(t, extraClaims, 2)
diff --git a/internal/handlers/response.go b/internal/handlers/response.go
index 96301c35..a9f3bb74 100644
--- a/internal/handlers/response.go
+++ b/internal/handlers/response.go
@@ -7,9 +7,9 @@ import (
"github.com/valyala/fasthttp"
- "github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/middlewares"
+ "github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/utils"
)
@@ -17,14 +17,15 @@ import (
func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession()
- if userSession.OIDCWorkflowSession.Require2FA && userSession.AuthenticationLevel != authentication.TwoFactor {
- ctx.Logger.Warnf("OpenID Connect client '%s' requires 2FA, cannot be redirected yet", userSession.OIDCWorkflowSession.ClientID)
- ctx.ReplyOK()
+ if userSession.ConsentChallengeID == nil {
+ ctx.Logger.Errorf("Unable to handle OIDC workflow response because the user session doesn't contain a consent challenge id")
+
+ respondUnauthorized(ctx, messageOperationFailed)
return
}
- uri, err := ctx.ExternalRootURL()
+ externalRootURL, err := ctx.ExternalRootURL()
if err != nil {
ctx.Logger.Errorf("Unable to determine external Base URL: %v", err)
@@ -33,18 +34,37 @@ func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) {
return
}
- if isConsentMissing(
- userSession.OIDCWorkflowSession,
- userSession.OIDCWorkflowSession.RequestedScopes,
- userSession.OIDCWorkflowSession.RequestedAudience) {
- err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s/consent", uri)})
+ consent, err := ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID)
+ if err != nil {
+ ctx.Logger.Errorf("Unable to load consent session from database: %v", err)
- if err != nil {
+ respondUnauthorized(ctx, messageOperationFailed)
+
+ return
+ }
+
+ client, err := ctx.Providers.OpenIDConnect.Store.GetFullClient(consent.ClientID)
+ if err != nil {
+ ctx.Logger.Errorf("Unable to find client for the consent session: %v", err)
+
+ respondUnauthorized(ctx, messageOperationFailed)
+
+ return
+ }
+
+ if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
+ ctx.Logger.Warnf("OpenID Connect client '%s' requires 2FA, cannot be redirected yet", client.ID)
+ ctx.ReplyOK()
+
+ return
+ }
+
+ if userSession.ConsentChallengeID != nil {
+ if err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s/consent", externalRootURL)}); err != nil {
ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err)
}
} else {
- err = ctx.SetJSONBody(redirectResponse{Redirect: userSession.OIDCWorkflowSession.AuthURI})
- if err != nil {
+ if err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s%s?%s", externalRootURL, oidc.AuthorizationPath, consent.Form)}); err != nil {
ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err)
}
}
diff --git a/internal/handlers/types_oidc.go b/internal/handlers/types_oidc.go
deleted file mode 100644
index 0464bde6..00000000
--- a/internal/handlers/types_oidc.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package handlers
-
-// ConsentPostRequestBody schema of the request body of the consent POST endpoint.
-type ConsentPostRequestBody struct {
- ClientID string `json:"client_id"`
- AcceptOrReject string `json:"accept_or_reject"`
-}
-
-// ConsentPostResponseBody schema of the response body of the consent POST endpoint.
-type ConsentPostResponseBody struct {
- RedirectURI string `json:"redirect_uri"`
-}
diff --git a/internal/mocks/storage.go b/internal/mocks/storage.go
index 9c1ccce4..4bcea9ad 100644
--- a/internal/mocks/storage.go
+++ b/internal/mocks/storage.go
@@ -10,8 +10,10 @@ import (
time "time"
gomock "github.com/golang/mock/gomock"
+ uuid "github.com/google/uuid"
model "github.com/authelia/authelia/v4/internal/model"
+ storage "github.com/authelia/authelia/v4/internal/storage"
)
// MockStorage is a mock of Provider interface.
@@ -51,6 +53,21 @@ func (mr *MockStorageMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorage)(nil).AppendAuthenticationLog), arg0, arg1)
}
+// BeginTX mocks base method.
+func (m *MockStorage) BeginTX(arg0 context.Context) (context.Context, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "BeginTX", arg0)
+ ret0, _ := ret[0].(context.Context)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// BeginTX indicates an expected call of BeginTX.
+func (mr *MockStorageMockRecorder) BeginTX(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTX", reflect.TypeOf((*MockStorage)(nil).BeginTX), arg0)
+}
+
// Close mocks base method.
func (m *MockStorage) Close() error {
m.ctrl.T.Helper()
@@ -65,6 +82,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close))
}
+// Commit mocks base method.
+func (m *MockStorage) Commit(arg0 context.Context) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Commit", arg0)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Commit indicates an expected call of Commit.
+func (mr *MockStorageMockRecorder) Commit(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockStorage)(nil).Commit), arg0)
+}
+
// ConsumeIdentityVerification mocks base method.
func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 model.NullIP) error {
m.ctrl.T.Helper()
@@ -79,6 +110,34 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
}
+// DeactivateOAuth2Session mocks base method.
+func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeactivateOAuth2Session", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// DeactivateOAuth2Session indicates an expected call of DeactivateOAuth2Session.
+func (mr *MockStorageMockRecorder) DeactivateOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeactivateOAuth2Session", reflect.TypeOf((*MockStorage)(nil).DeactivateOAuth2Session), arg0, arg1, arg2)
+}
+
+// DeactivateOAuth2SessionByRequestID mocks base method.
+func (m *MockStorage) DeactivateOAuth2SessionByRequestID(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "DeactivateOAuth2SessionByRequestID", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// DeactivateOAuth2SessionByRequestID indicates an expected call of DeactivateOAuth2SessionByRequestID.
+func (mr *MockStorageMockRecorder) DeactivateOAuth2SessionByRequestID(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeactivateOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).DeactivateOAuth2SessionByRequestID), arg0, arg1, arg2)
+}
+
// DeletePreferredDuoDevice mocks base method.
func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
@@ -137,6 +196,66 @@ func (mr *MockStorageMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4)
}
+// LoadOAuth2BlacklistedJTI mocks base method.
+func (m *MockStorage) LoadOAuth2BlacklistedJTI(arg0 context.Context, arg1 string) (*model.OAuth2BlacklistedJTI, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadOAuth2BlacklistedJTI", arg0, arg1)
+ ret0, _ := ret[0].(*model.OAuth2BlacklistedJTI)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LoadOAuth2BlacklistedJTI indicates an expected call of LoadOAuth2BlacklistedJTI.
+func (mr *MockStorageMockRecorder) LoadOAuth2BlacklistedJTI(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2BlacklistedJTI", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2BlacklistedJTI), arg0, arg1)
+}
+
+// LoadOAuth2ConsentSessionByChallengeID mocks base method.
+func (m *MockStorage) LoadOAuth2ConsentSessionByChallengeID(arg0 context.Context, arg1 uuid.UUID) (*model.OAuth2ConsentSession, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadOAuth2ConsentSessionByChallengeID", arg0, arg1)
+ ret0, _ := ret[0].(*model.OAuth2ConsentSession)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LoadOAuth2ConsentSessionByChallengeID indicates an expected call of LoadOAuth2ConsentSessionByChallengeID.
+func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionByChallengeID(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionByChallengeID", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionByChallengeID), arg0, arg1)
+}
+
+// LoadOAuth2ConsentSessionsPreConfigured mocks base method.
+func (m *MockStorage) LoadOAuth2ConsentSessionsPreConfigured(arg0 context.Context, arg1 string, arg2 uuid.UUID) (*storage.ConsentSessionRows, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadOAuth2ConsentSessionsPreConfigured", arg0, arg1, arg2)
+ ret0, _ := ret[0].(*storage.ConsentSessionRows)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LoadOAuth2ConsentSessionsPreConfigured indicates an expected call of LoadOAuth2ConsentSessionsPreConfigured.
+func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionsPreConfigured(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionsPreConfigured", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionsPreConfigured), arg0, arg1, arg2)
+}
+
+// LoadOAuth2Session mocks base method.
+func (m *MockStorage) LoadOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) (*model.OAuth2Session, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadOAuth2Session", arg0, arg1, arg2)
+ ret0, _ := ret[0].(*model.OAuth2Session)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LoadOAuth2Session indicates an expected call of LoadOAuth2Session.
+func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
+}
+
// LoadPreferred2FAMethod mocks base method.
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
@@ -212,6 +331,36 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1)
}
+// LoadUserOpaqueIdentifier mocks base method.
+func (m *MockStorage) LoadUserOpaqueIdentifier(arg0 context.Context, arg1 uuid.UUID) (*model.UserOpaqueIdentifier, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadUserOpaqueIdentifier", arg0, arg1)
+ ret0, _ := ret[0].(*model.UserOpaqueIdentifier)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LoadUserOpaqueIdentifier indicates an expected call of LoadUserOpaqueIdentifier.
+func (mr *MockStorageMockRecorder) LoadUserOpaqueIdentifier(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserOpaqueIdentifier", reflect.TypeOf((*MockStorage)(nil).LoadUserOpaqueIdentifier), arg0, arg1)
+}
+
+// LoadUserOpaqueIdentifierBySignature mocks base method.
+func (m *MockStorage) LoadUserOpaqueIdentifierBySignature(arg0 context.Context, arg1, arg2, arg3 string) (*model.UserOpaqueIdentifier, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "LoadUserOpaqueIdentifierBySignature", arg0, arg1, arg2, arg3)
+ ret0, _ := ret[0].(*model.UserOpaqueIdentifier)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// LoadUserOpaqueIdentifierBySignature indicates an expected call of LoadUserOpaqueIdentifierBySignature.
+func (mr *MockStorageMockRecorder) LoadUserOpaqueIdentifierBySignature(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserOpaqueIdentifierBySignature", reflect.TypeOf((*MockStorage)(nil).LoadUserOpaqueIdentifierBySignature), arg0, arg1, arg2, arg3)
+}
+
// LoadWebauthnDevices mocks base method.
func (m *MockStorage) LoadWebauthnDevices(arg0 context.Context, arg1, arg2 int) ([]model.WebauthnDevice, error) {
m.ctrl.T.Helper()
@@ -242,6 +391,48 @@ func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1 inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1)
}
+// RevokeOAuth2Session mocks base method.
+func (m *MockStorage) RevokeOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RevokeOAuth2Session", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// RevokeOAuth2Session indicates an expected call of RevokeOAuth2Session.
+func (mr *MockStorageMockRecorder) RevokeOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2Session", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2Session), arg0, arg1, arg2)
+}
+
+// RevokeOAuth2SessionByRequestID mocks base method.
+func (m *MockStorage) RevokeOAuth2SessionByRequestID(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "RevokeOAuth2SessionByRequestID", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// RevokeOAuth2SessionByRequestID indicates an expected call of RevokeOAuth2SessionByRequestID.
+func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
+}
+
+// Rollback mocks base method.
+func (m *MockStorage) Rollback(arg0 context.Context) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Rollback", arg0)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Rollback indicates an expected call of Rollback.
+func (mr *MockStorageMockRecorder) Rollback(arg0 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockStorage)(nil).Rollback), arg0)
+}
+
// SaveIdentityVerification mocks base method.
func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 model.IdentityVerification) error {
m.ctrl.T.Helper()
@@ -256,6 +447,76 @@ func (mr *MockStorageMockRecorder) SaveIdentityVerification(arg0, arg1 interface
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).SaveIdentityVerification), arg0, arg1)
}
+// SaveOAuth2BlacklistedJTI mocks base method.
+func (m *MockStorage) SaveOAuth2BlacklistedJTI(arg0 context.Context, arg1 model.OAuth2BlacklistedJTI) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SaveOAuth2BlacklistedJTI", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SaveOAuth2BlacklistedJTI indicates an expected call of SaveOAuth2BlacklistedJTI.
+func (mr *MockStorageMockRecorder) SaveOAuth2BlacklistedJTI(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2BlacklistedJTI", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2BlacklistedJTI), arg0, arg1)
+}
+
+// SaveOAuth2ConsentSession mocks base method.
+func (m *MockStorage) SaveOAuth2ConsentSession(arg0 context.Context, arg1 model.OAuth2ConsentSession) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SaveOAuth2ConsentSession", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SaveOAuth2ConsentSession indicates an expected call of SaveOAuth2ConsentSession.
+func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSession(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSession", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSession), arg0, arg1)
+}
+
+// SaveOAuth2ConsentSessionGranted mocks base method.
+func (m *MockStorage) SaveOAuth2ConsentSessionGranted(arg0 context.Context, arg1 int) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SaveOAuth2ConsentSessionGranted", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SaveOAuth2ConsentSessionGranted indicates an expected call of SaveOAuth2ConsentSessionGranted.
+func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionGranted(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionGranted", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionGranted), arg0, arg1)
+}
+
+// SaveOAuth2ConsentSessionResponse mocks base method.
+func (m *MockStorage) SaveOAuth2ConsentSessionResponse(arg0 context.Context, arg1 model.OAuth2ConsentSession, arg2 bool) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SaveOAuth2ConsentSessionResponse", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SaveOAuth2ConsentSessionResponse indicates an expected call of SaveOAuth2ConsentSessionResponse.
+func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionResponse", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionResponse), arg0, arg1, arg2)
+}
+
+// SaveOAuth2Session mocks base method.
+func (m *MockStorage) SaveOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 model.OAuth2Session) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SaveOAuth2Session", arg0, arg1, arg2)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SaveOAuth2Session indicates an expected call of SaveOAuth2Session.
+func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
+}
+
// SavePreferred2FAMethod mocks base method.
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
@@ -298,6 +559,20 @@ func (mr *MockStorageMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).SaveTOTPConfiguration), arg0, arg1)
}
+// SaveUserOpaqueIdentifier mocks base method.
+func (m *MockStorage) SaveUserOpaqueIdentifier(arg0 context.Context, arg1 model.UserOpaqueIdentifier) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "SaveUserOpaqueIdentifier", arg0, arg1)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// SaveUserOpaqueIdentifier indicates an expected call of SaveUserOpaqueIdentifier.
+func (mr *MockStorageMockRecorder) SaveUserOpaqueIdentifier(arg0, arg1 interface{}) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveUserOpaqueIdentifier", reflect.TypeOf((*MockStorage)(nil).SaveUserOpaqueIdentifier), arg0, arg1)
+}
+
// SaveWebauthnDevice mocks base method.
func (m *MockStorage) SaveWebauthnDevice(arg0 context.Context, arg1 model.WebauthnDevice) error {
m.ctrl.T.Helper()
diff --git a/internal/model/oidc.go b/internal/model/oidc.go
index d599d9f7..2c132556 100644
--- a/internal/model/oidc.go
+++ b/internal/model/oidc.go
@@ -1,14 +1,228 @@
package model
-// OIDCWorkflowSession represent an OIDC workflow session.
-type OIDCWorkflowSession struct {
- ClientID string
- RequestedScopes []string
- GrantedScopes []string
- RequestedAudience []string
- GrantedAudience []string
- TargetURI string
- AuthURI string
- Require2FA bool
- CreatedTimestamp int64
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/url"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/ory/fosite"
+ "github.com/ory/fosite/handler/openid"
+
+ "github.com/authelia/authelia/v4/internal/utils"
+)
+
+// NewOAuth2ConsentSession creates a new OAuth2ConsentSession.
+func NewOAuth2ConsentSession(subject uuid.UUID, r fosite.Requester) (consent *OAuth2ConsentSession, err error) {
+ consent = &OAuth2ConsentSession{
+ ClientID: r.GetClient().GetID(),
+ Subject: subject,
+ Form: r.GetRequestForm().Encode(),
+ RequestedAt: r.GetRequestedAt(),
+ RequestedScopes: StringSlicePipeDelimited(r.GetRequestedScopes()),
+ RequestedAudience: StringSlicePipeDelimited(r.GetRequestedAudience()),
+ GrantedScopes: StringSlicePipeDelimited(r.GetGrantedScopes()),
+ GrantedAudience: StringSlicePipeDelimited(r.GetGrantedAudience()),
+ }
+
+ if consent.ChallengeID, err = uuid.NewRandom(); err != nil {
+ return nil, err
+ }
+
+ return consent, nil
+}
+
+// NewOAuth2SessionFromRequest creates a new OAuth2Session from a signature and fosite.Requester.
+func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session *OAuth2Session, err error) {
+ var (
+ subject string
+ openidSession *OpenIDSession
+ sessData []byte
+ )
+
+ openidSession = r.GetSession().(*OpenIDSession)
+ if openidSession == nil {
+ return nil, errors.New("unexpected session type")
+ }
+
+ subject = openidSession.GetSubject()
+
+ if sessData, err = json.Marshal(openidSession); err != nil {
+ return nil, err
+ }
+
+ return &OAuth2Session{
+ ChallengeID: openidSession.ChallengeID,
+ RequestID: r.GetID(),
+ ClientID: r.GetClient().GetID(),
+ Signature: signature,
+ RequestedAt: r.GetRequestedAt(),
+ Subject: subject,
+ RequestedScopes: StringSlicePipeDelimited(r.GetRequestedScopes()),
+ GrantedScopes: StringSlicePipeDelimited(r.GetGrantedScopes()),
+ RequestedAudience: StringSlicePipeDelimited(r.GetRequestedAudience()),
+ GrantedAudience: StringSlicePipeDelimited(r.GetGrantedAudience()),
+ Active: true,
+ Revoked: false,
+ Form: r.GetRequestForm().Encode(),
+ Session: sessData,
+ }, nil
+}
+
+// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI.
+func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) {
+ return OAuth2BlacklistedJTI{
+ Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))),
+ ExpiresAt: exp,
+ }
+}
+
+// OAuth2ConsentSession stores information about an OAuth2.0 Consent.
+type OAuth2ConsentSession struct {
+ ID int `db:"id"`
+ ChallengeID uuid.UUID `db:"challenge_id"`
+ ClientID string `db:"client_id"`
+ Subject uuid.UUID `db:"subject"`
+
+ Authorized bool `db:"authorized"`
+ Granted bool `db:"granted"`
+
+ RequestedAt time.Time `db:"requested_at"`
+ RespondedAt *time.Time `db:"responded_at"`
+ ExpiresAt *time.Time `db:"expires_at"`
+
+ Form string `db:"form_data"`
+
+ RequestedScopes StringSlicePipeDelimited `db:"requested_scopes"`
+ GrantedScopes StringSlicePipeDelimited `db:"granted_scopes"`
+ RequestedAudience StringSlicePipeDelimited `db:"requested_audience"`
+ GrantedAudience StringSlicePipeDelimited `db:"granted_audience"`
+}
+
+// HasExactGrants returns true if the granted audience and scopes of this consent matches exactly with another
+// audience and set of scopes.
+func (s OAuth2ConsentSession) HasExactGrants(scopes, audience []string) (has bool) {
+ return s.HasExactGrantedScopes(scopes) && s.HasExactGrantedAudience(audience)
+}
+
+// HasExactGrantedAudience returns true if the granted audience of this consent matches exactly with another audience.
+func (s OAuth2ConsentSession) HasExactGrantedAudience(audience []string) (has bool) {
+ return !utils.IsStringSlicesDifferent(s.GrantedAudience, audience)
+}
+
+// HasExactGrantedScopes returns true if the granted scopes of this consent matches exactly with another set of scopes.
+func (s OAuth2ConsentSession) HasExactGrantedScopes(scopes []string) (has bool) {
+ return !utils.IsStringSlicesDifferent(s.GrantedScopes, scopes)
+}
+
+// IsAuthorized returns true if the user has responded to the consent session and it was authorized.
+func (s OAuth2ConsentSession) IsAuthorized() bool {
+ return s.Responded() && s.Authorized
+}
+
+// CanGrant returns true if the user has responded to the consent session, it was authorized, and it either hast not
+// previously been granted or the ability to grant has not expired.
+func (s OAuth2ConsentSession) CanGrant() bool {
+ if !s.Responded() {
+ return false
+ }
+
+ if s.Granted && (s.ExpiresAt == nil || s.ExpiresAt.Before(time.Now())) {
+ return false
+ }
+
+ return true
+}
+
+// IsDenied returns true if the user has responded to the consent session and it was not authorized.
+func (s OAuth2ConsentSession) IsDenied() bool {
+ return s.Responded() && !s.Authorized
+}
+
+// Responded returns true if the user has responded to the consent session.
+func (s OAuth2ConsentSession) Responded() bool {
+ return s.RespondedAt != nil
+}
+
+// GetForm returns the form.
+func (s OAuth2ConsentSession) GetForm() (form url.Values, err error) {
+ return url.ParseQuery(s.Form)
+}
+
+// OAuth2BlacklistedJTI represents a blacklisted JTI used with OAuth2.0.
+type OAuth2BlacklistedJTI struct {
+ ID int `db:"id"`
+ Signature string `db:"signature"`
+ ExpiresAt time.Time `db:"expires_at"`
+}
+
+// OpenIDSession holds OIDC Session information.
+type OpenIDSession struct {
+ *openid.DefaultSession `json:"id_token"`
+
+ ChallengeID uuid.UUID `db:"challenge_id"`
+ ClientID string
+
+ Extra map[string]interface{} `json:"extra"`
+}
+
+// OAuth2Session represents a OAuth2.0 session.
+type OAuth2Session struct {
+ ID int `db:"id"`
+ ChallengeID uuid.UUID `db:"challenge_id"`
+ RequestID string `db:"request_id"`
+ ClientID string `db:"client_id"`
+ Signature string `db:"signature"`
+ RequestedAt time.Time `db:"requested_at"`
+ Subject string `db:"subject"`
+ RequestedScopes StringSlicePipeDelimited `db:"requested_scopes"`
+ GrantedScopes StringSlicePipeDelimited `db:"granted_scopes"`
+ RequestedAudience StringSlicePipeDelimited `db:"requested_audience"`
+ GrantedAudience StringSlicePipeDelimited `db:"granted_audience"`
+ Active bool `db:"active"`
+ Revoked bool `db:"revoked"`
+ Form string `db:"form_data"`
+ Session []byte `db:"session_data"`
+}
+
+// SetSubject implements an interface required for RFC7523.
+func (s *OAuth2Session) SetSubject(subject string) {
+ s.Subject = subject
+}
+
+// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage.
+func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
+ sessionData := s.Session
+
+ if session != nil {
+ if err = json.Unmarshal(sessionData, session); err != nil {
+ return nil, err
+ }
+ }
+
+ client, err := store.GetClient(ctx, s.ClientID)
+ if err != nil {
+ return nil, err
+ }
+
+ values, err := url.ParseQuery(s.Form)
+ if err != nil {
+ return nil, err
+ }
+
+ return &fosite.Request{
+ ID: s.RequestID,
+ RequestedAt: s.RequestedAt,
+ Client: client,
+ RequestedScope: fosite.Arguments(s.RequestedScopes),
+ GrantedScope: fosite.Arguments(s.GrantedScopes),
+ RequestedAudience: fosite.Arguments(s.RequestedAudience),
+ GrantedAudience: fosite.Arguments(s.GrantedAudience),
+ Form: values,
+ Session: session,
+ }, nil
}
diff --git a/internal/model/types.go b/internal/model/types.go
index de1658b2..674e6aa4 100644
--- a/internal/model/types.go
+++ b/internal/model/types.go
@@ -1,10 +1,13 @@
package model
import (
+ "database/sql"
"database/sql/driver"
"encoding/base64"
"fmt"
"net"
+
+ "github.com/authelia/authelia/v4/internal/utils"
)
// NewIP easily constructs a new IP.
@@ -150,3 +153,26 @@ func (b *Base64) Scan(src interface{}) (err error) {
type StartupCheck interface {
StartupCheck() (err error)
}
+
+// StringSlicePipeDelimited is a string slice that is stored in the database delimited by pipes.
+type StringSlicePipeDelimited []string
+
+// Scan is the StringSlicePipeDelimited implementation of the sql.Scanner.
+func (s *StringSlicePipeDelimited) Scan(value interface{}) (err error) {
+ var nullStr sql.NullString
+
+ if err = nullStr.Scan(value); err != nil {
+ return err
+ }
+
+ if nullStr.Valid {
+ *s = utils.StringSplitDelimitedEscaped(nullStr.String, '|')
+ }
+
+ return nil
+}
+
+// Value is the StringSlicePipeDelimited implementation of the databases/sql driver.Valuer.
+func (s StringSlicePipeDelimited) Value() (driver.Value, error) {
+ return utils.StringJoinDelimitedEscaped(s, '|'), nil
+}
diff --git a/internal/model/types_test.go b/internal/model/types_test.go
index ada1e9bb..50d8bf38 100644
--- a/internal/model/types_test.go
+++ b/internal/model/types_test.go
@@ -1,11 +1,21 @@
package model
import (
+ "fmt"
"testing"
+ "github.com/ory/fosite"
"github.com/stretchr/testify/assert"
)
+func Test(t *testing.T) {
+ args := fosite.Arguments{"abc", "123"}
+
+ x := StringSlicePipeDelimited(args)
+
+ fmt.Println(x)
+}
+
func TestDatabaseModelTypeIP(t *testing.T) {
ip := IP{}
diff --git a/internal/model/user_opaque_identifier.go b/internal/model/user_opaque_identifier.go
new file mode 100644
index 00000000..c89b8038
--- /dev/null
+++ b/internal/model/user_opaque_identifier.go
@@ -0,0 +1,33 @@
+package model
+
+import (
+ "fmt"
+
+ "github.com/google/uuid"
+)
+
+// NewUserOpaqueIdentifier either creates a new UserOpaqueIdentifier or returns an error.
+func NewUserOpaqueIdentifier(service, sectorID, username string) (id *UserOpaqueIdentifier, err error) {
+ var opaqueID uuid.UUID
+
+ if opaqueID, err = uuid.NewRandom(); err != nil {
+ return nil, fmt.Errorf("unable to generate uuid: %w", err)
+ }
+
+ return &UserOpaqueIdentifier{
+ Service: service,
+ SectorID: sectorID,
+ Username: username,
+ Identifier: opaqueID,
+ }, nil
+}
+
+// UserOpaqueIdentifier represents an opaque identifier for a user. Commonly used with OAuth 2.0 and OpenID Connect.
+type UserOpaqueIdentifier struct {
+ ID int `db:"id"`
+ Service string `db:"service"`
+ SectorID string `db:"sector_id"`
+ Username string `db:"username"`
+
+ Identifier uuid.UUID `db:"identifier"`
+}
diff --git a/internal/oidc/client.go b/internal/oidc/client.go
index 22f8dc98..28a35af6 100644
--- a/internal/oidc/client.go
+++ b/internal/oidc/client.go
@@ -9,9 +9,9 @@ import (
"github.com/authelia/authelia/v4/internal/model"
)
-// NewClient creates a new InternalClient.
-func NewClient(config schema.OpenIDConnectClientConfiguration) (client *InternalClient) {
- client = &InternalClient{
+// NewClient creates a new Client.
+func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client) {
+ client = &Client{
ID: config.ID,
Description: config.Description,
Secret: []byte(config.Secret),
@@ -37,42 +37,47 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Internal
}
// IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient.
-func (c InternalClient) IsAuthenticationLevelSufficient(level authentication.Level) bool {
+func (c Client) IsAuthenticationLevelSufficient(level authentication.Level) bool {
return authorization.IsAuthLevelSufficient(level, c.Policy)
}
// GetID returns the ID.
-func (c InternalClient) GetID() string {
+func (c Client) GetID() string {
return c.ID
}
-// GetConsentResponseBody returns the proper consent response body for this model.OIDCWorkflowSession.
-func (c InternalClient) GetConsentResponseBody(session *model.OIDCWorkflowSession) ConsentGetResponseBody {
+// GetSectorIdentifier returns the SectorIdentifier for this client.
+func (c Client) GetSectorIdentifier() string {
+ return c.SectorIdentifier
+}
+
+// GetConsentResponseBody returns the proper consent response body for this session.OIDCWorkflowSession.
+func (c Client) GetConsentResponseBody(consent *model.OAuth2ConsentSession) ConsentGetResponseBody {
body := ConsentGetResponseBody{
ClientID: c.ID,
ClientDescription: c.Description,
}
- if session != nil {
- body.Scopes = session.RequestedScopes
- body.Audience = session.RequestedAudience
+ if consent != nil {
+ body.Scopes = consent.RequestedScopes
+ body.Audience = consent.RequestedAudience
}
return body
}
// GetHashedSecret returns the Secret.
-func (c InternalClient) GetHashedSecret() []byte {
+func (c Client) GetHashedSecret() []byte {
return c.Secret
}
// GetRedirectURIs returns the RedirectURIs.
-func (c InternalClient) GetRedirectURIs() []string {
+func (c Client) GetRedirectURIs() []string {
return c.RedirectURIs
}
// GetGrantTypes returns the GrantTypes.
-func (c InternalClient) GetGrantTypes() fosite.Arguments {
+func (c Client) GetGrantTypes() fosite.Arguments {
if len(c.GrantTypes) == 0 {
return fosite.Arguments{"authorization_code"}
}
@@ -81,7 +86,7 @@ func (c InternalClient) GetGrantTypes() fosite.Arguments {
}
// GetResponseTypes returns the ResponseTypes.
-func (c InternalClient) GetResponseTypes() fosite.Arguments {
+func (c Client) GetResponseTypes() fosite.Arguments {
if len(c.ResponseTypes) == 0 {
return fosite.Arguments{"code"}
}
@@ -90,23 +95,23 @@ func (c InternalClient) GetResponseTypes() fosite.Arguments {
}
// GetScopes returns the Scopes.
-func (c InternalClient) GetScopes() fosite.Arguments {
+func (c Client) GetScopes() fosite.Arguments {
return c.Scopes
}
// IsPublic returns the value of the Public property.
-func (c InternalClient) IsPublic() bool {
+func (c Client) IsPublic() bool {
return c.Public
}
// GetAudience returns the Audience.
-func (c InternalClient) GetAudience() fosite.Arguments {
+func (c Client) GetAudience() fosite.Arguments {
return c.Audience
}
// GetResponseModes returns the valid response modes for this client.
//
// Implements the fosite.ResponseModeClient.
-func (c InternalClient) GetResponseModes() []fosite.ResponseModeType {
+func (c Client) GetResponseModes() []fosite.ResponseModeType {
return c.ResponseModes
}
diff --git a/internal/oidc/client_test.go b/internal/oidc/client_test.go
index fdad40f3..e208c082 100644
--- a/internal/oidc/client_test.go
+++ b/internal/oidc/client_test.go
@@ -44,7 +44,7 @@ func TestNewClient(t *testing.T) {
}
func TestIsAuthenticationLevelSufficient(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
c.Policy = authorization.Bypass
assert.True(t, c.IsAuthenticationLevelSufficient(authentication.NotAuthenticated))
@@ -68,7 +68,7 @@ func TestIsAuthenticationLevelSufficient(t *testing.T) {
}
func TestInternalClient_GetConsentResponseBody(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
consentRequestBody := c.GetConsentResponseBody(nil)
assert.Equal(t, "", consentRequestBody.ClientID)
@@ -79,7 +79,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) {
c.ID = "myclient"
c.Description = "My Client"
- workflow := &model.OIDCWorkflowSession{
+ consent := &model.OAuth2ConsentSession{
RequestedAudience: []string{"https://example.com"},
RequestedScopes: []string{"openid", "groups"},
}
@@ -87,7 +87,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) {
expectedScopes := []string{"openid", "groups"}
expectedAudiences := []string{"https://example.com"}
- consentRequestBody = c.GetConsentResponseBody(workflow)
+ consentRequestBody = c.GetConsentResponseBody(consent)
assert.Equal(t, "myclient", consentRequestBody.ClientID)
assert.Equal(t, "My Client", consentRequestBody.ClientDescription)
assert.Equal(t, expectedScopes, consentRequestBody.Scopes)
@@ -95,7 +95,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) {
}
func TestInternalClient_GetAudience(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
audience := c.GetAudience()
assert.Len(t, audience, 0)
@@ -108,7 +108,7 @@ func TestInternalClient_GetAudience(t *testing.T) {
}
func TestInternalClient_GetScopes(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
scopes := c.GetScopes()
assert.Len(t, scopes, 0)
@@ -121,7 +121,7 @@ func TestInternalClient_GetScopes(t *testing.T) {
}
func TestInternalClient_GetGrantTypes(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
grantTypes := c.GetGrantTypes()
require.Len(t, grantTypes, 1)
@@ -135,7 +135,7 @@ func TestInternalClient_GetGrantTypes(t *testing.T) {
}
func TestInternalClient_GetHashedSecret(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
hashedSecret := c.GetHashedSecret()
assert.Equal(t, []byte(nil), hashedSecret)
@@ -147,7 +147,7 @@ func TestInternalClient_GetHashedSecret(t *testing.T) {
}
func TestInternalClient_GetID(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
id := c.GetID()
assert.Equal(t, "", id)
@@ -159,7 +159,7 @@ func TestInternalClient_GetID(t *testing.T) {
}
func TestInternalClient_GetRedirectURIs(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
redirectURIs := c.GetRedirectURIs()
require.Len(t, redirectURIs, 0)
@@ -172,7 +172,7 @@ func TestInternalClient_GetRedirectURIs(t *testing.T) {
}
func TestInternalClient_GetResponseModes(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
responseModes := c.GetResponseModes()
require.Len(t, responseModes, 0)
@@ -191,7 +191,7 @@ func TestInternalClient_GetResponseModes(t *testing.T) {
}
func TestInternalClient_GetResponseTypes(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
responseTypes := c.GetResponseTypes()
require.Len(t, responseTypes, 1)
@@ -206,7 +206,7 @@ func TestInternalClient_GetResponseTypes(t *testing.T) {
}
func TestInternalClient_IsPublic(t *testing.T) {
- c := InternalClient{}
+ c := Client{}
assert.False(t, c.IsPublic())
diff --git a/internal/oidc/discovery.go b/internal/oidc/discovery.go
new file mode 100644
index 00000000..f2619969
--- /dev/null
+++ b/internal/oidc/discovery.go
@@ -0,0 +1,79 @@
+package oidc
+
+// NewOpenIDConnectWellKnownConfiguration generates a new OpenIDConnectWellKnownConfiguration.
+func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge, pairwise bool) (config OpenIDConnectWellKnownConfiguration) {
+ config = OpenIDConnectWellKnownConfiguration{
+ CommonDiscoveryOptions: CommonDiscoveryOptions{
+ SubjectTypesSupported: []string{
+ "public",
+ },
+ ResponseTypesSupported: []string{
+ "code",
+ "token",
+ "id_token",
+ "code token",
+ "code id_token",
+ "token id_token",
+ "code token id_token",
+ "none",
+ },
+ ResponseModesSupported: []string{
+ "form_post",
+ "query",
+ "fragment",
+ },
+ ScopesSupported: []string{
+ ScopeOfflineAccess,
+ ScopeOpenID,
+ ScopeProfile,
+ ScopeGroups,
+ ScopeEmail,
+ },
+ ClaimsSupported: []string{
+ "aud",
+ "exp",
+ "iat",
+ "iss",
+ "jti",
+ "rat",
+ "sub",
+ "auth_time",
+ "nonce",
+ ClaimEmail,
+ ClaimEmailVerified,
+ ClaimEmailAlts,
+ ClaimGroups,
+ ClaimPreferredUsername,
+ ClaimDisplayName,
+ },
+ },
+ OAuth2DiscoveryOptions: OAuth2DiscoveryOptions{
+ CodeChallengeMethodsSupported: []string{
+ "S256",
+ },
+ },
+ OpenIDConnectDiscoveryOptions: OpenIDConnectDiscoveryOptions{
+ IDTokenSigningAlgValuesSupported: []string{
+ "RS256",
+ },
+ UserinfoSigningAlgValuesSupported: []string{
+ "none",
+ "RS256",
+ },
+ RequestObjectSigningAlgValuesSupported: []string{
+ "none",
+ "RS256",
+ },
+ },
+ }
+
+ if pairwise {
+ config.SubjectTypesSupported = append(config.SubjectTypesSupported, "pairwise")
+ }
+
+ if enablePKCEPlainChallenge {
+ config.CodeChallengeMethodsSupported = append(config.CodeChallengeMethodsSupported, "plain")
+ }
+
+ return config
+}
diff --git a/internal/oidc/discovery_test.go b/internal/oidc/discovery_test.go
new file mode 100644
index 00000000..256f5f70
--- /dev/null
+++ b/internal/oidc/discovery_test.go
@@ -0,0 +1,65 @@
+package oidc
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) {
+ testCases := []struct {
+ desc string
+ pkcePlainChallenge, pairwise bool
+ expectCodeChallengeMethodsSupported, expectSubjectTypesSupported []string
+ }{
+ {
+ desc: "ShouldHaveChallengeMethodsS256ANDSubjectTypesSupportedPublic",
+ pkcePlainChallenge: false,
+ pairwise: false,
+ expectCodeChallengeMethodsSupported: []string{"S256"},
+ expectSubjectTypesSupported: []string{"public"},
+ },
+ {
+ desc: "ShouldHaveChallengeMethodsS256PlainANDSubjectTypesSupportedPublic",
+ pkcePlainChallenge: true,
+ pairwise: false,
+ expectCodeChallengeMethodsSupported: []string{"S256", "plain"},
+ expectSubjectTypesSupported: []string{"public"},
+ },
+ {
+ desc: "ShouldHaveChallengeMethodsS256ANDSubjectTypesSupportedPublicPairwise",
+ pkcePlainChallenge: false,
+ pairwise: true,
+ expectCodeChallengeMethodsSupported: []string{"S256"},
+ expectSubjectTypesSupported: []string{"public", "pairwise"},
+ },
+ {
+ desc: "ShouldHaveChallengeMethodsS256PlainANDSubjectTypesSupportedPublicPairwise",
+ pkcePlainChallenge: true,
+ pairwise: true,
+ expectCodeChallengeMethodsSupported: []string{"S256", "plain"},
+ expectSubjectTypesSupported: []string{"public", "pairwise"},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ actual := NewOpenIDConnectWellKnownConfiguration(tc.pkcePlainChallenge, tc.pairwise)
+ for _, codeChallengeMethod := range tc.expectCodeChallengeMethodsSupported {
+ assert.Contains(t, actual.CodeChallengeMethodsSupported, codeChallengeMethod)
+ }
+
+ for _, subjectType := range tc.expectSubjectTypesSupported {
+ assert.Contains(t, actual.SubjectTypesSupported, subjectType)
+ }
+
+ for _, codeChallengeMethod := range actual.CodeChallengeMethodsSupported {
+ assert.Contains(t, tc.expectCodeChallengeMethodsSupported, codeChallengeMethod)
+ }
+
+ for _, subjectType := range actual.SubjectTypesSupported {
+ assert.Contains(t, tc.expectSubjectTypesSupported, subjectType)
+ }
+ })
+ }
+}
diff --git a/internal/oidc/hasher.go b/internal/oidc/hasher.go
index 58f84a93..0bacbe98 100644
--- a/internal/oidc/hasher.go
+++ b/internal/oidc/hasher.go
@@ -6,7 +6,7 @@ import (
)
// Compare compares the hash with the data and returns an error if they don't match.
-func (h AutheliaHasher) Compare(_ context.Context, hash, data []byte) (err error) {
+func (h PlainTextHasher) Compare(_ context.Context, hash, data []byte) (err error) {
if subtle.ConstantTimeCompare(hash, data) == 0 {
return errPasswordsDoNotMatch
}
@@ -15,6 +15,6 @@ func (h AutheliaHasher) Compare(_ context.Context, hash, data []byte) (err error
}
// Hash creates a new hash from data.
-func (h AutheliaHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) {
+func (h PlainTextHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) {
return data, nil
}
diff --git a/internal/oidc/hasher_test.go b/internal/oidc/hasher_test.go
index bc93a235..a86123bf 100644
--- a/internal/oidc/hasher_test.go
+++ b/internal/oidc/hasher_test.go
@@ -8,7 +8,7 @@ import (
)
func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
- hasher := AutheliaHasher{}
+ hasher := PlainTextHasher{}
a := []byte("abc")
b := []byte("abc")
@@ -21,7 +21,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
}
func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
- hasher := AutheliaHasher{}
+ hasher := PlainTextHasher{}
a := []byte("abc")
b := []byte("abcd")
@@ -34,7 +34,7 @@ func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
}
func TestShouldHashPassword(t *testing.T) {
- hasher := AutheliaHasher{}
+ hasher := PlainTextHasher{}
data := []byte("abc")
diff --git a/internal/oidc/provider.go b/internal/oidc/provider.go
index 862ddce6..670d1e42 100644
--- a/internal/oidc/provider.go
+++ b/internal/oidc/provider.go
@@ -8,34 +8,35 @@ import (
"github.com/ory/herodot"
"github.com/authelia/authelia/v4/internal/configuration/schema"
+ "github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewOpenIDConnectProvider new-ups a OpenIDConnectProvider.
-func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) (provider OpenIDConnectProvider, err error) {
+func NewOpenIDConnectProvider(config *schema.OpenIDConnectConfiguration, storageProvider storage.Provider) (provider OpenIDConnectProvider, err error) {
provider = OpenIDConnectProvider{
Fosite: nil,
}
- if configuration == nil {
+ if config == nil {
return provider, nil
}
- provider.Store = NewOpenIDConnectStore(configuration)
+ provider.Store = NewOpenIDConnectStore(config, storageProvider)
composeConfiguration := &compose.Config{
- AccessTokenLifespan: configuration.AccessTokenLifespan,
- AuthorizeCodeLifespan: configuration.AuthorizeCodeLifespan,
- IDTokenLifespan: configuration.IDTokenLifespan,
- RefreshTokenLifespan: configuration.RefreshTokenLifespan,
- SendDebugMessagesToClients: configuration.EnableClientDebugMessages,
- MinParameterEntropy: configuration.MinimumParameterEntropy,
- EnforcePKCE: configuration.EnforcePKCE == "always",
- EnforcePKCEForPublicClients: configuration.EnforcePKCE != "never",
- EnablePKCEPlainChallengeMethod: configuration.EnablePKCEPlainChallenge,
+ AccessTokenLifespan: config.AccessTokenLifespan,
+ AuthorizeCodeLifespan: config.AuthorizeCodeLifespan,
+ IDTokenLifespan: config.IDTokenLifespan,
+ RefreshTokenLifespan: config.RefreshTokenLifespan,
+ SendDebugMessagesToClients: config.EnableClientDebugMessages,
+ MinParameterEntropy: config.MinimumParameterEntropy,
+ EnforcePKCE: config.EnforcePKCE == "always",
+ EnforcePKCEForPublicClients: config.EnforcePKCE != "never",
+ EnablePKCEPlainChallengeMethod: config.EnablePKCEPlainChallenge,
}
- keyManager, err := NewKeyManagerWithConfiguration(configuration)
+ keyManager, err := NewKeyManagerWithConfiguration(config)
if err != nil {
return provider, err
}
@@ -50,7 +51,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
strategy := &compose.CommonStrategy{
CoreStrategy: compose.NewOAuth2HMACStrategy(
composeConfiguration,
- []byte(utils.HashSHA256FromString(configuration.HMACSecret)),
+ []byte(utils.HashSHA256FromString(config.HMACSecret)),
nil,
),
OpenIDConnectTokenStrategy: compose.NewOpenIDConnectStrategy(
@@ -64,7 +65,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
composeConfiguration,
provider.Store,
strategy,
- AutheliaHasher{},
+ PlainTextHasher{},
/*
These are the OAuth2 and OpenIDConnect factories. Order is important (the OAuth2 factories at the top must
@@ -75,7 +76,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
compose.OAuth2AuthorizeImplicitFactory,
compose.OAuth2ClientCredentialsGrantFactory,
compose.OAuth2RefreshTokenGrantFactory,
- compose.OAuth2ResourceOwnerPasswordCredentialsFactory,
+ // compose.OAuth2ResourceOwnerPasswordCredentialsFactory,
// compose.RFC7523AssertionGrantFactory,.
compose.OpenIDConnectExplicitFactory,
@@ -89,80 +90,24 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
compose.OAuth2PKCEFactory,
)
- provider.discovery = OpenIDConnectWellKnownConfiguration{
- CommonDiscoveryOptions: CommonDiscoveryOptions{
- SubjectTypesSupported: []string{
- "public",
- },
- ResponseTypesSupported: []string{
- "code",
- "token",
- "id_token",
- "code token",
- "code id_token",
- "token id_token",
- "code token id_token",
- "none",
- },
- ResponseModesSupported: []string{
- "form_post",
- "query",
- "fragment",
- },
- ScopesSupported: []string{
- ScopeOfflineAccess,
- ScopeOpenID,
- ScopeProfile,
- ScopeGroups,
- ScopeEmail,
- },
- ClaimsSupported: []string{
- "aud",
- "exp",
- "iat",
- "iss",
- "jti",
- "rat",
- "sub",
- "auth_time",
- "nonce",
- ClaimEmail,
- ClaimEmailVerified,
- ClaimEmailAlts,
- ClaimGroups,
- ClaimPreferredUsername,
- ClaimDisplayName,
- },
- },
- OAuth2DiscoveryOptions: OAuth2DiscoveryOptions{
- CodeChallengeMethodsSupported: []string{
- "S256",
- },
- },
- OpenIDConnectDiscoveryOptions: OpenIDConnectDiscoveryOptions{
- IDTokenSigningAlgValuesSupported: []string{
- "RS256",
- },
- UserinfoSigningAlgValuesSupported: []string{
- "none",
- "RS256",
- },
- RequestObjectSigningAlgValuesSupported: []string{
- "none",
- "RS256",
- },
- },
- }
-
- if configuration.EnablePKCEPlainChallenge {
- provider.discovery.CodeChallengeMethodsSupported = append(provider.discovery.CodeChallengeMethodsSupported, "plain")
- }
+ provider.discovery = NewOpenIDConnectWellKnownConfiguration(config.EnablePKCEPlainChallenge, provider.Pairwise())
provider.herodot = herodot.NewJSONWriter(nil)
return provider, nil
}
+// Pairwise returns true if this provider is configured with clients that require pairwise.
+func (p OpenIDConnectProvider) Pairwise() bool {
+ for _, c := range p.Store.clients {
+ if c.SectorIdentifier != "" {
+ return true
+ }
+ }
+
+ return false
+}
+
// Write writes data with herodot.JSONWriter.
func (p OpenIDConnectProvider) Write(w http.ResponseWriter, r *http.Request, e interface{}, opts ...herodot.EncoderOptions) {
p.herodot.Write(w, r, e, opts...)
diff --git a/internal/oidc/provider_test.go b/internal/oidc/provider_test.go
index a6f6e0cb..0a07d78e 100644
--- a/internal/oidc/provider_test.go
+++ b/internal/oidc/provider_test.go
@@ -12,7 +12,7 @@ import (
var exampleIssuerPrivateKey = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEAvcMVMB2vEbqI6PlSNJ4HmUyMxBDJ5iY7FS+zDDAHOZBg9S3S\nKcAn1CZcnyL0VvJ7wcdhR6oTnOwR94eKvzUyJZ+GL2hTMm27dubEYsNdhoCl6N3X\nyEEohNfoxiiCYraVauX8X3M9jFzbEz9+pacaDbHB2syaJ1qFmMNR+HSu2jPzOo7M\nlqKIOgUzA0741MaYNt47AEVg4XU5ORLdolbAkItmYg1QbyFndg9H5IvwKkYaXTGE\nlgDBcPUC0yVjAC15Mguquq+jZeQay+6PSbHTD8PQMOkLjyChI2xEhVNbdCXe676R\ncMW2R/gjrcK23zmtmTWRfdC1iZLSlHO+bJj9vQIDAQABAoIBAEZvkP/JJOCJwqPn\nV3IcbmmilmV4bdi1vByDFgyiDyx4wOSA24+PubjvfFW9XcCgRPuKjDtTj/AhWBHv\nB7stfa2lZuNV7/u562mZArA+IAr62Zp0LdIxDV8x3T8gbjVB3HhPYbv0RJZDKTYd\nzV6jhfIrVu9mHpoY6ZnodhapCPYIyk/d49KBIHZuAc25CUjMXgTeaVtf0c996036\nUxW6ef33wAOJAvW0RCvbXAJfmBeEq2qQlkjTIlpYx71fhZWexHifi8Ouv3Zonc+1\n/P2Adq5uzYVBT92f9RKHg9QxxNzVrLjSMaxyvUtWQCAQfW0tFIRdqBGsHYsQrFtI\nF4yzv8ECgYEA7ntpyN9HD9Z9lYQzPCR73sFCLM+ID99aVij0wHuxK97bkSyyvkLd\n7MyTaym3lg1UEqWNWBCLvFULZx7F0Ah6qCzD4ymm3Bj/ADpWWPgljBI0AFml+HHs\nhcATmXUrj5QbLyhiP2gmJjajp1o/rgATx6ED66seSynD6JOH8wUhhZUCgYEAy7OA\n06PF8GfseNsTqlDjNF0K7lOqd21S0prdwrsJLiVzUlfMM25MLE0XLDUutCnRheeh\nIlcuDoBsVTxz6rkvFGD74N+pgXlN4CicsBq5ofK060PbqCQhSII3fmHobrZ9Cr75\nHmBjAxHx998SKaAAGbBbcYGUAp521i1pH5CEPYkCgYEAkUd1Zf0+2RMdZhwm6hh/\nrW+l1I6IoMK70YkZsLipccRNld7Y9LbfYwYtODcts6di9AkOVfueZJiaXbONZfIE\nZrb+jkAteh9wGL9xIrnohbABJcV3Kiaco84jInUSmGDtPokncOENfHIEuEpuSJ2b\nbx1TuhmAVuGWivR0+ULC7RECgYEAgS0cDRpWc9Xzh9Cl7+PLsXEvdWNpPsL9OsEq\n0Ep7z9+/+f/jZtoTRCS/BTHUpDvAuwHglT5j3p5iFMt5VuiIiovWLwynGYwrbnNS\nqfrIrYKUaH1n1oDS+oBZYLQGCe9/7EifAjxtjYzbvSyg//SPG7tSwfBCREbpZXj2\nqSWkNsECgYA/mCDzCTlrrWPuiepo6kTmN+4TnFA+hJI6NccDVQ+jvbqEdoJ4SW4L\nzqfZSZRFJMNpSgIqkQNRPJqMP0jQ5KRtJrjMWBnYxktwKz9fDg2R2MxdFgMF2LH2\nHEMMhFHlv8NDjVOXh1KwRoltNGVWYsSrD9wKU9GhRCEfmNCGrvBcEg==\n-----END RSA PRIVATE KEY-----"
func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing.T) {
- provider, err := NewOpenIDConnectProvider(nil)
+ provider, err := NewOpenIDConnectProvider(nil, nil)
assert.NoError(t, err)
assert.Nil(t, provider.Fosite)
@@ -22,7 +22,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing
func TestOpenIDConnectProvider_NewOpenIDConnectProvider_BadIssuerKey(t *testing.T) {
_, err := NewOpenIDConnectProvider(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: "BAD KEY",
- })
+ }, nil)
assert.Error(t, err, "abc")
}
@@ -60,7 +60,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GoodConfiguration(t *tes
},
},
},
- })
+ }, nil)
assert.NotNil(t, provider)
assert.NoError(t, err)
@@ -80,10 +80,12 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
},
},
},
- })
+ }, nil)
assert.NoError(t, err)
+ assert.False(t, provider.Pairwise())
+
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
assert.Equal(t, "https://example.com", disco.Issuer)
@@ -95,8 +97,8 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
assert.Equal(t, "https://example.com/api/oidc/revocation", disco.RevocationEndpoint)
assert.Equal(t, "", disco.RegistrationEndpoint)
- require.Len(t, disco.CodeChallengeMethodsSupported, 1)
- assert.Equal(t, "S256", disco.CodeChallengeMethodsSupported[0])
+ assert.Len(t, disco.CodeChallengeMethodsSupported, 1)
+ assert.Contains(t, disco.CodeChallengeMethodsSupported, "S256")
assert.Len(t, disco.ScopesSupported, 5)
assert.Contains(t, disco.ScopesSupported, ScopeOpenID)
@@ -166,7 +168,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig
},
},
},
- })
+ }, nil)
assert.NoError(t, err)
@@ -241,7 +243,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
},
},
},
- })
+ }, nil)
assert.NoError(t, err)
diff --git a/internal/oidc/store.go b/internal/oidc/store.go
index 0b9d9a86..4a66ee9b 100644
--- a/internal/oidc/store.go
+++ b/internal/oidc/store.go
@@ -2,38 +2,32 @@ package oidc
import (
"context"
+ "crypto/sha256"
+ "database/sql"
+ "errors"
+ "fmt"
"time"
+ "github.com/google/uuid"
"github.com/ory/fosite"
- "github.com/ory/fosite/storage"
- "gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
+ "github.com/authelia/authelia/v4/internal/model"
+ "github.com/authelia/authelia/v4/internal/storage"
)
-// NewOpenIDConnectStore returns a new OpenIDConnectStore using the provided schema.OpenIDConnectConfiguration.
-func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore) {
+// NewOpenIDConnectStore returns a OpenIDConnectStore when provided with a schema.OpenIDConnectConfiguration and storage.Provider.
+func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider storage.Provider) (store *OpenIDConnectStore) {
logger := logging.Logger()
store = &OpenIDConnectStore{
- memory: &storage.MemoryStore{
- IDSessions: map[string]fosite.Requester{},
- Users: map[string]storage.MemoryUserRelation{},
- AuthorizeCodes: map[string]storage.StoreAuthorizeCode{},
- AccessTokens: map[string]fosite.Requester{},
- RefreshTokens: map[string]storage.StoreRefreshToken{},
- PKCES: map[string]fosite.Requester{},
- BlacklistedJTIs: map[string]time.Time{},
- AccessTokenRequestIDs: map[string]string{},
- RefreshTokenRequestIDs: map[string]string{},
- },
+ provider: provider,
+ clients: map[string]*Client{},
}
- store.clients = make(map[string]*InternalClient)
-
- for _, client := range configuration.Clients {
+ for _, client := range config.Clients {
policy := authorization.PolicyToLevel(client.Policy)
logger.Debugf("Registering client %s with policy %s (%v)", client.ID, client.Policy, policy)
@@ -43,9 +37,37 @@ func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (st
return store
}
+// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username.
+func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
+ if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil {
+ return nil, err
+ } else if opaqueID == nil {
+ if opaqueID, err = model.NewUserOpaqueIdentifier("openid", sectorID, username); err != nil {
+ return nil, err
+ }
+
+ if err = s.provider.SaveUserOpaqueIdentifier(ctx, *opaqueID); err != nil {
+ return nil, err
+ }
+ }
+
+ return opaqueID, nil
+}
+
+// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it.
+func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
+ var opaqueID *model.UserOpaqueIdentifier
+
+ if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil {
+ return uuid.UUID{}, err
+ }
+
+ return opaqueID.Identifier, nil
+}
+
// GetClientPolicy retrieves the policy from the client with the matching provided id.
func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
- client, err := s.GetInternalClient(id)
+ client, err := s.GetFullClient(id)
if err != nil {
return authorization.TwoFactor
}
@@ -53,8 +75,8 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve
return client.Policy
}
-// GetInternalClient returns a fosite.Client asserted as an InternalClient matching the provided id.
-func (s OpenIDConnectStore) GetInternalClient(id string) (client *InternalClient, err error) {
+// GetFullClient returns a fosite.Client asserted as an Client matching the provided id.
+func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
client, ok := s.clients[id]
if !ok {
return nil, fosite.ErrNotFound
@@ -65,142 +87,248 @@ func (s OpenIDConnectStore) GetInternalClient(id string) (client *InternalClient
// IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map.
func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
- _, err := s.GetInternalClient(id)
+ _, err := s.GetFullClient(id)
return err == nil
}
-// CreateOpenIDConnectSession decorates fosite's storage.MemoryStore CreateOpenIDConnectSession method.
-func (s *OpenIDConnectStore) CreateOpenIDConnectSession(ctx context.Context, authorizeCode string, requester fosite.Requester) error {
- return s.memory.CreateOpenIDConnectSession(ctx, authorizeCode, requester)
+// BeginTX starts a transaction.
+// This implements a portion of fosite storage.Transactional interface.
+func (s *OpenIDConnectStore) BeginTX(ctx context.Context) (c context.Context, err error) {
+ return s.provider.BeginTX(ctx)
}
-// GetOpenIDConnectSession decorates fosite's storage.MemoryStore GetOpenIDConnectSession method.
-func (s *OpenIDConnectStore) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, requester fosite.Requester) (fosite.Requester, error) {
- return s.memory.GetOpenIDConnectSession(ctx, authorizeCode, requester)
+// Commit completes a transaction.
+// This implements a portion of fosite storage.Transactional interface.
+func (s *OpenIDConnectStore) Commit(ctx context.Context) (err error) {
+ return s.provider.Commit(ctx)
}
-// DeleteOpenIDConnectSession decorates fosite's storage.MemoryStore DeleteOpenIDConnectSession method.
-func (s *OpenIDConnectStore) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error {
- return s.memory.DeleteOpenIDConnectSession(ctx, authorizeCode)
+// Rollback rolls a transaction back.
+// This implements a portion of fosite storage.Transactional interface.
+func (s *OpenIDConnectStore) Rollback(ctx context.Context) (err error) {
+ return s.provider.Rollback(ctx)
}
-// GetClient decorates fosite's storage.MemoryStore GetClient method.
-func (s *OpenIDConnectStore) GetClient(_ context.Context, id string) (fosite.Client, error) {
- return s.GetInternalClient(id)
+// GetClient loads the client by its ID or returns an error if the client does not exist or another error occurred.
+// This implements a portion of fosite.ClientManager.
+func (s *OpenIDConnectStore) GetClient(_ context.Context, id string) (client fosite.Client, err error) {
+ return s.GetFullClient(id)
}
-// ClientAssertionJWTValid decorates fosite's storage.MemoryStore ClientAssertionJWTValid method.
-func (s *OpenIDConnectStore) ClientAssertionJWTValid(ctx context.Context, jti string) error {
- return s.memory.ClientAssertionJWTValid(ctx, jti)
+// ClientAssertionJWTValid returns an error if the JTI is known or the DB check failed and nil if the JTI is not known.
+// This implements a portion of fosite.ClientManager.
+func (s *OpenIDConnectStore) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
+ signature := fmt.Sprintf("%x", sha256.Sum256([]byte(jti)))
+
+ blacklistedJTI, err := s.provider.LoadOAuth2BlacklistedJTI(ctx, signature)
+
+ switch {
+ case errors.Is(sql.ErrNoRows, err):
+ return nil
+ case err != nil:
+ return err
+ case blacklistedJTI.ExpiresAt.After(time.Now()):
+ return fosite.ErrJTIKnown
+ default:
+ return nil
+ }
}
-// SetClientAssertionJWT decorates fosite's storage.MemoryStore SetClientAssertionJWT method.
-func (s *OpenIDConnectStore) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) error {
- return s.memory.SetClientAssertionJWT(ctx, jti, exp)
+// SetClientAssertionJWT marks a JTI as known for the given expiry time. Before inserting the new JTI, it will clean
+// up any existing JTIs that have expired as those tokens can not be replayed due to the expiry.
+// This implements a portion of fosite.ClientManager.
+func (s *OpenIDConnectStore) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) {
+ blacklistedJTI := model.NewOAuth2BlacklistedJTI(jti, exp)
+
+ return s.provider.SaveOAuth2BlacklistedJTI(ctx, blacklistedJTI)
}
-// CreateAuthorizeCodeSession decorates fosite's storage.MemoryStore CreateAuthorizeCodeSession method.
-func (s *OpenIDConnectStore) CreateAuthorizeCodeSession(ctx context.Context, code string, req fosite.Requester) error {
- return s.memory.CreateAuthorizeCodeSession(ctx, code, req)
+// CreateAuthorizeCodeSession stores the authorization request for a given authorization code.
+// This implements a portion of oauth2.AuthorizeCodeStorage.
+func (s *OpenIDConnectStore) CreateAuthorizeCodeSession(ctx context.Context, code string, request fosite.Requester) (err error) {
+ return s.saveSession(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, request)
}
-// GetAuthorizeCodeSession decorates fosite's storage.MemoryStore GetAuthorizeCodeSession method.
-func (s *OpenIDConnectStore) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (fosite.Requester, error) {
- return s.memory.GetAuthorizeCodeSession(ctx, code, session)
+// InvalidateAuthorizeCodeSession is called when an authorize code is being used. The state of the authorization
+// code should be set to invalid and consecutive requests to GetAuthorizeCodeSession should return the
+// ErrInvalidatedAuthorizeCode error.
+// This implements a portion of oauth2.AuthorizeCodeStorage.
+func (s *OpenIDConnectStore) InvalidateAuthorizeCodeSession(ctx context.Context, code string) (err error) {
+ return s.provider.DeactivateOAuth2Session(ctx, storage.OAuth2SessionTypeAuthorizeCode, code)
}
-// InvalidateAuthorizeCodeSession decorates fosite's storage.MemoryStore InvalidateAuthorizeCodeSession method.
-func (s *OpenIDConnectStore) InvalidateAuthorizeCodeSession(ctx context.Context, code string) error {
- return s.memory.InvalidateAuthorizeCodeSession(ctx, code)
+// GetAuthorizeCodeSession hydrates the session based on the given code and returns the authorization request.
+// If the authorization code has been invalidated with `InvalidateAuthorizeCodeSession`, this
+// method should return the ErrInvalidatedAuthorizeCode error.
+// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedAuthorizeCode error!
+// This implements a portion of oauth2.AuthorizeCodeStorage.
+func (s *OpenIDConnectStore) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) {
+ // TODO: Implement the fosite.ErrInvalidatedAuthorizeCode error above. This requires splitting the invalidated sessions and deleted sessions.
+ return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session)
}
-// CreatePKCERequestSession decorates fosite's storage.MemoryStore CreatePKCERequestSession method.
-func (s *OpenIDConnectStore) CreatePKCERequestSession(ctx context.Context, code string, req fosite.Requester) error {
- return s.memory.CreatePKCERequestSession(ctx, code, req)
+// CreateAccessTokenSession stores the authorization request for a given access token.
+// This implements a portion of oauth2.AccessTokenStorage.
+func (s *OpenIDConnectStore) CreateAccessTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) {
+ return s.saveSession(ctx, storage.OAuth2SessionTypeAccessToken, signature, request)
}
-// GetPKCERequestSession decorates fosite's storage.MemoryStore GetPKCERequestSession method.
-func (s *OpenIDConnectStore) GetPKCERequestSession(ctx context.Context, code string, session fosite.Session) (fosite.Requester, error) {
- return s.memory.GetPKCERequestSession(ctx, code, session)
+// DeleteAccessTokenSession marks an access token session as deleted.
+// This implements a portion of oauth2.AccessTokenStorage.
+func (s *OpenIDConnectStore) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) {
+ return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature)
}
-// DeletePKCERequestSession decorates fosite's storage.MemoryStore DeletePKCERequestSession method.
-func (s *OpenIDConnectStore) DeletePKCERequestSession(ctx context.Context, code string) error {
- return s.memory.DeletePKCERequestSession(ctx, code)
+// RevokeAccessToken revokes an access token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1
+// If the token passed to the request is an access token, the server MAY revoke the respective refresh token as well.
+// This implements a portion of oauth2.TokenRevocationStorage.
+func (s *OpenIDConnectStore) RevokeAccessToken(ctx context.Context, requestID string) (err error) {
+ return s.revokeSessionByRequestID(ctx, storage.OAuth2SessionTypeAccessToken, requestID)
}
-// CreateAccessTokenSession decorates fosite's storage.MemoryStore CreateAccessTokenSession method.
-func (s *OpenIDConnectStore) CreateAccessTokenSession(ctx context.Context, signature string, req fosite.Requester) error {
- return s.memory.CreateAccessTokenSession(ctx, signature, req)
+// GetAccessTokenSession gets the authorization request for a given access token.
+// This implements a portion of oauth2.AccessTokenStorage.
+func (s *OpenIDConnectStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
+ return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session)
}
-// GetAccessTokenSession decorates fosite's storage.MemoryStore GetAccessTokenSession method.
-func (s *OpenIDConnectStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) {
- return s.memory.GetAccessTokenSession(ctx, signature, session)
+// CreateRefreshTokenSession stores the authorization request for a given refresh token.
+// This implements a portion of oauth2.RefreshTokenStorage.
+func (s *OpenIDConnectStore) CreateRefreshTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) {
+ return s.saveSession(ctx, storage.OAuth2SessionTypeRefreshToken, signature, request)
}
-// DeleteAccessTokenSession decorates fosite's storage.MemoryStore DeleteAccessTokenSession method.
-func (s *OpenIDConnectStore) DeleteAccessTokenSession(ctx context.Context, signature string) error {
- return s.memory.DeleteAccessTokenSession(ctx, signature)
+// DeleteRefreshTokenSession marks the authorization request for a given refresh token as deleted.
+// This implements a portion of oauth2.RefreshTokenStorage.
+func (s *OpenIDConnectStore) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) {
+ return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature)
}
-// CreateRefreshTokenSession decorates fosite's storage.MemoryStore CreateRefreshTokenSession method.
-func (s *OpenIDConnectStore) CreateRefreshTokenSession(ctx context.Context, signature string, req fosite.Requester) error {
- return s.memory.CreateRefreshTokenSession(ctx, signature, req)
+// RevokeRefreshToken revokes a refresh token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1
+// If the particular token is a refresh token and the authorization server supports the revocation of access tokens,
+// then the authorization server SHOULD also invalidate all access tokens based on the same authorization grant (see Implementation Note).
+// This implements a portion of oauth2.TokenRevocationStorage.
+func (s *OpenIDConnectStore) RevokeRefreshToken(ctx context.Context, requestID string) (err error) {
+ return s.provider.DeactivateOAuth2SessionByRequestID(ctx, storage.OAuth2SessionTypeRefreshToken, requestID)
}
-// GetRefreshTokenSession decorates fosite's storage.MemoryStore GetRefreshTokenSession method.
-func (s *OpenIDConnectStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) {
- return s.memory.GetRefreshTokenSession(ctx, signature, session)
+// RevokeRefreshTokenMaybeGracePeriod revokes an access token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1
+// If the token passed to the request is an access token, the server MAY revoke the respective refresh token as well.
+// This implements a portion of oauth2.TokenRevocationStorage.
+func (s *OpenIDConnectStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) (err error) {
+ return s.RevokeRefreshToken(ctx, requestID)
}
-// DeleteRefreshTokenSession decorates fosite's storage.MemoryStore DeleteRefreshTokenSession method.
-func (s *OpenIDConnectStore) DeleteRefreshTokenSession(ctx context.Context, signature string) error {
- return s.memory.DeleteRefreshTokenSession(ctx, signature)
+// GetRefreshTokenSession gets the authorization request for a given refresh token.
+// This implements a portion of oauth2.RefreshTokenStorage.
+func (s *OpenIDConnectStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
+ return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session)
}
-// Authenticate decorates fosite's storage.MemoryStore Authenticate method.
-func (s *OpenIDConnectStore) Authenticate(ctx context.Context, name string, secret string) error {
- return s.memory.Authenticate(ctx, name, secret)
+// CreatePKCERequestSession stores the authorization request for a given PKCE request.
+// This implements a portion of pkce.PKCERequestStorage.
+func (s *OpenIDConnectStore) CreatePKCERequestSession(ctx context.Context, signature string, request fosite.Requester) (err error) {
+ return s.saveSession(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, request)
}
-// RevokeRefreshToken decorates fosite's storage.MemoryStore RevokeRefreshToken method.
-func (s *OpenIDConnectStore) RevokeRefreshToken(ctx context.Context, requestID string) error {
- return s.memory.RevokeRefreshToken(ctx, requestID)
+// DeletePKCERequestSession marks the authorization request for a given PKCE request as deleted.
+// This implements a portion of pkce.PKCERequestStorage.
+func (s *OpenIDConnectStore) DeletePKCERequestSession(ctx context.Context, signature string) (err error) {
+ return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature)
}
-// RevokeRefreshTokenMaybeGracePeriod decorates fosite's storage.MemoryStore RevokeRefreshTokenMaybeGracePeriod method.
-func (s OpenIDConnectStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) error {
- return s.memory.RevokeRefreshTokenMaybeGracePeriod(ctx, requestID, signature)
+// GetPKCERequestSession gets the authorization request for a given PKCE request.
+// This implements a portion of pkce.PKCERequestStorage.
+func (s *OpenIDConnectStore) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (requester fosite.Requester, err error) {
+ return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session)
}
-// RevokeAccessToken decorates fosite's storage.MemoryStore RevokeAccessToken method.
-func (s *OpenIDConnectStore) RevokeAccessToken(ctx context.Context, requestID string) error {
- return s.memory.RevokeAccessToken(ctx, requestID)
+// CreateOpenIDConnectSession creates an open id connect session for a given authorize code.
+// This is relevant for explicit open id connect flow.
+// This implements a portion of openid.OpenIDConnectRequestStorage.
+func (s *OpenIDConnectStore) CreateOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (err error) {
+ return s.saveSession(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request)
}
-// GetPublicKey decorates fosite's storage.MemoryStore GetPublicKey method.
-func (s *OpenIDConnectStore) GetPublicKey(ctx context.Context, issuer string, subject string, keyID string) (*jose.JSONWebKey, error) {
- return s.memory.GetPublicKey(ctx, issuer, subject, keyID)
+// DeleteOpenIDConnectSession just implements the method required by fosite even though it's unused.
+// This implements a portion of openid.OpenIDConnectRequestStorage.
+func (s *OpenIDConnectStore) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) (err error) {
+ return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, authorizeCode)
}
-// GetPublicKeys decorates fosite's storage.MemoryStore GetPublicKeys method.
-func (s *OpenIDConnectStore) GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) {
- return s.memory.GetPublicKeys(ctx, issuer, subject)
+// GetOpenIDConnectSession returns error:
+// - nil if a session was found,
+// - ErrNoSessionFound if no session was found
+// - or an arbitrary error if an error occurred.
+// This implements a portion of openid.OpenIDConnectRequestStorage.
+func (s *OpenIDConnectStore) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (r fosite.Requester, err error) {
+ return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession())
}
-// GetPublicKeyScopes decorates fosite's storage.MemoryStore GetPublicKeyScopes method.
-func (s *OpenIDConnectStore) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyID string) ([]string, error) {
- return s.memory.GetPublicKeyScopes(ctx, issuer, subject, keyID)
+// IsJWTUsed implements an interface required for RFC7523.
+func (s *OpenIDConnectStore) IsJWTUsed(ctx context.Context, jti string) (used bool, err error) {
+ if err = s.ClientAssertionJWTValid(ctx, jti); err != nil {
+ return true, err
+ }
+
+ return false, nil
}
-// IsJWTUsed decorates fosite's storage.MemoryStore IsJWTUsed method.
-func (s *OpenIDConnectStore) IsJWTUsed(ctx context.Context, jti string) (bool, error) {
- return s.memory.IsJWTUsed(ctx, jti)
+// MarkJWTUsedForTime implements an interface required for rfc7523.RFC7523KeyStorage.
+func (s *OpenIDConnectStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) {
+ return s.SetClientAssertionJWT(ctx, jti, exp)
}
-// MarkJWTUsedForTime decorates fosite's storage.MemoryStore MarkJWTUsedForTime method.
-func (s *OpenIDConnectStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error {
- return s.memory.MarkJWTUsedForTime(ctx, jti, exp)
+func (s *OpenIDConnectStore) loadSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) {
+ var (
+ sessionModel *model.OAuth2Session
+ )
+
+ sessionModel, err = s.provider.LoadOAuth2Session(ctx, sessionType, signature)
+ if err != nil {
+ switch {
+ case errors.Is(err, sql.ErrNoRows):
+ return nil, fosite.ErrNotFound
+ default:
+ return nil, err
+ }
+ }
+
+ if r, err = sessionModel.ToRequest(ctx, session, s); err != nil {
+ return nil, err
+ }
+
+ if !sessionModel.Active && sessionType == storage.OAuth2SessionTypeAuthorizeCode {
+ return r, fosite.ErrInvalidatedAuthorizeCode
+ }
+
+ return r, nil
+}
+
+func (s *OpenIDConnectStore) saveSession(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, r fosite.Requester) (err error) {
+ var session *model.OAuth2Session
+
+ if session, err = model.NewOAuth2SessionFromRequest(signature, r); err != nil {
+ return err
+ }
+
+ return s.provider.SaveOAuth2Session(ctx, sessionType, *session)
+}
+
+func (s *OpenIDConnectStore) revokeSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string) (err error) {
+ return s.provider.RevokeOAuth2Session(ctx, sessionType, signature)
+}
+
+func (s *OpenIDConnectStore) revokeSessionByRequestID(ctx context.Context, sessionType storage.OAuth2SessionType, requestID string) (err error) {
+ if err = s.provider.RevokeOAuth2SessionByRequestID(ctx, sessionType, requestID); err != nil {
+ switch {
+ case errors.Is(err, sql.ErrNoRows):
+ return fosite.ErrNotFound
+ default:
+ return err
+ }
+ }
+
+ return nil
}
diff --git a/internal/oidc/store_test.go b/internal/oidc/store_test.go
index d69df0a8..f73e2d8c 100644
--- a/internal/oidc/store_test.go
+++ b/internal/oidc/store_test.go
@@ -1,15 +1,6 @@
package oidc
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "github.com/authelia/authelia/v4/internal/authorization"
- "github.com/authelia/authelia/v4/internal/configuration/schema"
-)
+/*
func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
@@ -80,7 +71,7 @@ func TestOpenIDConnectStore_GetInternalClient_ValidClient(t *testing.T) {
Clients: []schema.OpenIDConnectClientConfiguration{c1},
})
- client, err := s.GetInternalClient(c1.ID)
+ client, err := s.GetFullClient(c1.ID)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, client.ID, c1.ID)
@@ -107,7 +98,7 @@ func TestOpenIDConnectStore_GetInternalClient_InvalidClient(t *testing.T) {
Clients: []schema.OpenIDConnectClientConfiguration{c1},
})
- client, err := s.GetInternalClient("another-client")
+ client, err := s.GetFullClient("another-client")
assert.Nil(t, client)
assert.EqualError(t, err, "not_found")
}
@@ -131,4 +122,5 @@ func TestOpenIDConnectStore_IsValidClientID(t *testing.T) {
assert.True(t, validClient)
assert.False(t, invalidClient)
-}
+}.
+*/
diff --git a/internal/oidc/types.go b/internal/oidc/types.go
index 9129a38a..e8859d75 100644
--- a/internal/oidc/types.go
+++ b/internal/oidc/types.go
@@ -6,17 +6,19 @@ import (
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
- "github.com/ory/fosite/storage"
"github.com/ory/fosite/token/jwt"
"github.com/ory/herodot"
"gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/authorization"
+ "github.com/authelia/authelia/v4/internal/model"
+ "github.com/authelia/authelia/v4/internal/storage"
+ "github.com/authelia/authelia/v4/internal/utils"
)
-// NewSession creates a new OpenIDSession struct.
-func NewSession() (session *OpenIDSession) {
- return &OpenIDSession{
+// NewSession creates a new empty OpenIDSession struct.
+func NewSession() (session *model.OpenIDSession) {
+ return &model.OpenIDSession{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Extra: map[string]interface{}{},
@@ -30,19 +32,19 @@ func NewSession() (session *OpenIDSession) {
}
// NewSessionWithAuthorizeRequest uses details from an AuthorizeRequester to generate an OpenIDSession.
-func NewSessionWithAuthorizeRequest(issuer, kid, subject, username string, amr []string, extra map[string]interface{},
- authTime, requestedAt time.Time, requester fosite.AuthorizeRequester) (session *OpenIDSession) {
+func NewSessionWithAuthorizeRequest(issuer, kid, username string, amr []string, extra map[string]interface{},
+ authTime time.Time, consent *model.OAuth2ConsentSession, requester fosite.AuthorizeRequester) (session *model.OpenIDSession) {
if extra == nil {
extra = make(map[string]interface{})
}
- return &OpenIDSession{
+ session = &model.OpenIDSession{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
- Subject: subject,
+ Subject: consent.Subject.String(),
Issuer: issuer,
AuthTime: authTime,
- RequestedAt: requestedAt,
+ RequestedAt: consent.RequestedAt,
IssuedAt: time.Now(),
Nonce: requester.GetRequestForm().Get("nonce"),
Audience: requester.GetGrantedAudience(),
@@ -55,12 +57,20 @@ func NewSessionWithAuthorizeRequest(issuer, kid, subject, username string, amr [
"kid": kid,
},
},
- Subject: subject,
+ Subject: consent.Subject.String(),
Username: username,
},
- Extra: map[string]interface{}{},
- ClientID: requester.GetClient().GetID(),
+ Extra: map[string]interface{}{},
+ ClientID: requester.GetClient().GetID(),
+ ChallengeID: consent.ChallengeID,
}
+
+ // Ensure required audience value of the client_id exists.
+ if !utils.IsStringInSlice(requester.GetClient().GetID(), session.Claims.Audience) {
+ session.Claims.Audience = append(session.Claims.Audience, requester.GetClient().GetID())
+ }
+
+ return session
}
// OpenIDConnectProvider for OpenID Connect.
@@ -74,33 +84,34 @@ type OpenIDConnectProvider struct {
discovery OpenIDConnectWellKnownConfiguration
}
-// OpenIDConnectStore is Authelia's internal representation of the fosite.Storage interface.
-//
-// Currently it is mostly just implementing a decorator pattern other then GetInternalClient.
-// The long term plan is to have these methods interact with the Authelia storage and
-// session providers where applicable.
+// OpenIDConnectStore is Authelia's internal representation of the fosite.Storage interface. It maps the following
+// interfaces to the storage.Provider interface:
+// fosite.Storage, fosite.ClientManager, storage.Transactional, oauth2.AuthorizeCodeStorage, oauth2.AccessTokenStorage,
+// oauth2.RefreshTokenStorage, oauth2.TokenRevocationStorage, pkce.PKCERequestStorage,
+// openid.OpenIDConnectRequestStorage, and partially implements rfc7523.RFC7523KeyStorage.
type OpenIDConnectStore struct {
- clients map[string]*InternalClient
- memory *storage.MemoryStore
+ provider storage.Provider
+ clients map[string]*Client
}
-// InternalClient represents the client internally.
-type InternalClient struct {
- ID string `json:"id"`
- Description string `json:"-"`
- Secret []byte `json:"client_secret,omitempty"`
- Public bool `json:"public"`
+// Client represents the client internally.
+type Client struct {
+ ID string
+ SectorIdentifier string
+ Description string
+ Secret []byte
+ Public bool
- Policy authorization.Level `json:"-"`
+ Policy authorization.Level
- Audience []string `json:"audience"`
- Scopes []string `json:"scopes"`
- RedirectURIs []string `json:"redirect_uris"`
- GrantTypes []string `json:"grant_types"`
- ResponseTypes []string `json:"response_types"`
- ResponseModes []fosite.ResponseModeType `json:"response_modes"`
+ Audience []string
+ Scopes []string
+ RedirectURIs []string
+ GrantTypes []string
+ ResponseTypes []string
+ ResponseModes []fosite.ResponseModeType
- UserinfoSigningAlgorithm string `json:"userinfo_signed_response_alg,omitempty"`
+ UserinfoSigningAlgorithm string
}
// KeyManager keeps track of all of the active/inactive rsa keys and provides them to services requiring them.
@@ -112,8 +123,8 @@ type KeyManager struct {
strategy *RS256JWTStrategy
}
-// AutheliaHasher implements the fosite.Hasher interface without an actual hashing algo.
-type AutheliaHasher struct{}
+// PlainTextHasher implements the fosite.Hasher interface without an actual hashing algo.
+type PlainTextHasher struct{}
// ConsentGetResponseBody schema of the response body of the consent GET endpoint.
type ConsentGetResponseBody struct {
@@ -123,12 +134,15 @@ type ConsentGetResponseBody struct {
Audience []string `json:"audience"`
}
-// OpenIDSession holds OIDC Session information.
-type OpenIDSession struct {
- *openid.DefaultSession `json:"idToken"`
+// ConsentPostRequestBody schema of the request body of the consent POST endpoint.
+type ConsentPostRequestBody struct {
+ ClientID string `json:"client_id"`
+ AcceptOrReject string `json:"accept_or_reject"`
+}
- Extra map[string]interface{} `json:"extra"`
- ClientID string
+// ConsentPostResponseBody schema of the response body of the consent POST endpoint.
+type ConsentPostResponseBody struct {
+ RedirectURI string `json:"redirect_uri"`
}
/*
diff --git a/internal/oidc/types_test.go b/internal/oidc/types_test.go
index f8d31dac..7a96558a 100644
--- a/internal/oidc/types_test.go
+++ b/internal/oidc/types_test.go
@@ -9,6 +9,8 @@ import (
"github.com/ory/fosite"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+
+ "github.com/authelia/authelia/v4/internal/model"
)
func TestNewSession(t *testing.T) {
@@ -38,7 +40,7 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) {
Request: fosite.Request{
ID: requestID.String(),
Form: formValues,
- Client: &InternalClient{ID: "example"},
+ Client: &Client{ID: "example"},
},
}
@@ -51,7 +53,13 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) {
issuer := "https://example.com"
amr := []string{AMRPasswordBasedAuthentication}
- session := NewSessionWithAuthorizeRequest(issuer, "primary", subject.String(), "john", amr, extra, authAt, requested, request)
+ consent := &model.OAuth2ConsentSession{
+ ChallengeID: uuid.New(),
+ RequestedAt: requested,
+ Subject: subject,
+ }
+
+ session := NewSessionWithAuthorizeRequest(issuer, "primary", "john", amr, extra, authAt, consent, request)
require.NotNil(t, session)
require.NotNil(t, session.Extra)
@@ -78,7 +86,12 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) {
require.Contains(t, session.Claims.Extra, "preferred_username")
- session = NewSessionWithAuthorizeRequest(issuer, "primary", subject.String(), "john", nil, nil, authAt, requested, request)
+ consent = &model.OAuth2ConsentSession{
+ ChallengeID: uuid.New(),
+ RequestedAt: requested,
+ }
+
+ session = NewSessionWithAuthorizeRequest(issuer, "primary", "john", nil, nil, authAt, consent, request)
require.NotNil(t, session)
require.NotNil(t, session.Claims)
diff --git a/internal/session/types.go b/internal/session/types.go
index 235df041..b73937d7 100644
--- a/internal/session/types.go
+++ b/internal/session/types.go
@@ -7,11 +7,11 @@ import (
"github.com/fasthttp/session/v2"
"github.com/fasthttp/session/v2/providers/redis"
"github.com/go-webauthn/webauthn/webauthn"
+ "github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/logging"
- "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc"
)
@@ -43,8 +43,8 @@ type UserSession struct {
// Webauthn holds the session registration data for this session.
Webauthn *webauthn.SessionData
- // Represent an OIDC workflow session initiated by the client if not null.
- OIDCWorkflowSession *model.OIDCWorkflowSession
+ // ConsentChallengeID is the OpenID Connect Consent Session challenge ID.
+ ConsentChallengeID *uuid.UUID
// This boolean is set to true after identity verification and checked
// while doing the query actually updating the password.
diff --git a/internal/storage/const.go b/internal/storage/const.go
index affa0bca..cf188b10 100644
--- a/internal/storage/const.go
+++ b/internal/storage/const.go
@@ -5,18 +5,40 @@ import (
)
const (
- tableUserPreferences = "user_preferences"
+ tableAuthenticationLogs = "authentication_logs"
+ tableDuoDevices = "duo_devices"
tableIdentityVerification = "identity_verification"
tableTOTPConfigurations = "totp_configurations"
+ tableUserOpaqueIdentifier = "user_opaque_identifier"
+ tableUserPreferences = "user_preferences"
tableWebauthnDevices = "webauthn_devices"
- tableDuoDevices = "duo_devices"
- tableAuthenticationLogs = "authentication_logs"
- tableMigrations = "migrations"
- tableEncryption = "encryption"
+
+ tableOAuth2ConsentSession = "oauth2_consent_session"
+ tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
+ tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential.
+ tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
+ tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
+ tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session"
+ tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
+
+ tableMigrations = "migrations"
+ tableEncryption = "encryption"
tablePrefixBackup = "_bkp_"
)
+// OAuth2SessionType represents the potential OAuth 2.0 session types.
+type OAuth2SessionType string
+
+// Representation of specific OAuth 2.0 session types.
+const (
+ OAuth2SessionTypeAuthorizeCode OAuth2SessionType = "authorization code"
+ OAuth2SessionTypeAccessToken OAuth2SessionType = "access token"
+ OAuth2SessionTypeRefreshToken OAuth2SessionType = "refresh token"
+ OAuth2SessionTypePKCEChallenge OAuth2SessionType = "pkce challenge"
+ OAuth2SessionTypeOpenIDConnect OAuth2SessionType = "openid connect"
+)
+
const (
encryptionNameCheck = "check"
)
@@ -56,7 +78,7 @@ const (
const (
// This is the latest schema version for the purpose of tests.
- testLatestVersion = 3
+ testLatestVersion = 4
)
const (
@@ -64,6 +86,12 @@ const (
SchemaLatest = 2147483647
)
+type ctxKey int
+
+const (
+ ctxKeyTransaction ctxKey = iota
+)
+
var (
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`)
)
diff --git a/internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql b/internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql
new file mode 100644
index 00000000..4fc3adc7
--- /dev/null
+++ b/internal/storage/migrations/V0004.OpenIDConenct.mysql.up.sql
@@ -0,0 +1,188 @@
+CREATE TABLE IF NOT EXISTS user_opaque_identifier (
+ id INTEGER AUTO_INCREMENT,
+ service VARCHAR(20) NOT NULL,
+ sector_id VARCHAR(255) NOT NULL,
+ username VARCHAR(100) NOT NULL,
+ identifier CHAR(36) NOT NULL,
+ PRIMARY KEY (id)
+);
+
+CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username);
+CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier);
+
+CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti (
+ id INTEGER AUTO_INCREMENT,
+ signature VARCHAR(64) NOT NULL,
+ expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ PRIMARY KEY (id)
+);
+
+CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature);
+
+CREATE TABLE IF NOT EXISTS oauth2_consent_session (
+ id INTEGER AUTO_INCREMENT,
+ challenge_id CHAR(36) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ authorized BOOLEAN NOT NULL DEFAULT FALSE,
+ granted BOOLEAN NOT NULL DEFAULT FALSE,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ responded_at TIMESTAMP NULL DEFAULT NULL,
+ expires_at TIMESTAMP NULL DEFAULT NULL,
+ form_data TEXT NOT NULL,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL,
+ granted_audience TEXT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_consent_subject_fkey
+ FOREIGN KEY (subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id);
+
+CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session (
+ id INTEGER AUTO_INCREMENT,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL,
+ granted_audience TEXT NULL,
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey
+ FOREIGN KEY (challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_authorization_code_session_subject_fkey
+ FOREIGN KEY (subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id);
+CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id);
+CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_access_token_session (
+ id INTEGER AUTO_INCREMENT,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL,
+ granted_audience TEXT NULL,
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_access_token_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_access_token_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id);
+CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id);
+CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session (
+ id INTEGER AUTO_INCREMENT,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL,
+ granted_audience TEXT NULL,
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_refresh_token_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id);
+CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id);
+CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session (
+ id INTEGER AUTO_INCREMENT,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL,
+ granted_audience TEXT NULL,
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_pkce_request_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id);
+CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id);
+CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session (
+ id INTEGER AUTO_INCREMENT,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL,
+ granted_audience TEXT NULL,
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_openid_connect_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id);
+CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id);
+CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject);
\ No newline at end of file
diff --git a/internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql b/internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql
new file mode 100644
index 00000000..c5685dd2
--- /dev/null
+++ b/internal/storage/migrations/V0004.OpenIDConenct.postgres.up.sql
@@ -0,0 +1,188 @@
+CREATE TABLE IF NOT EXISTS user_opaque_identifier (
+ id SERIAL,
+ service VARCHAR(20) NOT NULL,
+ sector_id VARCHAR(255) NOT NULL,
+ username VARCHAR(100) NOT NULL,
+ identifier CHAR(36) NOT NULL,
+ PRIMARY KEY (id)
+);
+
+CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username);
+CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier);
+
+CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti (
+ id SERIAL,
+ signature VARCHAR(64) NOT NULL,
+ expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ PRIMARY KEY (id)
+);
+
+CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature);
+
+CREATE TABLE IF NOT EXISTS oauth2_consent_session (
+ id SERIAL,
+ challenge_id CHAR(36) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ authorized BOOLEAN NOT NULL DEFAULT FALSE,
+ granted BOOLEAN NOT NULL DEFAULT FALSE,
+ requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ responded_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
+ expires_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
+ form_data TEXT NOT NULL,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_consent_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id);
+
+CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session (
+ id SERIAL,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36),
+ requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BYTEA NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_authorization_code_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id);
+CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id);
+CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_access_token_session (
+ id SERIAL,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BYTEA NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_access_token_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_access_token_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id);
+CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id);
+CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session (
+ id SERIAL,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BYTEA NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_refresh_token_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id);
+CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id);
+CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session (
+ id SERIAL,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BYTEA NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_pkce_request_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id);
+CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id);
+CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session (
+ id SERIAL,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BYTEA NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_openid_connect_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id);
+CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id);
+CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject);
diff --git a/internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql b/internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql
new file mode 100644
index 00000000..372fbeb5
--- /dev/null
+++ b/internal/storage/migrations/V0004.OpenIDConenct.sqlite.up.sql
@@ -0,0 +1,188 @@
+CREATE TABLE IF NOT EXISTS user_opaque_identifier (
+ id INTEGER,
+ service VARCHAR(20) NOT NULL,
+ sector_id VARCHAR(255) NOT NULL,
+ username VARCHAR(100) NOT NULL,
+ identifier CHAR(36) NOT NULL,
+ PRIMARY KEY (id)
+);
+
+CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username);
+CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier);
+
+CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti (
+ id INTEGER,
+ signature VARCHAR(64) NOT NULL,
+ expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ PRIMARY KEY (id)
+);
+
+CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature);
+
+CREATE TABLE IF NOT EXISTS oauth2_consent_session (
+ id INTEGER,
+ challenge_id CHAR(36) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ authorized BOOLEAN NOT NULL DEFAULT FALSE,
+ granted BOOLEAN NOT NULL DEFAULT FALSE,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ responded_at TIMESTAMP NULL DEFAULT NULL,
+ expires_at TIMESTAMP NULL DEFAULT NULL,
+ form_data TEXT NOT NULL,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_consent_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id);
+
+CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session (
+ id INTEGER,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_authorization_code_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id);
+CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id);
+CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_access_token_session (
+ id INTEGER,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_access_token_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_access_token_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id);
+CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id);
+CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session (
+ id INTEGER,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_refresh_token_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id);
+CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id);
+CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session (
+ id INTEGER,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_pkce_request_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id);
+CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id);
+CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject);
+
+CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session (
+ id INTEGER,
+ challenge_id CHAR(36) NOT NULL,
+ request_id VARCHAR(40) NOT NULL,
+ client_id VARCHAR(255) NOT NULL,
+ signature VARCHAR(255) NOT NULL,
+ subject CHAR(36) NOT NULL,
+ requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ requested_scopes TEXT NOT NULL,
+ granted_scopes TEXT NOT NULL,
+ requested_audience TEXT NULL DEFAULT '',
+ granted_audience TEXT NULL DEFAULT '',
+ active BOOLEAN NOT NULL DEFAULT FALSE,
+ revoked BOOLEAN NOT NULL DEFAULT FALSE,
+ form_data TEXT NOT NULL,
+ session_data BLOB NOT NULL,
+ PRIMARY KEY (id),
+ CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey
+ FOREIGN KEY(challenge_id)
+ REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ CONSTRAINT oauth2_openid_connect_session_subject_fkey
+ FOREIGN KEY(subject)
+ REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
+);
+
+CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id);
+CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id);
+CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject);
diff --git a/internal/storage/migrations/V0004.OpenIDConnect.all.down.sql b/internal/storage/migrations/V0004.OpenIDConnect.all.down.sql
new file mode 100644
index 00000000..481f502f
--- /dev/null
+++ b/internal/storage/migrations/V0004.OpenIDConnect.all.down.sql
@@ -0,0 +1,8 @@
+DROP TABLE IF EXISTS oauth2_blacklisted_jti;
+DROP TABLE IF EXISTS oauth2_authorization_code_session;
+DROP TABLE IF EXISTS oauth2_access_token_session;
+DROP TABLE IF EXISTS oauth2_refresh_token_session;
+DROP TABLE IF EXISTS oauth2_pkce_request_session;
+DROP TABLE IF EXISTS oauth2_openid_connect_session;
+DROP TABLE IF EXISTS oauth2_consent_session;
+DROP TABLE IF EXISTS user_opaque_identifier;
\ No newline at end of file
diff --git a/internal/storage/provider.go b/internal/storage/provider.go
index 7805e15c..742b3600 100644
--- a/internal/storage/provider.go
+++ b/internal/storage/provider.go
@@ -4,6 +4,9 @@ import (
"context"
"time"
+ "github.com/google/uuid"
+ "github.com/ory/fosite/storage"
+
"github.com/authelia/authelia/v4/internal/model"
)
@@ -13,10 +16,16 @@ type Provider interface {
RegulatorProvider
+ storage.Transactional
+
SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error)
LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error)
LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error)
+ SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error)
+ LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (subject *model.UserOpaqueIdentifier, err error)
+ LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error)
+
SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error)
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
@@ -36,6 +45,22 @@ type Provider interface {
DeletePreferredDuoDevice(ctx context.Context, username string) (err error)
LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error)
+ SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error)
+ SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, rejection bool) (err error)
+ SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error)
+ LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error)
+ LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error)
+
+ SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error)
+ RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error)
+ RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error)
+ DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error)
+ DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error)
+ LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error)
+
+ SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error)
+ LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error)
+
SchemaTables(ctx context.Context) (tables []string, err error)
SchemaVersion(ctx context.Context) (version int, err error)
SchemaLatestVersion() (version int, err error)
diff --git a/internal/storage/sql_provider.go b/internal/storage/sql_provider.go
index 126c4ddb..94b9c839 100644
--- a/internal/storage/sql_provider.go
+++ b/internal/storage/sql_provider.go
@@ -8,6 +8,7 @@ import (
"fmt"
"time"
+ "github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus"
@@ -63,6 +64,54 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebauthnDevices, tableDuoDevices, tableUserPreferences),
+ sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier),
+ sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier),
+ sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier),
+
+ sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
+ sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
+ sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
+
+ sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession),
+ sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession),
+ sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession),
+ sqlRevokeOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
+ sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession),
+ sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
+
+ sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
+ sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
+ sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
+
+ sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
+ sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
+ sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
+
+ sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession),
+ sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession),
+ sqlRevokeOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2OpenIDConnectSession),
+ sqlRevokeOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
+ sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession),
+ sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
+
+ sqlInsertOAuth2ConsentSession: fmt.Sprintf(queryFmtInsertOAuth2ConsentSession, tableOAuth2ConsentSession),
+ sqlUpdateOAuth2ConsentSessionResponse: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionResponse, tableOAuth2ConsentSession),
+ sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession),
+ sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession),
+ sqlSelectOAuth2ConsentSessionsPreConfigured: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionsPreConfigured, tableOAuth2ConsentSession),
+
+ sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
+ sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
+
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
@@ -128,6 +177,11 @@ type SQLProvider struct {
sqlSelectPreferred2FAMethod string
sqlSelectUserInfo string
+ // Table: user_opaque_identifier.
+ sqlInsertUserOpaqueIdentifier string
+ sqlSelectUserOpaqueIdentifier string
+ sqlSelectUserOpaqueIdentifierBySignature string
+
// Table: migrations.
sqlInsertMigration string
sqlSelectMigrations string
@@ -137,6 +191,56 @@ type SQLProvider struct {
sqlUpsertEncryptionValue string
sqlSelectEncryptionValue string
+ // Table: oauth2_authorization_code_session.
+ sqlInsertOAuth2AuthorizeCodeSession string
+ sqlSelectOAuth2AuthorizeCodeSession string
+ sqlRevokeOAuth2AuthorizeCodeSession string
+ sqlRevokeOAuth2AuthorizeCodeSessionByRequestID string
+ sqlDeactivateOAuth2AuthorizeCodeSession string
+ sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID string
+
+ // Table: oauth2_access_token_session.
+ sqlInsertOAuth2AccessTokenSession string
+ sqlSelectOAuth2AccessTokenSession string
+ sqlRevokeOAuth2AccessTokenSession string
+ sqlRevokeOAuth2AccessTokenSessionByRequestID string
+ sqlDeactivateOAuth2AccessTokenSession string
+ sqlDeactivateOAuth2AccessTokenSessionByRequestID string
+
+ // Table: oauth2_refresh_token_session.
+ sqlInsertOAuth2RefreshTokenSession string
+ sqlSelectOAuth2RefreshTokenSession string
+ sqlRevokeOAuth2RefreshTokenSession string
+ sqlRevokeOAuth2RefreshTokenSessionByRequestID string
+ sqlDeactivateOAuth2RefreshTokenSession string
+ sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
+
+ // Table: oauth2_pkce_request_session.
+ sqlInsertOAuth2PKCERequestSession string
+ sqlSelectOAuth2PKCERequestSession string
+ sqlRevokeOAuth2PKCERequestSession string
+ sqlRevokeOAuth2PKCERequestSessionByRequestID string
+ sqlDeactivateOAuth2PKCERequestSession string
+ sqlDeactivateOAuth2PKCERequestSessionByRequestID string
+
+ // Table: oauth2_openid_connect_session.
+ sqlInsertOAuth2OpenIDConnectSession string
+ sqlSelectOAuth2OpenIDConnectSession string
+ sqlRevokeOAuth2OpenIDConnectSession string
+ sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
+ sqlDeactivateOAuth2OpenIDConnectSession string
+ sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
+
+ // Table: oauth2_consent_session.
+ sqlInsertOAuth2ConsentSession string
+ sqlUpdateOAuth2ConsentSessionResponse string
+ sqlUpdateOAuth2ConsentSessionGranted string
+ sqlSelectOAuth2ConsentSessionByChallengeID string
+ sqlSelectOAuth2ConsentSessionsPreConfigured string
+
+ sqlUpsertOAuth2BlacklistedJTI string
+ sqlSelectOAuth2BlacklistedJTI string
+
// Utility.
sqlSelectExistingTables string
sqlFmtRenameTable string
@@ -187,6 +291,329 @@ func (p *SQLProvider) StartupCheck() (err error) {
}
}
+// BeginTX begins a transaction.
+func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) {
+ var tx *sql.Tx
+
+ if tx, err = p.db.Begin(); err != nil {
+ return nil, err
+ }
+
+ return context.WithValue(ctx, ctxKeyTransaction, tx), nil
+}
+
+// Commit performs a database commit.
+func (p *SQLProvider) Commit(ctx context.Context) (err error) {
+ tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
+
+ if !ok {
+ return errors.New("could not retrieve tx")
+ }
+
+ return tx.Commit()
+}
+
+// Rollback performs a database rollback.
+func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
+ tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
+
+ if !ok {
+ return errors.New("could not retrieve tx")
+ }
+
+ return tx.Rollback()
+}
+
+// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
+func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil {
+ return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err)
+ }
+
+ return nil
+}
+
+// LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
+func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) {
+ opaqueID = &model.UserOpaqueIdentifier{}
+
+ if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil {
+ switch {
+ case errors.Is(err, sql.ErrNoRows):
+ return nil, nil
+ default:
+ return nil, err
+ }
+ }
+
+ return opaqueID, nil
+}
+
+// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
+func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
+ opaqueID = &model.UserOpaqueIdentifier{}
+
+ if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
+ switch {
+ case errors.Is(err, sql.ErrNoRows):
+ return nil, nil
+ default:
+ return nil, err
+ }
+ }
+
+ return opaqueID, nil
+}
+
+// SaveOAuth2ConsentSession inserts an OAuth2.0 consent.
+func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession,
+ consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted,
+ consent.RequestedAt, consent.RespondedAt, consent.ExpiresAt, consent.Form,
+ consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience); err != nil {
+ return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.String(), err)
+ }
+
+ return nil
+}
+
+// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent with the consent response.
+func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) {
+ _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.ExpiresAt, consent.GrantedScopes, consent.GrantedAudience, consent.ID)
+ if err != nil {
+ return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject, err)
+ }
+
+ return nil
+}
+
+// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint.
+func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil {
+ return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err)
+ }
+
+ return nil
+}
+
+// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID.
+func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) {
+ consent = &model.OAuth2ConsentSession{}
+
+ if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil {
+ return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err)
+ }
+
+ return consent, nil
+}
+
+// LoadOAuth2ConsentSessionsPreConfigured returns an OAuth2.0 consents that are pre-configured given the consent signature.
+func (p *SQLProvider) LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error) {
+ var r *sqlx.Rows
+
+ if r, err = p.db.QueryxContext(ctx, p.sqlSelectOAuth2ConsentSessionsPreConfigured, clientID, subject); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return &ConsentSessionRows{}, nil
+ }
+
+ return &ConsentSessionRows{}, fmt.Errorf("error selecting oauth2 consent session by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err)
+ }
+
+ return &ConsentSessionRows{rows: r}, nil
+}
+
+// SaveOAuth2Session saves a OAuth2Session to the database.
+func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) {
+ var query string
+
+ switch sessionType {
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlInsertOAuth2AuthorizeCodeSession
+ case OAuth2SessionTypeAccessToken:
+ query = p.sqlInsertOAuth2AccessTokenSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlInsertOAuth2RefreshTokenSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlInsertOAuth2PKCERequestSession
+ case OAuth2SessionTypeOpenIDConnect:
+ query = p.sqlInsertOAuth2OpenIDConnectSession
+ default:
+ return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType)
+ }
+
+ if session.Session, err = p.encrypt(session.Session); err != nil {
+ return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
+ }
+
+ _, err = p.db.ExecContext(ctx, query,
+ session.ChallengeID, session.RequestID, session.ClientID, session.Signature,
+ session.Subject, session.RequestedAt, session.RequestedScopes, session.GrantedScopes,
+ session.RequestedAudience, session.GrantedAudience,
+ session.Active, session.Revoked, session.Form, session.Session)
+
+ if err != nil {
+ return fmt.Errorf("error inserting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
+ }
+
+ return nil
+}
+
+// RevokeOAuth2Session marks a OAuth2Session as revoked in the database.
+func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
+ var query string
+
+ switch sessionType {
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlRevokeOAuth2AuthorizeCodeSession
+ case OAuth2SessionTypeAccessToken:
+ query = p.sqlRevokeOAuth2AccessTokenSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlRevokeOAuth2RefreshTokenSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlRevokeOAuth2PKCERequestSession
+ case OAuth2SessionTypeOpenIDConnect:
+ query = p.sqlRevokeOAuth2OpenIDConnectSession
+ default:
+ return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
+ }
+
+ if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
+ return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err)
+ }
+
+ return nil
+}
+
+// RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database.
+func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
+ var query string
+
+ switch sessionType {
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
+ case OAuth2SessionTypeAccessToken:
+ query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
+ case OAuth2SessionTypeOpenIDConnect:
+ query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID
+ default:
+ return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
+ }
+
+ if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
+ return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
+ }
+
+ return nil
+}
+
+// DeactivateOAuth2Session marks a OAuth2Session as inactive in the database.
+func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
+ var query string
+
+ switch sessionType {
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlDeactivateOAuth2AuthorizeCodeSession
+ case OAuth2SessionTypeAccessToken:
+ query = p.sqlDeactivateOAuth2AccessTokenSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlDeactivateOAuth2RefreshTokenSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlDeactivateOAuth2PKCERequestSession
+ case OAuth2SessionTypeOpenIDConnect:
+ query = p.sqlDeactivateOAuth2OpenIDConnectSession
+ default:
+ return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
+ }
+
+ if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
+ return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err)
+ }
+
+ return nil
+}
+
+// DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database.
+func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
+ var query string
+
+ switch sessionType {
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlDeactivateOAuth2AuthorizeCodeSession
+ case OAuth2SessionTypeAccessToken:
+ query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
+ case OAuth2SessionTypeOpenIDConnect:
+ query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID
+ default:
+ return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
+ }
+
+ if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
+ return fmt.Errorf("error deactivating oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
+ }
+
+ return nil
+}
+
+// LoadOAuth2Session saves a OAuth2Session from the database.
+func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) {
+ var query string
+
+ switch sessionType {
+ case OAuth2SessionTypeAuthorizeCode:
+ query = p.sqlSelectOAuth2AuthorizeCodeSession
+ case OAuth2SessionTypeAccessToken:
+ query = p.sqlSelectOAuth2AccessTokenSession
+ case OAuth2SessionTypeRefreshToken:
+ query = p.sqlSelectOAuth2RefreshTokenSession
+ case OAuth2SessionTypePKCEChallenge:
+ query = p.sqlSelectOAuth2PKCERequestSession
+ case OAuth2SessionTypeOpenIDConnect:
+ query = p.sqlSelectOAuth2OpenIDConnectSession
+ default:
+ return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType)
+ }
+
+ session = &model.OAuth2Session{}
+
+ if err = p.db.GetContext(ctx, session, query, signature); err != nil {
+ return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err)
+ }
+
+ if session.Session, err = p.decrypt(session.Session); err != nil {
+ return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject, session.RequestID, err)
+ }
+
+ return session, nil
+}
+
+// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
+func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
+ if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
+ return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
+ }
+
+ return nil
+}
+
+// LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database.
+func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) {
+ blacklistedJTI = &model.OAuth2BlacklistedJTI{}
+
+ if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil {
+ return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
+ }
+
+ return blacklistedJTI, nil
+}
+
// SavePreferred2FAMethod save the preferred method for 2FA to the database.
func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil {
diff --git a/internal/storage/sql_provider_backend_postgres.go b/internal/storage/sql_provider_backend_postgres.go
index 36c959e7..e1fdf630 100644
--- a/internal/storage/sql_provider_backend_postgres.go
+++ b/internal/storage/sql_provider_backend_postgres.go
@@ -26,19 +26,27 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
// Specific alterations to this provider.
// PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead.
- provider.sqlUpsertWebauthnDevice = fmt.Sprintf(queryFmtPostgresUpsertWebauthnDevice, tableWebauthnDevices)
- provider.sqlUpsertDuoDevice = fmt.Sprintf(queryFmtPostgresUpsertDuoDevice, tableDuoDevices)
- provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations)
- provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences)
- provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtPostgresUpsertEncryptionValue, tableEncryption)
+ provider.sqlUpsertWebauthnDevice = fmt.Sprintf(queryFmtUpsertWebauthnDevicePostgreSQL, tableWebauthnDevices)
+ provider.sqlUpsertDuoDevice = fmt.Sprintf(queryFmtUpsertDuoDevicePostgreSQL, tableDuoDevices)
+ provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtUpsertTOTPConfigurationPostgreSQL, tableTOTPConfigurations)
+ provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtUpsertPreferred2FAMethodPostgreSQL, tableUserPreferences)
+ provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtUpsertEncryptionValuePostgreSQL, tableEncryption)
+ provider.sqlUpsertOAuth2BlacklistedJTI = fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTIPostgreSQL, tableOAuth2BlacklistedJTI)
// PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders.
provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable)
+
provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod)
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
+
+ provider.sqlInsertUserOpaqueIdentifier = provider.db.Rebind(provider.sqlInsertUserOpaqueIdentifier)
+ provider.sqlSelectUserOpaqueIdentifier = provider.db.Rebind(provider.sqlSelectUserOpaqueIdentifier)
+ provider.sqlSelectUserOpaqueIdentifierBySignature = provider.db.Rebind(provider.sqlSelectUserOpaqueIdentifierBySignature)
+
provider.sqlSelectIdentityVerification = provider.db.Rebind(provider.sqlSelectIdentityVerification)
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
+
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)
@@ -46,21 +54,69 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
+
provider.sqlSelectWebauthnDevices = provider.db.Rebind(provider.sqlSelectWebauthnDevices)
provider.sqlSelectWebauthnDevicesByUsername = provider.db.Rebind(provider.sqlSelectWebauthnDevicesByUsername)
provider.sqlUpdateWebauthnDevicePublicKey = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKey)
provider.sqlUpdateWebauthnDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKeyByUsername)
provider.sqlUpdateWebauthnDeviceRecordSignIn = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignIn)
provider.sqlUpdateWebauthnDeviceRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignInByUsername)
+
provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice)
provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice)
+
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername)
+
provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration)
provider.sqlSelectMigrations = provider.db.Rebind(provider.sqlSelectMigrations)
provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration)
+
provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue)
+ provider.sqlInsertOAuth2ConsentSession = provider.db.Rebind(provider.sqlInsertOAuth2ConsentSession)
+ provider.sqlUpdateOAuth2ConsentSessionResponse = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionResponse)
+ provider.sqlUpdateOAuth2ConsentSessionGranted = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionGranted)
+ provider.sqlSelectOAuth2ConsentSessionByChallengeID = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionByChallengeID)
+ provider.sqlSelectOAuth2ConsentSessionsPreConfigured = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionsPreConfigured)
+
+ provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession)
+ provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession)
+ provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID)
+ provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession)
+ provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID)
+ provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession)
+
+ provider.sqlInsertOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2AccessTokenSession)
+ provider.sqlRevokeOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSession)
+ provider.sqlRevokeOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSessionByRequestID)
+ provider.sqlDeactivateOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSession)
+ provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID)
+ provider.sqlSelectOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2AccessTokenSession)
+
+ provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession)
+ provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession)
+ provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID)
+ provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession)
+ provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID)
+ provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession)
+
+ provider.sqlInsertOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlInsertOAuth2PKCERequestSession)
+ provider.sqlRevokeOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSession)
+ provider.sqlRevokeOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSessionByRequestID)
+ provider.sqlDeactivateOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSession)
+ provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID)
+ provider.sqlSelectOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlSelectOAuth2PKCERequestSession)
+
+ provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession)
+ provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession)
+ provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID)
+ provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession)
+ provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID)
+ provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession)
+
+ provider.sqlSelectOAuth2BlacklistedJTI = provider.db.Rebind(provider.sqlSelectOAuth2BlacklistedJTI)
+
provider.schema = config.Storage.PostgreSQL.Schema
return provider
diff --git a/internal/storage/sql_provider_queries.go b/internal/storage/sql_provider_queries.go
index 43009413..57ecbc66 100644
--- a/internal/storage/sql_provider_queries.go
+++ b/internal/storage/sql_provider_queries.go
@@ -48,7 +48,7 @@ const (
REPLACE INTO %s (username, second_factor_method)
VALUES (?, ?);`
- queryFmtPostgresUpsertPreferred2FAMethod = `
+ queryFmtUpsertPreferred2FAMethodPostgreSQL = `
INSERT INTO %s (username, second_factor_method)
VALUES ($1, $2)
ON CONFLICT (username)
@@ -99,7 +99,7 @@ const (
REPLACE INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
- queryFmtPostgresUpsertTOTPConfiguration = `
+ queryFmtUpsertTOTPConfigurationPostgreSQL = `
INSERT INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (username)
@@ -160,7 +160,7 @@ const (
REPLACE INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
- queryFmtPostgresUpsertWebauthnDevice = `
+ queryFmtUpsertWebauthnDevicePostgreSQL = `
INSERT INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (username, description)
@@ -172,7 +172,7 @@ const (
REPLACE INTO %s (username, device, method)
VALUES (?, ?, ?);`
- queryFmtPostgresUpsertDuoDevice = `
+ queryFmtUpsertDuoDevicePostgreSQL = `
INSERT INTO %s (username, device, method)
VALUES ($1, $2, $3)
ON CONFLICT (username)
@@ -214,9 +214,103 @@ const (
REPLACE INTO %s (name, value)
VALUES (?, ?);`
- queryFmtPostgresUpsertEncryptionValue = `
+ queryFmtUpsertEncryptionValuePostgreSQL = `
INSERT INTO %s (name, value)
VALUES ($1, $2)
ON CONFLICT (name)
DO UPDATE SET value = $2;`
)
+
+const (
+ queryFmtSelectOAuth2ConsentSessionByChallengeID = `
+ SELECT id, challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at,
+ form_data, requested_scopes, granted_scopes, requested_audience, granted_audience
+ FROM %s
+ WHERE challenge_id = ?;`
+
+ queryFmtSelectOAuth2ConsentSessionsPreConfigured = `
+ SELECT id, challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at,
+ form_data, requested_scopes, granted_scopes, requested_audience, granted_audience
+ FROM %s
+ WHERE client_id = ? AND subject = ? AND
+ authorized = TRUE AND granted = TRUE AND expires_at IS NOT NULL AND expires_at >= CURRENT_TIMESTAMP;`
+
+ queryFmtInsertOAuth2ConsentSession = `
+ INSERT INTO %s (challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at,
+ form_data, requested_scopes, granted_scopes, requested_audience, granted_audience)
+ VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
+
+ queryFmtUpdateOAuth2ConsentSessionResponse = `
+ UPDATE %s
+ SET authorized = ?, responded_at = CURRENT_TIMESTAMP, expires_at = ?, granted_scopes = ?, granted_audience = ?
+ WHERE id = ? AND responded_at IS NULL;`
+
+ queryFmtUpdateOAuth2ConsentSessionGranted = `
+ UPDATE %s
+ SET granted = TRUE
+ WHERE id = ? AND responded_at IS NOT NULL;`
+
+ queryFmtSelectOAuth2Session = `
+ SELECT id, challenge_id, request_id, client_id, signature, subject, requested_at,
+ requested_scopes, granted_scopes, requested_audience, granted_audience,
+ active, revoked, form_data, session_data
+ FROM %s
+ WHERE signature = ? AND revoked = FALSE;`
+
+ queryFmtInsertOAuth2Session = `
+ INSERT INTO %s (challenge_id, request_id, client_id, signature, subject, requested_at,
+ requested_scopes, granted_scopes, requested_audience, granted_audience,
+ active, revoked, form_data, session_data)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
+
+ queryFmtRevokeOAuth2Session = `
+ UPDATE %s
+ SET revoked = TRUE
+ WHERE signature = ?;`
+
+ queryFmtRevokeOAuth2SessionByRequestID = `
+ UPDATE %s
+ SET revoked = TRUE
+ WHERE request_id = ?;`
+
+ queryFmtDeactivateOAuth2Session = `
+ UPDATE %s
+ SET active = FALSE
+ WHERE signature = ?;`
+
+ queryFmtDeactivateOAuth2SessionByRequestID = `
+ UPDATE %s
+ SET active = FALSE
+ WHERE request_id = ?;"`
+
+ queryFmtSelectOAuth2BlacklistedJTI = `
+ SELECT id, signature, expires_at
+ FROM %s
+ WHERE signature = ?;`
+
+ queryFmtUpsertOAuth2BlacklistedJTI = `
+ REPLACE INTO %s (signature, expires_at)
+ VALUES(?, ?);`
+
+ queryFmtUpsertOAuth2BlacklistedJTIPostgreSQL = `
+ INSERT INTO %s (signature, expires_at)
+ VALUES ($1, $2)
+ ON CONFLICT (signature)
+ DO UPDATE SET expires_at = $2;`
+)
+
+const (
+ queryFmtInsertUserOpaqueIdentifier = `
+ INSERT INTO %s (service, sector_id, username, identifier)
+ VALUES(?, ?, ?, ?);`
+
+ queryFmtSelectUserOpaqueIdentifier = `
+ SELECT id, sector_id, username, identifier
+ FROM %s
+ WHERE identifier = ?;`
+
+ queryFmtSelectUserOpaqueIdentifierBySignature = `
+ SELECT id, service, sector_id, username, identifier
+ FROM %s
+ WHERE service = ? AND sector_id = ? AND username = ?;`
+)
diff --git a/internal/storage/sql_rows.go b/internal/storage/sql_rows.go
new file mode 100644
index 00000000..148dbf38
--- /dev/null
+++ b/internal/storage/sql_rows.go
@@ -0,0 +1,47 @@
+package storage
+
+import (
+ "database/sql"
+
+ "github.com/jmoiron/sqlx"
+
+ "github.com/authelia/authelia/v4/internal/model"
+)
+
+// ConsentSessionRows holds and assists with retrieving multiple model.OAuth2ConsentSession rows.
+type ConsentSessionRows struct {
+ rows *sqlx.Rows
+}
+
+// Next is the row iterator.
+func (r *ConsentSessionRows) Next() bool {
+ if r.rows == nil {
+ return false
+ }
+
+ return r.rows.Next()
+}
+
+// Close the rows.
+func (r *ConsentSessionRows) Close() (err error) {
+ if r.rows == nil {
+ return nil
+ }
+
+ return r.rows.Close()
+}
+
+// Get returns the *model.OAuth2ConsentSession or scan error.
+func (r *ConsentSessionRows) Get() (consent *model.OAuth2ConsentSession, err error) {
+ if r.rows == nil {
+ return nil, sql.ErrNoRows
+ }
+
+ consent = &model.OAuth2ConsentSession{}
+
+ if err = r.rows.StructScan(consent); err != nil {
+ return nil, err
+ }
+
+ return consent, nil
+}
diff --git a/internal/suites/OIDC/configuration.yml b/internal/suites/OIDC/configuration.yml
index c63892dc..a68a82a6 100644
--- a/internal/suites/OIDC/configuration.yml
+++ b/internal/suites/OIDC/configuration.yml
@@ -58,6 +58,7 @@ notifier:
identity_providers:
oidc:
+ enable_client_debug_messages: true
hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm
issuer_private_key: |
-----BEGIN RSA PRIVATE KEY-----
diff --git a/internal/suites/OIDCTraefik/configuration.yml b/internal/suites/OIDCTraefik/configuration.yml
index cc949e56..11fb3e73 100644
--- a/internal/suites/OIDCTraefik/configuration.yml
+++ b/internal/suites/OIDCTraefik/configuration.yml
@@ -60,6 +60,7 @@ notifier:
identity_providers:
oidc:
+ enable_client_debug_messages: true
hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm
issuer_private_key: |
-----BEGIN RSA PRIVATE KEY-----
diff --git a/internal/suites/action_2fa_methods.go b/internal/suites/action_2fa_methods.go
index 89a1f241..b012c110 100644
--- a/internal/suites/action_2fa_methods.go
+++ b/internal/suites/action_2fa_methods.go
@@ -9,26 +9,26 @@ import (
)
func (rs *RodSession) doChangeMethod(t *testing.T, page *rod.Page, method string) {
- err := rs.WaitElementLocatedByCSSSelector(t, page, "methods-button").Click("left")
+ err := rs.WaitElementLocatedByID(t, page, "methods-button").Click("left")
require.NoError(t, err)
- rs.WaitElementLocatedByCSSSelector(t, page, "methods-dialog")
- err = rs.WaitElementLocatedByCSSSelector(t, page, fmt.Sprintf("%s-option", method)).Click("left")
+ rs.WaitElementLocatedByID(t, page, "methods-dialog")
+ err = rs.WaitElementLocatedByID(t, page, fmt.Sprintf("%s-option", method)).Click("left")
require.NoError(t, err)
}
func (rs *RodSession) doChangeDevice(t *testing.T, page *rod.Page, deviceID string) {
- err := rs.WaitElementLocatedByCSSSelector(t, page, "selection-link").Click("left")
+ err := rs.WaitElementLocatedByID(t, page, "selection-link").Click("left")
require.NoError(t, err)
rs.doSelectDevice(t, page, deviceID)
}
func (rs *RodSession) doSelectDevice(t *testing.T, page *rod.Page, deviceID string) {
- rs.WaitElementLocatedByCSSSelector(t, page, "device-selection")
- err := rs.WaitElementLocatedByCSSSelector(t, page, fmt.Sprintf("device-%s", deviceID)).Click("left")
+ rs.WaitElementLocatedByID(t, page, "device-selection")
+ err := rs.WaitElementLocatedByID(t, page, fmt.Sprintf("device-%s", deviceID)).Click("left")
require.NoError(t, err)
}
func (rs *RodSession) doClickButton(t *testing.T, page *rod.Page, buttonID string) {
- err := rs.WaitElementLocatedByCSSSelector(t, page, buttonID).Click("left")
+ err := rs.WaitElementLocatedByID(t, page, buttonID).Click("left")
require.NoError(t, err)
}
diff --git a/internal/suites/action_login.go b/internal/suites/action_login.go
index fd4af24b..73e7b807 100644
--- a/internal/suites/action_login.go
+++ b/internal/suites/action_login.go
@@ -9,21 +9,21 @@ import (
)
func (rs *RodSession) doFillLoginPageAndClick(t *testing.T, page *rod.Page, username, password string, keepMeLoggedIn bool) {
- usernameElement := rs.WaitElementLocatedByCSSSelector(t, page, "username-textfield")
+ usernameElement := rs.WaitElementLocatedByID(t, page, "username-textfield")
err := usernameElement.Input(username)
require.NoError(t, err)
- passwordElement := rs.WaitElementLocatedByCSSSelector(t, page, "password-textfield")
+ passwordElement := rs.WaitElementLocatedByID(t, page, "password-textfield")
err = passwordElement.Input(password)
require.NoError(t, err)
if keepMeLoggedIn {
- keepMeLoggedInElement := rs.WaitElementLocatedByCSSSelector(t, page, "remember-checkbox")
+ keepMeLoggedInElement := rs.WaitElementLocatedByID(t, page, "remember-checkbox")
err = keepMeLoggedInElement.Click("left")
require.NoError(t, err)
}
- buttonElement := rs.WaitElementLocatedByCSSSelector(t, page, "sign-in-button")
+ buttonElement := rs.WaitElementLocatedByID(t, page, "sign-in-button")
err = buttonElement.Click("left")
require.NoError(t, err)
}
diff --git a/internal/suites/action_reset_password.go b/internal/suites/action_reset_password.go
index 2dcf0a25..5a3ca5d5 100644
--- a/internal/suites/action_reset_password.go
+++ b/internal/suites/action_reset_password.go
@@ -9,13 +9,13 @@ import (
)
func (rs *RodSession) doInitiatePasswordReset(t *testing.T, page *rod.Page, username string) {
- err := rs.WaitElementLocatedByCSSSelector(t, page, "reset-password-button").Click("left")
+ err := rs.WaitElementLocatedByID(t, page, "reset-password-button").Click("left")
require.NoError(t, err)
// Fill in username.
- err = rs.WaitElementLocatedByCSSSelector(t, page, "username-textfield").Input(username)
+ err = rs.WaitElementLocatedByID(t, page, "username-textfield").Input(username)
require.NoError(t, err)
// And click on the reset button.
- err = rs.WaitElementLocatedByCSSSelector(t, page, "reset-button").Click("left")
+ err = rs.WaitElementLocatedByID(t, page, "reset-button").Click("left")
require.NoError(t, err)
}
@@ -25,15 +25,15 @@ func (rs *RodSession) doCompletePasswordReset(t *testing.T, page *rod.Page, newP
time.Sleep(1 * time.Second)
- err := rs.WaitElementLocatedByCSSSelector(t, page, "password1-textfield").Input(newPassword1)
+ err := rs.WaitElementLocatedByID(t, page, "password1-textfield").Input(newPassword1)
require.NoError(t, err)
time.Sleep(1 * time.Second)
- err = rs.WaitElementLocatedByCSSSelector(t, page, "password2-textfield").Input(newPassword2)
+ err = rs.WaitElementLocatedByID(t, page, "password2-textfield").Input(newPassword2)
require.NoError(t, err)
- err = rs.WaitElementLocatedByCSSSelector(t, page, "reset-button").Click("left")
+ err = rs.WaitElementLocatedByID(t, page, "reset-button").Click("left")
require.NoError(t, err)
}
diff --git a/internal/suites/action_totp.go b/internal/suites/action_totp.go
index 07d00736..bf3d16bc 100644
--- a/internal/suites/action_totp.go
+++ b/internal/suites/action_totp.go
@@ -12,7 +12,7 @@ import (
)
func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string {
- err := rs.WaitElementLocatedByCSSSelector(t, page, "register-link").Click("left")
+ err := rs.WaitElementLocatedByID(t, page, "register-link").Click("left")
require.NoError(t, err)
rs.verifyMailNotificationDisplayed(t, page)
link := doGetLinkFromLastMail(t)
@@ -28,7 +28,7 @@ func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string {
}
func (rs *RodSession) doEnterOTP(t *testing.T, page *rod.Page, code string) {
- inputs := rs.WaitElementsLocatedByCSSSelector(t, page, "otp-input input")
+ inputs := rs.WaitElementsLocatedByID(t, page, "otp-input input")
for i := 0; i < len(code); i++ {
_ = inputs[i].Input(string(code[i]))
diff --git a/internal/suites/example/compose/oidc-client/docker-compose.yml b/internal/suites/example/compose/oidc-client/docker-compose.yml
index 40b211a5..0a150c36 100644
--- a/internal/suites/example/compose/oidc-client/docker-compose.yml
+++ b/internal/suites/example/compose/oidc-client/docker-compose.yml
@@ -2,7 +2,7 @@
version: '3'
services:
oidc-client:
- image: ghcr.io/authelia/oidc-tester-app:master-89622a8
+ image: ghcr.io/authelia/oidc-tester-app:master-01ff268
command: /entrypoint.sh
depends_on:
- authelia-backend
diff --git a/internal/suites/example/compose/oidc-client/entrypoint.sh b/internal/suites/example/compose/oidc-client/entrypoint.sh
index 6cf8792e..0fe5729f 100755
--- a/internal/suites/example/compose/oidc-client/entrypoint.sh
+++ b/internal/suites/example/compose/oidc-client/entrypoint.sh
@@ -2,6 +2,6 @@
while true;
do
- oidc-tester-app --issuer https://login.example.com:8080 --id oidc-tester-app --secret foobar --scopes openid,profile,email --public-url https://oidc.example.com:8080
+ oidc-tester-app --issuer=https://login.example.com:8080 --id=oidc-tester-app --secret=foobar --scopes=openid,profile,email,groups --public-url=https://oidc.example.com:8080
sleep 5
done
\ No newline at end of file
diff --git a/internal/suites/scenario_available_methods_test.go b/internal/suites/scenario_available_methods_test.go
index 976fb70d..7135f7af 100644
--- a/internal/suites/scenario_available_methods_test.go
+++ b/internal/suites/scenario_available_methods_test.go
@@ -58,11 +58,11 @@ func (s *AvailableMethodsScenario) TestShouldCheckAvailableMethods() {
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
- methodsButton := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "methods-button")
+ methodsButton := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "methods-button")
err := methodsButton.Click("left")
s.Assert().NoError(err)
- methodsDialog := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "methods-dialog")
+ methodsDialog := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "methods-dialog")
options, err := methodsDialog.Elements(".method-option")
s.Assert().NoError(err)
s.Assert().Len(options, len(s.methods))
diff --git a/internal/suites/scenario_oidc_test.go b/internal/suites/scenario_oidc_test.go
index 6adc04a6..7add41b4 100644
--- a/internal/suites/scenario_oidc_test.go
+++ b/internal/suites/scenario_oidc_test.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
+ "regexp"
"testing"
"time"
@@ -89,11 +90,50 @@ func (s *OIDCScenario) TestShouldAuthorizeAccessToOIDCApp() {
assert.NoError(s.T(), err)
s.verifyIsConsentPage(s.T(), s.Context(ctx))
- err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "accept-button").Click("left")
+ err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "accept-button").Click("left")
assert.NoError(s.T(), err)
// Verify that the app is showing the info related to the user stored in the JWT token.
- s.waitBodyContains(s.T(), s.Context(ctx), "Logged in as john!")
+
+ rUUID := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
+ rInteger := regexp.MustCompile(`^\d+$`)
+ rBoolean := regexp.MustCompile(`^(true|false)$`)
+ rBase64 := regexp.MustCompile(`^[-_A-Za-z0-9+\\/]+([=]{0,3})$`)
+
+ testCases := []struct {
+ desc, elementID, elementText string
+ pattern *regexp.Regexp
+ }{
+ {"welcome", "welcome", "Logged in as john!", nil},
+ {"at_hash", "claim-at_hash", "", rBase64},
+ {"jti", "claim-jti", "", rUUID},
+ {"iat", "claim-iat", "", rInteger},
+ {"nbf", "claim-nbf", "", rInteger},
+ {"rat", "claim-rat", "", rInteger},
+ {"expires", "claim-exp", "", rInteger},
+ {"amr", "claim-amr", "pwd, otp, mfa", nil},
+ {"acr", "claim-acr", "", nil},
+ {"issuer", "claim-iss", "https://login.example.com:8080", nil},
+ {"name", "claim-name", "John Doe", nil},
+ {"preferred_username", "claim-preferred_username", "john", nil},
+ {"groups", "claim-groups", "admins, dev", nil},
+ {"email", "claim-email", "john.doe@authelia.com", nil},
+ {"email_verified", "claim-email_verified", "", rBoolean},
+ }
+
+ var text string
+
+ for _, tc := range testCases {
+ s.T().Run(fmt.Sprintf("check_claims/%s", tc.desc), func(t *testing.T) {
+ text, err = s.WaitElementLocatedByID(t, s.Context(ctx), tc.elementID).Text()
+ assert.NoError(t, err)
+ if tc.pattern == nil {
+ assert.Equal(t, tc.elementText, text)
+ } else {
+ assert.Regexp(t, tc.pattern, text)
+ }
+ })
+ }
}
func (s *OIDCScenario) TestShouldDenyConsent() {
@@ -117,10 +157,17 @@ func (s *OIDCScenario) TestShouldDenyConsent() {
s.verifyIsConsentPage(s.T(), s.Context(ctx))
- err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "deny-button").Click("left")
+ err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "deny-button").Click("left")
assert.NoError(s.T(), err)
- s.verifyIsOIDC(s.T(), s.Context(ctx), "oauth2:", "https://oidc.example.com:8080/oauth2/callback?error=access_denied&error_description=User%20has%20rejected%20the%20scopes")
+ s.verifyIsOIDC(s.T(), s.Context(ctx), "access_denied", "https://oidc.example.com:8080/error?error=access_denied&error_description=The+resource+owner+or+authorization+server+denied+the+request.+Make+sure+that+the+request+you+are+making+is+valid.+Maybe+the+credential+or+request+parameters+you+are+using+are+limited+in+scope+or+otherwise+restricted.&state=random-string-here")
+
+ errorDescription := "The resource owner or authorization server denied the request. Make sure that the request " +
+ "you are making is valid. Maybe the credential or request parameters you are using are limited in scope or " +
+ "otherwise restricted."
+
+ s.verifyIsOIDCErrorPage(s.T(), s.Context(ctx), "access_denied", errorDescription, "",
+ "random-string-here")
}
func TestRunOIDCScenario(t *testing.T) {
diff --git a/internal/suites/scenario_regulation_test.go b/internal/suites/scenario_regulation_test.go
index 4d7e20b0..0eba4661 100644
--- a/internal/suites/scenario_regulation_test.go
+++ b/internal/suites/scenario_regulation_test.go
@@ -60,16 +60,16 @@ func (s *RegulationScenario) TestShouldBanUserAfterTooManyAttempt() {
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.")
for i := 0; i < 3; i++ {
- err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("bad-password")
+ err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("bad-password")
require.NoError(s.T(), err)
- err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left")
+ err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left")
require.NoError(s.T(), err)
}
// Enter the correct password and test the regulation lock out.
- err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("password")
+ err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("password")
require.NoError(s.T(), err)
- err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left")
+ err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left")
require.NoError(s.T(), err)
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.")
@@ -77,9 +77,9 @@ func (s *RegulationScenario) TestShouldBanUserAfterTooManyAttempt() {
time.Sleep(10 * time.Second)
// Enter the correct password and test a successful login.
- err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("password")
+ err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("password")
require.NoError(s.T(), err)
- err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left")
+ err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left")
require.NoError(s.T(), err)
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
}
diff --git a/internal/suites/scenario_user_preferences_test.go b/internal/suites/scenario_user_preferences_test.go
index 6b35e805..f8e76b23 100644
--- a/internal/suites/scenario_user_preferences_test.go
+++ b/internal/suites/scenario_user_preferences_test.go
@@ -60,7 +60,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
// Then switch to push notification method.
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Switch context to clean up state in portal.
s.doVisit(s.T(), s.Context(ctx), HomeBaseURL)
@@ -70,7 +70,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
s.doVisit(s.T(), s.Context(ctx), GetLoginBaseURL())
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
// And check the latest method is still used.
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Meaning the authentication is successful.
s.verifyIsHome(s.T(), s.Context(ctx))
@@ -78,7 +78,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
s.doLogout(s.T(), s.Context(ctx))
s.doLoginOneFactor(s.T(), s.Context(ctx), "harry", "password", false, "")
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "one-time-password-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "one-time-password-method")
s.doLogout(s.T(), s.Context(ctx))
s.verifyIsFirstFactorPage(s.T(), s.Context(ctx))
@@ -86,7 +86,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
// Then log back as previous user and verify the push notification is still the default method.
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
s.verifyIsHome(s.T(), s.Context(ctx))
s.doLogout(s.T(), s.Context(ctx))
@@ -94,7 +94,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
// Eventually restore the default method.
s.doChangeMethod(s.T(), s.Context(ctx), "one-time-password")
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "one-time-password-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "one-time-password-method")
}
func TestUserPreferencesScenario(t *testing.T) {
diff --git a/internal/suites/suite_duo_push_test.go b/internal/suites/suite_duo_push_test.go
index 79ca2dc6..4573c5b0 100644
--- a/internal/suites/suite_duo_push_test.go
+++ b/internal/suites/suite_duo_push_test.go
@@ -110,7 +110,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAskUserToRegister() {
s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "state-not-registered")
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "No compatible device found")
enrollPage := s.Page.MustWaitOpen()
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "register-link").MustClick()
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "register-link").MustClick()
s.Page = enrollPage()
assert.Contains(s.T(), s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "description").MustText(), "This enrollment code has expired. Contact your administrator to get a new enrollment code.")
@@ -142,7 +142,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAutoSelectDevice() {
s.doLogout(s.T(), s.Context(ctx))
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
// And check the latest method and device is still used.
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Meaning the authentication is successful.
s.verifyIsHome(s.T(), s.Context(ctx))
}
@@ -176,7 +176,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() {
// Switch Method where Device Selection should open automatically.
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
// Check for available Device 1.
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "device-12345ABCDEFGHIJ67890")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "device-12345ABCDEFGHIJ67890")
// Test Back button.
s.doClickButton(s.T(), s.Context(ctx), "device-selection-back")
// then select Device 2 for further use and be redirected.
@@ -187,7 +187,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() {
s.doLogout(s.T(), s.Context(ctx))
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
// And check the latest method and device is still used.
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Meaning the authentication is successful.
s.verifyIsHome(s.T(), s.Context(ctx))
}
@@ -238,7 +238,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectNewDeviceAfterSavedDeviceMethodI
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
- s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "device-selection")
+ s.WaitElementLocatedByID(s.T(), s.Context(ctx), "device-selection")
s.doSelectDevice(s.T(), s.Context(ctx), "12345ABCDEFGHIJ67890")
s.verifyIsHome(s.T(), s.Context(ctx))
}
@@ -303,7 +303,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailSelectionBecauseOfSelectionDenied(
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
- err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "selection-link").Click("left")
+ err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "selection-link").Click("left")
require.NoError(s.T(), err)
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Device selection was denied by Duo policy")
}
diff --git a/internal/suites/verify_is_authenticated_page.go b/internal/suites/verify_is_authenticated_page.go
index 5cb24d4b..1e9bf7dd 100644
--- a/internal/suites/verify_is_authenticated_page.go
+++ b/internal/suites/verify_is_authenticated_page.go
@@ -7,5 +7,5 @@ import (
)
func (rs *RodSession) verifyIsAuthenticatedPage(t *testing.T, page *rod.Page) {
- rs.WaitElementLocatedByCSSSelector(t, page, "authenticated-stage")
+ rs.WaitElementLocatedByID(t, page, "authenticated-stage")
}
diff --git a/internal/suites/verify_is_consent_page.go b/internal/suites/verify_is_consent_page.go
index de143a83..0862b611 100644
--- a/internal/suites/verify_is_consent_page.go
+++ b/internal/suites/verify_is_consent_page.go
@@ -7,5 +7,5 @@ import (
)
func (rs *RodSession) verifyIsConsentPage(t *testing.T, page *rod.Page) {
- rs.WaitElementLocatedByCSSSelector(t, page, "consent-stage")
+ rs.WaitElementLocatedByID(t, page, "consent-stage")
}
diff --git a/internal/suites/verify_is_first_factor_page.go b/internal/suites/verify_is_first_factor_page.go
index 208194e3..40a50895 100644
--- a/internal/suites/verify_is_first_factor_page.go
+++ b/internal/suites/verify_is_first_factor_page.go
@@ -7,5 +7,5 @@ import (
)
func (rs *RodSession) verifyIsFirstFactorPage(t *testing.T, page *rod.Page) {
- rs.WaitElementLocatedByCSSSelector(t, page, "first-factor-stage")
+ rs.WaitElementLocatedByID(t, page, "first-factor-stage")
}
diff --git a/internal/suites/verify_is_oidc.go b/internal/suites/verify_is_oidc.go
index 9e8e2219..a6579bea 100644
--- a/internal/suites/verify_is_oidc.go
+++ b/internal/suites/verify_is_oidc.go
@@ -4,9 +4,33 @@ import (
"testing"
"github.com/go-rod/rod"
+ "github.com/stretchr/testify/assert"
)
func (rs *RodSession) verifyIsOIDC(t *testing.T, page *rod.Page, pattern, url string) {
page.MustElementR("body", pattern)
rs.verifyURLIs(t, page, url)
}
+
+func (rs *RodSession) verifyIsOIDCErrorPage(t *testing.T, page *rod.Page, errorCode, errorDescription, errorURI, state string) {
+ testCases := []struct {
+ ElementID, ElementText string
+ }{
+ {"error", errorCode},
+ {"error_description", errorDescription},
+ {"error_uri", errorURI},
+ {"state", state},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.ElementID, func(t *testing.T) {
+ if tc.ElementText == "" {
+ t.Skip("Test Skipped as the element is not expected.")
+ }
+
+ text, err := rs.WaitElementLocatedByID(t, page, tc.ElementID).Text()
+ assert.NoError(t, err)
+ assert.Equal(t, tc.ElementText, text)
+ })
+ }
+}
diff --git a/internal/suites/verify_is_second_factor_page.go b/internal/suites/verify_is_second_factor_page.go
index a37ffc27..8813afc6 100644
--- a/internal/suites/verify_is_second_factor_page.go
+++ b/internal/suites/verify_is_second_factor_page.go
@@ -7,5 +7,5 @@ import (
)
func (rs *RodSession) verifyIsSecondFactorPage(t *testing.T, page *rod.Page) {
- rs.WaitElementLocatedByCSSSelector(t, page, "second-factor-stage")
+ rs.WaitElementLocatedByID(t, page, "second-factor-stage")
}
diff --git a/internal/suites/verify_secret_authorized.go b/internal/suites/verify_secret_authorized.go
index 2c6c94f4..18f6bbe8 100644
--- a/internal/suites/verify_secret_authorized.go
+++ b/internal/suites/verify_secret_authorized.go
@@ -7,5 +7,5 @@ import (
)
func (rs *RodSession) verifySecretAuthorized(t *testing.T, page *rod.Page) {
- rs.WaitElementLocatedByCSSSelector(t, page, "secret")
+ rs.WaitElementLocatedByID(t, page, "secret")
}
diff --git a/internal/suites/webdriver.go b/internal/suites/webdriver.go
index e93b2353..28b7750b 100644
--- a/internal/suites/webdriver.go
+++ b/internal/suites/webdriver.go
@@ -82,8 +82,8 @@ func (rs *RodSession) WaitElementLocatedByClassName(t *testing.T, page *rod.Page
return e
}
-// WaitElementLocatedByCSSSelector wait an element is located by class name.
-func (rs *RodSession) WaitElementLocatedByCSSSelector(t *testing.T, page *rod.Page, cssSelector string) *rod.Element {
+// WaitElementLocatedByID waits for an element located by an id.
+func (rs *RodSession) WaitElementLocatedByID(t *testing.T, page *rod.Page, cssSelector string) *rod.Element {
e, err := page.Element("#" + cssSelector)
require.NoError(t, err)
require.NotNil(t, e)
@@ -91,8 +91,8 @@ func (rs *RodSession) WaitElementLocatedByCSSSelector(t *testing.T, page *rod.Pa
return e
}
-// WaitElementsLocatedByCSSSelector wait an element is located by CSS selector.
-func (rs *RodSession) WaitElementsLocatedByCSSSelector(t *testing.T, page *rod.Page, cssSelector string) rod.Elements {
+// WaitElementsLocatedByID waits for an elements located by an id.
+func (rs *RodSession) WaitElementsLocatedByID(t *testing.T, page *rod.Page, cssSelector string) rod.Elements {
e, err := page.Elements("#" + cssSelector)
require.NoError(t, err)
require.NotNil(t, e)
diff --git a/internal/utils/strings.go b/internal/utils/strings.go
index 231621f2..7fb6b8ee 100644
--- a/internal/utils/strings.go
+++ b/internal/utils/strings.go
@@ -241,6 +241,39 @@ func StringHTMLEscape(input string) (output string) {
return htmlEscaper.Replace(input)
}
+// StringJoinDelimitedEscaped joins a string with a specified rune delimiter after escaping any instance of that string
+// in the string slice. Used with StringSplitDelimitedEscaped.
+func StringJoinDelimitedEscaped(value []string, delimiter rune) string {
+ escaped := make([]string, len(value))
+ for k, v := range value {
+ escaped[k] = strings.ReplaceAll(v, string(delimiter), "\\"+string(delimiter))
+ }
+
+ return strings.Join(escaped, string(delimiter))
+}
+
+// StringSplitDelimitedEscaped splits a string with a specified rune delimiter after unescaping any instance of that
+// string in the string slice that has been escaped. Used with StringJoinDelimitedEscaped.
+func StringSplitDelimitedEscaped(value string, delimiter rune) (out []string) {
+ var escape bool
+
+ split := strings.FieldsFunc(value, func(r rune) bool {
+ if r == '\\' {
+ escape = !escape
+ } else if escape && r != delimiter {
+ escape = false
+ }
+
+ return !escape && r == delimiter
+ })
+
+ for k, v := range split {
+ split[k] = strings.ReplaceAll(v, "\\"+string(delimiter), string(delimiter))
+ }
+
+ return split
+}
+
// JoinAndCanonicalizeHeaders join header strings by a given sep.
func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) {
for i, header := range headers {
diff --git a/internal/utils/strings_test.go b/internal/utils/strings_test.go
index 2c9ad981..6b1b5e7c 100644
--- a/internal/utils/strings_test.go
+++ b/internal/utils/strings_test.go
@@ -8,6 +8,51 @@ import (
"github.com/stretchr/testify/require"
)
+func TestStringSplitDelimitedEscaped(t *testing.T) {
+ testCases := []struct {
+ desc, have string
+ delimiter rune
+ want []string
+ }{
+ {desc: "ShouldSplitNormalString", have: "abc,123,456", delimiter: ',', want: []string{"abc", "123", "456"}},
+ {desc: "ShouldSplitEscapedString", have: "a\\,bc,123,456", delimiter: ',', want: []string{"a,bc", "123", "456"}},
+ {desc: "ShouldSplitEscapedStringPipe", have: "a\\|bc|123|456", delimiter: '|', want: []string{"a|bc", "123", "456"}},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ actual := StringSplitDelimitedEscaped(tc.have, tc.delimiter)
+
+ assert.Equal(t, tc.want, actual)
+ })
+ }
+}
+
+func TestStringJoinDelimitedEscaped(t *testing.T) {
+ testCases := []struct {
+ desc, want string
+ delimiter rune
+ have []string
+ }{
+ {desc: "ShouldJoinNormalStringSlice", have: []string{"abc", "123", "456"}, delimiter: ',', want: "abc,123,456"},
+ {desc: "ShouldJoinEscapeNeededStringSlice", have: []string{"abc", "1,23", "456"}, delimiter: ',', want: "abc,1\\,23,456"},
+ {desc: "ShouldJoinEscapeNeededStringSlicePipe", have: []string{"abc", "1|23", "456"}, delimiter: '|', want: "abc|1\\|23|456"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.desc, func(t *testing.T) {
+ actual := StringJoinDelimitedEscaped(tc.have, tc.delimiter)
+
+ assert.Equal(t, tc.want, actual)
+
+ // Ensure splitting again also works fine.
+ split := StringSplitDelimitedEscaped(actual, tc.delimiter)
+
+ assert.Equal(t, tc.have, split)
+ })
+ }
+}
+
func TestShouldNotGenerateSameRandomString(t *testing.T) {
randomStringOne := RandomString(10, AlphaNumericCharacters, false)
randomStringTwo := RandomString(10, AlphaNumericCharacters, false)