feat(oidc): persistent storage (#2965)

This moves the OpenID Connect storage from memory into the SQL storage, making it persistent and allowing it to be used with clustered deployments like the rest of Authelia.
This commit is contained in:
James Elliott 2022-04-07 15:33:53 +10:00 committed by GitHub
parent 06fd7105ea
commit 0a970aef8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 2946 additions and 591 deletions

View File

@ -232,7 +232,7 @@ Allows PKCE `plain` challenges when set to `true`.
Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows Some OpenID Connect Endpoints need to allow cross-origin resource sharing, however some are optional. This section allows
you to configure the optional parts. We reply with CORS headers when the request includes the Origin header. you to configure the optional parts. We reply with CORS headers when the request includes the Origin header.
##### endpoints #### endpoints
<div markdown="1"> <div markdown="1">
type: list(string) type: list(string)
{: .label .label-config .label-purple } {: .label .label-config .label-purple }
@ -523,7 +523,7 @@ individual user. Please use the claim `preferred_username` instead._
This scope includes the groups the authentication backend reports the user is a member of in the token. This scope includes the groups the authentication backend reports the user is a member of in the token.
| Claim | JWT Type | Authelia Attribute | Description | | Claim | JWT Type | Authelia Attribute | Description |
|:------:|:-------------:|:------------------:|:----------------------:| |:------:|:-------------:|:------------------:|:------------------------------------------------------------------------------------------------------------------:|
| groups | array[string] | groups | List of user's groups discovered via [authentication](https://www.authelia.com/docs/configuration/authentication/) | | groups | array[string] | groups | List of user's groups discovered via [authentication](https://www.authelia.com/docs/configuration/authentication/) |
### email ### email
@ -586,9 +586,9 @@ an example of the Authelia root URL which is also the OpenID Connect issuer.
These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP. These endpoints can be utilized to discover other endpoints and metadata about the Authelia OP.
| Endpoint | Path | | Endpoint | Path |
|:-------------:|:---------------------------------------------------------------:| |:-----------------------------------------:|:---------------------------------------------------------------:|
| Discovery | https://auth.example.com/.well-known/openid-configuration | | [OpenID Connect Discovery] | https://auth.example.com/.well-known/openid-configuration |
| Metadata | https://auth.example.com/.well-known/oauth-authorization-server | | [OAuth 2.0 Authorization Server Metadata] | https://auth.example.com/.well-known/oauth-authorization-server |
### Discoverable Endpoints ### Discoverable Endpoints
@ -596,19 +596,22 @@ These endpoints can be utilized to discover other endpoints and metadata about t
These endpoints implement OpenID Connect elements. These endpoints implement OpenID Connect elements.
| Endpoint | Path | Discovery Attribute | | Endpoint | Path | Discovery Attribute |
|:---------------:|:-----------------------------------------------:|:----------------------:| |:-------------------:|:-----------------------------------------------:|:----------------------:|
| JWKS | https://auth.example.com/jwks.json | jwks_uri | | [JSON Web Key Sets] | https://auth.example.com/jwks.json | jwks_uri |
| [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint | | [Authorization] | https://auth.example.com/api/oidc/authorization | authorization_endpoint |
| [Token] | https://auth.example.com/api/oidc/token | token_endpoint | | [Token] | https://auth.example.com/api/oidc/token | token_endpoint |
| [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint | | [Userinfo] | https://auth.example.com/api/oidc/userinfo | userinfo_endpoint |
| [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint | | [Introspection] | https://auth.example.com/api/oidc/introspection | introspection_endpoint |
| [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint | | [Revocation] | https://auth.example.com/api/oidc/revocation | revocation_endpoint |
[JSON Web Key Sets]: https://datatracker.ietf.org/doc/html/rfc7517#section-5
[OpenID Connect]: https://openid.net/connect/ [OpenID Connect]: https://openid.net/connect/
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration [OpenID Connect Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
[OAuth 2.0 Authorization Server Metadata]: https://datatracker.ietf.org/doc/html/rfc8414
[Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint [Authorization]: https://openid.net/specs/openid-connect-core-1_0.html#AuthorizationEndpoint
[Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint [Token]: https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
[Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo [Userinfo]: https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
[Introspection]: https://datatracker.ietf.org/doc/html/rfc7662 [Introspection]: https://datatracker.ietf.org/doc/html/rfc7662
[Revocation]: https://datatracker.ietf.org/doc/html/rfc7009 [Revocation]: https://datatracker.ietf.org/doc/html/rfc7009
[RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176 [RFC8176]: https://datatracker.ietf.org/doc/html/rfc8176
[token lifespan]: https://docs.apigee.com/api-platform/antipatterns/oauth-long-expiration

View File

@ -24,3 +24,4 @@ this instance if you wanted to downgrade to pre1 you would need to use an Authel
| 1 | 4.33.0 | Initial migration managed version | | 1 | 4.33.0 | Initial migration managed version |
| 2 | 4.34.0 | Webauthn - added webauthn_devices table, altered totp_config to include device created/used dates | | 2 | 4.34.0 | Webauthn - added webauthn_devices table, altered totp_config to include device created/used dates |
| 3 | 4.34.2 | Webauthn - fix V2 migration kid column length and provide migration path for anyone on V2 | | 3 | 4.34.2 | Webauthn - fix V2 migration kid column length and provide migration path for anyone on V2 |
| 4 | 4.35.0 | Added OpenID Connect storage tables and opaque user identifier tables |

View File

@ -74,7 +74,7 @@ for which stage will have each feature, and may evolve over time:
<td>Proof Key for Code Exchange (PKCE) for Authorization Code Flow</td> <td>Proof Key for Code Exchange (PKCE) for Authorization Code Flow</td>
</tr> </tr>
<tr> <tr>
<td rowspan="3" class="tbl-header tbl-beta-stage">beta4 <sup>1</sup></td> <td rowspan="8" class="tbl-header tbl-beta-stage">beta5 (4.35.0)</td>
<td>Token Storage</td> <td>Token Storage</td>
</tr> </tr>
<tr> <tr>
@ -83,6 +83,21 @@ for which stage will have each feature, and may evolve over time:
<tr> <tr>
<td class="tbl-beta-stage">Subject Storage</td> <td class="tbl-beta-stage">Subject Storage</td>
</tr> </tr>
<tr>
<td class="tbl-beta-stage"><a href="https://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes" target="_blank">Pairwise Subject Identifier Type</a></td>
</tr>
<tr>
<td class="tbl-beta-stage">Per-Client Consent Pre-Configuration</td>
</tr>
<tr>
<td class="tbl-beta-stage">Cross-Origin Resource Sharing Configuration</td>
</tr>
<tr>
<td class="tbl-beta-stage">Authentication Methods References Claim</td>
</tr>
<tr>
<td class="tbl-beta-stage">UUID v4 <code>sub</code> claim</td>
</tr>
<tr> <tr>
<td rowspan="2" class="tbl-header tbl-beta-stage">beta5 <sup>1</sup></td> <td rowspan="2" class="tbl-header tbl-beta-stage">beta5 <sup>1</sup></td>
<td class="tbl-beta-stage">Prompt Handling</td> <td class="tbl-beta-stage">Prompt Handling</td>
@ -91,7 +106,7 @@ for which stage will have each feature, and may evolve over time:
<td class="tbl-beta-stage">Display Handling</td> <td class="tbl-beta-stage">Display Handling</td>
</tr> </tr>
<tr> <tr>
<td rowspan="5" class="tbl-header tbl-beta-stage">beta6 <sup>1</sup></td> <td rowspan="4" class="tbl-header tbl-beta-stage">beta6 <sup>1</sup></td>
<td><a href="https://openid.net/specs/openid-connect-backchannel-1_0.html" target="_blank" rel="noopener noreferrer">Back-Channel Logout</a></td> <td><a href="https://openid.net/specs/openid-connect-backchannel-1_0.html" target="_blank" rel="noopener noreferrer">Back-Channel Logout</a></td>
</tr> </tr>
<tr> <tr>
@ -103,9 +118,6 @@ for which stage will have each feature, and may evolve over time:
<tr> <tr>
<td class="tbl-beta-stage">Client Secrets Hashed in Configuration</td> <td class="tbl-beta-stage">Client Secrets Hashed in Configuration</td>
</tr> </tr>
<tr>
<td class="tbl-beta-stage">UUID or Random String for <code>sub</code> claim</td>
</tr>
<tr> <tr>
<td class="tbl-header tbl-beta-stage">GA <sup>1</sup></td> <td class="tbl-header tbl-beta-stage">GA <sup>1</sup></td>
<td class="tbl-beta-stage">General Availability after previous stages are vetted for bug fixes</td> <td class="tbl-beta-stage">General Availability after previous stages are vetted for bug fixes</td>

View File

@ -64,7 +64,7 @@ func getProviders() (providers middlewares.Providers, warnings []error, errors [
sessionProvider := session.NewProvider(config.Session, autheliaCertPool) sessionProvider := session.NewProvider(config.Session, autheliaCertPool)
regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock) regulator := regulation.NewRegulator(config.Regulation, storageProvider, clock)
oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC) oidcProvider, err := oidc.NewOpenIDConnectProvider(config.IdentityProviders.OIDC, storageProvider)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
} }

View File

@ -12,7 +12,6 @@ type IdentityProvidersConfiguration struct {
// OpenIDConnectConfiguration configuration for OpenID Connect. // OpenIDConnectConfiguration configuration for OpenID Connect.
type OpenIDConnectConfiguration struct { type OpenIDConnectConfiguration struct {
// This secret must be 32 bytes long.
HMACSecret string `koanf:"hmac_secret"` HMACSecret string `koanf:"hmac_secret"`
IssuerPrivateKey string `koanf:"issuer_private_key"` IssuerPrivateKey string `koanf:"issuer_private_key"`
@ -49,9 +48,10 @@ type OpenIDConnectClientConfiguration struct {
Policy string `koanf:"authorization_policy"` Policy string `koanf:"authorization_policy"`
RedirectURIs []string `koanf:"redirect_uris"`
Audience []string `koanf:"audience"` Audience []string `koanf:"audience"`
Scopes []string `koanf:"scopes"` Scopes []string `koanf:"scopes"`
RedirectURIs []string `koanf:"redirect_uris"`
GrantTypes []string `koanf:"grant_types"` GrantTypes []string `koanf:"grant_types"`
ResponseTypes []string `koanf:"response_types"` ResponseTypes []string `koanf:"response_types"`
ResponseModes []string `koanf:"response_modes"` ResponseModes []string `koanf:"response_modes"`

View File

@ -73,7 +73,7 @@ func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re
userSession := ctx.GetSession() userSession := ctx.GetSession()
newSession := session.NewDefaultUserSession() newSession := session.NewDefaultUserSession()
newSession.OIDCWorkflowSession = userSession.OIDCWorkflowSession newSession.ConsentChallengeID = userSession.ConsentChallengeID
// Reset all values from previous session except OIDC workflow before regenerating the cookie. // Reset all values from previous session except OIDC workflow before regenerating the cookie.
if err = ctx.SaveSession(newSession); err != nil { if err = ctx.SaveSession(newSession); err != nil {
@ -135,7 +135,7 @@ func FirstFactorPost(delayFunc middlewares.TimingAttackDelayFunc) middlewares.Re
successful = true successful = true
if userSession.OIDCWorkflowSession != nil { if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx) handleOIDCWorkflowResponse(ctx)
} else { } else {
Handle1FAResponse(ctx, bodyJSON.TargetURL, bodyJSON.RequestMethod, userSession.Username, userSession.Groups) Handle1FAResponse(ctx, bodyJSON.TargetURL, bodyJSON.RequestMethod, userSession.Username, userSession.Groups)

View File

@ -2,18 +2,15 @@ package handlers
import ( import (
"errors" "errors"
"fmt"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/google/uuid"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/session"
) )
// OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint. // OpenIDConnectAuthorizationGET handles GET requests to the OpenID Connect 1.0 Authorization endpoint.
@ -23,7 +20,7 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
var ( var (
requester fosite.AuthorizeRequester requester fosite.AuthorizeRequester
responder fosite.AuthorizeResponder responder fosite.AuthorizeResponder
client *oidc.InternalClient client *oidc.Client
authTime time.Time authTime time.Time
issuer string issuer string
err error err error
@ -43,7 +40,7 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' is being processed", requester.GetID(), clientID) ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' is being processed", requester.GetID(), clientID)
if client, err = ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID); err != nil { if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(clientID); err != nil {
if errors.Is(err, fosite.ErrNotFound) { if errors.Is(err, fosite.ErrNotFound) {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: client was not found", requester.GetID(), clientID) ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: client was not found", requester.GetID(), clientID)
} else { } else {
@ -65,31 +62,27 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
userSession := ctx.GetSession() userSession := ctx.GetSession()
requestedScopes := requester.GetRequestedScopes() var subject uuid.UUID
requestedAudience := requester.GetRequestedAudience()
isAuthInsufficient := !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) if subject, err = ctx.Providers.OpenIDConnect.Store.GetSubject(ctx, client.GetSectorIdentifier(), userSession.Username); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred retrieving subject for user '%s': %+v", requester.GetID(), client.GetID(), userSession.Username, err)
if isAuthInsufficient || (isConsentMissing(userSession.OIDCWorkflowSession, requestedScopes, requestedAudience)) { ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not retrieve the subject."))
oidcAuthorizeHandleAuthorizationOrConsentInsufficient(ctx, userSession, client, isAuthInsufficient, rw, r, requester, issuer)
return return
} }
extraClaims := oidcGrantRequests(requester, requestedScopes, requestedAudience, &userSession) var (
consent *model.OAuth2ConsentSession
workflowCreated := time.Unix(userSession.OIDCWorkflowSession.CreatedTimestamp, 0) handled bool
)
userSession.OIDCWorkflowSession = nil
if err = ctx.SaveSession(userSession); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session: %+v", requester.GetID(), client.GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session."))
if consent, handled = handleOIDCAuthorizationConsent(ctx, issuer, client, userSession, subject, rw, r, requester); handled {
return return
} }
extraClaims := oidcGrantRequests(requester, consent, &userSession)
if authTime, err = userSession.AuthenticatedTime(client.Policy); err != nil { if authTime, err = userSession.AuthenticatedTime(client.Policy); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred checking authentication time: %+v", requester.GetID(), client.GetID(), err) ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred checking authentication time: %+v", requester.GetID(), client.GetID(), err)
@ -100,9 +93,8 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID) ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' was successfully processed, proceeding to build Authorization Response", requester.GetID(), clientID)
subject := userSession.Username
oidcSession := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(), oidcSession := oidc.NewSessionWithAuthorizeRequest(issuer, ctx.Providers.OpenIDConnect.KeyManager.GetActiveKeyID(),
subject, userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, workflowCreated, requester) userSession.Username, userSession.AuthenticationMethodRefs.MarshalRFC8176(), extraClaims, authTime, consent, requester)
ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v", ctx.Logger.Tracef("Authorization Request with id '%s' on client with id '%s' creating session for Authorization Response for subject '%s' with username '%s' with claims: %+v",
requester.GetID(), oidcSession.ClientID, oidcSession.Subject, oidcSession.Username, oidcSession.Claims) requester.GetID(), oidcSession.ClientID, oidcSession.Subject, oidcSession.Username, oidcSession.Claims)
@ -119,39 +111,13 @@ func OpenIDConnectAuthorizationGET(ctx *middlewares.AutheliaCtx, rw http.Respons
return return
} }
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeResponse(rw, requester, responder) if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSessionGranted(ctx, consent.ID); err != nil {
} ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving consent session: %+v", requester.GetID(), client.GetID(), err)
func oidcAuthorizeHandleAuthorizationOrConsentInsufficient(
ctx *middlewares.AutheliaCtx, userSession session.UserSession, client *oidc.InternalClient, isAuthInsufficient bool,
rw http.ResponseWriter, r *http.Request,
requester fosite.AuthorizeRequester, issuer string) {
redirectURL := fmt.Sprintf("%s%s", issuer, string(ctx.Request.RequestURI()))
ctx.Logger.Debugf("Authorization Request with id '%s' on client with id '%s' requires user '%s' provides consent for scopes '%s'",
requester.GetID(), client.GetID(), userSession.Username, strings.Join(requester.GetRequestedScopes(), "', '"))
userSession.OIDCWorkflowSession = &model.OIDCWorkflowSession{
ClientID: client.GetID(),
RequestedScopes: requester.GetRequestedScopes(),
RequestedAudience: requester.GetRequestedAudience(),
AuthURI: redirectURL,
TargetURI: requester.GetRedirectURI().String(),
Require2FA: client.Policy == authorization.TwoFactor,
CreatedTimestamp: time.Now().Unix(),
}
if err := ctx.SaveSession(userSession); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session for consent: %+v", requester.GetID(), client.GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session.")) ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session."))
return return
} }
if isAuthInsufficient { ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeResponse(rw, requester, responder)
http.Redirect(rw, r, issuer, http.StatusFound)
} else {
http.Redirect(rw, r, fmt.Sprintf("%s/consent", issuer), http.StatusFound)
}
} }

View File

@ -0,0 +1,168 @@
package handlers
import (
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/ory/fosite"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils"
)
func handleOIDCAuthorizationConsent(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client,
userSession session.UserSession, subject uuid.UUID,
rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) {
if userSession.ConsentChallengeID != nil {
return handleOIDCAuthorizationConsentWithChallengeID(ctx, rootURI, client, userSession, rw, r, requester)
}
return handleOIDCAuthorizationConsentOrGenerate(ctx, rootURI, client, userSession, subject, rw, r, requester)
}
func handleOIDCAuthorizationConsentWithChallengeID(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client,
userSession session.UserSession,
rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) {
var (
err error
)
if consent, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred during consent session lookup: %+v", requester.GetID(), requester.GetClient().GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Failed to lookup consent session."))
userSession.ConsentChallengeID = nil
if err = ctx.SaveSession(userSession); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred unlinking consent session challenge id: %+v", requester.GetID(), requester.GetClient().GetID(), err)
}
return nil, true
}
if consent.Responded() {
userSession.ConsentChallengeID = nil
if err = ctx.SaveSession(userSession); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving session: %+v", requester.GetID(), client.GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the session."))
return nil, true
}
if consent.Granted {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: this consent session with challenge id '%s' was already granted", requester.GetID(), client.GetID(), consent.ChallengeID.String())
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Authorization already granted."))
return nil, true
}
ctx.Logger.Debugf("Authorization Request with id '%s' loaded consent session with id '%d' and challenge id '%s' for client id '%s' and subject '%s' and scopes '%s'", requester.GetID(), consent.ID, consent.ChallengeID.String(), client.GetID(), consent.Subject.String(), strings.Join(requester.GetRequestedScopes(), " "))
if consent.IsDenied() {
ctx.Logger.Warnf("Authorization Request with id '%s' and challenge id '%s' for client id '%s' and subject '%s' and scopes '%s' was not denied by the user durng the consent session", requester.GetID(), consent.ChallengeID.String(), client.GetID(), consent.Subject.String(), strings.Join(requester.GetRequestedScopes(), " "))
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrAccessDenied)
return nil, true
}
return consent, false
}
handleOIDCAuthorizationConsentRedirect(rootURI, client, userSession, rw, r)
return consent, true
}
func handleOIDCAuthorizationConsentOrGenerate(ctx *middlewares.AutheliaCtx, rootURI string, client *oidc.Client,
userSession session.UserSession, subject uuid.UUID,
rw http.ResponseWriter, r *http.Request, requester fosite.AuthorizeRequester) (consent *model.OAuth2ConsentSession, handled bool) {
var (
rows *storage.ConsentSessionRows
scopes, audience []string
err error
)
if rows, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionsPreConfigured(ctx, client.GetID(), subject); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' had error looking up pre-configured consent sessions: %+v", requester.GetID(), requester.GetClient().GetID(), err)
}
defer rows.Close()
for rows.Next() {
if consent, err = rows.Get(); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' had error looking up pre-configured consent sessions: %+v", requester.GetID(), requester.GetClient().GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not lookup pre-configured consent sessions."))
return nil, true
}
scopes, audience = getExpectedScopesAndAudience(requester)
if consent.HasExactGrants(scopes, audience) && consent.CanGrant() {
break
}
}
if consent != nil && consent.HasExactGrants(scopes, audience) && consent.CanGrant() {
return consent, false
}
if consent, err = model.NewOAuth2ConsentSession(subject, requester); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred generating consent: %+v", requester.GetID(), requester.GetClient().GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not generate the consent session."))
return nil, true
}
if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSession(ctx, *consent); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving consent session: %+v", requester.GetID(), client.GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the consent session."))
return nil, true
}
userSession.ConsentChallengeID = &consent.ChallengeID
if err = ctx.SaveSession(userSession); err != nil {
ctx.Logger.Errorf("Authorization Request with id '%s' on client with id '%s' could not be processed: error occurred saving user session for consent: %+v", requester.GetID(), client.GetID(), err)
ctx.Providers.OpenIDConnect.Fosite.WriteAuthorizeError(rw, requester, fosite.ErrServerError.WithHint("Could not save the user session."))
return nil, true
}
handleOIDCAuthorizationConsentRedirect(rootURI, client, userSession, rw, r)
return consent, true
}
func handleOIDCAuthorizationConsentRedirect(destination string, client *oidc.Client, userSession session.UserSession, rw http.ResponseWriter, r *http.Request) {
if client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
destination = fmt.Sprintf("%s/consent", destination)
}
http.Redirect(rw, r, destination, http.StatusFound)
}
func getExpectedScopesAndAudience(requester fosite.Requester) (scopes, audience []string) {
audience = requester.GetRequestedAudience()
if !utils.IsStringInSlice(requester.GetClient().GetID(), audience) {
audience = append(audience, requester.GetClient().GetID())
}
return requester.GetRequestedScopes(), audience
}

View File

@ -5,116 +5,135 @@ import (
"fmt" "fmt"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/utils"
) )
// OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect. // OpenIDConnectConsentGET handles requests to provide consent for OpenID Connect.
func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) { func OpenIDConnectConsentGET(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession() userSession, consent, client, handled := oidcConsentGetSessionsAndClient(ctx)
if handled {
if userSession.OIDCWorkflowSession == nil {
ctx.Logger.Debugf("Cannot consent for user %s when OIDC workflow has not been initiated", userSession.Username)
ctx.ReplyForbidden()
return
}
clientID := userSession.OIDCWorkflowSession.ClientID
client, err := ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID)
if err != nil {
ctx.Logger.Debugf("Unable to find related client configuration with name '%s': %v", clientID, err)
ctx.ReplyForbidden()
return return
} }
if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
ctx.Logger.Debugf("Insufficient permissions to give consent during GET current level: %d, require 2FA: %t", userSession.AuthenticationLevel, userSession.OIDCWorkflowSession.Require2FA) ctx.Logger.Errorf("Unable to perform consent without sufficient authentication for user '%s' and client id '%s'", userSession.Username, consent.ClientID)
ctx.ReplyForbidden() ctx.ReplyForbidden()
return return
} }
if err := ctx.SetJSONBody(client.GetConsentResponseBody(userSession.OIDCWorkflowSession)); err != nil { if err := ctx.SetJSONBody(client.GetConsentResponseBody(consent)); err != nil {
ctx.Error(fmt.Errorf("unable to set JSON body: %v", err), "Operation failed") ctx.Error(fmt.Errorf("unable to set JSON body: %v", err), "Operation failed")
} }
} }
// OpenIDConnectConsentPOST handles consent responses for OpenID Connect. // OpenIDConnectConsentPOST handles consent responses for OpenID Connect.
func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) { func OpenIDConnectConsentPOST(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession() var (
body oidc.ConsentPostRequestBody
err error
)
if userSession.OIDCWorkflowSession == nil { if err = json.Unmarshal(ctx.Request.Body(), &body); err != nil {
ctx.Logger.Debugf("Cannot consent for user %s when OIDC workflow has not been initiated", userSession.Username) ctx.Logger.Errorf("Failed to parse JSON body in consent POST: %+v", err)
ctx.ReplyForbidden() ctx.SetJSONError(messageOperationFailed)
return return
} }
client, err := ctx.Providers.OpenIDConnect.Store.GetInternalClient(userSession.OIDCWorkflowSession.ClientID) userSession, consent, client, handled := oidcConsentGetSessionsAndClient(ctx)
if handled {
if err != nil {
ctx.Logger.Debugf("Unable to find related client configuration with name '%s': %v", userSession.OIDCWorkflowSession.ClientID, err)
ctx.ReplyForbidden()
return return
} }
if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) { if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
ctx.Logger.Debugf("Insufficient permissions to give consent during POST current level: %d, require 2FA: %t", userSession.AuthenticationLevel, userSession.OIDCWorkflowSession.Require2FA) ctx.Logger.Debugf("Insufficient permissions to give consent during POST current level: %d, require 2FA: %d", userSession.AuthenticationLevel, client.Policy)
ctx.ReplyForbidden() ctx.ReplyForbidden()
return return
} }
var body ConsentPostRequestBody if consent.ClientID != body.ClientID {
err = json.Unmarshal(ctx.Request.Body(), &body) ctx.Logger.Errorf("User '%s' consented to scopes of another client (%s) than expected (%s). Beware this can be a sign of attack",
userSession.Username, body.ClientID, consent.ClientID)
ctx.SetJSONError(messageOperationFailed)
if err != nil {
ctx.Error(fmt.Errorf("unable to unmarshal body: %v", err), "Operation failed")
return return
} }
if body.AcceptOrReject != accept && body.AcceptOrReject != reject { var (
ctx.Logger.Infof("User %s tried to reply to consent with an unexpected verb", userSession.Username) externalRootURL string
authorized = true
)
switch body.AcceptOrReject {
case accept:
if externalRootURL, err = ctx.ExternalRootURL(); err != nil {
ctx.Logger.Errorf("Could not determine the external URL during consent session processing with challenge id '%s' for user '%s': %v", consent.ChallengeID.String(), userSession.Username, err)
ctx.SetJSONError(messageOperationFailed)
return
}
consent.GrantedScopes = consent.RequestedScopes
consent.GrantedAudience = consent.RequestedAudience
if !utils.IsStringInSlice(consent.ClientID, consent.GrantedAudience) {
consent.GrantedAudience = append(consent.GrantedAudience, consent.ClientID)
}
case reject:
authorized = false
default:
ctx.Logger.Warnf("User '%s' tried to reply to consent with an unexpected verb", userSession.Username)
ctx.ReplyBadRequest() ctx.ReplyBadRequest()
return return
} }
if userSession.OIDCWorkflowSession.ClientID != body.ClientID { if err = ctx.Providers.StorageProvider.SaveOAuth2ConsentSessionResponse(ctx, *consent, authorized); err != nil {
ctx.Logger.Infof("User %s consented to scopes of another client (%s) than expected (%s). Beware this can be a sign of attack", ctx.Logger.Errorf("Failed to save the consent session response to the database: %+v", err)
userSession.Username, body.ClientID, userSession.OIDCWorkflowSession.ClientID) ctx.SetJSONError(messageOperationFailed)
ctx.ReplyBadRequest()
return return
} }
var redirectionURL string response := oidc.ConsentPostResponseBody{RedirectURI: fmt.Sprintf("%s%s?%s", externalRootURL, oidc.AuthorizationPath, consent.Form)}
if body.AcceptOrReject == accept { if err = ctx.SetJSONBody(response); err != nil {
redirectionURL = userSession.OIDCWorkflowSession.AuthURI
userSession.OIDCWorkflowSession.GrantedScopes = userSession.OIDCWorkflowSession.RequestedScopes
userSession.OIDCWorkflowSession.GrantedAudience = userSession.OIDCWorkflowSession.RequestedAudience
if err := ctx.SaveSession(userSession); err != nil {
ctx.Error(fmt.Errorf("unable to write session: %v", err), "Operation failed")
return
}
} else if body.AcceptOrReject == reject {
redirectionURL = fmt.Sprintf("%s?error=access_denied&error_description=%s",
userSession.OIDCWorkflowSession.TargetURI, "User has rejected the scopes")
userSession.OIDCWorkflowSession = nil
if err := ctx.SaveSession(userSession); err != nil {
ctx.Error(fmt.Errorf("unable to write session: %v", err), "Operation failed")
return
}
}
response := ConsentPostResponseBody{RedirectURI: redirectionURL}
if err := ctx.SetJSONBody(response); err != nil {
ctx.Error(fmt.Errorf("unable to set JSON body in response"), "Operation failed") ctx.Error(fmt.Errorf("unable to set JSON body in response"), "Operation failed")
} }
} }
func oidcConsentGetSessionsAndClient(ctx *middlewares.AutheliaCtx) (userSession session.UserSession, consent *model.OAuth2ConsentSession, client *oidc.Client, handled bool) {
var (
err error
)
userSession = ctx.GetSession()
if userSession.ConsentChallengeID == nil {
ctx.Logger.Errorf("Cannot consent for user '%s' when OIDC consent session has not been initiated", userSession.Username)
ctx.ReplyForbidden()
return userSession, nil, nil, true
}
if consent, err = ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID); err != nil {
ctx.Logger.Errorf("Unable to load consent session with challenge id '%s': %v", userSession.ConsentChallengeID.String(), err)
ctx.ReplyForbidden()
return userSession, nil, nil, true
}
if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(consent.ClientID); err != nil {
ctx.Logger.Errorf("Unable to find related client configuration with name '%s': %v", consent.ClientID, err)
ctx.ReplyForbidden()
return userSession, nil, nil, true
}
return userSession, consent, client, false
}

View File

@ -11,6 +11,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
) )
@ -21,7 +22,7 @@ func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter,
var ( var (
tokenType fosite.TokenType tokenType fosite.TokenType
requester fosite.AccessRequester requester fosite.AccessRequester
client *oidc.InternalClient client *oidc.Client
err error err error
) )
@ -54,13 +55,13 @@ func OpenIDConnectUserinfo(ctx *middlewares.AutheliaCtx, rw http.ResponseWriter,
return return
} }
if client, err = ctx.Providers.OpenIDConnect.Store.GetInternalClient(clientID); err != nil { if client, err = ctx.Providers.OpenIDConnect.Store.GetFullClient(clientID); err != nil {
ctx.Providers.OpenIDConnect.WriteError(rw, req, errors.WithStack(fosite.ErrServerError.WithHint("Unable to assert type of client"))) ctx.Providers.OpenIDConnect.WriteError(rw, req, errors.WithStack(fosite.ErrServerError.WithHint("Unable to assert type of client")))
return return
} }
claims := requester.GetSession().(*oidc.OpenIDSession).IDTokenClaims().ToMap() claims := requester.GetSession().(*model.OpenIDSession).IDTokenClaims().ToMap()
delete(claims, "jti") delete(claims, "jti")
delete(claims, "sid") delete(claims, "sid")
delete(claims, "at_hash") delete(claims, "at_hash")

View File

@ -266,7 +266,7 @@ func HandleAllow(ctx *middlewares.AutheliaCtx, targetURL string) {
return return
} }
if userSession.OIDCWorkflowSession != nil { if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx) handleOIDCWorkflowResponse(ctx)
} else { } else {
Handle2FAResponse(ctx, targetURL) Handle2FAResponse(ctx, targetURL)

View File

@ -78,7 +78,7 @@ func SecondFactorTOTPPost(ctx *middlewares.AutheliaCtx) {
return return
} }
if userSession.OIDCWorkflowSession != nil { if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx) handleOIDCWorkflowResponse(ctx)
} else { } else {
Handle2FAResponse(ctx, requestBody.TargetURL) Handle2FAResponse(ctx, requestBody.TargetURL)

View File

@ -197,7 +197,7 @@ func SecondFactorWebauthnAssertionPOST(ctx *middlewares.AutheliaCtx) {
return return
} }
if userSession.OIDCWorkflowSession != nil { if userSession.ConsentChallengeID != nil {
handleOIDCWorkflowResponse(ctx) handleOIDCWorkflowResponse(ctx)
} else { } else {
Handle2FAResponse(ctx, requestBody.TargetURL) Handle2FAResponse(ctx, requestBody.TargetURL)

View File

@ -6,24 +6,12 @@ import (
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
"github.com/authelia/authelia/v4/internal/utils"
) )
// isConsentMissing compares the requestedScopes and requestedAudience to the workflows func oidcGrantRequests(ar fosite.AuthorizeRequester, consent *model.OAuth2ConsentSession, userSession *session.UserSession) (extraClaims map[string]interface{}) {
// GrantedScopes and GrantedAudience and returns true if they do not match or the workflow is nil.
func isConsentMissing(workflow *model.OIDCWorkflowSession, requestedScopes, requestedAudience []string) (isMissing bool) {
if workflow == nil {
return true
}
return len(requestedScopes) > 0 && utils.IsStringSlicesDifferent(requestedScopes, workflow.GrantedScopes) ||
len(requestedAudience) > 0 && utils.IsStringSlicesDifferentFold(requestedAudience, workflow.GrantedAudience)
}
func oidcGrantRequests(ar fosite.AuthorizeRequester, scopes, audiences []string, userSession *session.UserSession) (extraClaims map[string]interface{}) {
extraClaims = map[string]interface{}{} extraClaims = map[string]interface{}{}
for _, scope := range scopes { for _, scope := range consent.GrantedScopes {
if ar != nil { if ar != nil {
ar.GrantScope(scope) ar.GrantScope(scope)
} }
@ -47,13 +35,9 @@ func oidcGrantRequests(ar fosite.AuthorizeRequester, scopes, audiences []string,
} }
if ar != nil { if ar != nil {
for _, audience := range audiences { for _, audience := range consent.GrantedAudience {
ar.GrantAudience(audience) ar.GrantAudience(audience)
} }
if !utils.IsStringInSlice(ar.GetClient().GetID(), ar.GetGrantedAudience()) {
ar.GrantAudience(ar.GetClient().GetID())
}
} }
return extraClaims return extraClaims

View File

@ -11,32 +11,12 @@ import (
"github.com/authelia/authelia/v4/internal/session" "github.com/authelia/authelia/v4/internal/session"
) )
func TestShouldDetectIfConsentIsMissing(t *testing.T) { func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) {
var workflow *model.OIDCWorkflowSession consent := &model.OAuth2ConsentSession{
GrantedScopes: []string{oidc.ScopeProfile},
requestedScopes := []string{"openid", "profile"}
requestedAudience := []string{"https://authelia.com"}
assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
workflow = &model.OIDCWorkflowSession{
GrantedScopes: []string{"openid", "profile"},
GrantedAudience: []string{"https://authelia.com"},
} }
assert.False(t, isConsentMissing(workflow, requestedScopes, requestedAudience)) extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
requestedScopes = []string{"openid", "profile", "group"}
assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
requestedScopes = []string{"openid", "profile"}
requestedAudience = []string{"https://not.authelia.com"}
assert.True(t, isConsentMissing(workflow, requestedScopes, requestedAudience))
}
func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) {
extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeProfile}, []string{}, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 2) assert.Len(t, extraClaims, 2)
@ -48,7 +28,11 @@ func TestShouldGrantAppropriateClaimsForScopeProfile(t *testing.T) {
} }
func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) { func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) {
extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeGroups}, []string{}, &oidcUserSessionJohn) consent := &model.OAuth2ConsentSession{
GrantedScopes: []string{oidc.ScopeGroups},
}
extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 1) assert.Len(t, extraClaims, 1)
@ -57,7 +41,7 @@ func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) {
assert.Contains(t, extraClaims[oidc.ClaimGroups], "admin") assert.Contains(t, extraClaims[oidc.ClaimGroups], "admin")
assert.Contains(t, extraClaims[oidc.ClaimGroups], "dev") assert.Contains(t, extraClaims[oidc.ClaimGroups], "dev")
extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeGroups}, []string{}, &oidcUserSessionFred) extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred)
assert.Len(t, extraClaims, 1) assert.Len(t, extraClaims, 1)
@ -67,7 +51,11 @@ func TestShouldGrantAppropriateClaimsForScopeGroups(t *testing.T) {
} }
func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) { func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) {
extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeEmail}, []string{}, &oidcUserSessionJohn) consent := &model.OAuth2ConsentSession{
GrantedScopes: []string{oidc.ScopeEmail},
}
extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 3) assert.Len(t, extraClaims, 3)
@ -81,7 +69,7 @@ func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) {
require.Contains(t, extraClaims, oidc.ClaimEmailVerified) require.Contains(t, extraClaims, oidc.ClaimEmailVerified)
assert.Equal(t, true, extraClaims[oidc.ClaimEmailVerified]) assert.Equal(t, true, extraClaims[oidc.ClaimEmailVerified])
extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeEmail}, []string{}, &oidcUserSessionFred) extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred)
assert.Len(t, extraClaims, 2) assert.Len(t, extraClaims, 2)
@ -93,7 +81,11 @@ func TestShouldGrantAppropriateClaimsForScopeEmail(t *testing.T) {
} }
func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) { func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) {
extraClaims := oidcGrantRequests(nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile}, []string{}, &oidcUserSessionJohn) consent := &model.OAuth2ConsentSession{
GrantedScopes: []string{oidc.ScopeOpenID, oidc.ScopeProfile},
}
extraClaims := oidcGrantRequests(nil, consent, &oidcUserSessionJohn)
assert.Len(t, extraClaims, 2) assert.Len(t, extraClaims, 2)
@ -103,7 +95,7 @@ func TestShouldGrantAppropriateClaimsForScopeOpenIDAndProfile(t *testing.T) {
require.Contains(t, extraClaims, oidc.ClaimDisplayName) require.Contains(t, extraClaims, oidc.ClaimDisplayName)
assert.Equal(t, "John Smith", extraClaims[oidc.ClaimDisplayName]) assert.Equal(t, "John Smith", extraClaims[oidc.ClaimDisplayName])
extraClaims = oidcGrantRequests(nil, []string{oidc.ScopeOpenID, oidc.ScopeProfile}, []string{}, &oidcUserSessionFred) extraClaims = oidcGrantRequests(nil, consent, &oidcUserSessionFred)
assert.Len(t, extraClaims, 2) assert.Len(t, extraClaims, 2)

View File

@ -7,9 +7,9 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/middlewares" "github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/oidc"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
@ -17,14 +17,15 @@ import (
func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) { func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession() userSession := ctx.GetSession()
if userSession.OIDCWorkflowSession.Require2FA && userSession.AuthenticationLevel != authentication.TwoFactor { if userSession.ConsentChallengeID == nil {
ctx.Logger.Warnf("OpenID Connect client '%s' requires 2FA, cannot be redirected yet", userSession.OIDCWorkflowSession.ClientID) ctx.Logger.Errorf("Unable to handle OIDC workflow response because the user session doesn't contain a consent challenge id")
ctx.ReplyOK()
respondUnauthorized(ctx, messageOperationFailed)
return return
} }
uri, err := ctx.ExternalRootURL() externalRootURL, err := ctx.ExternalRootURL()
if err != nil { if err != nil {
ctx.Logger.Errorf("Unable to determine external Base URL: %v", err) ctx.Logger.Errorf("Unable to determine external Base URL: %v", err)
@ -33,18 +34,37 @@ func handleOIDCWorkflowResponse(ctx *middlewares.AutheliaCtx) {
return return
} }
if isConsentMissing( consent, err := ctx.Providers.StorageProvider.LoadOAuth2ConsentSessionByChallengeID(ctx, *userSession.ConsentChallengeID)
userSession.OIDCWorkflowSession,
userSession.OIDCWorkflowSession.RequestedScopes,
userSession.OIDCWorkflowSession.RequestedAudience) {
err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s/consent", uri)})
if err != nil { if err != nil {
ctx.Logger.Errorf("Unable to load consent session from database: %v", err)
respondUnauthorized(ctx, messageOperationFailed)
return
}
client, err := ctx.Providers.OpenIDConnect.Store.GetFullClient(consent.ClientID)
if err != nil {
ctx.Logger.Errorf("Unable to find client for the consent session: %v", err)
respondUnauthorized(ctx, messageOperationFailed)
return
}
if !client.IsAuthenticationLevelSufficient(userSession.AuthenticationLevel) {
ctx.Logger.Warnf("OpenID Connect client '%s' requires 2FA, cannot be redirected yet", client.ID)
ctx.ReplyOK()
return
}
if userSession.ConsentChallengeID != nil {
if err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s/consent", externalRootURL)}); err != nil {
ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err) ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err)
} }
} else { } else {
err = ctx.SetJSONBody(redirectResponse{Redirect: userSession.OIDCWorkflowSession.AuthURI}) if err = ctx.SetJSONBody(redirectResponse{Redirect: fmt.Sprintf("%s%s?%s", externalRootURL, oidc.AuthorizationPath, consent.Form)}); err != nil {
if err != nil {
ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err) ctx.Logger.Errorf("Unable to set default redirection URL in body: %s", err)
} }
} }

View File

@ -1,12 +0,0 @@
package handlers
// ConsentPostRequestBody schema of the request body of the consent POST endpoint.
type ConsentPostRequestBody struct {
ClientID string `json:"client_id"`
AcceptOrReject string `json:"accept_or_reject"`
}
// ConsentPostResponseBody schema of the response body of the consent POST endpoint.
type ConsentPostResponseBody struct {
RedirectURI string `json:"redirect_uri"`
}

View File

@ -10,8 +10,10 @@ import (
time "time" time "time"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
uuid "github.com/google/uuid"
model "github.com/authelia/authelia/v4/internal/model" model "github.com/authelia/authelia/v4/internal/model"
storage "github.com/authelia/authelia/v4/internal/storage"
) )
// MockStorage is a mock of Provider interface. // MockStorage is a mock of Provider interface.
@ -51,6 +53,21 @@ func (mr *MockStorageMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorage)(nil).AppendAuthenticationLog), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockStorage)(nil).AppendAuthenticationLog), arg0, arg1)
} }
// BeginTX mocks base method.
func (m *MockStorage) BeginTX(arg0 context.Context) (context.Context, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTX", arg0)
ret0, _ := ret[0].(context.Context)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTX indicates an expected call of BeginTX.
func (mr *MockStorageMockRecorder) BeginTX(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTX", reflect.TypeOf((*MockStorage)(nil).BeginTX), arg0)
}
// Close mocks base method. // Close mocks base method.
func (m *MockStorage) Close() error { func (m *MockStorage) Close() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -65,6 +82,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close))
} }
// Commit mocks base method.
func (m *MockStorage) Commit(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockStorageMockRecorder) Commit(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockStorage)(nil).Commit), arg0)
}
// ConsumeIdentityVerification mocks base method. // ConsumeIdentityVerification mocks base method.
func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 model.NullIP) error { func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 model.NullIP) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -79,6 +110,34 @@ func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
} }
// DeactivateOAuth2Session mocks base method.
func (m *MockStorage) DeactivateOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeactivateOAuth2Session", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// DeactivateOAuth2Session indicates an expected call of DeactivateOAuth2Session.
func (mr *MockStorageMockRecorder) DeactivateOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeactivateOAuth2Session", reflect.TypeOf((*MockStorage)(nil).DeactivateOAuth2Session), arg0, arg1, arg2)
}
// DeactivateOAuth2SessionByRequestID mocks base method.
func (m *MockStorage) DeactivateOAuth2SessionByRequestID(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeactivateOAuth2SessionByRequestID", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// DeactivateOAuth2SessionByRequestID indicates an expected call of DeactivateOAuth2SessionByRequestID.
func (mr *MockStorageMockRecorder) DeactivateOAuth2SessionByRequestID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeactivateOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).DeactivateOAuth2SessionByRequestID), arg0, arg1, arg2)
}
// DeletePreferredDuoDevice mocks base method. // DeletePreferredDuoDevice mocks base method.
func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error { func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -137,6 +196,66 @@ func (mr *MockStorageMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockStorage)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4)
} }
// LoadOAuth2BlacklistedJTI mocks base method.
func (m *MockStorage) LoadOAuth2BlacklistedJTI(arg0 context.Context, arg1 string) (*model.OAuth2BlacklistedJTI, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOAuth2BlacklistedJTI", arg0, arg1)
ret0, _ := ret[0].(*model.OAuth2BlacklistedJTI)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOAuth2BlacklistedJTI indicates an expected call of LoadOAuth2BlacklistedJTI.
func (mr *MockStorageMockRecorder) LoadOAuth2BlacklistedJTI(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2BlacklistedJTI", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2BlacklistedJTI), arg0, arg1)
}
// LoadOAuth2ConsentSessionByChallengeID mocks base method.
func (m *MockStorage) LoadOAuth2ConsentSessionByChallengeID(arg0 context.Context, arg1 uuid.UUID) (*model.OAuth2ConsentSession, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOAuth2ConsentSessionByChallengeID", arg0, arg1)
ret0, _ := ret[0].(*model.OAuth2ConsentSession)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOAuth2ConsentSessionByChallengeID indicates an expected call of LoadOAuth2ConsentSessionByChallengeID.
func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionByChallengeID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionByChallengeID", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionByChallengeID), arg0, arg1)
}
// LoadOAuth2ConsentSessionsPreConfigured mocks base method.
func (m *MockStorage) LoadOAuth2ConsentSessionsPreConfigured(arg0 context.Context, arg1 string, arg2 uuid.UUID) (*storage.ConsentSessionRows, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOAuth2ConsentSessionsPreConfigured", arg0, arg1, arg2)
ret0, _ := ret[0].(*storage.ConsentSessionRows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOAuth2ConsentSessionsPreConfigured indicates an expected call of LoadOAuth2ConsentSessionsPreConfigured.
func (mr *MockStorageMockRecorder) LoadOAuth2ConsentSessionsPreConfigured(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2ConsentSessionsPreConfigured", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2ConsentSessionsPreConfigured), arg0, arg1, arg2)
}
// LoadOAuth2Session mocks base method.
func (m *MockStorage) LoadOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) (*model.OAuth2Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadOAuth2Session", arg0, arg1, arg2)
ret0, _ := ret[0].(*model.OAuth2Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadOAuth2Session indicates an expected call of LoadOAuth2Session.
func (mr *MockStorageMockRecorder) LoadOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadOAuth2Session", reflect.TypeOf((*MockStorage)(nil).LoadOAuth2Session), arg0, arg1, arg2)
}
// LoadPreferred2FAMethod mocks base method. // LoadPreferred2FAMethod mocks base method.
func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) { func (m *MockStorage) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -212,6 +331,36 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1)
} }
// LoadUserOpaqueIdentifier mocks base method.
func (m *MockStorage) LoadUserOpaqueIdentifier(arg0 context.Context, arg1 uuid.UUID) (*model.UserOpaqueIdentifier, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadUserOpaqueIdentifier", arg0, arg1)
ret0, _ := ret[0].(*model.UserOpaqueIdentifier)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadUserOpaqueIdentifier indicates an expected call of LoadUserOpaqueIdentifier.
func (mr *MockStorageMockRecorder) LoadUserOpaqueIdentifier(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserOpaqueIdentifier", reflect.TypeOf((*MockStorage)(nil).LoadUserOpaqueIdentifier), arg0, arg1)
}
// LoadUserOpaqueIdentifierBySignature mocks base method.
func (m *MockStorage) LoadUserOpaqueIdentifierBySignature(arg0 context.Context, arg1, arg2, arg3 string) (*model.UserOpaqueIdentifier, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadUserOpaqueIdentifierBySignature", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*model.UserOpaqueIdentifier)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadUserOpaqueIdentifierBySignature indicates an expected call of LoadUserOpaqueIdentifierBySignature.
func (mr *MockStorageMockRecorder) LoadUserOpaqueIdentifierBySignature(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserOpaqueIdentifierBySignature", reflect.TypeOf((*MockStorage)(nil).LoadUserOpaqueIdentifierBySignature), arg0, arg1, arg2, arg3)
}
// LoadWebauthnDevices mocks base method. // LoadWebauthnDevices mocks base method.
func (m *MockStorage) LoadWebauthnDevices(arg0 context.Context, arg1, arg2 int) ([]model.WebauthnDevice, error) { func (m *MockStorage) LoadWebauthnDevices(arg0 context.Context, arg1, arg2 int) ([]model.WebauthnDevice, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -242,6 +391,48 @@ func (mr *MockStorageMockRecorder) LoadWebauthnDevicesByUsername(arg0, arg1 inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWebauthnDevicesByUsername", reflect.TypeOf((*MockStorage)(nil).LoadWebauthnDevicesByUsername), arg0, arg1)
} }
// RevokeOAuth2Session mocks base method.
func (m *MockStorage) RevokeOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeOAuth2Session", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeOAuth2Session indicates an expected call of RevokeOAuth2Session.
func (mr *MockStorageMockRecorder) RevokeOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2Session", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2Session), arg0, arg1, arg2)
}
// RevokeOAuth2SessionByRequestID mocks base method.
func (m *MockStorage) RevokeOAuth2SessionByRequestID(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeOAuth2SessionByRequestID", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeOAuth2SessionByRequestID indicates an expected call of RevokeOAuth2SessionByRequestID.
func (mr *MockStorageMockRecorder) RevokeOAuth2SessionByRequestID(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeOAuth2SessionByRequestID", reflect.TypeOf((*MockStorage)(nil).RevokeOAuth2SessionByRequestID), arg0, arg1, arg2)
}
// Rollback mocks base method.
func (m *MockStorage) Rollback(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockStorageMockRecorder) Rollback(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockStorage)(nil).Rollback), arg0)
}
// SaveIdentityVerification mocks base method. // SaveIdentityVerification mocks base method.
func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 model.IdentityVerification) error { func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 model.IdentityVerification) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -256,6 +447,76 @@ func (mr *MockStorageMockRecorder) SaveIdentityVerification(arg0, arg1 interface
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).SaveIdentityVerification), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).SaveIdentityVerification), arg0, arg1)
} }
// SaveOAuth2BlacklistedJTI mocks base method.
func (m *MockStorage) SaveOAuth2BlacklistedJTI(arg0 context.Context, arg1 model.OAuth2BlacklistedJTI) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOAuth2BlacklistedJTI", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveOAuth2BlacklistedJTI indicates an expected call of SaveOAuth2BlacklistedJTI.
func (mr *MockStorageMockRecorder) SaveOAuth2BlacklistedJTI(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2BlacklistedJTI", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2BlacklistedJTI), arg0, arg1)
}
// SaveOAuth2ConsentSession mocks base method.
func (m *MockStorage) SaveOAuth2ConsentSession(arg0 context.Context, arg1 model.OAuth2ConsentSession) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOAuth2ConsentSession", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveOAuth2ConsentSession indicates an expected call of SaveOAuth2ConsentSession.
func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSession(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSession", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSession), arg0, arg1)
}
// SaveOAuth2ConsentSessionGranted mocks base method.
func (m *MockStorage) SaveOAuth2ConsentSessionGranted(arg0 context.Context, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOAuth2ConsentSessionGranted", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveOAuth2ConsentSessionGranted indicates an expected call of SaveOAuth2ConsentSessionGranted.
func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionGranted(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionGranted", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionGranted), arg0, arg1)
}
// SaveOAuth2ConsentSessionResponse mocks base method.
func (m *MockStorage) SaveOAuth2ConsentSessionResponse(arg0 context.Context, arg1 model.OAuth2ConsentSession, arg2 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOAuth2ConsentSessionResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SaveOAuth2ConsentSessionResponse indicates an expected call of SaveOAuth2ConsentSessionResponse.
func (mr *MockStorageMockRecorder) SaveOAuth2ConsentSessionResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2ConsentSessionResponse", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2ConsentSessionResponse), arg0, arg1, arg2)
}
// SaveOAuth2Session mocks base method.
func (m *MockStorage) SaveOAuth2Session(arg0 context.Context, arg1 storage.OAuth2SessionType, arg2 model.OAuth2Session) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveOAuth2Session", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SaveOAuth2Session indicates an expected call of SaveOAuth2Session.
func (mr *MockStorageMockRecorder) SaveOAuth2Session(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveOAuth2Session", reflect.TypeOf((*MockStorage)(nil).SaveOAuth2Session), arg0, arg1, arg2)
}
// SavePreferred2FAMethod mocks base method. // SavePreferred2FAMethod mocks base method.
func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error { func (m *MockStorage) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -298,6 +559,20 @@ func (mr *MockStorageMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).SaveTOTPConfiguration), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockStorage)(nil).SaveTOTPConfiguration), arg0, arg1)
} }
// SaveUserOpaqueIdentifier mocks base method.
func (m *MockStorage) SaveUserOpaqueIdentifier(arg0 context.Context, arg1 model.UserOpaqueIdentifier) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveUserOpaqueIdentifier", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveUserOpaqueIdentifier indicates an expected call of SaveUserOpaqueIdentifier.
func (mr *MockStorageMockRecorder) SaveUserOpaqueIdentifier(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveUserOpaqueIdentifier", reflect.TypeOf((*MockStorage)(nil).SaveUserOpaqueIdentifier), arg0, arg1)
}
// SaveWebauthnDevice mocks base method. // SaveWebauthnDevice mocks base method.
func (m *MockStorage) SaveWebauthnDevice(arg0 context.Context, arg1 model.WebauthnDevice) error { func (m *MockStorage) SaveWebauthnDevice(arg0 context.Context, arg1 model.WebauthnDevice) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -1,14 +1,228 @@
package model package model
// OIDCWorkflowSession represent an OIDC workflow session. import (
type OIDCWorkflowSession struct { "context"
ClientID string "crypto/sha256"
RequestedScopes []string "encoding/json"
GrantedScopes []string "errors"
RequestedAudience []string "fmt"
GrantedAudience []string "net/url"
TargetURI string "time"
AuthURI string
Require2FA bool "github.com/google/uuid"
CreatedTimestamp int64 "github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewOAuth2ConsentSession creates a new OAuth2ConsentSession.
func NewOAuth2ConsentSession(subject uuid.UUID, r fosite.Requester) (consent *OAuth2ConsentSession, err error) {
consent = &OAuth2ConsentSession{
ClientID: r.GetClient().GetID(),
Subject: subject,
Form: r.GetRequestForm().Encode(),
RequestedAt: r.GetRequestedAt(),
RequestedScopes: StringSlicePipeDelimited(r.GetRequestedScopes()),
RequestedAudience: StringSlicePipeDelimited(r.GetRequestedAudience()),
GrantedScopes: StringSlicePipeDelimited(r.GetGrantedScopes()),
GrantedAudience: StringSlicePipeDelimited(r.GetGrantedAudience()),
}
if consent.ChallengeID, err = uuid.NewRandom(); err != nil {
return nil, err
}
return consent, nil
}
// NewOAuth2SessionFromRequest creates a new OAuth2Session from a signature and fosite.Requester.
func NewOAuth2SessionFromRequest(signature string, r fosite.Requester) (session *OAuth2Session, err error) {
var (
subject string
openidSession *OpenIDSession
sessData []byte
)
openidSession = r.GetSession().(*OpenIDSession)
if openidSession == nil {
return nil, errors.New("unexpected session type")
}
subject = openidSession.GetSubject()
if sessData, err = json.Marshal(openidSession); err != nil {
return nil, err
}
return &OAuth2Session{
ChallengeID: openidSession.ChallengeID,
RequestID: r.GetID(),
ClientID: r.GetClient().GetID(),
Signature: signature,
RequestedAt: r.GetRequestedAt(),
Subject: subject,
RequestedScopes: StringSlicePipeDelimited(r.GetRequestedScopes()),
GrantedScopes: StringSlicePipeDelimited(r.GetGrantedScopes()),
RequestedAudience: StringSlicePipeDelimited(r.GetRequestedAudience()),
GrantedAudience: StringSlicePipeDelimited(r.GetGrantedAudience()),
Active: true,
Revoked: false,
Form: r.GetRequestForm().Encode(),
Session: sessData,
}, nil
}
// NewOAuth2BlacklistedJTI creates a new OAuth2BlacklistedJTI.
func NewOAuth2BlacklistedJTI(jti string, exp time.Time) (jtiBlacklist OAuth2BlacklistedJTI) {
return OAuth2BlacklistedJTI{
Signature: fmt.Sprintf("%x", sha256.Sum256([]byte(jti))),
ExpiresAt: exp,
}
}
// OAuth2ConsentSession stores information about an OAuth2.0 Consent.
type OAuth2ConsentSession struct {
ID int `db:"id"`
ChallengeID uuid.UUID `db:"challenge_id"`
ClientID string `db:"client_id"`
Subject uuid.UUID `db:"subject"`
Authorized bool `db:"authorized"`
Granted bool `db:"granted"`
RequestedAt time.Time `db:"requested_at"`
RespondedAt *time.Time `db:"responded_at"`
ExpiresAt *time.Time `db:"expires_at"`
Form string `db:"form_data"`
RequestedScopes StringSlicePipeDelimited `db:"requested_scopes"`
GrantedScopes StringSlicePipeDelimited `db:"granted_scopes"`
RequestedAudience StringSlicePipeDelimited `db:"requested_audience"`
GrantedAudience StringSlicePipeDelimited `db:"granted_audience"`
}
// HasExactGrants returns true if the granted audience and scopes of this consent matches exactly with another
// audience and set of scopes.
func (s OAuth2ConsentSession) HasExactGrants(scopes, audience []string) (has bool) {
return s.HasExactGrantedScopes(scopes) && s.HasExactGrantedAudience(audience)
}
// HasExactGrantedAudience returns true if the granted audience of this consent matches exactly with another audience.
func (s OAuth2ConsentSession) HasExactGrantedAudience(audience []string) (has bool) {
return !utils.IsStringSlicesDifferent(s.GrantedAudience, audience)
}
// HasExactGrantedScopes returns true if the granted scopes of this consent matches exactly with another set of scopes.
func (s OAuth2ConsentSession) HasExactGrantedScopes(scopes []string) (has bool) {
return !utils.IsStringSlicesDifferent(s.GrantedScopes, scopes)
}
// IsAuthorized returns true if the user has responded to the consent session and it was authorized.
func (s OAuth2ConsentSession) IsAuthorized() bool {
return s.Responded() && s.Authorized
}
// CanGrant returns true if the user has responded to the consent session, it was authorized, and it either hast not
// previously been granted or the ability to grant has not expired.
func (s OAuth2ConsentSession) CanGrant() bool {
if !s.Responded() {
return false
}
if s.Granted && (s.ExpiresAt == nil || s.ExpiresAt.Before(time.Now())) {
return false
}
return true
}
// IsDenied returns true if the user has responded to the consent session and it was not authorized.
func (s OAuth2ConsentSession) IsDenied() bool {
return s.Responded() && !s.Authorized
}
// Responded returns true if the user has responded to the consent session.
func (s OAuth2ConsentSession) Responded() bool {
return s.RespondedAt != nil
}
// GetForm returns the form.
func (s OAuth2ConsentSession) GetForm() (form url.Values, err error) {
return url.ParseQuery(s.Form)
}
// OAuth2BlacklistedJTI represents a blacklisted JTI used with OAuth2.0.
type OAuth2BlacklistedJTI struct {
ID int `db:"id"`
Signature string `db:"signature"`
ExpiresAt time.Time `db:"expires_at"`
}
// OpenIDSession holds OIDC Session information.
type OpenIDSession struct {
*openid.DefaultSession `json:"id_token"`
ChallengeID uuid.UUID `db:"challenge_id"`
ClientID string
Extra map[string]interface{} `json:"extra"`
}
// OAuth2Session represents a OAuth2.0 session.
type OAuth2Session struct {
ID int `db:"id"`
ChallengeID uuid.UUID `db:"challenge_id"`
RequestID string `db:"request_id"`
ClientID string `db:"client_id"`
Signature string `db:"signature"`
RequestedAt time.Time `db:"requested_at"`
Subject string `db:"subject"`
RequestedScopes StringSlicePipeDelimited `db:"requested_scopes"`
GrantedScopes StringSlicePipeDelimited `db:"granted_scopes"`
RequestedAudience StringSlicePipeDelimited `db:"requested_audience"`
GrantedAudience StringSlicePipeDelimited `db:"granted_audience"`
Active bool `db:"active"`
Revoked bool `db:"revoked"`
Form string `db:"form_data"`
Session []byte `db:"session_data"`
}
// SetSubject implements an interface required for RFC7523.
func (s *OAuth2Session) SetSubject(subject string) {
s.Subject = subject
}
// ToRequest converts an OAuth2Session into a fosite.Request given a fosite.Session and fosite.Storage.
func (s OAuth2Session) ToRequest(ctx context.Context, session fosite.Session, store fosite.Storage) (request *fosite.Request, err error) {
sessionData := s.Session
if session != nil {
if err = json.Unmarshal(sessionData, session); err != nil {
return nil, err
}
}
client, err := store.GetClient(ctx, s.ClientID)
if err != nil {
return nil, err
}
values, err := url.ParseQuery(s.Form)
if err != nil {
return nil, err
}
return &fosite.Request{
ID: s.RequestID,
RequestedAt: s.RequestedAt,
Client: client,
RequestedScope: fosite.Arguments(s.RequestedScopes),
GrantedScope: fosite.Arguments(s.GrantedScopes),
RequestedAudience: fosite.Arguments(s.RequestedAudience),
GrantedAudience: fosite.Arguments(s.GrantedAudience),
Form: values,
Session: session,
}, nil
} }

View File

@ -1,10 +1,13 @@
package model package model
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
"github.com/authelia/authelia/v4/internal/utils"
) )
// NewIP easily constructs a new IP. // NewIP easily constructs a new IP.
@ -150,3 +153,26 @@ func (b *Base64) Scan(src interface{}) (err error) {
type StartupCheck interface { type StartupCheck interface {
StartupCheck() (err error) StartupCheck() (err error)
} }
// StringSlicePipeDelimited is a string slice that is stored in the database delimited by pipes.
type StringSlicePipeDelimited []string
// Scan is the StringSlicePipeDelimited implementation of the sql.Scanner.
func (s *StringSlicePipeDelimited) Scan(value interface{}) (err error) {
var nullStr sql.NullString
if err = nullStr.Scan(value); err != nil {
return err
}
if nullStr.Valid {
*s = utils.StringSplitDelimitedEscaped(nullStr.String, '|')
}
return nil
}
// Value is the StringSlicePipeDelimited implementation of the databases/sql driver.Valuer.
func (s StringSlicePipeDelimited) Value() (driver.Value, error) {
return utils.StringJoinDelimitedEscaped(s, '|'), nil
}

View File

@ -1,11 +1,21 @@
package model package model
import ( import (
"fmt"
"testing" "testing"
"github.com/ory/fosite"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test(t *testing.T) {
args := fosite.Arguments{"abc", "123"}
x := StringSlicePipeDelimited(args)
fmt.Println(x)
}
func TestDatabaseModelTypeIP(t *testing.T) { func TestDatabaseModelTypeIP(t *testing.T) {
ip := IP{} ip := IP{}

View File

@ -0,0 +1,33 @@
package model
import (
"fmt"
"github.com/google/uuid"
)
// NewUserOpaqueIdentifier either creates a new UserOpaqueIdentifier or returns an error.
func NewUserOpaqueIdentifier(service, sectorID, username string) (id *UserOpaqueIdentifier, err error) {
var opaqueID uuid.UUID
if opaqueID, err = uuid.NewRandom(); err != nil {
return nil, fmt.Errorf("unable to generate uuid: %w", err)
}
return &UserOpaqueIdentifier{
Service: service,
SectorID: sectorID,
Username: username,
Identifier: opaqueID,
}, nil
}
// UserOpaqueIdentifier represents an opaque identifier for a user. Commonly used with OAuth 2.0 and OpenID Connect.
type UserOpaqueIdentifier struct {
ID int `db:"id"`
Service string `db:"service"`
SectorID string `db:"sector_id"`
Username string `db:"username"`
Identifier uuid.UUID `db:"identifier"`
}

View File

@ -9,9 +9,9 @@ import (
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
) )
// NewClient creates a new InternalClient. // NewClient creates a new Client.
func NewClient(config schema.OpenIDConnectClientConfiguration) (client *InternalClient) { func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Client) {
client = &InternalClient{ client = &Client{
ID: config.ID, ID: config.ID,
Description: config.Description, Description: config.Description,
Secret: []byte(config.Secret), Secret: []byte(config.Secret),
@ -37,42 +37,47 @@ func NewClient(config schema.OpenIDConnectClientConfiguration) (client *Internal
} }
// IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient. // IsAuthenticationLevelSufficient returns if the provided authentication.Level is sufficient for the client of the AutheliaClient.
func (c InternalClient) IsAuthenticationLevelSufficient(level authentication.Level) bool { func (c Client) IsAuthenticationLevelSufficient(level authentication.Level) bool {
return authorization.IsAuthLevelSufficient(level, c.Policy) return authorization.IsAuthLevelSufficient(level, c.Policy)
} }
// GetID returns the ID. // GetID returns the ID.
func (c InternalClient) GetID() string { func (c Client) GetID() string {
return c.ID return c.ID
} }
// GetConsentResponseBody returns the proper consent response body for this model.OIDCWorkflowSession. // GetSectorIdentifier returns the SectorIdentifier for this client.
func (c InternalClient) GetConsentResponseBody(session *model.OIDCWorkflowSession) ConsentGetResponseBody { func (c Client) GetSectorIdentifier() string {
return c.SectorIdentifier
}
// GetConsentResponseBody returns the proper consent response body for this session.OIDCWorkflowSession.
func (c Client) GetConsentResponseBody(consent *model.OAuth2ConsentSession) ConsentGetResponseBody {
body := ConsentGetResponseBody{ body := ConsentGetResponseBody{
ClientID: c.ID, ClientID: c.ID,
ClientDescription: c.Description, ClientDescription: c.Description,
} }
if session != nil { if consent != nil {
body.Scopes = session.RequestedScopes body.Scopes = consent.RequestedScopes
body.Audience = session.RequestedAudience body.Audience = consent.RequestedAudience
} }
return body return body
} }
// GetHashedSecret returns the Secret. // GetHashedSecret returns the Secret.
func (c InternalClient) GetHashedSecret() []byte { func (c Client) GetHashedSecret() []byte {
return c.Secret return c.Secret
} }
// GetRedirectURIs returns the RedirectURIs. // GetRedirectURIs returns the RedirectURIs.
func (c InternalClient) GetRedirectURIs() []string { func (c Client) GetRedirectURIs() []string {
return c.RedirectURIs return c.RedirectURIs
} }
// GetGrantTypes returns the GrantTypes. // GetGrantTypes returns the GrantTypes.
func (c InternalClient) GetGrantTypes() fosite.Arguments { func (c Client) GetGrantTypes() fosite.Arguments {
if len(c.GrantTypes) == 0 { if len(c.GrantTypes) == 0 {
return fosite.Arguments{"authorization_code"} return fosite.Arguments{"authorization_code"}
} }
@ -81,7 +86,7 @@ func (c InternalClient) GetGrantTypes() fosite.Arguments {
} }
// GetResponseTypes returns the ResponseTypes. // GetResponseTypes returns the ResponseTypes.
func (c InternalClient) GetResponseTypes() fosite.Arguments { func (c Client) GetResponseTypes() fosite.Arguments {
if len(c.ResponseTypes) == 0 { if len(c.ResponseTypes) == 0 {
return fosite.Arguments{"code"} return fosite.Arguments{"code"}
} }
@ -90,23 +95,23 @@ func (c InternalClient) GetResponseTypes() fosite.Arguments {
} }
// GetScopes returns the Scopes. // GetScopes returns the Scopes.
func (c InternalClient) GetScopes() fosite.Arguments { func (c Client) GetScopes() fosite.Arguments {
return c.Scopes return c.Scopes
} }
// IsPublic returns the value of the Public property. // IsPublic returns the value of the Public property.
func (c InternalClient) IsPublic() bool { func (c Client) IsPublic() bool {
return c.Public return c.Public
} }
// GetAudience returns the Audience. // GetAudience returns the Audience.
func (c InternalClient) GetAudience() fosite.Arguments { func (c Client) GetAudience() fosite.Arguments {
return c.Audience return c.Audience
} }
// GetResponseModes returns the valid response modes for this client. // GetResponseModes returns the valid response modes for this client.
// //
// Implements the fosite.ResponseModeClient. // Implements the fosite.ResponseModeClient.
func (c InternalClient) GetResponseModes() []fosite.ResponseModeType { func (c Client) GetResponseModes() []fosite.ResponseModeType {
return c.ResponseModes return c.ResponseModes
} }

View File

@ -44,7 +44,7 @@ func TestNewClient(t *testing.T) {
} }
func TestIsAuthenticationLevelSufficient(t *testing.T) { func TestIsAuthenticationLevelSufficient(t *testing.T) {
c := InternalClient{} c := Client{}
c.Policy = authorization.Bypass c.Policy = authorization.Bypass
assert.True(t, c.IsAuthenticationLevelSufficient(authentication.NotAuthenticated)) assert.True(t, c.IsAuthenticationLevelSufficient(authentication.NotAuthenticated))
@ -68,7 +68,7 @@ func TestIsAuthenticationLevelSufficient(t *testing.T) {
} }
func TestInternalClient_GetConsentResponseBody(t *testing.T) { func TestInternalClient_GetConsentResponseBody(t *testing.T) {
c := InternalClient{} c := Client{}
consentRequestBody := c.GetConsentResponseBody(nil) consentRequestBody := c.GetConsentResponseBody(nil)
assert.Equal(t, "", consentRequestBody.ClientID) assert.Equal(t, "", consentRequestBody.ClientID)
@ -79,7 +79,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) {
c.ID = "myclient" c.ID = "myclient"
c.Description = "My Client" c.Description = "My Client"
workflow := &model.OIDCWorkflowSession{ consent := &model.OAuth2ConsentSession{
RequestedAudience: []string{"https://example.com"}, RequestedAudience: []string{"https://example.com"},
RequestedScopes: []string{"openid", "groups"}, RequestedScopes: []string{"openid", "groups"},
} }
@ -87,7 +87,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) {
expectedScopes := []string{"openid", "groups"} expectedScopes := []string{"openid", "groups"}
expectedAudiences := []string{"https://example.com"} expectedAudiences := []string{"https://example.com"}
consentRequestBody = c.GetConsentResponseBody(workflow) consentRequestBody = c.GetConsentResponseBody(consent)
assert.Equal(t, "myclient", consentRequestBody.ClientID) assert.Equal(t, "myclient", consentRequestBody.ClientID)
assert.Equal(t, "My Client", consentRequestBody.ClientDescription) assert.Equal(t, "My Client", consentRequestBody.ClientDescription)
assert.Equal(t, expectedScopes, consentRequestBody.Scopes) assert.Equal(t, expectedScopes, consentRequestBody.Scopes)
@ -95,7 +95,7 @@ func TestInternalClient_GetConsentResponseBody(t *testing.T) {
} }
func TestInternalClient_GetAudience(t *testing.T) { func TestInternalClient_GetAudience(t *testing.T) {
c := InternalClient{} c := Client{}
audience := c.GetAudience() audience := c.GetAudience()
assert.Len(t, audience, 0) assert.Len(t, audience, 0)
@ -108,7 +108,7 @@ func TestInternalClient_GetAudience(t *testing.T) {
} }
func TestInternalClient_GetScopes(t *testing.T) { func TestInternalClient_GetScopes(t *testing.T) {
c := InternalClient{} c := Client{}
scopes := c.GetScopes() scopes := c.GetScopes()
assert.Len(t, scopes, 0) assert.Len(t, scopes, 0)
@ -121,7 +121,7 @@ func TestInternalClient_GetScopes(t *testing.T) {
} }
func TestInternalClient_GetGrantTypes(t *testing.T) { func TestInternalClient_GetGrantTypes(t *testing.T) {
c := InternalClient{} c := Client{}
grantTypes := c.GetGrantTypes() grantTypes := c.GetGrantTypes()
require.Len(t, grantTypes, 1) require.Len(t, grantTypes, 1)
@ -135,7 +135,7 @@ func TestInternalClient_GetGrantTypes(t *testing.T) {
} }
func TestInternalClient_GetHashedSecret(t *testing.T) { func TestInternalClient_GetHashedSecret(t *testing.T) {
c := InternalClient{} c := Client{}
hashedSecret := c.GetHashedSecret() hashedSecret := c.GetHashedSecret()
assert.Equal(t, []byte(nil), hashedSecret) assert.Equal(t, []byte(nil), hashedSecret)
@ -147,7 +147,7 @@ func TestInternalClient_GetHashedSecret(t *testing.T) {
} }
func TestInternalClient_GetID(t *testing.T) { func TestInternalClient_GetID(t *testing.T) {
c := InternalClient{} c := Client{}
id := c.GetID() id := c.GetID()
assert.Equal(t, "", id) assert.Equal(t, "", id)
@ -159,7 +159,7 @@ func TestInternalClient_GetID(t *testing.T) {
} }
func TestInternalClient_GetRedirectURIs(t *testing.T) { func TestInternalClient_GetRedirectURIs(t *testing.T) {
c := InternalClient{} c := Client{}
redirectURIs := c.GetRedirectURIs() redirectURIs := c.GetRedirectURIs()
require.Len(t, redirectURIs, 0) require.Len(t, redirectURIs, 0)
@ -172,7 +172,7 @@ func TestInternalClient_GetRedirectURIs(t *testing.T) {
} }
func TestInternalClient_GetResponseModes(t *testing.T) { func TestInternalClient_GetResponseModes(t *testing.T) {
c := InternalClient{} c := Client{}
responseModes := c.GetResponseModes() responseModes := c.GetResponseModes()
require.Len(t, responseModes, 0) require.Len(t, responseModes, 0)
@ -191,7 +191,7 @@ func TestInternalClient_GetResponseModes(t *testing.T) {
} }
func TestInternalClient_GetResponseTypes(t *testing.T) { func TestInternalClient_GetResponseTypes(t *testing.T) {
c := InternalClient{} c := Client{}
responseTypes := c.GetResponseTypes() responseTypes := c.GetResponseTypes()
require.Len(t, responseTypes, 1) require.Len(t, responseTypes, 1)
@ -206,7 +206,7 @@ func TestInternalClient_GetResponseTypes(t *testing.T) {
} }
func TestInternalClient_IsPublic(t *testing.T) { func TestInternalClient_IsPublic(t *testing.T) {
c := InternalClient{} c := Client{}
assert.False(t, c.IsPublic()) assert.False(t, c.IsPublic())

View File

@ -0,0 +1,79 @@
package oidc
// NewOpenIDConnectWellKnownConfiguration generates a new OpenIDConnectWellKnownConfiguration.
func NewOpenIDConnectWellKnownConfiguration(enablePKCEPlainChallenge, pairwise bool) (config OpenIDConnectWellKnownConfiguration) {
config = OpenIDConnectWellKnownConfiguration{
CommonDiscoveryOptions: CommonDiscoveryOptions{
SubjectTypesSupported: []string{
"public",
},
ResponseTypesSupported: []string{
"code",
"token",
"id_token",
"code token",
"code id_token",
"token id_token",
"code token id_token",
"none",
},
ResponseModesSupported: []string{
"form_post",
"query",
"fragment",
},
ScopesSupported: []string{
ScopeOfflineAccess,
ScopeOpenID,
ScopeProfile,
ScopeGroups,
ScopeEmail,
},
ClaimsSupported: []string{
"aud",
"exp",
"iat",
"iss",
"jti",
"rat",
"sub",
"auth_time",
"nonce",
ClaimEmail,
ClaimEmailVerified,
ClaimEmailAlts,
ClaimGroups,
ClaimPreferredUsername,
ClaimDisplayName,
},
},
OAuth2DiscoveryOptions: OAuth2DiscoveryOptions{
CodeChallengeMethodsSupported: []string{
"S256",
},
},
OpenIDConnectDiscoveryOptions: OpenIDConnectDiscoveryOptions{
IDTokenSigningAlgValuesSupported: []string{
"RS256",
},
UserinfoSigningAlgValuesSupported: []string{
"none",
"RS256",
},
RequestObjectSigningAlgValuesSupported: []string{
"none",
"RS256",
},
},
}
if pairwise {
config.SubjectTypesSupported = append(config.SubjectTypesSupported, "pairwise")
}
if enablePKCEPlainChallenge {
config.CodeChallengeMethodsSupported = append(config.CodeChallengeMethodsSupported, "plain")
}
return config
}

View File

@ -0,0 +1,65 @@
package oidc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewOpenIDConnectWellKnownConfiguration(t *testing.T) {
testCases := []struct {
desc string
pkcePlainChallenge, pairwise bool
expectCodeChallengeMethodsSupported, expectSubjectTypesSupported []string
}{
{
desc: "ShouldHaveChallengeMethodsS256ANDSubjectTypesSupportedPublic",
pkcePlainChallenge: false,
pairwise: false,
expectCodeChallengeMethodsSupported: []string{"S256"},
expectSubjectTypesSupported: []string{"public"},
},
{
desc: "ShouldHaveChallengeMethodsS256PlainANDSubjectTypesSupportedPublic",
pkcePlainChallenge: true,
pairwise: false,
expectCodeChallengeMethodsSupported: []string{"S256", "plain"},
expectSubjectTypesSupported: []string{"public"},
},
{
desc: "ShouldHaveChallengeMethodsS256ANDSubjectTypesSupportedPublicPairwise",
pkcePlainChallenge: false,
pairwise: true,
expectCodeChallengeMethodsSupported: []string{"S256"},
expectSubjectTypesSupported: []string{"public", "pairwise"},
},
{
desc: "ShouldHaveChallengeMethodsS256PlainANDSubjectTypesSupportedPublicPairwise",
pkcePlainChallenge: true,
pairwise: true,
expectCodeChallengeMethodsSupported: []string{"S256", "plain"},
expectSubjectTypesSupported: []string{"public", "pairwise"},
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
actual := NewOpenIDConnectWellKnownConfiguration(tc.pkcePlainChallenge, tc.pairwise)
for _, codeChallengeMethod := range tc.expectCodeChallengeMethodsSupported {
assert.Contains(t, actual.CodeChallengeMethodsSupported, codeChallengeMethod)
}
for _, subjectType := range tc.expectSubjectTypesSupported {
assert.Contains(t, actual.SubjectTypesSupported, subjectType)
}
for _, codeChallengeMethod := range actual.CodeChallengeMethodsSupported {
assert.Contains(t, tc.expectCodeChallengeMethodsSupported, codeChallengeMethod)
}
for _, subjectType := range actual.SubjectTypesSupported {
assert.Contains(t, tc.expectSubjectTypesSupported, subjectType)
}
})
}
}

View File

@ -6,7 +6,7 @@ import (
) )
// Compare compares the hash with the data and returns an error if they don't match. // Compare compares the hash with the data and returns an error if they don't match.
func (h AutheliaHasher) Compare(_ context.Context, hash, data []byte) (err error) { func (h PlainTextHasher) Compare(_ context.Context, hash, data []byte) (err error) {
if subtle.ConstantTimeCompare(hash, data) == 0 { if subtle.ConstantTimeCompare(hash, data) == 0 {
return errPasswordsDoNotMatch return errPasswordsDoNotMatch
} }
@ -15,6 +15,6 @@ func (h AutheliaHasher) Compare(_ context.Context, hash, data []byte) (err error
} }
// Hash creates a new hash from data. // Hash creates a new hash from data.
func (h AutheliaHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) { func (h PlainTextHasher) Hash(_ context.Context, data []byte) (hash []byte, err error) {
return data, nil return data, nil
} }

View File

@ -8,7 +8,7 @@ import (
) )
func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) { func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
hasher := AutheliaHasher{} hasher := PlainTextHasher{}
a := []byte("abc") a := []byte("abc")
b := []byte("abc") b := []byte("abc")
@ -21,7 +21,7 @@ func TestShouldNotRaiseErrorOnEqualPasswordsPlainText(t *testing.T) {
} }
func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) { func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
hasher := AutheliaHasher{} hasher := PlainTextHasher{}
a := []byte("abc") a := []byte("abc")
b := []byte("abcd") b := []byte("abcd")
@ -34,7 +34,7 @@ func TestShouldRaiseErrorOnNonEqualPasswordsPlainText(t *testing.T) {
} }
func TestShouldHashPassword(t *testing.T) { func TestShouldHashPassword(t *testing.T) {
hasher := AutheliaHasher{} hasher := PlainTextHasher{}
data := []byte("abc") data := []byte("abc")

View File

@ -8,34 +8,35 @@ import (
"github.com/ory/herodot" "github.com/ory/herodot"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils" "github.com/authelia/authelia/v4/internal/utils"
) )
// NewOpenIDConnectProvider new-ups a OpenIDConnectProvider. // NewOpenIDConnectProvider new-ups a OpenIDConnectProvider.
func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration) (provider OpenIDConnectProvider, err error) { func NewOpenIDConnectProvider(config *schema.OpenIDConnectConfiguration, storageProvider storage.Provider) (provider OpenIDConnectProvider, err error) {
provider = OpenIDConnectProvider{ provider = OpenIDConnectProvider{
Fosite: nil, Fosite: nil,
} }
if configuration == nil { if config == nil {
return provider, nil return provider, nil
} }
provider.Store = NewOpenIDConnectStore(configuration) provider.Store = NewOpenIDConnectStore(config, storageProvider)
composeConfiguration := &compose.Config{ composeConfiguration := &compose.Config{
AccessTokenLifespan: configuration.AccessTokenLifespan, AccessTokenLifespan: config.AccessTokenLifespan,
AuthorizeCodeLifespan: configuration.AuthorizeCodeLifespan, AuthorizeCodeLifespan: config.AuthorizeCodeLifespan,
IDTokenLifespan: configuration.IDTokenLifespan, IDTokenLifespan: config.IDTokenLifespan,
RefreshTokenLifespan: configuration.RefreshTokenLifespan, RefreshTokenLifespan: config.RefreshTokenLifespan,
SendDebugMessagesToClients: configuration.EnableClientDebugMessages, SendDebugMessagesToClients: config.EnableClientDebugMessages,
MinParameterEntropy: configuration.MinimumParameterEntropy, MinParameterEntropy: config.MinimumParameterEntropy,
EnforcePKCE: configuration.EnforcePKCE == "always", EnforcePKCE: config.EnforcePKCE == "always",
EnforcePKCEForPublicClients: configuration.EnforcePKCE != "never", EnforcePKCEForPublicClients: config.EnforcePKCE != "never",
EnablePKCEPlainChallengeMethod: configuration.EnablePKCEPlainChallenge, EnablePKCEPlainChallengeMethod: config.EnablePKCEPlainChallenge,
} }
keyManager, err := NewKeyManagerWithConfiguration(configuration) keyManager, err := NewKeyManagerWithConfiguration(config)
if err != nil { if err != nil {
return provider, err return provider, err
} }
@ -50,7 +51,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
strategy := &compose.CommonStrategy{ strategy := &compose.CommonStrategy{
CoreStrategy: compose.NewOAuth2HMACStrategy( CoreStrategy: compose.NewOAuth2HMACStrategy(
composeConfiguration, composeConfiguration,
[]byte(utils.HashSHA256FromString(configuration.HMACSecret)), []byte(utils.HashSHA256FromString(config.HMACSecret)),
nil, nil,
), ),
OpenIDConnectTokenStrategy: compose.NewOpenIDConnectStrategy( OpenIDConnectTokenStrategy: compose.NewOpenIDConnectStrategy(
@ -64,7 +65,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
composeConfiguration, composeConfiguration,
provider.Store, provider.Store,
strategy, strategy,
AutheliaHasher{}, PlainTextHasher{},
/* /*
These are the OAuth2 and OpenIDConnect factories. Order is important (the OAuth2 factories at the top must These are the OAuth2 and OpenIDConnect factories. Order is important (the OAuth2 factories at the top must
@ -75,7 +76,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
compose.OAuth2AuthorizeImplicitFactory, compose.OAuth2AuthorizeImplicitFactory,
compose.OAuth2ClientCredentialsGrantFactory, compose.OAuth2ClientCredentialsGrantFactory,
compose.OAuth2RefreshTokenGrantFactory, compose.OAuth2RefreshTokenGrantFactory,
compose.OAuth2ResourceOwnerPasswordCredentialsFactory, // compose.OAuth2ResourceOwnerPasswordCredentialsFactory,
// compose.RFC7523AssertionGrantFactory,. // compose.RFC7523AssertionGrantFactory,.
compose.OpenIDConnectExplicitFactory, compose.OpenIDConnectExplicitFactory,
@ -89,80 +90,24 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
compose.OAuth2PKCEFactory, compose.OAuth2PKCEFactory,
) )
provider.discovery = OpenIDConnectWellKnownConfiguration{ provider.discovery = NewOpenIDConnectWellKnownConfiguration(config.EnablePKCEPlainChallenge, provider.Pairwise())
CommonDiscoveryOptions: CommonDiscoveryOptions{
SubjectTypesSupported: []string{
"public",
},
ResponseTypesSupported: []string{
"code",
"token",
"id_token",
"code token",
"code id_token",
"token id_token",
"code token id_token",
"none",
},
ResponseModesSupported: []string{
"form_post",
"query",
"fragment",
},
ScopesSupported: []string{
ScopeOfflineAccess,
ScopeOpenID,
ScopeProfile,
ScopeGroups,
ScopeEmail,
},
ClaimsSupported: []string{
"aud",
"exp",
"iat",
"iss",
"jti",
"rat",
"sub",
"auth_time",
"nonce",
ClaimEmail,
ClaimEmailVerified,
ClaimEmailAlts,
ClaimGroups,
ClaimPreferredUsername,
ClaimDisplayName,
},
},
OAuth2DiscoveryOptions: OAuth2DiscoveryOptions{
CodeChallengeMethodsSupported: []string{
"S256",
},
},
OpenIDConnectDiscoveryOptions: OpenIDConnectDiscoveryOptions{
IDTokenSigningAlgValuesSupported: []string{
"RS256",
},
UserinfoSigningAlgValuesSupported: []string{
"none",
"RS256",
},
RequestObjectSigningAlgValuesSupported: []string{
"none",
"RS256",
},
},
}
if configuration.EnablePKCEPlainChallenge {
provider.discovery.CodeChallengeMethodsSupported = append(provider.discovery.CodeChallengeMethodsSupported, "plain")
}
provider.herodot = herodot.NewJSONWriter(nil) provider.herodot = herodot.NewJSONWriter(nil)
return provider, nil return provider, nil
} }
// Pairwise returns true if this provider is configured with clients that require pairwise.
func (p OpenIDConnectProvider) Pairwise() bool {
for _, c := range p.Store.clients {
if c.SectorIdentifier != "" {
return true
}
}
return false
}
// Write writes data with herodot.JSONWriter. // Write writes data with herodot.JSONWriter.
func (p OpenIDConnectProvider) Write(w http.ResponseWriter, r *http.Request, e interface{}, opts ...herodot.EncoderOptions) { func (p OpenIDConnectProvider) Write(w http.ResponseWriter, r *http.Request, e interface{}, opts ...herodot.EncoderOptions) {
p.herodot.Write(w, r, e, opts...) p.herodot.Write(w, r, e, opts...)

View File

@ -12,7 +12,7 @@ import (
var exampleIssuerPrivateKey = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEAvcMVMB2vEbqI6PlSNJ4HmUyMxBDJ5iY7FS+zDDAHOZBg9S3S\nKcAn1CZcnyL0VvJ7wcdhR6oTnOwR94eKvzUyJZ+GL2hTMm27dubEYsNdhoCl6N3X\nyEEohNfoxiiCYraVauX8X3M9jFzbEz9+pacaDbHB2syaJ1qFmMNR+HSu2jPzOo7M\nlqKIOgUzA0741MaYNt47AEVg4XU5ORLdolbAkItmYg1QbyFndg9H5IvwKkYaXTGE\nlgDBcPUC0yVjAC15Mguquq+jZeQay+6PSbHTD8PQMOkLjyChI2xEhVNbdCXe676R\ncMW2R/gjrcK23zmtmTWRfdC1iZLSlHO+bJj9vQIDAQABAoIBAEZvkP/JJOCJwqPn\nV3IcbmmilmV4bdi1vByDFgyiDyx4wOSA24+PubjvfFW9XcCgRPuKjDtTj/AhWBHv\nB7stfa2lZuNV7/u562mZArA+IAr62Zp0LdIxDV8x3T8gbjVB3HhPYbv0RJZDKTYd\nzV6jhfIrVu9mHpoY6ZnodhapCPYIyk/d49KBIHZuAc25CUjMXgTeaVtf0c996036\nUxW6ef33wAOJAvW0RCvbXAJfmBeEq2qQlkjTIlpYx71fhZWexHifi8Ouv3Zonc+1\n/P2Adq5uzYVBT92f9RKHg9QxxNzVrLjSMaxyvUtWQCAQfW0tFIRdqBGsHYsQrFtI\nF4yzv8ECgYEA7ntpyN9HD9Z9lYQzPCR73sFCLM+ID99aVij0wHuxK97bkSyyvkLd\n7MyTaym3lg1UEqWNWBCLvFULZx7F0Ah6qCzD4ymm3Bj/ADpWWPgljBI0AFml+HHs\nhcATmXUrj5QbLyhiP2gmJjajp1o/rgATx6ED66seSynD6JOH8wUhhZUCgYEAy7OA\n06PF8GfseNsTqlDjNF0K7lOqd21S0prdwrsJLiVzUlfMM25MLE0XLDUutCnRheeh\nIlcuDoBsVTxz6rkvFGD74N+pgXlN4CicsBq5ofK060PbqCQhSII3fmHobrZ9Cr75\nHmBjAxHx998SKaAAGbBbcYGUAp521i1pH5CEPYkCgYEAkUd1Zf0+2RMdZhwm6hh/\nrW+l1I6IoMK70YkZsLipccRNld7Y9LbfYwYtODcts6di9AkOVfueZJiaXbONZfIE\nZrb+jkAteh9wGL9xIrnohbABJcV3Kiaco84jInUSmGDtPokncOENfHIEuEpuSJ2b\nbx1TuhmAVuGWivR0+ULC7RECgYEAgS0cDRpWc9Xzh9Cl7+PLsXEvdWNpPsL9OsEq\n0Ep7z9+/+f/jZtoTRCS/BTHUpDvAuwHglT5j3p5iFMt5VuiIiovWLwynGYwrbnNS\nqfrIrYKUaH1n1oDS+oBZYLQGCe9/7EifAjxtjYzbvSyg//SPG7tSwfBCREbpZXj2\nqSWkNsECgYA/mCDzCTlrrWPuiepo6kTmN+4TnFA+hJI6NccDVQ+jvbqEdoJ4SW4L\nzqfZSZRFJMNpSgIqkQNRPJqMP0jQ5KRtJrjMWBnYxktwKz9fDg2R2MxdFgMF2LH2\nHEMMhFHlv8NDjVOXh1KwRoltNGVWYsSrD9wKU9GhRCEfmNCGrvBcEg==\n-----END RSA PRIVATE KEY-----" var exampleIssuerPrivateKey = "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEAvcMVMB2vEbqI6PlSNJ4HmUyMxBDJ5iY7FS+zDDAHOZBg9S3S\nKcAn1CZcnyL0VvJ7wcdhR6oTnOwR94eKvzUyJZ+GL2hTMm27dubEYsNdhoCl6N3X\nyEEohNfoxiiCYraVauX8X3M9jFzbEz9+pacaDbHB2syaJ1qFmMNR+HSu2jPzOo7M\nlqKIOgUzA0741MaYNt47AEVg4XU5ORLdolbAkItmYg1QbyFndg9H5IvwKkYaXTGE\nlgDBcPUC0yVjAC15Mguquq+jZeQay+6PSbHTD8PQMOkLjyChI2xEhVNbdCXe676R\ncMW2R/gjrcK23zmtmTWRfdC1iZLSlHO+bJj9vQIDAQABAoIBAEZvkP/JJOCJwqPn\nV3IcbmmilmV4bdi1vByDFgyiDyx4wOSA24+PubjvfFW9XcCgRPuKjDtTj/AhWBHv\nB7stfa2lZuNV7/u562mZArA+IAr62Zp0LdIxDV8x3T8gbjVB3HhPYbv0RJZDKTYd\nzV6jhfIrVu9mHpoY6ZnodhapCPYIyk/d49KBIHZuAc25CUjMXgTeaVtf0c996036\nUxW6ef33wAOJAvW0RCvbXAJfmBeEq2qQlkjTIlpYx71fhZWexHifi8Ouv3Zonc+1\n/P2Adq5uzYVBT92f9RKHg9QxxNzVrLjSMaxyvUtWQCAQfW0tFIRdqBGsHYsQrFtI\nF4yzv8ECgYEA7ntpyN9HD9Z9lYQzPCR73sFCLM+ID99aVij0wHuxK97bkSyyvkLd\n7MyTaym3lg1UEqWNWBCLvFULZx7F0Ah6qCzD4ymm3Bj/ADpWWPgljBI0AFml+HHs\nhcATmXUrj5QbLyhiP2gmJjajp1o/rgATx6ED66seSynD6JOH8wUhhZUCgYEAy7OA\n06PF8GfseNsTqlDjNF0K7lOqd21S0prdwrsJLiVzUlfMM25MLE0XLDUutCnRheeh\nIlcuDoBsVTxz6rkvFGD74N+pgXlN4CicsBq5ofK060PbqCQhSII3fmHobrZ9Cr75\nHmBjAxHx998SKaAAGbBbcYGUAp521i1pH5CEPYkCgYEAkUd1Zf0+2RMdZhwm6hh/\nrW+l1I6IoMK70YkZsLipccRNld7Y9LbfYwYtODcts6di9AkOVfueZJiaXbONZfIE\nZrb+jkAteh9wGL9xIrnohbABJcV3Kiaco84jInUSmGDtPokncOENfHIEuEpuSJ2b\nbx1TuhmAVuGWivR0+ULC7RECgYEAgS0cDRpWc9Xzh9Cl7+PLsXEvdWNpPsL9OsEq\n0Ep7z9+/+f/jZtoTRCS/BTHUpDvAuwHglT5j3p5iFMt5VuiIiovWLwynGYwrbnNS\nqfrIrYKUaH1n1oDS+oBZYLQGCe9/7EifAjxtjYzbvSyg//SPG7tSwfBCREbpZXj2\nqSWkNsECgYA/mCDzCTlrrWPuiepo6kTmN+4TnFA+hJI6NccDVQ+jvbqEdoJ4SW4L\nzqfZSZRFJMNpSgIqkQNRPJqMP0jQ5KRtJrjMWBnYxktwKz9fDg2R2MxdFgMF2LH2\nHEMMhFHlv8NDjVOXh1KwRoltNGVWYsSrD9wKU9GhRCEfmNCGrvBcEg==\n-----END RSA PRIVATE KEY-----"
func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing.T) { func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing.T) {
provider, err := NewOpenIDConnectProvider(nil) provider, err := NewOpenIDConnectProvider(nil, nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, provider.Fosite) assert.Nil(t, provider.Fosite)
@ -22,7 +22,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_NotConfigured(t *testing
func TestOpenIDConnectProvider_NewOpenIDConnectProvider_BadIssuerKey(t *testing.T) { func TestOpenIDConnectProvider_NewOpenIDConnectProvider_BadIssuerKey(t *testing.T) {
_, err := NewOpenIDConnectProvider(&schema.OpenIDConnectConfiguration{ _, err := NewOpenIDConnectProvider(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: "BAD KEY", IssuerPrivateKey: "BAD KEY",
}) }, nil)
assert.Error(t, err, "abc") assert.Error(t, err, "abc")
} }
@ -60,7 +60,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GoodConfiguration(t *tes
}, },
}, },
}, },
}) }, nil)
assert.NotNil(t, provider) assert.NotNil(t, provider)
assert.NoError(t, err) assert.NoError(t, err)
@ -80,10 +80,12 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
}, },
}, },
}, },
}) }, nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, provider.Pairwise())
disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com") disco := provider.GetOpenIDConnectWellKnownConfiguration("https://example.com")
assert.Equal(t, "https://example.com", disco.Issuer) assert.Equal(t, "https://example.com", disco.Issuer)
@ -95,8 +97,8 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
assert.Equal(t, "https://example.com/api/oidc/revocation", disco.RevocationEndpoint) assert.Equal(t, "https://example.com/api/oidc/revocation", disco.RevocationEndpoint)
assert.Equal(t, "", disco.RegistrationEndpoint) assert.Equal(t, "", disco.RegistrationEndpoint)
require.Len(t, disco.CodeChallengeMethodsSupported, 1) assert.Len(t, disco.CodeChallengeMethodsSupported, 1)
assert.Equal(t, "S256", disco.CodeChallengeMethodsSupported[0]) assert.Contains(t, disco.CodeChallengeMethodsSupported, "S256")
assert.Len(t, disco.ScopesSupported, 5) assert.Len(t, disco.ScopesSupported, 5)
assert.Contains(t, disco.ScopesSupported, ScopeOpenID) assert.Contains(t, disco.ScopesSupported, ScopeOpenID)
@ -166,7 +168,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOAuth2WellKnownConfig
}, },
}, },
}, },
}) }, nil)
assert.NoError(t, err) assert.NoError(t, err)
@ -241,7 +243,7 @@ func TestOpenIDConnectProvider_NewOpenIDConnectProvider_GetOpenIDConnectWellKnow
}, },
}, },
}, },
}) }, nil)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -2,38 +2,32 @@ package oidc
import ( import (
"context" "context"
"crypto/sha256"
"database/sql"
"errors"
"fmt"
"time" "time"
"github.com/google/uuid"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/storage"
"gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/configuration/schema" "github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/storage"
) )
// NewOpenIDConnectStore returns a new OpenIDConnectStore using the provided schema.OpenIDConnectConfiguration. // NewOpenIDConnectStore returns a OpenIDConnectStore when provided with a schema.OpenIDConnectConfiguration and storage.Provider.
func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore) { func NewOpenIDConnectStore(config *schema.OpenIDConnectConfiguration, provider storage.Provider) (store *OpenIDConnectStore) {
logger := logging.Logger() logger := logging.Logger()
store = &OpenIDConnectStore{ store = &OpenIDConnectStore{
memory: &storage.MemoryStore{ provider: provider,
IDSessions: map[string]fosite.Requester{}, clients: map[string]*Client{},
Users: map[string]storage.MemoryUserRelation{},
AuthorizeCodes: map[string]storage.StoreAuthorizeCode{},
AccessTokens: map[string]fosite.Requester{},
RefreshTokens: map[string]storage.StoreRefreshToken{},
PKCES: map[string]fosite.Requester{},
BlacklistedJTIs: map[string]time.Time{},
AccessTokenRequestIDs: map[string]string{},
RefreshTokenRequestIDs: map[string]string{},
},
} }
store.clients = make(map[string]*InternalClient) for _, client := range config.Clients {
for _, client := range configuration.Clients {
policy := authorization.PolicyToLevel(client.Policy) policy := authorization.PolicyToLevel(client.Policy)
logger.Debugf("Registering client %s with policy %s (%v)", client.ID, client.Policy, policy) logger.Debugf("Registering client %s with policy %s (%v)", client.ID, client.Policy, policy)
@ -43,9 +37,37 @@ func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (st
return store return store
} }
// GenerateOpaqueUserID either retrieves or creates an opaque user id from a sectorID and username.
func (s OpenIDConnectStore) GenerateOpaqueUserID(ctx context.Context, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
if opaqueID, err = s.provider.LoadUserOpaqueIdentifierBySignature(ctx, "openid", sectorID, username); err != nil {
return nil, err
} else if opaqueID == nil {
if opaqueID, err = model.NewUserOpaqueIdentifier("openid", sectorID, username); err != nil {
return nil, err
}
if err = s.provider.SaveUserOpaqueIdentifier(ctx, *opaqueID); err != nil {
return nil, err
}
}
return opaqueID, nil
}
// GetSubject returns a subject UUID for a username. If it exists, it returns the existing one, otherwise it creates and saves it.
func (s OpenIDConnectStore) GetSubject(ctx context.Context, sectorID, username string) (subject uuid.UUID, err error) {
var opaqueID *model.UserOpaqueIdentifier
if opaqueID, err = s.GenerateOpaqueUserID(ctx, sectorID, username); err != nil {
return uuid.UUID{}, err
}
return opaqueID.Identifier, nil
}
// GetClientPolicy retrieves the policy from the client with the matching provided id. // GetClientPolicy retrieves the policy from the client with the matching provided id.
func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) { func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Level) {
client, err := s.GetInternalClient(id) client, err := s.GetFullClient(id)
if err != nil { if err != nil {
return authorization.TwoFactor return authorization.TwoFactor
} }
@ -53,8 +75,8 @@ func (s OpenIDConnectStore) GetClientPolicy(id string) (level authorization.Leve
return client.Policy return client.Policy
} }
// GetInternalClient returns a fosite.Client asserted as an InternalClient matching the provided id. // GetFullClient returns a fosite.Client asserted as an Client matching the provided id.
func (s OpenIDConnectStore) GetInternalClient(id string) (client *InternalClient, err error) { func (s OpenIDConnectStore) GetFullClient(id string) (client *Client, err error) {
client, ok := s.clients[id] client, ok := s.clients[id]
if !ok { if !ok {
return nil, fosite.ErrNotFound return nil, fosite.ErrNotFound
@ -65,142 +87,248 @@ func (s OpenIDConnectStore) GetInternalClient(id string) (client *InternalClient
// IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map. // IsValidClientID returns true if the provided id exists in the OpenIDConnectProvider.Clients map.
func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) { func (s OpenIDConnectStore) IsValidClientID(id string) (valid bool) {
_, err := s.GetInternalClient(id) _, err := s.GetFullClient(id)
return err == nil return err == nil
} }
// CreateOpenIDConnectSession decorates fosite's storage.MemoryStore CreateOpenIDConnectSession method. // BeginTX starts a transaction.
func (s *OpenIDConnectStore) CreateOpenIDConnectSession(ctx context.Context, authorizeCode string, requester fosite.Requester) error { // This implements a portion of fosite storage.Transactional interface.
return s.memory.CreateOpenIDConnectSession(ctx, authorizeCode, requester) func (s *OpenIDConnectStore) BeginTX(ctx context.Context) (c context.Context, err error) {
return s.provider.BeginTX(ctx)
} }
// GetOpenIDConnectSession decorates fosite's storage.MemoryStore GetOpenIDConnectSession method. // Commit completes a transaction.
func (s *OpenIDConnectStore) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, requester fosite.Requester) (fosite.Requester, error) { // This implements a portion of fosite storage.Transactional interface.
return s.memory.GetOpenIDConnectSession(ctx, authorizeCode, requester) func (s *OpenIDConnectStore) Commit(ctx context.Context) (err error) {
return s.provider.Commit(ctx)
} }
// DeleteOpenIDConnectSession decorates fosite's storage.MemoryStore DeleteOpenIDConnectSession method. // Rollback rolls a transaction back.
func (s *OpenIDConnectStore) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) error { // This implements a portion of fosite storage.Transactional interface.
return s.memory.DeleteOpenIDConnectSession(ctx, authorizeCode) func (s *OpenIDConnectStore) Rollback(ctx context.Context) (err error) {
return s.provider.Rollback(ctx)
} }
// GetClient decorates fosite's storage.MemoryStore GetClient method. // GetClient loads the client by its ID or returns an error if the client does not exist or another error occurred.
func (s *OpenIDConnectStore) GetClient(_ context.Context, id string) (fosite.Client, error) { // This implements a portion of fosite.ClientManager.
return s.GetInternalClient(id) func (s *OpenIDConnectStore) GetClient(_ context.Context, id string) (client fosite.Client, err error) {
return s.GetFullClient(id)
} }
// ClientAssertionJWTValid decorates fosite's storage.MemoryStore ClientAssertionJWTValid method. // ClientAssertionJWTValid returns an error if the JTI is known or the DB check failed and nil if the JTI is not known.
func (s *OpenIDConnectStore) ClientAssertionJWTValid(ctx context.Context, jti string) error { // This implements a portion of fosite.ClientManager.
return s.memory.ClientAssertionJWTValid(ctx, jti) func (s *OpenIDConnectStore) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
signature := fmt.Sprintf("%x", sha256.Sum256([]byte(jti)))
blacklistedJTI, err := s.provider.LoadOAuth2BlacklistedJTI(ctx, signature)
switch {
case errors.Is(sql.ErrNoRows, err):
return nil
case err != nil:
return err
case blacklistedJTI.ExpiresAt.After(time.Now()):
return fosite.ErrJTIKnown
default:
return nil
}
} }
// SetClientAssertionJWT decorates fosite's storage.MemoryStore SetClientAssertionJWT method. // SetClientAssertionJWT marks a JTI as known for the given expiry time. Before inserting the new JTI, it will clean
func (s *OpenIDConnectStore) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) error { // up any existing JTIs that have expired as those tokens can not be replayed due to the expiry.
return s.memory.SetClientAssertionJWT(ctx, jti, exp) // This implements a portion of fosite.ClientManager.
func (s *OpenIDConnectStore) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) (err error) {
blacklistedJTI := model.NewOAuth2BlacklistedJTI(jti, exp)
return s.provider.SaveOAuth2BlacklistedJTI(ctx, blacklistedJTI)
} }
// CreateAuthorizeCodeSession decorates fosite's storage.MemoryStore CreateAuthorizeCodeSession method. // CreateAuthorizeCodeSession stores the authorization request for a given authorization code.
func (s *OpenIDConnectStore) CreateAuthorizeCodeSession(ctx context.Context, code string, req fosite.Requester) error { // This implements a portion of oauth2.AuthorizeCodeStorage.
return s.memory.CreateAuthorizeCodeSession(ctx, code, req) func (s *OpenIDConnectStore) CreateAuthorizeCodeSession(ctx context.Context, code string, request fosite.Requester) (err error) {
return s.saveSession(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, request)
} }
// GetAuthorizeCodeSession decorates fosite's storage.MemoryStore GetAuthorizeCodeSession method. // InvalidateAuthorizeCodeSession is called when an authorize code is being used. The state of the authorization
func (s *OpenIDConnectStore) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (fosite.Requester, error) { // code should be set to invalid and consecutive requests to GetAuthorizeCodeSession should return the
return s.memory.GetAuthorizeCodeSession(ctx, code, session) // ErrInvalidatedAuthorizeCode error.
// This implements a portion of oauth2.AuthorizeCodeStorage.
func (s *OpenIDConnectStore) InvalidateAuthorizeCodeSession(ctx context.Context, code string) (err error) {
return s.provider.DeactivateOAuth2Session(ctx, storage.OAuth2SessionTypeAuthorizeCode, code)
} }
// InvalidateAuthorizeCodeSession decorates fosite's storage.MemoryStore InvalidateAuthorizeCodeSession method. // GetAuthorizeCodeSession hydrates the session based on the given code and returns the authorization request.
func (s *OpenIDConnectStore) InvalidateAuthorizeCodeSession(ctx context.Context, code string) error { // If the authorization code has been invalidated with `InvalidateAuthorizeCodeSession`, this
return s.memory.InvalidateAuthorizeCodeSession(ctx, code) // method should return the ErrInvalidatedAuthorizeCode error.
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedAuthorizeCode error!
// This implements a portion of oauth2.AuthorizeCodeStorage.
func (s *OpenIDConnectStore) GetAuthorizeCodeSession(ctx context.Context, code string, session fosite.Session) (request fosite.Requester, err error) {
// TODO: Implement the fosite.ErrInvalidatedAuthorizeCode error above. This requires splitting the invalidated sessions and deleted sessions.
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAuthorizeCode, code, session)
} }
// CreatePKCERequestSession decorates fosite's storage.MemoryStore CreatePKCERequestSession method. // CreateAccessTokenSession stores the authorization request for a given access token.
func (s *OpenIDConnectStore) CreatePKCERequestSession(ctx context.Context, code string, req fosite.Requester) error { // This implements a portion of oauth2.AccessTokenStorage.
return s.memory.CreatePKCERequestSession(ctx, code, req) func (s *OpenIDConnectStore) CreateAccessTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) {
return s.saveSession(ctx, storage.OAuth2SessionTypeAccessToken, signature, request)
} }
// GetPKCERequestSession decorates fosite's storage.MemoryStore GetPKCERequestSession method. // DeleteAccessTokenSession marks an access token session as deleted.
func (s *OpenIDConnectStore) GetPKCERequestSession(ctx context.Context, code string, session fosite.Session) (fosite.Requester, error) { // This implements a portion of oauth2.AccessTokenStorage.
return s.memory.GetPKCERequestSession(ctx, code, session) func (s *OpenIDConnectStore) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) {
return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature)
} }
// DeletePKCERequestSession decorates fosite's storage.MemoryStore DeletePKCERequestSession method. // RevokeAccessToken revokes an access token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1
func (s *OpenIDConnectStore) DeletePKCERequestSession(ctx context.Context, code string) error { // If the token passed to the request is an access token, the server MAY revoke the respective refresh token as well.
return s.memory.DeletePKCERequestSession(ctx, code) // This implements a portion of oauth2.TokenRevocationStorage.
func (s *OpenIDConnectStore) RevokeAccessToken(ctx context.Context, requestID string) (err error) {
return s.revokeSessionByRequestID(ctx, storage.OAuth2SessionTypeAccessToken, requestID)
} }
// CreateAccessTokenSession decorates fosite's storage.MemoryStore CreateAccessTokenSession method. // GetAccessTokenSession gets the authorization request for a given access token.
func (s *OpenIDConnectStore) CreateAccessTokenSession(ctx context.Context, signature string, req fosite.Requester) error { // This implements a portion of oauth2.AccessTokenStorage.
return s.memory.CreateAccessTokenSession(ctx, signature, req) func (s *OpenIDConnectStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature, session)
} }
// GetAccessTokenSession decorates fosite's storage.MemoryStore GetAccessTokenSession method. // CreateRefreshTokenSession stores the authorization request for a given refresh token.
func (s *OpenIDConnectStore) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { // This implements a portion of oauth2.RefreshTokenStorage.
return s.memory.GetAccessTokenSession(ctx, signature, session) func (s *OpenIDConnectStore) CreateRefreshTokenSession(ctx context.Context, signature string, request fosite.Requester) (err error) {
return s.saveSession(ctx, storage.OAuth2SessionTypeRefreshToken, signature, request)
} }
// DeleteAccessTokenSession decorates fosite's storage.MemoryStore DeleteAccessTokenSession method. // DeleteRefreshTokenSession marks the authorization request for a given refresh token as deleted.
func (s *OpenIDConnectStore) DeleteAccessTokenSession(ctx context.Context, signature string) error { // This implements a portion of oauth2.RefreshTokenStorage.
return s.memory.DeleteAccessTokenSession(ctx, signature) func (s *OpenIDConnectStore) DeleteRefreshTokenSession(ctx context.Context, signature string) (err error) {
return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature)
} }
// CreateRefreshTokenSession decorates fosite's storage.MemoryStore CreateRefreshTokenSession method. // RevokeRefreshToken revokes a refresh token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1
func (s *OpenIDConnectStore) CreateRefreshTokenSession(ctx context.Context, signature string, req fosite.Requester) error { // If the particular token is a refresh token and the authorization server supports the revocation of access tokens,
return s.memory.CreateRefreshTokenSession(ctx, signature, req) // then the authorization server SHOULD also invalidate all access tokens based on the same authorization grant (see Implementation Note).
// This implements a portion of oauth2.TokenRevocationStorage.
func (s *OpenIDConnectStore) RevokeRefreshToken(ctx context.Context, requestID string) (err error) {
return s.provider.DeactivateOAuth2SessionByRequestID(ctx, storage.OAuth2SessionTypeRefreshToken, requestID)
} }
// GetRefreshTokenSession decorates fosite's storage.MemoryStore GetRefreshTokenSession method. // RevokeRefreshTokenMaybeGracePeriod revokes an access token as specified in: https://tools.ietf.org/html/rfc7009#section-2.1
func (s *OpenIDConnectStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { // If the token passed to the request is an access token, the server MAY revoke the respective refresh token as well.
return s.memory.GetRefreshTokenSession(ctx, signature, session) // This implements a portion of oauth2.TokenRevocationStorage.
func (s *OpenIDConnectStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) (err error) {
return s.RevokeRefreshToken(ctx, requestID)
} }
// DeleteRefreshTokenSession decorates fosite's storage.MemoryStore DeleteRefreshTokenSession method. // GetRefreshTokenSession gets the authorization request for a given refresh token.
func (s *OpenIDConnectStore) DeleteRefreshTokenSession(ctx context.Context, signature string) error { // This implements a portion of oauth2.RefreshTokenStorage.
return s.memory.DeleteRefreshTokenSession(ctx, signature) func (s *OpenIDConnectStore) GetRefreshTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeRefreshToken, signature, session)
} }
// Authenticate decorates fosite's storage.MemoryStore Authenticate method. // CreatePKCERequestSession stores the authorization request for a given PKCE request.
func (s *OpenIDConnectStore) Authenticate(ctx context.Context, name string, secret string) error { // This implements a portion of pkce.PKCERequestStorage.
return s.memory.Authenticate(ctx, name, secret) func (s *OpenIDConnectStore) CreatePKCERequestSession(ctx context.Context, signature string, request fosite.Requester) (err error) {
return s.saveSession(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, request)
} }
// RevokeRefreshToken decorates fosite's storage.MemoryStore RevokeRefreshToken method. // DeletePKCERequestSession marks the authorization request for a given PKCE request as deleted.
func (s *OpenIDConnectStore) RevokeRefreshToken(ctx context.Context, requestID string) error { // This implements a portion of pkce.PKCERequestStorage.
return s.memory.RevokeRefreshToken(ctx, requestID) func (s *OpenIDConnectStore) DeletePKCERequestSession(ctx context.Context, signature string) (err error) {
return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, signature)
} }
// RevokeRefreshTokenMaybeGracePeriod decorates fosite's storage.MemoryStore RevokeRefreshTokenMaybeGracePeriod method. // GetPKCERequestSession gets the authorization request for a given PKCE request.
func (s OpenIDConnectStore) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestID string, signature string) error { // This implements a portion of pkce.PKCERequestStorage.
return s.memory.RevokeRefreshTokenMaybeGracePeriod(ctx, requestID, signature) func (s *OpenIDConnectStore) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (requester fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypePKCEChallenge, signature, session)
} }
// RevokeAccessToken decorates fosite's storage.MemoryStore RevokeAccessToken method. // CreateOpenIDConnectSession creates an open id connect session for a given authorize code.
func (s *OpenIDConnectStore) RevokeAccessToken(ctx context.Context, requestID string) error { // This is relevant for explicit open id connect flow.
return s.memory.RevokeAccessToken(ctx, requestID) // This implements a portion of openid.OpenIDConnectRequestStorage.
func (s *OpenIDConnectStore) CreateOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (err error) {
return s.saveSession(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request)
} }
// GetPublicKey decorates fosite's storage.MemoryStore GetPublicKey method. // DeleteOpenIDConnectSession just implements the method required by fosite even though it's unused.
func (s *OpenIDConnectStore) GetPublicKey(ctx context.Context, issuer string, subject string, keyID string) (*jose.JSONWebKey, error) { // This implements a portion of openid.OpenIDConnectRequestStorage.
return s.memory.GetPublicKey(ctx, issuer, subject, keyID) func (s *OpenIDConnectStore) DeleteOpenIDConnectSession(ctx context.Context, authorizeCode string) (err error) {
return s.revokeSessionBySignature(ctx, storage.OAuth2SessionTypeAccessToken, authorizeCode)
} }
// GetPublicKeys decorates fosite's storage.MemoryStore GetPublicKeys method. // GetOpenIDConnectSession returns error:
func (s *OpenIDConnectStore) GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) { // - nil if a session was found,
return s.memory.GetPublicKeys(ctx, issuer, subject) // - ErrNoSessionFound if no session was found
// - or an arbitrary error if an error occurred.
// This implements a portion of openid.OpenIDConnectRequestStorage.
func (s *OpenIDConnectStore) GetOpenIDConnectSession(ctx context.Context, authorizeCode string, request fosite.Requester) (r fosite.Requester, err error) {
return s.loadSessionBySignature(ctx, storage.OAuth2SessionTypeOpenIDConnect, authorizeCode, request.GetSession())
} }
// GetPublicKeyScopes decorates fosite's storage.MemoryStore GetPublicKeyScopes method. // IsJWTUsed implements an interface required for RFC7523.
func (s *OpenIDConnectStore) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyID string) ([]string, error) { func (s *OpenIDConnectStore) IsJWTUsed(ctx context.Context, jti string) (used bool, err error) {
return s.memory.GetPublicKeyScopes(ctx, issuer, subject, keyID) if err = s.ClientAssertionJWTValid(ctx, jti); err != nil {
return true, err
}
return false, nil
} }
// IsJWTUsed decorates fosite's storage.MemoryStore IsJWTUsed method. // MarkJWTUsedForTime implements an interface required for rfc7523.RFC7523KeyStorage.
func (s *OpenIDConnectStore) IsJWTUsed(ctx context.Context, jti string) (bool, error) { func (s *OpenIDConnectStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) {
return s.memory.IsJWTUsed(ctx, jti) return s.SetClientAssertionJWT(ctx, jti, exp)
} }
// MarkJWTUsedForTime decorates fosite's storage.MemoryStore MarkJWTUsedForTime method. func (s *OpenIDConnectStore) loadSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, session fosite.Session) (r fosite.Requester, err error) {
func (s *OpenIDConnectStore) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error { var (
return s.memory.MarkJWTUsedForTime(ctx, jti, exp) sessionModel *model.OAuth2Session
)
sessionModel, err = s.provider.LoadOAuth2Session(ctx, sessionType, signature)
if err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, fosite.ErrNotFound
default:
return nil, err
}
}
if r, err = sessionModel.ToRequest(ctx, session, s); err != nil {
return nil, err
}
if !sessionModel.Active && sessionType == storage.OAuth2SessionTypeAuthorizeCode {
return r, fosite.ErrInvalidatedAuthorizeCode
}
return r, nil
}
func (s *OpenIDConnectStore) saveSession(ctx context.Context, sessionType storage.OAuth2SessionType, signature string, r fosite.Requester) (err error) {
var session *model.OAuth2Session
if session, err = model.NewOAuth2SessionFromRequest(signature, r); err != nil {
return err
}
return s.provider.SaveOAuth2Session(ctx, sessionType, *session)
}
func (s *OpenIDConnectStore) revokeSessionBySignature(ctx context.Context, sessionType storage.OAuth2SessionType, signature string) (err error) {
return s.provider.RevokeOAuth2Session(ctx, sessionType, signature)
}
func (s *OpenIDConnectStore) revokeSessionByRequestID(ctx context.Context, sessionType storage.OAuth2SessionType, requestID string) (err error) {
if err = s.provider.RevokeOAuth2SessionByRequestID(ctx, sessionType, requestID); err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return fosite.ErrNotFound
default:
return err
}
}
return nil
} }

View File

@ -1,15 +1,6 @@
package oidc package oidc
import ( /*
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) { func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{ s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
@ -80,7 +71,7 @@ func TestOpenIDConnectStore_GetInternalClient_ValidClient(t *testing.T) {
Clients: []schema.OpenIDConnectClientConfiguration{c1}, Clients: []schema.OpenIDConnectClientConfiguration{c1},
}) })
client, err := s.GetInternalClient(c1.ID) client, err := s.GetFullClient(c1.ID)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, client) require.NotNil(t, client)
assert.Equal(t, client.ID, c1.ID) assert.Equal(t, client.ID, c1.ID)
@ -107,7 +98,7 @@ func TestOpenIDConnectStore_GetInternalClient_InvalidClient(t *testing.T) {
Clients: []schema.OpenIDConnectClientConfiguration{c1}, Clients: []schema.OpenIDConnectClientConfiguration{c1},
}) })
client, err := s.GetInternalClient("another-client") client, err := s.GetFullClient("another-client")
assert.Nil(t, client) assert.Nil(t, client)
assert.EqualError(t, err, "not_found") assert.EqualError(t, err, "not_found")
} }
@ -131,4 +122,5 @@ func TestOpenIDConnectStore_IsValidClientID(t *testing.T) {
assert.True(t, validClient) assert.True(t, validClient)
assert.False(t, invalidClient) assert.False(t, invalidClient)
} }.
*/

View File

@ -6,17 +6,19 @@ import (
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/storage"
"github.com/ory/fosite/token/jwt" "github.com/ory/fosite/token/jwt"
"github.com/ory/herodot" "github.com/ory/herodot"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"github.com/authelia/authelia/v4/internal/authorization" "github.com/authelia/authelia/v4/internal/authorization"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils"
) )
// NewSession creates a new OpenIDSession struct. // NewSession creates a new empty OpenIDSession struct.
func NewSession() (session *OpenIDSession) { func NewSession() (session *model.OpenIDSession) {
return &OpenIDSession{ return &model.OpenIDSession{
DefaultSession: &openid.DefaultSession{ DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{
Extra: map[string]interface{}{}, Extra: map[string]interface{}{},
@ -30,19 +32,19 @@ func NewSession() (session *OpenIDSession) {
} }
// NewSessionWithAuthorizeRequest uses details from an AuthorizeRequester to generate an OpenIDSession. // NewSessionWithAuthorizeRequest uses details from an AuthorizeRequester to generate an OpenIDSession.
func NewSessionWithAuthorizeRequest(issuer, kid, subject, username string, amr []string, extra map[string]interface{}, func NewSessionWithAuthorizeRequest(issuer, kid, username string, amr []string, extra map[string]interface{},
authTime, requestedAt time.Time, requester fosite.AuthorizeRequester) (session *OpenIDSession) { authTime time.Time, consent *model.OAuth2ConsentSession, requester fosite.AuthorizeRequester) (session *model.OpenIDSession) {
if extra == nil { if extra == nil {
extra = make(map[string]interface{}) extra = make(map[string]interface{})
} }
return &OpenIDSession{ session = &model.OpenIDSession{
DefaultSession: &openid.DefaultSession{ DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{
Subject: subject, Subject: consent.Subject.String(),
Issuer: issuer, Issuer: issuer,
AuthTime: authTime, AuthTime: authTime,
RequestedAt: requestedAt, RequestedAt: consent.RequestedAt,
IssuedAt: time.Now(), IssuedAt: time.Now(),
Nonce: requester.GetRequestForm().Get("nonce"), Nonce: requester.GetRequestForm().Get("nonce"),
Audience: requester.GetGrantedAudience(), Audience: requester.GetGrantedAudience(),
@ -55,12 +57,20 @@ func NewSessionWithAuthorizeRequest(issuer, kid, subject, username string, amr [
"kid": kid, "kid": kid,
}, },
}, },
Subject: subject, Subject: consent.Subject.String(),
Username: username, Username: username,
}, },
Extra: map[string]interface{}{}, Extra: map[string]interface{}{},
ClientID: requester.GetClient().GetID(), ClientID: requester.GetClient().GetID(),
ChallengeID: consent.ChallengeID,
} }
// Ensure required audience value of the client_id exists.
if !utils.IsStringInSlice(requester.GetClient().GetID(), session.Claims.Audience) {
session.Claims.Audience = append(session.Claims.Audience, requester.GetClient().GetID())
}
return session
} }
// OpenIDConnectProvider for OpenID Connect. // OpenIDConnectProvider for OpenID Connect.
@ -74,33 +84,34 @@ type OpenIDConnectProvider struct {
discovery OpenIDConnectWellKnownConfiguration discovery OpenIDConnectWellKnownConfiguration
} }
// OpenIDConnectStore is Authelia's internal representation of the fosite.Storage interface. // OpenIDConnectStore is Authelia's internal representation of the fosite.Storage interface. It maps the following
// // interfaces to the storage.Provider interface:
// Currently it is mostly just implementing a decorator pattern other then GetInternalClient. // fosite.Storage, fosite.ClientManager, storage.Transactional, oauth2.AuthorizeCodeStorage, oauth2.AccessTokenStorage,
// The long term plan is to have these methods interact with the Authelia storage and // oauth2.RefreshTokenStorage, oauth2.TokenRevocationStorage, pkce.PKCERequestStorage,
// session providers where applicable. // openid.OpenIDConnectRequestStorage, and partially implements rfc7523.RFC7523KeyStorage.
type OpenIDConnectStore struct { type OpenIDConnectStore struct {
clients map[string]*InternalClient provider storage.Provider
memory *storage.MemoryStore clients map[string]*Client
} }
// InternalClient represents the client internally. // Client represents the client internally.
type InternalClient struct { type Client struct {
ID string `json:"id"` ID string
Description string `json:"-"` SectorIdentifier string
Secret []byte `json:"client_secret,omitempty"` Description string
Public bool `json:"public"` Secret []byte
Public bool
Policy authorization.Level `json:"-"` Policy authorization.Level
Audience []string `json:"audience"` Audience []string
Scopes []string `json:"scopes"` Scopes []string
RedirectURIs []string `json:"redirect_uris"` RedirectURIs []string
GrantTypes []string `json:"grant_types"` GrantTypes []string
ResponseTypes []string `json:"response_types"` ResponseTypes []string
ResponseModes []fosite.ResponseModeType `json:"response_modes"` ResponseModes []fosite.ResponseModeType
UserinfoSigningAlgorithm string `json:"userinfo_signed_response_alg,omitempty"` UserinfoSigningAlgorithm string
} }
// KeyManager keeps track of all of the active/inactive rsa keys and provides them to services requiring them. // KeyManager keeps track of all of the active/inactive rsa keys and provides them to services requiring them.
@ -112,8 +123,8 @@ type KeyManager struct {
strategy *RS256JWTStrategy strategy *RS256JWTStrategy
} }
// AutheliaHasher implements the fosite.Hasher interface without an actual hashing algo. // PlainTextHasher implements the fosite.Hasher interface without an actual hashing algo.
type AutheliaHasher struct{} type PlainTextHasher struct{}
// ConsentGetResponseBody schema of the response body of the consent GET endpoint. // ConsentGetResponseBody schema of the response body of the consent GET endpoint.
type ConsentGetResponseBody struct { type ConsentGetResponseBody struct {
@ -123,12 +134,15 @@ type ConsentGetResponseBody struct {
Audience []string `json:"audience"` Audience []string `json:"audience"`
} }
// OpenIDSession holds OIDC Session information. // ConsentPostRequestBody schema of the request body of the consent POST endpoint.
type OpenIDSession struct { type ConsentPostRequestBody struct {
*openid.DefaultSession `json:"idToken"` ClientID string `json:"client_id"`
AcceptOrReject string `json:"accept_or_reject"`
}
Extra map[string]interface{} `json:"extra"` // ConsentPostResponseBody schema of the response body of the consent POST endpoint.
ClientID string type ConsentPostResponseBody struct {
RedirectURI string `json:"redirect_uri"`
} }
/* /*

View File

@ -9,6 +9,8 @@ import (
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/authelia/authelia/v4/internal/model"
) )
func TestNewSession(t *testing.T) { func TestNewSession(t *testing.T) {
@ -38,7 +40,7 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) {
Request: fosite.Request{ Request: fosite.Request{
ID: requestID.String(), ID: requestID.String(),
Form: formValues, Form: formValues,
Client: &InternalClient{ID: "example"}, Client: &Client{ID: "example"},
}, },
} }
@ -51,7 +53,13 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) {
issuer := "https://example.com" issuer := "https://example.com"
amr := []string{AMRPasswordBasedAuthentication} amr := []string{AMRPasswordBasedAuthentication}
session := NewSessionWithAuthorizeRequest(issuer, "primary", subject.String(), "john", amr, extra, authAt, requested, request) consent := &model.OAuth2ConsentSession{
ChallengeID: uuid.New(),
RequestedAt: requested,
Subject: subject,
}
session := NewSessionWithAuthorizeRequest(issuer, "primary", "john", amr, extra, authAt, consent, request)
require.NotNil(t, session) require.NotNil(t, session)
require.NotNil(t, session.Extra) require.NotNil(t, session.Extra)
@ -78,7 +86,12 @@ func TestNewSessionWithAuthorizeRequest(t *testing.T) {
require.Contains(t, session.Claims.Extra, "preferred_username") require.Contains(t, session.Claims.Extra, "preferred_username")
session = NewSessionWithAuthorizeRequest(issuer, "primary", subject.String(), "john", nil, nil, authAt, requested, request) consent = &model.OAuth2ConsentSession{
ChallengeID: uuid.New(),
RequestedAt: requested,
}
session = NewSessionWithAuthorizeRequest(issuer, "primary", "john", nil, nil, authAt, consent, request)
require.NotNil(t, session) require.NotNil(t, session)
require.NotNil(t, session.Claims) require.NotNil(t, session.Claims)

View File

@ -7,11 +7,11 @@ import (
"github.com/fasthttp/session/v2" "github.com/fasthttp/session/v2"
"github.com/fasthttp/session/v2/providers/redis" "github.com/fasthttp/session/v2/providers/redis"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
"github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication" "github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/logging" "github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/oidc" "github.com/authelia/authelia/v4/internal/oidc"
) )
@ -43,8 +43,8 @@ type UserSession struct {
// Webauthn holds the session registration data for this session. // Webauthn holds the session registration data for this session.
Webauthn *webauthn.SessionData Webauthn *webauthn.SessionData
// Represent an OIDC workflow session initiated by the client if not null. // ConsentChallengeID is the OpenID Connect Consent Session challenge ID.
OIDCWorkflowSession *model.OIDCWorkflowSession ConsentChallengeID *uuid.UUID
// This boolean is set to true after identity verification and checked // This boolean is set to true after identity verification and checked
// while doing the query actually updating the password. // while doing the query actually updating the password.

View File

@ -5,18 +5,40 @@ import (
) )
const ( const (
tableUserPreferences = "user_preferences" tableAuthenticationLogs = "authentication_logs"
tableDuoDevices = "duo_devices"
tableIdentityVerification = "identity_verification" tableIdentityVerification = "identity_verification"
tableTOTPConfigurations = "totp_configurations" tableTOTPConfigurations = "totp_configurations"
tableUserOpaqueIdentifier = "user_opaque_identifier"
tableUserPreferences = "user_preferences"
tableWebauthnDevices = "webauthn_devices" tableWebauthnDevices = "webauthn_devices"
tableDuoDevices = "duo_devices"
tableAuthenticationLogs = "authentication_logs" tableOAuth2ConsentSession = "oauth2_consent_session"
tableOAuth2AuthorizeCodeSession = "oauth2_authorization_code_session"
tableOAuth2AccessTokenSession = "oauth2_access_token_session" //nolint:gosec // This is not a hardcoded credential.
tableOAuth2RefreshTokenSession = "oauth2_refresh_token_session" //nolint:gosec // This is not a hardcoded credential.
tableOAuth2PKCERequestSession = "oauth2_pkce_request_session"
tableOAuth2OpenIDConnectSession = "oauth2_openid_connect_session"
tableOAuth2BlacklistedJTI = "oauth2_blacklisted_jti"
tableMigrations = "migrations" tableMigrations = "migrations"
tableEncryption = "encryption" tableEncryption = "encryption"
tablePrefixBackup = "_bkp_" tablePrefixBackup = "_bkp_"
) )
// OAuth2SessionType represents the potential OAuth 2.0 session types.
type OAuth2SessionType string
// Representation of specific OAuth 2.0 session types.
const (
OAuth2SessionTypeAuthorizeCode OAuth2SessionType = "authorization code"
OAuth2SessionTypeAccessToken OAuth2SessionType = "access token"
OAuth2SessionTypeRefreshToken OAuth2SessionType = "refresh token"
OAuth2SessionTypePKCEChallenge OAuth2SessionType = "pkce challenge"
OAuth2SessionTypeOpenIDConnect OAuth2SessionType = "openid connect"
)
const ( const (
encryptionNameCheck = "check" encryptionNameCheck = "check"
) )
@ -56,7 +78,7 @@ const (
const ( const (
// This is the latest schema version for the purpose of tests. // This is the latest schema version for the purpose of tests.
testLatestVersion = 3 testLatestVersion = 4
) )
const ( const (
@ -64,6 +86,12 @@ const (
SchemaLatest = 2147483647 SchemaLatest = 2147483647
) )
type ctxKey int
const (
ctxKeyTransaction ctxKey = iota
)
var ( var (
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`) reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`)
) )

View File

@ -0,0 +1,188 @@
CREATE TABLE IF NOT EXISTS user_opaque_identifier (
id INTEGER AUTO_INCREMENT,
service VARCHAR(20) NOT NULL,
sector_id VARCHAR(255) NOT NULL,
username VARCHAR(100) NOT NULL,
identifier CHAR(36) NOT NULL,
PRIMARY KEY (id)
);
CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username);
CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier);
CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti (
id INTEGER AUTO_INCREMENT,
signature VARCHAR(64) NOT NULL,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
);
CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature);
CREATE TABLE IF NOT EXISTS oauth2_consent_session (
id INTEGER AUTO_INCREMENT,
challenge_id CHAR(36) NOT NULL,
client_id VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
authorized BOOLEAN NOT NULL DEFAULT FALSE,
granted BOOLEAN NOT NULL DEFAULT FALSE,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
responded_at TIMESTAMP NULL DEFAULT NULL,
expires_at TIMESTAMP NULL DEFAULT NULL,
form_data TEXT NOT NULL,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL,
granted_audience TEXT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_consent_subject_fkey
FOREIGN KEY (subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id);
CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session (
id INTEGER AUTO_INCREMENT,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL,
granted_audience TEXT NULL,
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey
FOREIGN KEY (challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_authorization_code_session_subject_fkey
FOREIGN KEY (subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id);
CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id);
CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_access_token_session (
id INTEGER AUTO_INCREMENT,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL,
granted_audience TEXT NULL,
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_access_token_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_access_token_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id);
CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id);
CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session (
id INTEGER AUTO_INCREMENT,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL,
granted_audience TEXT NULL,
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_refresh_token_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id);
CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id);
CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session (
id INTEGER AUTO_INCREMENT,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL,
granted_audience TEXT NULL,
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_pkce_request_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id);
CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id);
CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session (
id INTEGER AUTO_INCREMENT,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL,
granted_audience TEXT NULL,
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_openid_connect_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id);
CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id);
CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject);

View File

@ -0,0 +1,188 @@
CREATE TABLE IF NOT EXISTS user_opaque_identifier (
id SERIAL,
service VARCHAR(20) NOT NULL,
sector_id VARCHAR(255) NOT NULL,
username VARCHAR(100) NOT NULL,
identifier CHAR(36) NOT NULL,
PRIMARY KEY (id)
);
CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username);
CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier);
CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti (
id SERIAL,
signature VARCHAR(64) NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
);
CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature);
CREATE TABLE IF NOT EXISTS oauth2_consent_session (
id SERIAL,
challenge_id CHAR(36) NOT NULL,
client_id VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
authorized BOOLEAN NOT NULL DEFAULT FALSE,
granted BOOLEAN NOT NULL DEFAULT FALSE,
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
responded_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
expires_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
form_data TEXT NOT NULL,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
PRIMARY KEY (id),
CONSTRAINT oauth2_consent_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id);
CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session (
id SERIAL,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36),
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BYTEA NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_authorization_code_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id);
CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id);
CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_access_token_session (
id SERIAL,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BYTEA NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_access_token_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_access_token_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id);
CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id);
CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session (
id SERIAL,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BYTEA NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_refresh_token_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id);
CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id);
CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session (
id SERIAL,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BYTEA NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_pkce_request_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id);
CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id);
CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session (
id SERIAL,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BYTEA NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_openid_connect_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id);
CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id);
CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject);

View File

@ -0,0 +1,188 @@
CREATE TABLE IF NOT EXISTS user_opaque_identifier (
id INTEGER,
service VARCHAR(20) NOT NULL,
sector_id VARCHAR(255) NOT NULL,
username VARCHAR(100) NOT NULL,
identifier CHAR(36) NOT NULL,
PRIMARY KEY (id)
);
CREATE UNIQUE INDEX user_opaque_identifier_service_sector_id_username_key ON user_opaque_identifier (service, sector_id, username);
CREATE UNIQUE INDEX user_opaque_identifier_identifier_key ON user_opaque_identifier (identifier);
CREATE TABLE IF NOT EXISTS oauth2_blacklisted_jti (
id INTEGER,
signature VARCHAR(64) NOT NULL,
expires_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id)
);
CREATE UNIQUE INDEX oauth2_blacklisted_jti_signature_key ON oauth2_blacklisted_jti (signature);
CREATE TABLE IF NOT EXISTS oauth2_consent_session (
id INTEGER,
challenge_id CHAR(36) NOT NULL,
client_id VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
authorized BOOLEAN NOT NULL DEFAULT FALSE,
granted BOOLEAN NOT NULL DEFAULT FALSE,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
responded_at TIMESTAMP NULL DEFAULT NULL,
expires_at TIMESTAMP NULL DEFAULT NULL,
form_data TEXT NOT NULL,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
PRIMARY KEY (id),
CONSTRAINT oauth2_consent_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE UNIQUE INDEX oauth2_consent_session_challenge_id_key ON oauth2_consent_session (challenge_id);
CREATE TABLE IF NOT EXISTS oauth2_authorization_code_session (
id INTEGER,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_authorization_code_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_authorization_code_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_authorization_code_session_request_id_idx ON oauth2_authorization_code_session (request_id);
CREATE INDEX oauth2_authorization_code_session_client_id_idx ON oauth2_authorization_code_session (client_id);
CREATE INDEX oauth2_authorization_code_session_client_id_subject_idx ON oauth2_authorization_code_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_access_token_session (
id INTEGER,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_access_token_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_access_token_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_access_token_session_request_id_idx ON oauth2_access_token_session (request_id);
CREATE INDEX oauth2_access_token_session_client_id_idx ON oauth2_access_token_session (client_id);
CREATE INDEX oauth2_access_token_session_client_id_subject_idx ON oauth2_access_token_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_refresh_token_session (
id INTEGER,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_refresh_token_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_refresh_token_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_refresh_token_session_request_id_idx ON oauth2_refresh_token_session (request_id);
CREATE INDEX oauth2_refresh_token_session_client_id_idx ON oauth2_refresh_token_session (client_id);
CREATE INDEX oauth2_refresh_token_session_client_id_subject_idx ON oauth2_refresh_token_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_pkce_request_session (
id INTEGER,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_pkce_request_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_pkce_request_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_pkce_request_session_request_id_idx ON oauth2_pkce_request_session (request_id);
CREATE INDEX oauth2_pkce_request_session_client_id_idx ON oauth2_pkce_request_session (client_id);
CREATE INDEX oauth2_pkce_request_session_client_id_subject_idx ON oauth2_pkce_request_session (client_id, subject);
CREATE TABLE IF NOT EXISTS oauth2_openid_connect_session (
id INTEGER,
challenge_id CHAR(36) NOT NULL,
request_id VARCHAR(40) NOT NULL,
client_id VARCHAR(255) NOT NULL,
signature VARCHAR(255) NOT NULL,
subject CHAR(36) NOT NULL,
requested_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
requested_scopes TEXT NOT NULL,
granted_scopes TEXT NOT NULL,
requested_audience TEXT NULL DEFAULT '',
granted_audience TEXT NULL DEFAULT '',
active BOOLEAN NOT NULL DEFAULT FALSE,
revoked BOOLEAN NOT NULL DEFAULT FALSE,
form_data TEXT NOT NULL,
session_data BLOB NOT NULL,
PRIMARY KEY (id),
CONSTRAINT oauth2_openid_connect_session_challenge_id_fkey
FOREIGN KEY(challenge_id)
REFERENCES oauth2_consent_session(challenge_id) ON UPDATE CASCADE ON DELETE CASCADE,
CONSTRAINT oauth2_openid_connect_session_subject_fkey
FOREIGN KEY(subject)
REFERENCES user_opaque_identifier(identifier) ON UPDATE RESTRICT ON DELETE RESTRICT
);
CREATE INDEX oauth2_openid_connect_session_request_id_idx ON oauth2_openid_connect_session (request_id);
CREATE INDEX oauth2_openid_connect_session_client_id_idx ON oauth2_openid_connect_session (client_id);
CREATE INDEX oauth2_openid_connect_session_client_id_subject_idx ON oauth2_openid_connect_session (client_id, subject);

View File

@ -0,0 +1,8 @@
DROP TABLE IF EXISTS oauth2_blacklisted_jti;
DROP TABLE IF EXISTS oauth2_authorization_code_session;
DROP TABLE IF EXISTS oauth2_access_token_session;
DROP TABLE IF EXISTS oauth2_refresh_token_session;
DROP TABLE IF EXISTS oauth2_pkce_request_session;
DROP TABLE IF EXISTS oauth2_openid_connect_session;
DROP TABLE IF EXISTS oauth2_consent_session;
DROP TABLE IF EXISTS user_opaque_identifier;

View File

@ -4,6 +4,9 @@ import (
"context" "context"
"time" "time"
"github.com/google/uuid"
"github.com/ory/fosite/storage"
"github.com/authelia/authelia/v4/internal/model" "github.com/authelia/authelia/v4/internal/model"
) )
@ -13,10 +16,16 @@ type Provider interface {
RegulatorProvider RegulatorProvider
storage.Transactional
SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error)
LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error)
LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error)
SaveUserOpaqueIdentifier(ctx context.Context, subject model.UserOpaqueIdentifier) (err error)
LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (subject *model.UserOpaqueIdentifier, err error)
LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (subject *model.UserOpaqueIdentifier, err error)
SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error)
ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error)
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
@ -36,6 +45,22 @@ type Provider interface {
DeletePreferredDuoDevice(ctx context.Context, username string) (err error) DeletePreferredDuoDevice(ctx context.Context, username string) (err error)
LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error)
SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error)
SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, rejection bool) (err error)
SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error)
LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error)
LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error)
SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error)
RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error)
RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error)
DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error)
DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error)
LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error)
SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error)
LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error)
SchemaTables(ctx context.Context) (tables []string, err error) SchemaTables(ctx context.Context) (tables []string, err error)
SchemaVersion(ctx context.Context) (version int, err error) SchemaVersion(ctx context.Context) (version int, err error)
SchemaLatestVersion() (version int, err error) SchemaLatestVersion() (version int, err error)

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -63,6 +64,54 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences), sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebauthnDevices, tableDuoDevices, tableUserPreferences), sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebauthnDevices, tableDuoDevices, tableUserPreferences),
sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier),
sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier),
sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier),
sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession),
sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession),
sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession),
sqlRevokeOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession),
sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlRevokeOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlRevokeOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession),
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
sqlInsertOAuth2ConsentSession: fmt.Sprintf(queryFmtInsertOAuth2ConsentSession, tableOAuth2ConsentSession),
sqlUpdateOAuth2ConsentSessionResponse: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionResponse, tableOAuth2ConsentSession),
sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession),
sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession),
sqlSelectOAuth2ConsentSessionsPreConfigured: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionsPreConfigured, tableOAuth2ConsentSession),
sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations), sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations), sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations), sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
@ -128,6 +177,11 @@ type SQLProvider struct {
sqlSelectPreferred2FAMethod string sqlSelectPreferred2FAMethod string
sqlSelectUserInfo string sqlSelectUserInfo string
// Table: user_opaque_identifier.
sqlInsertUserOpaqueIdentifier string
sqlSelectUserOpaqueIdentifier string
sqlSelectUserOpaqueIdentifierBySignature string
// Table: migrations. // Table: migrations.
sqlInsertMigration string sqlInsertMigration string
sqlSelectMigrations string sqlSelectMigrations string
@ -137,6 +191,56 @@ type SQLProvider struct {
sqlUpsertEncryptionValue string sqlUpsertEncryptionValue string
sqlSelectEncryptionValue string sqlSelectEncryptionValue string
// Table: oauth2_authorization_code_session.
sqlInsertOAuth2AuthorizeCodeSession string
sqlSelectOAuth2AuthorizeCodeSession string
sqlRevokeOAuth2AuthorizeCodeSession string
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID string
sqlDeactivateOAuth2AuthorizeCodeSession string
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID string
// Table: oauth2_access_token_session.
sqlInsertOAuth2AccessTokenSession string
sqlSelectOAuth2AccessTokenSession string
sqlRevokeOAuth2AccessTokenSession string
sqlRevokeOAuth2AccessTokenSessionByRequestID string
sqlDeactivateOAuth2AccessTokenSession string
sqlDeactivateOAuth2AccessTokenSessionByRequestID string
// Table: oauth2_refresh_token_session.
sqlInsertOAuth2RefreshTokenSession string
sqlSelectOAuth2RefreshTokenSession string
sqlRevokeOAuth2RefreshTokenSession string
sqlRevokeOAuth2RefreshTokenSessionByRequestID string
sqlDeactivateOAuth2RefreshTokenSession string
sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
// Table: oauth2_pkce_request_session.
sqlInsertOAuth2PKCERequestSession string
sqlSelectOAuth2PKCERequestSession string
sqlRevokeOAuth2PKCERequestSession string
sqlRevokeOAuth2PKCERequestSessionByRequestID string
sqlDeactivateOAuth2PKCERequestSession string
sqlDeactivateOAuth2PKCERequestSessionByRequestID string
// Table: oauth2_openid_connect_session.
sqlInsertOAuth2OpenIDConnectSession string
sqlSelectOAuth2OpenIDConnectSession string
sqlRevokeOAuth2OpenIDConnectSession string
sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
sqlDeactivateOAuth2OpenIDConnectSession string
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
// Table: oauth2_consent_session.
sqlInsertOAuth2ConsentSession string
sqlUpdateOAuth2ConsentSessionResponse string
sqlUpdateOAuth2ConsentSessionGranted string
sqlSelectOAuth2ConsentSessionByChallengeID string
sqlSelectOAuth2ConsentSessionsPreConfigured string
sqlUpsertOAuth2BlacklistedJTI string
sqlSelectOAuth2BlacklistedJTI string
// Utility. // Utility.
sqlSelectExistingTables string sqlSelectExistingTables string
sqlFmtRenameTable string sqlFmtRenameTable string
@ -187,6 +291,329 @@ func (p *SQLProvider) StartupCheck() (err error) {
} }
} }
// BeginTX begins a transaction.
func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) {
var tx *sql.Tx
if tx, err = p.db.Begin(); err != nil {
return nil, err
}
return context.WithValue(ctx, ctxKeyTransaction, tx), nil
}
// Commit performs a database commit.
func (p *SQLProvider) Commit(ctx context.Context) (err error) {
tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
if !ok {
return errors.New("could not retrieve tx")
}
return tx.Commit()
}
// Rollback performs a database rollback.
func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
if !ok {
return errors.New("could not retrieve tx")
}
return tx.Rollback()
}
// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil {
return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err)
}
return nil
}
// LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) {
opaqueID = &model.UserOpaqueIdentifier{}
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, nil
default:
return nil, err
}
}
return opaqueID, nil
}
// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
opaqueID = &model.UserOpaqueIdentifier{}
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, nil
default:
return nil, err
}
}
return opaqueID, nil
}
// SaveOAuth2ConsentSession inserts an OAuth2.0 consent.
func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession,
consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted,
consent.RequestedAt, consent.RespondedAt, consent.ExpiresAt, consent.Form,
consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience); err != nil {
return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.String(), err)
}
return nil
}
// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent with the consent response.
func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.ExpiresAt, consent.GrantedScopes, consent.GrantedAudience, consent.ID)
if err != nil {
return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject, err)
}
return nil
}
// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint.
func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil {
return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err)
}
return nil
}
// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID.
func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) {
consent = &model.OAuth2ConsentSession{}
if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil {
return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err)
}
return consent, nil
}
// LoadOAuth2ConsentSessionsPreConfigured returns an OAuth2.0 consents that are pre-configured given the consent signature.
func (p *SQLProvider) LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error) {
var r *sqlx.Rows
if r, err = p.db.QueryxContext(ctx, p.sqlSelectOAuth2ConsentSessionsPreConfigured, clientID, subject); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return &ConsentSessionRows{}, nil
}
return &ConsentSessionRows{}, fmt.Errorf("error selecting oauth2 consent session by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err)
}
return &ConsentSessionRows{rows: r}, nil
}
// SaveOAuth2Session saves a OAuth2Session to the database.
func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) {
var query string
switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlInsertOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlInsertOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlInsertOAuth2RefreshTokenSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlInsertOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlInsertOAuth2OpenIDConnectSession
default:
return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType)
}
if session.Session, err = p.encrypt(session.Session); err != nil {
return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
}
_, err = p.db.ExecContext(ctx, query,
session.ChallengeID, session.RequestID, session.ClientID, session.Signature,
session.Subject, session.RequestedAt, session.RequestedScopes, session.GrantedScopes,
session.RequestedAudience, session.GrantedAudience,
session.Active, session.Revoked, session.Form, session.Session)
if err != nil {
return fmt.Errorf("error inserting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
}
return nil
}
// RevokeOAuth2Session marks a OAuth2Session as revoked in the database.
func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
var query string
switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlRevokeOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlRevokeOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlRevokeOAuth2RefreshTokenSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlRevokeOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlRevokeOAuth2OpenIDConnectSession
default:
return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err)
}
return nil
}
// RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database.
func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
var query string
switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
case OAuth2SessionTypeAccessToken:
query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID
case OAuth2SessionTypeRefreshToken:
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
case OAuth2SessionTypePKCEChallenge:
query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID
default:
return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
}
return nil
}
// DeactivateOAuth2Session marks a OAuth2Session as inactive in the database.
func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
var query string
switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlDeactivateOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlDeactivateOAuth2RefreshTokenSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlDeactivateOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlDeactivateOAuth2OpenIDConnectSession
default:
return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err)
}
return nil
}
// DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database.
func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
var query string
switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID
case OAuth2SessionTypeRefreshToken:
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
case OAuth2SessionTypePKCEChallenge:
query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID
default:
return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
}
if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
return fmt.Errorf("error deactivating oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
}
return nil
}
// LoadOAuth2Session saves a OAuth2Session from the database.
func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) {
var query string
switch sessionType {
case OAuth2SessionTypeAuthorizeCode:
query = p.sqlSelectOAuth2AuthorizeCodeSession
case OAuth2SessionTypeAccessToken:
query = p.sqlSelectOAuth2AccessTokenSession
case OAuth2SessionTypeRefreshToken:
query = p.sqlSelectOAuth2RefreshTokenSession
case OAuth2SessionTypePKCEChallenge:
query = p.sqlSelectOAuth2PKCERequestSession
case OAuth2SessionTypeOpenIDConnect:
query = p.sqlSelectOAuth2OpenIDConnectSession
default:
return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType)
}
session = &model.OAuth2Session{}
if err = p.db.GetContext(ctx, session, query, signature); err != nil {
return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err)
}
if session.Session, err = p.decrypt(session.Session); err != nil {
return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject, session.RequestID, err)
}
return session, nil
}
// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
}
return nil
}
// LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database.
func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) {
blacklistedJTI = &model.OAuth2BlacklistedJTI{}
if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil {
return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
}
return blacklistedJTI, nil
}
// SavePreferred2FAMethod save the preferred method for 2FA to the database. // SavePreferred2FAMethod save the preferred method for 2FA to the database.
func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) { func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil { if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil {

View File

@ -26,19 +26,27 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
// Specific alterations to this provider. // Specific alterations to this provider.
// PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead. // PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead.
provider.sqlUpsertWebauthnDevice = fmt.Sprintf(queryFmtPostgresUpsertWebauthnDevice, tableWebauthnDevices) provider.sqlUpsertWebauthnDevice = fmt.Sprintf(queryFmtUpsertWebauthnDevicePostgreSQL, tableWebauthnDevices)
provider.sqlUpsertDuoDevice = fmt.Sprintf(queryFmtPostgresUpsertDuoDevice, tableDuoDevices) provider.sqlUpsertDuoDevice = fmt.Sprintf(queryFmtUpsertDuoDevicePostgreSQL, tableDuoDevices)
provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations) provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtUpsertTOTPConfigurationPostgreSQL, tableTOTPConfigurations)
provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences) provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtUpsertPreferred2FAMethodPostgreSQL, tableUserPreferences)
provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtPostgresUpsertEncryptionValue, tableEncryption) provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtUpsertEncryptionValuePostgreSQL, tableEncryption)
provider.sqlUpsertOAuth2BlacklistedJTI = fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTIPostgreSQL, tableOAuth2BlacklistedJTI)
// PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders. // PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders.
provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable) provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable)
provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod) provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod)
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo) provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
provider.sqlInsertUserOpaqueIdentifier = provider.db.Rebind(provider.sqlInsertUserOpaqueIdentifier)
provider.sqlSelectUserOpaqueIdentifier = provider.db.Rebind(provider.sqlSelectUserOpaqueIdentifier)
provider.sqlSelectUserOpaqueIdentifierBySignature = provider.db.Rebind(provider.sqlSelectUserOpaqueIdentifierBySignature)
provider.sqlSelectIdentityVerification = provider.db.Rebind(provider.sqlSelectIdentityVerification) provider.sqlSelectIdentityVerification = provider.db.Rebind(provider.sqlSelectIdentityVerification)
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification) provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification) provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig) provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn) provider.sqlUpdateTOTPConfigRecordSignIn = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignIn)
provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername) provider.sqlUpdateTOTPConfigRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigRecordSignInByUsername)
@ -46,21 +54,69 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs) provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret) provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername) provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
provider.sqlSelectWebauthnDevices = provider.db.Rebind(provider.sqlSelectWebauthnDevices) provider.sqlSelectWebauthnDevices = provider.db.Rebind(provider.sqlSelectWebauthnDevices)
provider.sqlSelectWebauthnDevicesByUsername = provider.db.Rebind(provider.sqlSelectWebauthnDevicesByUsername) provider.sqlSelectWebauthnDevicesByUsername = provider.db.Rebind(provider.sqlSelectWebauthnDevicesByUsername)
provider.sqlUpdateWebauthnDevicePublicKey = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKey) provider.sqlUpdateWebauthnDevicePublicKey = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKey)
provider.sqlUpdateWebauthnDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKeyByUsername) provider.sqlUpdateWebauthnDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDevicePublicKeyByUsername)
provider.sqlUpdateWebauthnDeviceRecordSignIn = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignIn) provider.sqlUpdateWebauthnDeviceRecordSignIn = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignIn)
provider.sqlUpdateWebauthnDeviceRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignInByUsername) provider.sqlUpdateWebauthnDeviceRecordSignInByUsername = provider.db.Rebind(provider.sqlUpdateWebauthnDeviceRecordSignInByUsername)
provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice) provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice)
provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice) provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice)
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt) provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername) provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername)
provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration) provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration)
provider.sqlSelectMigrations = provider.db.Rebind(provider.sqlSelectMigrations) provider.sqlSelectMigrations = provider.db.Rebind(provider.sqlSelectMigrations)
provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration) provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration)
provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue) provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue)
provider.sqlInsertOAuth2ConsentSession = provider.db.Rebind(provider.sqlInsertOAuth2ConsentSession)
provider.sqlUpdateOAuth2ConsentSessionResponse = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionResponse)
provider.sqlUpdateOAuth2ConsentSessionGranted = provider.db.Rebind(provider.sqlUpdateOAuth2ConsentSessionGranted)
provider.sqlSelectOAuth2ConsentSessionByChallengeID = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionByChallengeID)
provider.sqlSelectOAuth2ConsentSessionsPreConfigured = provider.db.Rebind(provider.sqlSelectOAuth2ConsentSessionsPreConfigured)
provider.sqlInsertOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlInsertOAuth2AuthorizeCodeSession)
provider.sqlRevokeOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSession)
provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID)
provider.sqlDeactivateOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSession)
provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID)
provider.sqlSelectOAuth2AuthorizeCodeSession = provider.db.Rebind(provider.sqlSelectOAuth2AuthorizeCodeSession)
provider.sqlInsertOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2AccessTokenSession)
provider.sqlRevokeOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSession)
provider.sqlRevokeOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2AccessTokenSessionByRequestID)
provider.sqlDeactivateOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSession)
provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2AccessTokenSessionByRequestID)
provider.sqlSelectOAuth2AccessTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2AccessTokenSession)
provider.sqlInsertOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlInsertOAuth2RefreshTokenSession)
provider.sqlRevokeOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSession)
provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2RefreshTokenSessionByRequestID)
provider.sqlDeactivateOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSession)
provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2RefreshTokenSessionByRequestID)
provider.sqlSelectOAuth2RefreshTokenSession = provider.db.Rebind(provider.sqlSelectOAuth2RefreshTokenSession)
provider.sqlInsertOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlInsertOAuth2PKCERequestSession)
provider.sqlRevokeOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSession)
provider.sqlRevokeOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2PKCERequestSessionByRequestID)
provider.sqlDeactivateOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSession)
provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2PKCERequestSessionByRequestID)
provider.sqlSelectOAuth2PKCERequestSession = provider.db.Rebind(provider.sqlSelectOAuth2PKCERequestSession)
provider.sqlInsertOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlInsertOAuth2OpenIDConnectSession)
provider.sqlRevokeOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSession)
provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlRevokeOAuth2OpenIDConnectSessionByRequestID)
provider.sqlDeactivateOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSession)
provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID = provider.db.Rebind(provider.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID)
provider.sqlSelectOAuth2OpenIDConnectSession = provider.db.Rebind(provider.sqlSelectOAuth2OpenIDConnectSession)
provider.sqlSelectOAuth2BlacklistedJTI = provider.db.Rebind(provider.sqlSelectOAuth2BlacklistedJTI)
provider.schema = config.Storage.PostgreSQL.Schema provider.schema = config.Storage.PostgreSQL.Schema
return provider return provider

View File

@ -48,7 +48,7 @@ const (
REPLACE INTO %s (username, second_factor_method) REPLACE INTO %s (username, second_factor_method)
VALUES (?, ?);` VALUES (?, ?);`
queryFmtPostgresUpsertPreferred2FAMethod = ` queryFmtUpsertPreferred2FAMethodPostgreSQL = `
INSERT INTO %s (username, second_factor_method) INSERT INTO %s (username, second_factor_method)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (username) ON CONFLICT (username)
@ -99,7 +99,7 @@ const (
REPLACE INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret) REPLACE INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);` VALUES (?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtPostgresUpsertTOTPConfiguration = ` queryFmtUpsertTOTPConfigurationPostgreSQL = `
INSERT INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret) INSERT INTO %s (created_at, last_used_at, username, issuer, algorithm, digits, period, secret)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (username) ON CONFLICT (username)
@ -160,7 +160,7 @@ const (
REPLACE INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning) REPLACE INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtPostgresUpsertWebauthnDevice = ` queryFmtUpsertWebauthnDevicePostgreSQL = `
INSERT INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning) INSERT INTO %s (created_at, last_used_at, rpid, username, description, kid, public_key, attestation_type, transport, aaguid, sign_count, clone_warning)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (username, description) ON CONFLICT (username, description)
@ -172,7 +172,7 @@ const (
REPLACE INTO %s (username, device, method) REPLACE INTO %s (username, device, method)
VALUES (?, ?, ?);` VALUES (?, ?, ?);`
queryFmtPostgresUpsertDuoDevice = ` queryFmtUpsertDuoDevicePostgreSQL = `
INSERT INTO %s (username, device, method) INSERT INTO %s (username, device, method)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (username) ON CONFLICT (username)
@ -214,9 +214,103 @@ const (
REPLACE INTO %s (name, value) REPLACE INTO %s (name, value)
VALUES (?, ?);` VALUES (?, ?);`
queryFmtPostgresUpsertEncryptionValue = ` queryFmtUpsertEncryptionValuePostgreSQL = `
INSERT INTO %s (name, value) INSERT INTO %s (name, value)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (name) ON CONFLICT (name)
DO UPDATE SET value = $2;` DO UPDATE SET value = $2;`
) )
const (
queryFmtSelectOAuth2ConsentSessionByChallengeID = `
SELECT id, challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at,
form_data, requested_scopes, granted_scopes, requested_audience, granted_audience
FROM %s
WHERE challenge_id = ?;`
queryFmtSelectOAuth2ConsentSessionsPreConfigured = `
SELECT id, challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at,
form_data, requested_scopes, granted_scopes, requested_audience, granted_audience
FROM %s
WHERE client_id = ? AND subject = ? AND
authorized = TRUE AND granted = TRUE AND expires_at IS NOT NULL AND expires_at >= CURRENT_TIMESTAMP;`
queryFmtInsertOAuth2ConsentSession = `
INSERT INTO %s (challenge_id, client_id, subject, authorized, granted, requested_at, responded_at, expires_at,
form_data, requested_scopes, granted_scopes, requested_audience, granted_audience)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtUpdateOAuth2ConsentSessionResponse = `
UPDATE %s
SET authorized = ?, responded_at = CURRENT_TIMESTAMP, expires_at = ?, granted_scopes = ?, granted_audience = ?
WHERE id = ? AND responded_at IS NULL;`
queryFmtUpdateOAuth2ConsentSessionGranted = `
UPDATE %s
SET granted = TRUE
WHERE id = ? AND responded_at IS NOT NULL;`
queryFmtSelectOAuth2Session = `
SELECT id, challenge_id, request_id, client_id, signature, subject, requested_at,
requested_scopes, granted_scopes, requested_audience, granted_audience,
active, revoked, form_data, session_data
FROM %s
WHERE signature = ? AND revoked = FALSE;`
queryFmtInsertOAuth2Session = `
INSERT INTO %s (challenge_id, request_id, client_id, signature, subject, requested_at,
requested_scopes, granted_scopes, requested_audience, granted_audience,
active, revoked, form_data, session_data)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);`
queryFmtRevokeOAuth2Session = `
UPDATE %s
SET revoked = TRUE
WHERE signature = ?;`
queryFmtRevokeOAuth2SessionByRequestID = `
UPDATE %s
SET revoked = TRUE
WHERE request_id = ?;`
queryFmtDeactivateOAuth2Session = `
UPDATE %s
SET active = FALSE
WHERE signature = ?;`
queryFmtDeactivateOAuth2SessionByRequestID = `
UPDATE %s
SET active = FALSE
WHERE request_id = ?;"`
queryFmtSelectOAuth2BlacklistedJTI = `
SELECT id, signature, expires_at
FROM %s
WHERE signature = ?;`
queryFmtUpsertOAuth2BlacklistedJTI = `
REPLACE INTO %s (signature, expires_at)
VALUES(?, ?);`
queryFmtUpsertOAuth2BlacklistedJTIPostgreSQL = `
INSERT INTO %s (signature, expires_at)
VALUES ($1, $2)
ON CONFLICT (signature)
DO UPDATE SET expires_at = $2;`
)
const (
queryFmtInsertUserOpaqueIdentifier = `
INSERT INTO %s (service, sector_id, username, identifier)
VALUES(?, ?, ?, ?);`
queryFmtSelectUserOpaqueIdentifier = `
SELECT id, sector_id, username, identifier
FROM %s
WHERE identifier = ?;`
queryFmtSelectUserOpaqueIdentifierBySignature = `
SELECT id, service, sector_id, username, identifier
FROM %s
WHERE service = ? AND sector_id = ? AND username = ?;`
)

View File

@ -0,0 +1,47 @@
package storage
import (
"database/sql"
"github.com/jmoiron/sqlx"
"github.com/authelia/authelia/v4/internal/model"
)
// ConsentSessionRows holds and assists with retrieving multiple model.OAuth2ConsentSession rows.
type ConsentSessionRows struct {
rows *sqlx.Rows
}
// Next is the row iterator.
func (r *ConsentSessionRows) Next() bool {
if r.rows == nil {
return false
}
return r.rows.Next()
}
// Close the rows.
func (r *ConsentSessionRows) Close() (err error) {
if r.rows == nil {
return nil
}
return r.rows.Close()
}
// Get returns the *model.OAuth2ConsentSession or scan error.
func (r *ConsentSessionRows) Get() (consent *model.OAuth2ConsentSession, err error) {
if r.rows == nil {
return nil, sql.ErrNoRows
}
consent = &model.OAuth2ConsentSession{}
if err = r.rows.StructScan(consent); err != nil {
return nil, err
}
return consent, nil
}

View File

@ -58,6 +58,7 @@ notifier:
identity_providers: identity_providers:
oidc: oidc:
enable_client_debug_messages: true
hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm
issuer_private_key: | issuer_private_key: |
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----

View File

@ -60,6 +60,7 @@ notifier:
identity_providers: identity_providers:
oidc: oidc:
enable_client_debug_messages: true
hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm hmac_secret: IVPWBkAdJHje3uz7LtFTDU2pFUfh39Xm
issuer_private_key: | issuer_private_key: |
-----BEGIN RSA PRIVATE KEY----- -----BEGIN RSA PRIVATE KEY-----

View File

@ -9,26 +9,26 @@ import (
) )
func (rs *RodSession) doChangeMethod(t *testing.T, page *rod.Page, method string) { func (rs *RodSession) doChangeMethod(t *testing.T, page *rod.Page, method string) {
err := rs.WaitElementLocatedByCSSSelector(t, page, "methods-button").Click("left") err := rs.WaitElementLocatedByID(t, page, "methods-button").Click("left")
require.NoError(t, err) require.NoError(t, err)
rs.WaitElementLocatedByCSSSelector(t, page, "methods-dialog") rs.WaitElementLocatedByID(t, page, "methods-dialog")
err = rs.WaitElementLocatedByCSSSelector(t, page, fmt.Sprintf("%s-option", method)).Click("left") err = rs.WaitElementLocatedByID(t, page, fmt.Sprintf("%s-option", method)).Click("left")
require.NoError(t, err) require.NoError(t, err)
} }
func (rs *RodSession) doChangeDevice(t *testing.T, page *rod.Page, deviceID string) { func (rs *RodSession) doChangeDevice(t *testing.T, page *rod.Page, deviceID string) {
err := rs.WaitElementLocatedByCSSSelector(t, page, "selection-link").Click("left") err := rs.WaitElementLocatedByID(t, page, "selection-link").Click("left")
require.NoError(t, err) require.NoError(t, err)
rs.doSelectDevice(t, page, deviceID) rs.doSelectDevice(t, page, deviceID)
} }
func (rs *RodSession) doSelectDevice(t *testing.T, page *rod.Page, deviceID string) { func (rs *RodSession) doSelectDevice(t *testing.T, page *rod.Page, deviceID string) {
rs.WaitElementLocatedByCSSSelector(t, page, "device-selection") rs.WaitElementLocatedByID(t, page, "device-selection")
err := rs.WaitElementLocatedByCSSSelector(t, page, fmt.Sprintf("device-%s", deviceID)).Click("left") err := rs.WaitElementLocatedByID(t, page, fmt.Sprintf("device-%s", deviceID)).Click("left")
require.NoError(t, err) require.NoError(t, err)
} }
func (rs *RodSession) doClickButton(t *testing.T, page *rod.Page, buttonID string) { func (rs *RodSession) doClickButton(t *testing.T, page *rod.Page, buttonID string) {
err := rs.WaitElementLocatedByCSSSelector(t, page, buttonID).Click("left") err := rs.WaitElementLocatedByID(t, page, buttonID).Click("left")
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -9,21 +9,21 @@ import (
) )
func (rs *RodSession) doFillLoginPageAndClick(t *testing.T, page *rod.Page, username, password string, keepMeLoggedIn bool) { func (rs *RodSession) doFillLoginPageAndClick(t *testing.T, page *rod.Page, username, password string, keepMeLoggedIn bool) {
usernameElement := rs.WaitElementLocatedByCSSSelector(t, page, "username-textfield") usernameElement := rs.WaitElementLocatedByID(t, page, "username-textfield")
err := usernameElement.Input(username) err := usernameElement.Input(username)
require.NoError(t, err) require.NoError(t, err)
passwordElement := rs.WaitElementLocatedByCSSSelector(t, page, "password-textfield") passwordElement := rs.WaitElementLocatedByID(t, page, "password-textfield")
err = passwordElement.Input(password) err = passwordElement.Input(password)
require.NoError(t, err) require.NoError(t, err)
if keepMeLoggedIn { if keepMeLoggedIn {
keepMeLoggedInElement := rs.WaitElementLocatedByCSSSelector(t, page, "remember-checkbox") keepMeLoggedInElement := rs.WaitElementLocatedByID(t, page, "remember-checkbox")
err = keepMeLoggedInElement.Click("left") err = keepMeLoggedInElement.Click("left")
require.NoError(t, err) require.NoError(t, err)
} }
buttonElement := rs.WaitElementLocatedByCSSSelector(t, page, "sign-in-button") buttonElement := rs.WaitElementLocatedByID(t, page, "sign-in-button")
err = buttonElement.Click("left") err = buttonElement.Click("left")
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -9,13 +9,13 @@ import (
) )
func (rs *RodSession) doInitiatePasswordReset(t *testing.T, page *rod.Page, username string) { func (rs *RodSession) doInitiatePasswordReset(t *testing.T, page *rod.Page, username string) {
err := rs.WaitElementLocatedByCSSSelector(t, page, "reset-password-button").Click("left") err := rs.WaitElementLocatedByID(t, page, "reset-password-button").Click("left")
require.NoError(t, err) require.NoError(t, err)
// Fill in username. // Fill in username.
err = rs.WaitElementLocatedByCSSSelector(t, page, "username-textfield").Input(username) err = rs.WaitElementLocatedByID(t, page, "username-textfield").Input(username)
require.NoError(t, err) require.NoError(t, err)
// And click on the reset button. // And click on the reset button.
err = rs.WaitElementLocatedByCSSSelector(t, page, "reset-button").Click("left") err = rs.WaitElementLocatedByID(t, page, "reset-button").Click("left")
require.NoError(t, err) require.NoError(t, err)
} }
@ -25,15 +25,15 @@ func (rs *RodSession) doCompletePasswordReset(t *testing.T, page *rod.Page, newP
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
err := rs.WaitElementLocatedByCSSSelector(t, page, "password1-textfield").Input(newPassword1) err := rs.WaitElementLocatedByID(t, page, "password1-textfield").Input(newPassword1)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
err = rs.WaitElementLocatedByCSSSelector(t, page, "password2-textfield").Input(newPassword2) err = rs.WaitElementLocatedByID(t, page, "password2-textfield").Input(newPassword2)
require.NoError(t, err) require.NoError(t, err)
err = rs.WaitElementLocatedByCSSSelector(t, page, "reset-button").Click("left") err = rs.WaitElementLocatedByID(t, page, "reset-button").Click("left")
require.NoError(t, err) require.NoError(t, err)
} }

View File

@ -12,7 +12,7 @@ import (
) )
func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string { func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string {
err := rs.WaitElementLocatedByCSSSelector(t, page, "register-link").Click("left") err := rs.WaitElementLocatedByID(t, page, "register-link").Click("left")
require.NoError(t, err) require.NoError(t, err)
rs.verifyMailNotificationDisplayed(t, page) rs.verifyMailNotificationDisplayed(t, page)
link := doGetLinkFromLastMail(t) link := doGetLinkFromLastMail(t)
@ -28,7 +28,7 @@ func (rs *RodSession) doRegisterTOTP(t *testing.T, page *rod.Page) string {
} }
func (rs *RodSession) doEnterOTP(t *testing.T, page *rod.Page, code string) { func (rs *RodSession) doEnterOTP(t *testing.T, page *rod.Page, code string) {
inputs := rs.WaitElementsLocatedByCSSSelector(t, page, "otp-input input") inputs := rs.WaitElementsLocatedByID(t, page, "otp-input input")
for i := 0; i < len(code); i++ { for i := 0; i < len(code); i++ {
_ = inputs[i].Input(string(code[i])) _ = inputs[i].Input(string(code[i]))

View File

@ -2,7 +2,7 @@
version: '3' version: '3'
services: services:
oidc-client: oidc-client:
image: ghcr.io/authelia/oidc-tester-app:master-89622a8 image: ghcr.io/authelia/oidc-tester-app:master-01ff268
command: /entrypoint.sh command: /entrypoint.sh
depends_on: depends_on:
- authelia-backend - authelia-backend

View File

@ -2,6 +2,6 @@
while true; while true;
do do
oidc-tester-app --issuer https://login.example.com:8080 --id oidc-tester-app --secret foobar --scopes openid,profile,email --public-url https://oidc.example.com:8080 oidc-tester-app --issuer=https://login.example.com:8080 --id=oidc-tester-app --secret=foobar --scopes=openid,profile,email,groups --public-url=https://oidc.example.com:8080
sleep 5 sleep 5
done done

View File

@ -58,11 +58,11 @@ func (s *AvailableMethodsScenario) TestShouldCheckAvailableMethods() {
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
methodsButton := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "methods-button") methodsButton := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "methods-button")
err := methodsButton.Click("left") err := methodsButton.Click("left")
s.Assert().NoError(err) s.Assert().NoError(err)
methodsDialog := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "methods-dialog") methodsDialog := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "methods-dialog")
options, err := methodsDialog.Elements(".method-option") options, err := methodsDialog.Elements(".method-option")
s.Assert().NoError(err) s.Assert().NoError(err)
s.Assert().Len(options, len(s.methods)) s.Assert().Len(options, len(s.methods))

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"regexp"
"testing" "testing"
"time" "time"
@ -89,11 +90,50 @@ func (s *OIDCScenario) TestShouldAuthorizeAccessToOIDCApp() {
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
s.verifyIsConsentPage(s.T(), s.Context(ctx)) s.verifyIsConsentPage(s.T(), s.Context(ctx))
err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "accept-button").Click("left") err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "accept-button").Click("left")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
// Verify that the app is showing the info related to the user stored in the JWT token. // Verify that the app is showing the info related to the user stored in the JWT token.
s.waitBodyContains(s.T(), s.Context(ctx), "Logged in as john!")
rUUID := regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
rInteger := regexp.MustCompile(`^\d+$`)
rBoolean := regexp.MustCompile(`^(true|false)$`)
rBase64 := regexp.MustCompile(`^[-_A-Za-z0-9+\\/]+([=]{0,3})$`)
testCases := []struct {
desc, elementID, elementText string
pattern *regexp.Regexp
}{
{"welcome", "welcome", "Logged in as john!", nil},
{"at_hash", "claim-at_hash", "", rBase64},
{"jti", "claim-jti", "", rUUID},
{"iat", "claim-iat", "", rInteger},
{"nbf", "claim-nbf", "", rInteger},
{"rat", "claim-rat", "", rInteger},
{"expires", "claim-exp", "", rInteger},
{"amr", "claim-amr", "pwd, otp, mfa", nil},
{"acr", "claim-acr", "", nil},
{"issuer", "claim-iss", "https://login.example.com:8080", nil},
{"name", "claim-name", "John Doe", nil},
{"preferred_username", "claim-preferred_username", "john", nil},
{"groups", "claim-groups", "admins, dev", nil},
{"email", "claim-email", "john.doe@authelia.com", nil},
{"email_verified", "claim-email_verified", "", rBoolean},
}
var text string
for _, tc := range testCases {
s.T().Run(fmt.Sprintf("check_claims/%s", tc.desc), func(t *testing.T) {
text, err = s.WaitElementLocatedByID(t, s.Context(ctx), tc.elementID).Text()
assert.NoError(t, err)
if tc.pattern == nil {
assert.Equal(t, tc.elementText, text)
} else {
assert.Regexp(t, tc.pattern, text)
}
})
}
} }
func (s *OIDCScenario) TestShouldDenyConsent() { func (s *OIDCScenario) TestShouldDenyConsent() {
@ -117,10 +157,17 @@ func (s *OIDCScenario) TestShouldDenyConsent() {
s.verifyIsConsentPage(s.T(), s.Context(ctx)) s.verifyIsConsentPage(s.T(), s.Context(ctx))
err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "deny-button").Click("left") err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "deny-button").Click("left")
assert.NoError(s.T(), err) assert.NoError(s.T(), err)
s.verifyIsOIDC(s.T(), s.Context(ctx), "oauth2:", "https://oidc.example.com:8080/oauth2/callback?error=access_denied&error_description=User%20has%20rejected%20the%20scopes") s.verifyIsOIDC(s.T(), s.Context(ctx), "access_denied", "https://oidc.example.com:8080/error?error=access_denied&error_description=The+resource+owner+or+authorization+server+denied+the+request.+Make+sure+that+the+request+you+are+making+is+valid.+Maybe+the+credential+or+request+parameters+you+are+using+are+limited+in+scope+or+otherwise+restricted.&state=random-string-here")
errorDescription := "The resource owner or authorization server denied the request. Make sure that the request " +
"you are making is valid. Maybe the credential or request parameters you are using are limited in scope or " +
"otherwise restricted."
s.verifyIsOIDCErrorPage(s.T(), s.Context(ctx), "access_denied", errorDescription, "",
"random-string-here")
} }
func TestRunOIDCScenario(t *testing.T) { func TestRunOIDCScenario(t *testing.T) {

View File

@ -60,16 +60,16 @@ func (s *RegulationScenario) TestShouldBanUserAfterTooManyAttempt() {
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.") s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.")
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("bad-password") err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("bad-password")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left") err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left")
require.NoError(s.T(), err) require.NoError(s.T(), err)
} }
// Enter the correct password and test the regulation lock out. // Enter the correct password and test the regulation lock out.
err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("password") err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("password")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left") err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left")
require.NoError(s.T(), err) require.NoError(s.T(), err)
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.") s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Incorrect username or password.")
@ -77,9 +77,9 @@ func (s *RegulationScenario) TestShouldBanUserAfterTooManyAttempt() {
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
// Enter the correct password and test a successful login. // Enter the correct password and test a successful login.
err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "password-textfield").Input("password") err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "password-textfield").Input("password")
require.NoError(s.T(), err) require.NoError(s.T(), err)
err = s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "sign-in-button").Click("left") err = s.WaitElementLocatedByID(s.T(), s.Context(ctx), "sign-in-button").Click("left")
require.NoError(s.T(), err) require.NoError(s.T(), err)
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
} }

View File

@ -60,7 +60,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
// Then switch to push notification method. // Then switch to push notification method.
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Switch context to clean up state in portal. // Switch context to clean up state in portal.
s.doVisit(s.T(), s.Context(ctx), HomeBaseURL) s.doVisit(s.T(), s.Context(ctx), HomeBaseURL)
@ -70,7 +70,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
s.doVisit(s.T(), s.Context(ctx), GetLoginBaseURL()) s.doVisit(s.T(), s.Context(ctx), GetLoginBaseURL())
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
// And check the latest method is still used. // And check the latest method is still used.
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Meaning the authentication is successful. // Meaning the authentication is successful.
s.verifyIsHome(s.T(), s.Context(ctx)) s.verifyIsHome(s.T(), s.Context(ctx))
@ -78,7 +78,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
s.doLogout(s.T(), s.Context(ctx)) s.doLogout(s.T(), s.Context(ctx))
s.doLoginOneFactor(s.T(), s.Context(ctx), "harry", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "harry", "password", false, "")
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "one-time-password-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "one-time-password-method")
s.doLogout(s.T(), s.Context(ctx)) s.doLogout(s.T(), s.Context(ctx))
s.verifyIsFirstFactorPage(s.T(), s.Context(ctx)) s.verifyIsFirstFactorPage(s.T(), s.Context(ctx))
@ -86,7 +86,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
// Then log back as previous user and verify the push notification is still the default method. // Then log back as previous user and verify the push notification is still the default method.
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
s.verifyIsSecondFactorPage(s.T(), s.Context(ctx)) s.verifyIsSecondFactorPage(s.T(), s.Context(ctx))
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
s.verifyIsHome(s.T(), s.Context(ctx)) s.verifyIsHome(s.T(), s.Context(ctx))
s.doLogout(s.T(), s.Context(ctx)) s.doLogout(s.T(), s.Context(ctx))
@ -94,7 +94,7 @@ func (s *UserPreferencesScenario) TestShouldRememberLastUsed2FAMethod() {
// Eventually restore the default method. // Eventually restore the default method.
s.doChangeMethod(s.T(), s.Context(ctx), "one-time-password") s.doChangeMethod(s.T(), s.Context(ctx), "one-time-password")
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "one-time-password-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "one-time-password-method")
} }
func TestUserPreferencesScenario(t *testing.T) { func TestUserPreferencesScenario(t *testing.T) {

View File

@ -110,7 +110,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAskUserToRegister() {
s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "state-not-registered") s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "state-not-registered")
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "No compatible device found") s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "No compatible device found")
enrollPage := s.Page.MustWaitOpen() enrollPage := s.Page.MustWaitOpen()
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "register-link").MustClick() s.WaitElementLocatedByID(s.T(), s.Context(ctx), "register-link").MustClick()
s.Page = enrollPage() s.Page = enrollPage()
assert.Contains(s.T(), s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "description").MustText(), "This enrollment code has expired. Contact your administrator to get a new enrollment code.") assert.Contains(s.T(), s.WaitElementLocatedByClassName(s.T(), s.Context(ctx), "description").MustText(), "This enrollment code has expired. Contact your administrator to get a new enrollment code.")
@ -142,7 +142,7 @@ func (s *DuoPushWebDriverSuite) TestShouldAutoSelectDevice() {
s.doLogout(s.T(), s.Context(ctx)) s.doLogout(s.T(), s.Context(ctx))
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
// And check the latest method and device is still used. // And check the latest method and device is still used.
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Meaning the authentication is successful. // Meaning the authentication is successful.
s.verifyIsHome(s.T(), s.Context(ctx)) s.verifyIsHome(s.T(), s.Context(ctx))
} }
@ -176,7 +176,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() {
// Switch Method where Device Selection should open automatically. // Switch Method where Device Selection should open automatically.
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
// Check for available Device 1. // Check for available Device 1.
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "device-12345ABCDEFGHIJ67890") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "device-12345ABCDEFGHIJ67890")
// Test Back button. // Test Back button.
s.doClickButton(s.T(), s.Context(ctx), "device-selection-back") s.doClickButton(s.T(), s.Context(ctx), "device-selection-back")
// then select Device 2 for further use and be redirected. // then select Device 2 for further use and be redirected.
@ -187,7 +187,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectDevice() {
s.doLogout(s.T(), s.Context(ctx)) s.doLogout(s.T(), s.Context(ctx))
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
// And check the latest method and device is still used. // And check the latest method and device is still used.
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "push-notification-method") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "push-notification-method")
// Meaning the authentication is successful. // Meaning the authentication is successful.
s.verifyIsHome(s.T(), s.Context(ctx)) s.verifyIsHome(s.T(), s.Context(ctx))
} }
@ -238,7 +238,7 @@ func (s *DuoPushWebDriverSuite) TestShouldSelectNewDeviceAfterSavedDeviceMethodI
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "device-selection") s.WaitElementLocatedByID(s.T(), s.Context(ctx), "device-selection")
s.doSelectDevice(s.T(), s.Context(ctx), "12345ABCDEFGHIJ67890") s.doSelectDevice(s.T(), s.Context(ctx), "12345ABCDEFGHIJ67890")
s.verifyIsHome(s.T(), s.Context(ctx)) s.verifyIsHome(s.T(), s.Context(ctx))
} }
@ -303,7 +303,7 @@ func (s *DuoPushWebDriverSuite) TestShouldFailSelectionBecauseOfSelectionDenied(
s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "") s.doLoginOneFactor(s.T(), s.Context(ctx), "john", "password", false, "")
s.doChangeMethod(s.T(), s.Context(ctx), "push-notification") s.doChangeMethod(s.T(), s.Context(ctx), "push-notification")
err := s.WaitElementLocatedByCSSSelector(s.T(), s.Context(ctx), "selection-link").Click("left") err := s.WaitElementLocatedByID(s.T(), s.Context(ctx), "selection-link").Click("left")
require.NoError(s.T(), err) require.NoError(s.T(), err)
s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Device selection was denied by Duo policy") s.verifyNotificationDisplayed(s.T(), s.Context(ctx), "Device selection was denied by Duo policy")
} }

View File

@ -7,5 +7,5 @@ import (
) )
func (rs *RodSession) verifyIsAuthenticatedPage(t *testing.T, page *rod.Page) { func (rs *RodSession) verifyIsAuthenticatedPage(t *testing.T, page *rod.Page) {
rs.WaitElementLocatedByCSSSelector(t, page, "authenticated-stage") rs.WaitElementLocatedByID(t, page, "authenticated-stage")
} }

View File

@ -7,5 +7,5 @@ import (
) )
func (rs *RodSession) verifyIsConsentPage(t *testing.T, page *rod.Page) { func (rs *RodSession) verifyIsConsentPage(t *testing.T, page *rod.Page) {
rs.WaitElementLocatedByCSSSelector(t, page, "consent-stage") rs.WaitElementLocatedByID(t, page, "consent-stage")
} }

View File

@ -7,5 +7,5 @@ import (
) )
func (rs *RodSession) verifyIsFirstFactorPage(t *testing.T, page *rod.Page) { func (rs *RodSession) verifyIsFirstFactorPage(t *testing.T, page *rod.Page) {
rs.WaitElementLocatedByCSSSelector(t, page, "first-factor-stage") rs.WaitElementLocatedByID(t, page, "first-factor-stage")
} }

View File

@ -4,9 +4,33 @@ import (
"testing" "testing"
"github.com/go-rod/rod" "github.com/go-rod/rod"
"github.com/stretchr/testify/assert"
) )
func (rs *RodSession) verifyIsOIDC(t *testing.T, page *rod.Page, pattern, url string) { func (rs *RodSession) verifyIsOIDC(t *testing.T, page *rod.Page, pattern, url string) {
page.MustElementR("body", pattern) page.MustElementR("body", pattern)
rs.verifyURLIs(t, page, url) rs.verifyURLIs(t, page, url)
} }
func (rs *RodSession) verifyIsOIDCErrorPage(t *testing.T, page *rod.Page, errorCode, errorDescription, errorURI, state string) {
testCases := []struct {
ElementID, ElementText string
}{
{"error", errorCode},
{"error_description", errorDescription},
{"error_uri", errorURI},
{"state", state},
}
for _, tc := range testCases {
t.Run(tc.ElementID, func(t *testing.T) {
if tc.ElementText == "" {
t.Skip("Test Skipped as the element is not expected.")
}
text, err := rs.WaitElementLocatedByID(t, page, tc.ElementID).Text()
assert.NoError(t, err)
assert.Equal(t, tc.ElementText, text)
})
}
}

View File

@ -7,5 +7,5 @@ import (
) )
func (rs *RodSession) verifyIsSecondFactorPage(t *testing.T, page *rod.Page) { func (rs *RodSession) verifyIsSecondFactorPage(t *testing.T, page *rod.Page) {
rs.WaitElementLocatedByCSSSelector(t, page, "second-factor-stage") rs.WaitElementLocatedByID(t, page, "second-factor-stage")
} }

View File

@ -7,5 +7,5 @@ import (
) )
func (rs *RodSession) verifySecretAuthorized(t *testing.T, page *rod.Page) { func (rs *RodSession) verifySecretAuthorized(t *testing.T, page *rod.Page) {
rs.WaitElementLocatedByCSSSelector(t, page, "secret") rs.WaitElementLocatedByID(t, page, "secret")
} }

View File

@ -82,8 +82,8 @@ func (rs *RodSession) WaitElementLocatedByClassName(t *testing.T, page *rod.Page
return e return e
} }
// WaitElementLocatedByCSSSelector wait an element is located by class name. // WaitElementLocatedByID waits for an element located by an id.
func (rs *RodSession) WaitElementLocatedByCSSSelector(t *testing.T, page *rod.Page, cssSelector string) *rod.Element { func (rs *RodSession) WaitElementLocatedByID(t *testing.T, page *rod.Page, cssSelector string) *rod.Element {
e, err := page.Element("#" + cssSelector) e, err := page.Element("#" + cssSelector)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, e) require.NotNil(t, e)
@ -91,8 +91,8 @@ func (rs *RodSession) WaitElementLocatedByCSSSelector(t *testing.T, page *rod.Pa
return e return e
} }
// WaitElementsLocatedByCSSSelector wait an element is located by CSS selector. // WaitElementsLocatedByID waits for an elements located by an id.
func (rs *RodSession) WaitElementsLocatedByCSSSelector(t *testing.T, page *rod.Page, cssSelector string) rod.Elements { func (rs *RodSession) WaitElementsLocatedByID(t *testing.T, page *rod.Page, cssSelector string) rod.Elements {
e, err := page.Elements("#" + cssSelector) e, err := page.Elements("#" + cssSelector)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, e) require.NotNil(t, e)

View File

@ -241,6 +241,39 @@ func StringHTMLEscape(input string) (output string) {
return htmlEscaper.Replace(input) return htmlEscaper.Replace(input)
} }
// StringJoinDelimitedEscaped joins a string with a specified rune delimiter after escaping any instance of that string
// in the string slice. Used with StringSplitDelimitedEscaped.
func StringJoinDelimitedEscaped(value []string, delimiter rune) string {
escaped := make([]string, len(value))
for k, v := range value {
escaped[k] = strings.ReplaceAll(v, string(delimiter), "\\"+string(delimiter))
}
return strings.Join(escaped, string(delimiter))
}
// StringSplitDelimitedEscaped splits a string with a specified rune delimiter after unescaping any instance of that
// string in the string slice that has been escaped. Used with StringJoinDelimitedEscaped.
func StringSplitDelimitedEscaped(value string, delimiter rune) (out []string) {
var escape bool
split := strings.FieldsFunc(value, func(r rune) bool {
if r == '\\' {
escape = !escape
} else if escape && r != delimiter {
escape = false
}
return !escape && r == delimiter
})
for k, v := range split {
split[k] = strings.ReplaceAll(v, "\\"+string(delimiter), string(delimiter))
}
return split
}
// JoinAndCanonicalizeHeaders join header strings by a given sep. // JoinAndCanonicalizeHeaders join header strings by a given sep.
func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) { func JoinAndCanonicalizeHeaders(sep []byte, headers ...string) (joined []byte) {
for i, header := range headers { for i, header := range headers {

View File

@ -8,6 +8,51 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestStringSplitDelimitedEscaped(t *testing.T) {
testCases := []struct {
desc, have string
delimiter rune
want []string
}{
{desc: "ShouldSplitNormalString", have: "abc,123,456", delimiter: ',', want: []string{"abc", "123", "456"}},
{desc: "ShouldSplitEscapedString", have: "a\\,bc,123,456", delimiter: ',', want: []string{"a,bc", "123", "456"}},
{desc: "ShouldSplitEscapedStringPipe", have: "a\\|bc|123|456", delimiter: '|', want: []string{"a|bc", "123", "456"}},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
actual := StringSplitDelimitedEscaped(tc.have, tc.delimiter)
assert.Equal(t, tc.want, actual)
})
}
}
func TestStringJoinDelimitedEscaped(t *testing.T) {
testCases := []struct {
desc, want string
delimiter rune
have []string
}{
{desc: "ShouldJoinNormalStringSlice", have: []string{"abc", "123", "456"}, delimiter: ',', want: "abc,123,456"},
{desc: "ShouldJoinEscapeNeededStringSlice", have: []string{"abc", "1,23", "456"}, delimiter: ',', want: "abc,1\\,23,456"},
{desc: "ShouldJoinEscapeNeededStringSlicePipe", have: []string{"abc", "1|23", "456"}, delimiter: '|', want: "abc|1\\|23|456"},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
actual := StringJoinDelimitedEscaped(tc.have, tc.delimiter)
assert.Equal(t, tc.want, actual)
// Ensure splitting again also works fine.
split := StringSplitDelimitedEscaped(actual, tc.delimiter)
assert.Equal(t, tc.have, split)
})
}
}
func TestShouldNotGenerateSameRandomString(t *testing.T) { func TestShouldNotGenerateSameRandomString(t *testing.T) {
randomStringOne := RandomString(10, AlphaNumericCharacters, false) randomStringOne := RandomString(10, AlphaNumericCharacters, false)
randomStringTwo := RandomString(10, AlphaNumericCharacters, false) randomStringTwo := RandomString(10, AlphaNumericCharacters, false)