mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
abf1c86ab9
Fix an issue that would prevent a correct ID Token from being generated for users who start off anonymous. This also avoids generating one in the first place for anonymous users.
957 lines
42 KiB
Go
957 lines
42 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
|
"github.com/authelia/authelia/v4/internal/logging"
|
|
"github.com/authelia/authelia/v4/internal/model"
|
|
)
|
|
|
|
// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's.
|
|
func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceName string) (provider SQLProvider) {
|
|
db, err := sqlx.Open(driverName, dataSourceName)
|
|
|
|
provider = SQLProvider{
|
|
db: db,
|
|
key: sha256.Sum256([]byte(config.Storage.EncryptionKey)),
|
|
name: name,
|
|
driverName: driverName,
|
|
config: config,
|
|
errOpen: err,
|
|
log: logging.Logger(),
|
|
|
|
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
|
|
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
|
|
|
|
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
|
|
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
|
|
sqlSelectIdentityVerification: fmt.Sprintf(queryFmtSelectIdentityVerification, tableIdentityVerification),
|
|
|
|
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
|
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
|
|
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
|
|
sqlSelectTOTPConfigs: fmt.Sprintf(queryFmtSelectTOTPConfigurations, tableTOTPConfigurations),
|
|
|
|
sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations),
|
|
sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations),
|
|
sqlUpdateTOTPConfigRecordSignIn: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignIn, tableTOTPConfigurations),
|
|
sqlUpdateTOTPConfigRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigRecordSignInByUsername, tableTOTPConfigurations),
|
|
|
|
sqlUpsertWebauthnDevice: fmt.Sprintf(queryFmtUpsertWebauthnDevice, tableWebauthnDevices),
|
|
sqlSelectWebauthnDevices: fmt.Sprintf(queryFmtSelectWebauthnDevices, tableWebauthnDevices),
|
|
sqlSelectWebauthnDevicesByUsername: fmt.Sprintf(queryFmtSelectWebauthnDevicesByUsername, tableWebauthnDevices),
|
|
|
|
sqlUpdateWebauthnDevicePublicKey: fmt.Sprintf(queryFmtUpdateWebauthnDevicePublicKey, tableWebauthnDevices),
|
|
sqlUpdateWebauthnDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateWebauthnDevicePublicKeyByUsername, tableWebauthnDevices),
|
|
sqlUpdateWebauthnDeviceRecordSignIn: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignIn, tableWebauthnDevices),
|
|
sqlUpdateWebauthnDeviceRecordSignInByUsername: fmt.Sprintf(queryFmtUpdateWebauthnDeviceRecordSignInByUsername, tableWebauthnDevices),
|
|
|
|
sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices),
|
|
sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices),
|
|
sqlSelectDuoDevice: fmt.Sprintf(queryFmtSelectDuoDevice, tableDuoDevices),
|
|
|
|
sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences),
|
|
sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
|
|
sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableWebauthnDevices, tableDuoDevices, tableUserPreferences),
|
|
|
|
sqlInsertUserOpaqueIdentifier: fmt.Sprintf(queryFmtInsertUserOpaqueIdentifier, tableUserOpaqueIdentifier),
|
|
sqlSelectUserOpaqueIdentifier: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifier, tableUserOpaqueIdentifier),
|
|
sqlSelectUserOpaqueIdentifiers: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifiers, tableUserOpaqueIdentifier),
|
|
sqlSelectUserOpaqueIdentifierBySignature: fmt.Sprintf(queryFmtSelectUserOpaqueIdentifierBySignature, tableUserOpaqueIdentifier),
|
|
|
|
sqlInsertOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
|
sqlSelectOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
|
sqlRevokeOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
|
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
|
|
sqlDeactivateOAuth2AuthorizeCodeSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AuthorizeCodeSession),
|
|
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AuthorizeCodeSession),
|
|
|
|
sqlInsertOAuth2AccessTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2AccessTokenSession),
|
|
sqlSelectOAuth2AccessTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2AccessTokenSession),
|
|
sqlRevokeOAuth2AccessTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2AccessTokenSession),
|
|
sqlRevokeOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
|
|
sqlDeactivateOAuth2AccessTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2AccessTokenSession),
|
|
sqlDeactivateOAuth2AccessTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2AccessTokenSession),
|
|
|
|
sqlInsertOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2RefreshTokenSession),
|
|
sqlSelectOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2RefreshTokenSession),
|
|
sqlRevokeOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2RefreshTokenSession),
|
|
sqlRevokeOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
|
|
sqlDeactivateOAuth2RefreshTokenSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2RefreshTokenSession),
|
|
sqlDeactivateOAuth2RefreshTokenSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2RefreshTokenSession),
|
|
|
|
sqlInsertOAuth2PKCERequestSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2PKCERequestSession),
|
|
sqlSelectOAuth2PKCERequestSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2PKCERequestSession),
|
|
sqlRevokeOAuth2PKCERequestSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2PKCERequestSession),
|
|
sqlRevokeOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
|
|
sqlDeactivateOAuth2PKCERequestSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2PKCERequestSession),
|
|
sqlDeactivateOAuth2PKCERequestSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2PKCERequestSession),
|
|
|
|
sqlInsertOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtInsertOAuth2Session, tableOAuth2OpenIDConnectSession),
|
|
sqlSelectOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtSelectOAuth2Session, tableOAuth2OpenIDConnectSession),
|
|
sqlRevokeOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtRevokeOAuth2Session, tableOAuth2OpenIDConnectSession),
|
|
sqlRevokeOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtRevokeOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
|
|
sqlDeactivateOAuth2OpenIDConnectSession: fmt.Sprintf(queryFmtDeactivateOAuth2Session, tableOAuth2OpenIDConnectSession),
|
|
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID: fmt.Sprintf(queryFmtDeactivateOAuth2SessionByRequestID, tableOAuth2OpenIDConnectSession),
|
|
|
|
sqlInsertOAuth2ConsentSession: fmt.Sprintf(queryFmtInsertOAuth2ConsentSession, tableOAuth2ConsentSession),
|
|
sqlUpdateOAuth2ConsentSessionSubject: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionSubject, tableOAuth2ConsentSession),
|
|
sqlUpdateOAuth2ConsentSessionResponse: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionResponse, tableOAuth2ConsentSession),
|
|
sqlUpdateOAuth2ConsentSessionGranted: fmt.Sprintf(queryFmtUpdateOAuth2ConsentSessionGranted, tableOAuth2ConsentSession),
|
|
sqlSelectOAuth2ConsentSessionByChallengeID: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionByChallengeID, tableOAuth2ConsentSession),
|
|
sqlSelectOAuth2ConsentSessionsPreConfigured: fmt.Sprintf(queryFmtSelectOAuth2ConsentSessionsPreConfigured, tableOAuth2ConsentSession),
|
|
|
|
sqlUpsertOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtUpsertOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
|
|
sqlSelectOAuth2BlacklistedJTI: fmt.Sprintf(queryFmtSelectOAuth2BlacklistedJTI, tableOAuth2BlacklistedJTI),
|
|
|
|
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
|
|
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
|
|
sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
|
|
|
|
sqlUpsertEncryptionValue: fmt.Sprintf(queryFmtUpsertEncryptionValue, tableEncryption),
|
|
sqlSelectEncryptionValue: fmt.Sprintf(queryFmtSelectEncryptionValue, tableEncryption),
|
|
|
|
sqlFmtRenameTable: queryFmtRenameTable,
|
|
}
|
|
|
|
return provider
|
|
}
|
|
|
|
// SQLProvider is a storage provider persisting data in a SQL database.
|
|
type SQLProvider struct {
|
|
db *sqlx.DB
|
|
key [32]byte
|
|
name string
|
|
driverName string
|
|
schema string
|
|
config *schema.Configuration
|
|
errOpen error
|
|
|
|
log *logrus.Logger
|
|
|
|
// Table: authentication_logs.
|
|
sqlInsertAuthenticationAttempt string
|
|
sqlSelectAuthenticationAttemptsByUsername string
|
|
|
|
// Table: identity_verification.
|
|
sqlInsertIdentityVerification string
|
|
sqlConsumeIdentityVerification string
|
|
sqlSelectIdentityVerification string
|
|
|
|
// Table: totp_configurations.
|
|
sqlUpsertTOTPConfig string
|
|
sqlDeleteTOTPConfig string
|
|
sqlSelectTOTPConfig string
|
|
sqlSelectTOTPConfigs string
|
|
|
|
sqlUpdateTOTPConfigSecret string
|
|
sqlUpdateTOTPConfigSecretByUsername string
|
|
sqlUpdateTOTPConfigRecordSignIn string
|
|
sqlUpdateTOTPConfigRecordSignInByUsername string
|
|
|
|
// Table: webauthn_devices.
|
|
sqlUpsertWebauthnDevice string
|
|
sqlSelectWebauthnDevices string
|
|
sqlSelectWebauthnDevicesByUsername string
|
|
|
|
sqlUpdateWebauthnDevicePublicKey string
|
|
sqlUpdateWebauthnDevicePublicKeyByUsername string
|
|
sqlUpdateWebauthnDeviceRecordSignIn string
|
|
sqlUpdateWebauthnDeviceRecordSignInByUsername string
|
|
|
|
// Table: duo_devices.
|
|
sqlUpsertDuoDevice string
|
|
sqlDeleteDuoDevice string
|
|
sqlSelectDuoDevice string
|
|
|
|
// Table: user_preferences.
|
|
sqlUpsertPreferred2FAMethod string
|
|
sqlSelectPreferred2FAMethod string
|
|
sqlSelectUserInfo string
|
|
|
|
// Table: user_opaque_identifier.
|
|
sqlInsertUserOpaqueIdentifier string
|
|
sqlSelectUserOpaqueIdentifier string
|
|
sqlSelectUserOpaqueIdentifiers string
|
|
sqlSelectUserOpaqueIdentifierBySignature string
|
|
|
|
// Table: migrations.
|
|
sqlInsertMigration string
|
|
sqlSelectMigrations string
|
|
sqlSelectLatestMigration string
|
|
|
|
// Table: encryption.
|
|
sqlUpsertEncryptionValue string
|
|
sqlSelectEncryptionValue string
|
|
|
|
// Table: oauth2_authorization_code_session.
|
|
sqlInsertOAuth2AuthorizeCodeSession string
|
|
sqlSelectOAuth2AuthorizeCodeSession string
|
|
sqlRevokeOAuth2AuthorizeCodeSession string
|
|
sqlRevokeOAuth2AuthorizeCodeSessionByRequestID string
|
|
sqlDeactivateOAuth2AuthorizeCodeSession string
|
|
sqlDeactivateOAuth2AuthorizeCodeSessionByRequestID string
|
|
|
|
// Table: oauth2_access_token_session.
|
|
sqlInsertOAuth2AccessTokenSession string
|
|
sqlSelectOAuth2AccessTokenSession string
|
|
sqlRevokeOAuth2AccessTokenSession string
|
|
sqlRevokeOAuth2AccessTokenSessionByRequestID string
|
|
sqlDeactivateOAuth2AccessTokenSession string
|
|
sqlDeactivateOAuth2AccessTokenSessionByRequestID string
|
|
|
|
// Table: oauth2_refresh_token_session.
|
|
sqlInsertOAuth2RefreshTokenSession string
|
|
sqlSelectOAuth2RefreshTokenSession string
|
|
sqlRevokeOAuth2RefreshTokenSession string
|
|
sqlRevokeOAuth2RefreshTokenSessionByRequestID string
|
|
sqlDeactivateOAuth2RefreshTokenSession string
|
|
sqlDeactivateOAuth2RefreshTokenSessionByRequestID string
|
|
|
|
// Table: oauth2_pkce_request_session.
|
|
sqlInsertOAuth2PKCERequestSession string
|
|
sqlSelectOAuth2PKCERequestSession string
|
|
sqlRevokeOAuth2PKCERequestSession string
|
|
sqlRevokeOAuth2PKCERequestSessionByRequestID string
|
|
sqlDeactivateOAuth2PKCERequestSession string
|
|
sqlDeactivateOAuth2PKCERequestSessionByRequestID string
|
|
|
|
// Table: oauth2_openid_connect_session.
|
|
sqlInsertOAuth2OpenIDConnectSession string
|
|
sqlSelectOAuth2OpenIDConnectSession string
|
|
sqlRevokeOAuth2OpenIDConnectSession string
|
|
sqlRevokeOAuth2OpenIDConnectSessionByRequestID string
|
|
sqlDeactivateOAuth2OpenIDConnectSession string
|
|
sqlDeactivateOAuth2OpenIDConnectSessionByRequestID string
|
|
|
|
// Table: oauth2_consent_session.
|
|
sqlInsertOAuth2ConsentSession string
|
|
sqlUpdateOAuth2ConsentSessionSubject string
|
|
sqlUpdateOAuth2ConsentSessionResponse string
|
|
sqlUpdateOAuth2ConsentSessionGranted string
|
|
sqlSelectOAuth2ConsentSessionByChallengeID string
|
|
sqlSelectOAuth2ConsentSessionsPreConfigured string
|
|
|
|
sqlUpsertOAuth2BlacklistedJTI string
|
|
sqlSelectOAuth2BlacklistedJTI string
|
|
|
|
// Utility.
|
|
sqlSelectExistingTables string
|
|
sqlFmtRenameTable string
|
|
}
|
|
|
|
// Close the underlying database connection.
|
|
func (p *SQLProvider) Close() (err error) {
|
|
return p.db.Close()
|
|
}
|
|
|
|
// StartupCheck implements the provider startup check interface.
|
|
func (p *SQLProvider) StartupCheck() (err error) {
|
|
if p.errOpen != nil {
|
|
return fmt.Errorf("error opening database: %w", p.errOpen)
|
|
}
|
|
|
|
// TODO: Decide if this is needed, or if it should be configurable.
|
|
for i := 0; i < 19; i++ {
|
|
if err = p.db.Ping(); err == nil {
|
|
break
|
|
}
|
|
|
|
time.Sleep(time.Millisecond * 500)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("error pinging database: %w", err)
|
|
}
|
|
|
|
p.log.Infof("Storage schema is being checked for updates")
|
|
|
|
ctx := context.Background()
|
|
|
|
if err = p.SchemaEncryptionCheckKey(ctx, false); err != nil && !errors.Is(err, ErrSchemaEncryptionVersionUnsupported) {
|
|
return err
|
|
}
|
|
|
|
err = p.SchemaMigrate(ctx, true, SchemaLatest)
|
|
|
|
switch err {
|
|
case ErrSchemaAlreadyUpToDate:
|
|
p.log.Infof("Storage schema is already up to date")
|
|
return nil
|
|
case nil:
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("error during schema migrate: %w", err)
|
|
}
|
|
}
|
|
|
|
// BeginTX begins a transaction.
|
|
func (p *SQLProvider) BeginTX(ctx context.Context) (c context.Context, err error) {
|
|
var tx *sql.Tx
|
|
|
|
if tx, err = p.db.Begin(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return context.WithValue(ctx, ctxKeyTransaction, tx), nil
|
|
}
|
|
|
|
// Commit performs a database commit.
|
|
func (p *SQLProvider) Commit(ctx context.Context) (err error) {
|
|
tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
|
|
|
|
if !ok {
|
|
return errors.New("could not retrieve tx")
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Rollback performs a database rollback.
|
|
func (p *SQLProvider) Rollback(ctx context.Context) (err error) {
|
|
tx, ok := ctx.Value(ctxKeyTransaction).(*sql.Tx)
|
|
|
|
if !ok {
|
|
return errors.New("could not retrieve tx")
|
|
}
|
|
|
|
return tx.Rollback()
|
|
}
|
|
|
|
// SaveUserOpaqueIdentifier saves a new opaque user identifier to the database.
|
|
func (p *SQLProvider) SaveUserOpaqueIdentifier(ctx context.Context, opaqueID model.UserOpaqueIdentifier) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlInsertUserOpaqueIdentifier, opaqueID.Service, opaqueID.SectorID, opaqueID.Username, opaqueID.Identifier); err != nil {
|
|
return fmt.Errorf("error inserting user opaque id for user '%s' with opaque id '%s': %w", opaqueID.Username, opaqueID.Identifier.String(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadUserOpaqueIdentifier selects an opaque user identifier from the database.
|
|
func (p *SQLProvider) LoadUserOpaqueIdentifier(ctx context.Context, opaqueUUID uuid.UUID) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
|
opaqueID = &model.UserOpaqueIdentifier{}
|
|
|
|
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifier, opaqueUUID); err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return opaqueID, nil
|
|
}
|
|
|
|
// LoadUserOpaqueIdentifiers selects an opaque user identifiers from the database.
|
|
func (p *SQLProvider) LoadUserOpaqueIdentifiers(ctx context.Context) (opaqueIDs []model.UserOpaqueIdentifier, err error) {
|
|
var rows *sqlx.Rows
|
|
|
|
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectUserOpaqueIdentifiers); err != nil {
|
|
return nil, fmt.Errorf("error selecting user opaque identifiers: %w", err)
|
|
}
|
|
|
|
var opaqueID *model.UserOpaqueIdentifier
|
|
|
|
for rows.Next() {
|
|
opaqueID = &model.UserOpaqueIdentifier{}
|
|
|
|
if err = rows.StructScan(opaqueID); err != nil {
|
|
return nil, fmt.Errorf("error selecting user opaque identifiers: error scanning row: %w", err)
|
|
}
|
|
|
|
opaqueIDs = append(opaqueIDs, *opaqueID)
|
|
}
|
|
|
|
return opaqueIDs, nil
|
|
}
|
|
|
|
// LoadUserOpaqueIdentifierBySignature selects an opaque user identifier from the database given a service name, sector id, and username.
|
|
func (p *SQLProvider) LoadUserOpaqueIdentifierBySignature(ctx context.Context, service, sectorID, username string) (opaqueID *model.UserOpaqueIdentifier, err error) {
|
|
opaqueID = &model.UserOpaqueIdentifier{}
|
|
|
|
if err = p.db.GetContext(ctx, opaqueID, p.sqlSelectUserOpaqueIdentifierBySignature, service, sectorID, username); err != nil {
|
|
switch {
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return nil, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return opaqueID, nil
|
|
}
|
|
|
|
// SaveOAuth2ConsentSession inserts an OAuth2.0 consent session.
|
|
func (p *SQLProvider) SaveOAuth2ConsentSession(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlInsertOAuth2ConsentSession,
|
|
consent.ChallengeID, consent.ClientID, consent.Subject, consent.Authorized, consent.Granted,
|
|
consent.RequestedAt, consent.RespondedAt, consent.ExpiresAt, consent.Form,
|
|
consent.RequestedScopes, consent.GrantedScopes, consent.RequestedAudience, consent.GrantedAudience); err != nil {
|
|
return fmt.Errorf("error inserting oauth2 consent session with challenge id '%s' for subject '%s': %w", consent.ChallengeID.String(), consent.Subject.String(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveOAuth2ConsentSessionSubject updates an OAuth2.0 consent session with the subject.
|
|
func (p *SQLProvider) SaveOAuth2ConsentSessionSubject(ctx context.Context, consent model.OAuth2ConsentSession) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionSubject, consent.Subject, consent.ID); err != nil {
|
|
return fmt.Errorf("error updating oauth2 consent session subject with id '%d' and challenge id '%s' for subject '%s': %w", consent.ID, consent.ChallengeID, consent.Subject, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveOAuth2ConsentSessionResponse updates an OAuth2.0 consent session with the response.
|
|
func (p *SQLProvider) SaveOAuth2ConsentSessionResponse(ctx context.Context, consent model.OAuth2ConsentSession, authorized bool) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionResponse, authorized, consent.ExpiresAt, consent.GrantedScopes, consent.GrantedAudience, consent.ID); err != nil {
|
|
return fmt.Errorf("error updating oauth2 consent session (authorized '%t') with id '%d' and challenge id '%s' for subject '%s': %w", authorized, consent.ID, consent.ChallengeID, consent.Subject, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveOAuth2ConsentSessionGranted updates an OAuth2.0 consent recording that it has been granted by the authorization endpoint.
|
|
func (p *SQLProvider) SaveOAuth2ConsentSessionGranted(ctx context.Context, id int) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpdateOAuth2ConsentSessionGranted, id); err != nil {
|
|
return fmt.Errorf("error updating oauth2 consent session (granted) with id '%d': %w", id, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadOAuth2ConsentSessionByChallengeID returns an OAuth2.0 consent given the challenge ID.
|
|
func (p *SQLProvider) LoadOAuth2ConsentSessionByChallengeID(ctx context.Context, challengeID uuid.UUID) (consent *model.OAuth2ConsentSession, err error) {
|
|
consent = &model.OAuth2ConsentSession{}
|
|
|
|
if err = p.db.GetContext(ctx, consent, p.sqlSelectOAuth2ConsentSessionByChallengeID, challengeID); err != nil {
|
|
return nil, fmt.Errorf("error selecting oauth2 consent session with challenge id '%s': %w", challengeID.String(), err)
|
|
}
|
|
|
|
return consent, nil
|
|
}
|
|
|
|
// LoadOAuth2ConsentSessionsPreConfigured returns an OAuth2.0 consents that are pre-configured given the consent signature.
|
|
func (p *SQLProvider) LoadOAuth2ConsentSessionsPreConfigured(ctx context.Context, clientID string, subject uuid.UUID) (rows *ConsentSessionRows, err error) {
|
|
var r *sqlx.Rows
|
|
|
|
if r, err = p.db.QueryxContext(ctx, p.sqlSelectOAuth2ConsentSessionsPreConfigured, clientID, subject); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return &ConsentSessionRows{}, nil
|
|
}
|
|
|
|
return &ConsentSessionRows{}, fmt.Errorf("error selecting oauth2 consent session by signature with client id '%s' and subject '%s': %w", clientID, subject.String(), err)
|
|
}
|
|
|
|
return &ConsentSessionRows{rows: r}, nil
|
|
}
|
|
|
|
// SaveOAuth2Session saves a OAuth2Session to the database.
|
|
func (p *SQLProvider) SaveOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, session model.OAuth2Session) (err error) {
|
|
var query string
|
|
|
|
switch sessionType {
|
|
case OAuth2SessionTypeAuthorizeCode:
|
|
query = p.sqlInsertOAuth2AuthorizeCodeSession
|
|
case OAuth2SessionTypeAccessToken:
|
|
query = p.sqlInsertOAuth2AccessTokenSession
|
|
case OAuth2SessionTypeRefreshToken:
|
|
query = p.sqlInsertOAuth2RefreshTokenSession
|
|
case OAuth2SessionTypePKCEChallenge:
|
|
query = p.sqlInsertOAuth2PKCERequestSession
|
|
case OAuth2SessionTypeOpenIDConnect:
|
|
query = p.sqlInsertOAuth2OpenIDConnectSession
|
|
default:
|
|
return fmt.Errorf("error inserting oauth2 session for subject '%s' and request id '%s': unknown oauth2 session type '%s'", session.Subject, session.RequestID, sessionType)
|
|
}
|
|
|
|
if session.Session, err = p.encrypt(session.Session); err != nil {
|
|
return fmt.Errorf("error encrypting the oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
|
|
}
|
|
|
|
_, err = p.db.ExecContext(ctx, query,
|
|
session.ChallengeID, session.RequestID, session.ClientID, session.Signature,
|
|
session.Subject, session.RequestedAt, session.RequestedScopes, session.GrantedScopes,
|
|
session.RequestedAudience, session.GrantedAudience,
|
|
session.Active, session.Revoked, session.Form, session.Session)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("error inserting oauth2 %s session data for subject '%s' and request id '%s' and challenge id '%s': %w", sessionType, session.Subject, session.RequestID, session.ChallengeID.String(), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RevokeOAuth2Session marks a OAuth2Session as revoked in the database.
|
|
func (p *SQLProvider) RevokeOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
|
|
var query string
|
|
|
|
switch sessionType {
|
|
case OAuth2SessionTypeAuthorizeCode:
|
|
query = p.sqlRevokeOAuth2AuthorizeCodeSession
|
|
case OAuth2SessionTypeAccessToken:
|
|
query = p.sqlRevokeOAuth2AccessTokenSession
|
|
case OAuth2SessionTypeRefreshToken:
|
|
query = p.sqlRevokeOAuth2RefreshTokenSession
|
|
case OAuth2SessionTypePKCEChallenge:
|
|
query = p.sqlRevokeOAuth2PKCERequestSession
|
|
case OAuth2SessionTypeOpenIDConnect:
|
|
query = p.sqlRevokeOAuth2OpenIDConnectSession
|
|
default:
|
|
return fmt.Errorf("error revoking oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
|
|
return fmt.Errorf("error revoking oauth2 %s session with signature '%s': %w", sessionType, signature, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RevokeOAuth2SessionByRequestID marks a OAuth2Session as revoked in the database.
|
|
func (p *SQLProvider) RevokeOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
|
|
var query string
|
|
|
|
switch sessionType {
|
|
case OAuth2SessionTypeAuthorizeCode:
|
|
query = p.sqlRevokeOAuth2AuthorizeCodeSessionByRequestID
|
|
case OAuth2SessionTypeAccessToken:
|
|
query = p.sqlRevokeOAuth2AccessTokenSessionByRequestID
|
|
case OAuth2SessionTypeRefreshToken:
|
|
query = p.sqlRevokeOAuth2RefreshTokenSessionByRequestID
|
|
case OAuth2SessionTypePKCEChallenge:
|
|
query = p.sqlRevokeOAuth2PKCERequestSessionByRequestID
|
|
case OAuth2SessionTypeOpenIDConnect:
|
|
query = p.sqlRevokeOAuth2OpenIDConnectSessionByRequestID
|
|
default:
|
|
return fmt.Errorf("error revoking oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
|
|
return fmt.Errorf("error revoking oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeactivateOAuth2Session marks a OAuth2Session as inactive in the database.
|
|
func (p *SQLProvider) DeactivateOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (err error) {
|
|
var query string
|
|
|
|
switch sessionType {
|
|
case OAuth2SessionTypeAuthorizeCode:
|
|
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
|
|
case OAuth2SessionTypeAccessToken:
|
|
query = p.sqlDeactivateOAuth2AccessTokenSession
|
|
case OAuth2SessionTypeRefreshToken:
|
|
query = p.sqlDeactivateOAuth2RefreshTokenSession
|
|
case OAuth2SessionTypePKCEChallenge:
|
|
query = p.sqlDeactivateOAuth2PKCERequestSession
|
|
case OAuth2SessionTypeOpenIDConnect:
|
|
query = p.sqlDeactivateOAuth2OpenIDConnectSession
|
|
default:
|
|
return fmt.Errorf("error deactivating oauth2 session with signature '%s': unknown oauth2 session type '%s'", signature, sessionType)
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, query, signature); err != nil {
|
|
return fmt.Errorf("error deactivating oauth2 %s session with signature '%s': %w", sessionType, signature, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeactivateOAuth2SessionByRequestID marks a OAuth2Session as inactive in the database.
|
|
func (p *SQLProvider) DeactivateOAuth2SessionByRequestID(ctx context.Context, sessionType OAuth2SessionType, requestID string) (err error) {
|
|
var query string
|
|
|
|
switch sessionType {
|
|
case OAuth2SessionTypeAuthorizeCode:
|
|
query = p.sqlDeactivateOAuth2AuthorizeCodeSession
|
|
case OAuth2SessionTypeAccessToken:
|
|
query = p.sqlDeactivateOAuth2AccessTokenSessionByRequestID
|
|
case OAuth2SessionTypeRefreshToken:
|
|
query = p.sqlDeactivateOAuth2RefreshTokenSessionByRequestID
|
|
case OAuth2SessionTypePKCEChallenge:
|
|
query = p.sqlDeactivateOAuth2PKCERequestSessionByRequestID
|
|
case OAuth2SessionTypeOpenIDConnect:
|
|
query = p.sqlDeactivateOAuth2OpenIDConnectSessionByRequestID
|
|
default:
|
|
return fmt.Errorf("error deactivating oauth2 session with request id '%s': unknown oauth2 session type '%s'", requestID, sessionType)
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, query, requestID); err != nil {
|
|
return fmt.Errorf("error deactivating oauth2 %s session with request id '%s': %w", sessionType, requestID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadOAuth2Session saves a OAuth2Session from the database.
|
|
func (p *SQLProvider) LoadOAuth2Session(ctx context.Context, sessionType OAuth2SessionType, signature string) (session *model.OAuth2Session, err error) {
|
|
var query string
|
|
|
|
switch sessionType {
|
|
case OAuth2SessionTypeAuthorizeCode:
|
|
query = p.sqlSelectOAuth2AuthorizeCodeSession
|
|
case OAuth2SessionTypeAccessToken:
|
|
query = p.sqlSelectOAuth2AccessTokenSession
|
|
case OAuth2SessionTypeRefreshToken:
|
|
query = p.sqlSelectOAuth2RefreshTokenSession
|
|
case OAuth2SessionTypePKCEChallenge:
|
|
query = p.sqlSelectOAuth2PKCERequestSession
|
|
case OAuth2SessionTypeOpenIDConnect:
|
|
query = p.sqlSelectOAuth2OpenIDConnectSession
|
|
default:
|
|
return nil, fmt.Errorf("error selecting oauth2 session: unknown oauth2 session type '%s'", sessionType)
|
|
}
|
|
|
|
session = &model.OAuth2Session{}
|
|
|
|
if err = p.db.GetContext(ctx, session, query, signature); err != nil {
|
|
return nil, fmt.Errorf("error selecting oauth2 %s session with signature '%s': %w", sessionType, signature, err)
|
|
}
|
|
|
|
if session.Session, err = p.decrypt(session.Session); err != nil {
|
|
return nil, fmt.Errorf("error decrypting the oauth2 %s session data with signature '%s' for subject '%s' and request id '%s': %w", sessionType, signature, session.Subject, session.RequestID, err)
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
// SaveOAuth2BlacklistedJTI saves a OAuth2BlacklistedJTI to the database.
|
|
func (p *SQLProvider) SaveOAuth2BlacklistedJTI(ctx context.Context, blacklistedJTI model.OAuth2BlacklistedJTI) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertOAuth2BlacklistedJTI, blacklistedJTI.Signature, blacklistedJTI.ExpiresAt); err != nil {
|
|
return fmt.Errorf("error inserting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadOAuth2BlacklistedJTI loads a OAuth2BlacklistedJTI from the database.
|
|
func (p *SQLProvider) LoadOAuth2BlacklistedJTI(ctx context.Context, signature string) (blacklistedJTI *model.OAuth2BlacklistedJTI, err error) {
|
|
blacklistedJTI = &model.OAuth2BlacklistedJTI{}
|
|
|
|
if err = p.db.GetContext(ctx, blacklistedJTI, p.sqlSelectOAuth2BlacklistedJTI, signature); err != nil {
|
|
return nil, fmt.Errorf("error selecting oauth2 blacklisted JTI with signature '%s': %w", blacklistedJTI.Signature, err)
|
|
}
|
|
|
|
return blacklistedJTI, nil
|
|
}
|
|
|
|
// SavePreferred2FAMethod save the preferred method for 2FA to the database.
|
|
func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method); err != nil {
|
|
return fmt.Errorf("error upserting preferred two factor method for user '%s': %w", username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
|
|
func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
|
|
err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
|
|
|
|
switch {
|
|
case err == nil:
|
|
return method, nil
|
|
case errors.Is(err, sql.ErrNoRows):
|
|
return "", sql.ErrNoRows
|
|
default:
|
|
return "", fmt.Errorf("error selecting preferred two factor method for user '%s': %w", username, err)
|
|
}
|
|
}
|
|
|
|
// LoadUserInfo loads the model.UserInfo from the database.
|
|
func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info model.UserInfo, err error) {
|
|
err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username, username)
|
|
|
|
switch {
|
|
case err == nil, errors.Is(err, sql.ErrNoRows):
|
|
return info, nil
|
|
default:
|
|
return model.UserInfo{}, fmt.Errorf("error selecting user info for user '%s': %w", username, err)
|
|
}
|
|
}
|
|
|
|
// SaveIdentityVerification save an identity verification record to the database.
|
|
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification model.IdentityVerification) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
|
|
verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt,
|
|
verification.Username, verification.Action); err != nil {
|
|
return fmt.Errorf("error inserting identity verification for user '%s' with uuid '%s': %w", verification.Username, verification.JTI, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ConsumeIdentityVerification marks an identity verification record in the database as consumed.
|
|
func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip model.NullIP) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil {
|
|
return fmt.Errorf("error updating identity verification: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// FindIdentityVerification checks if an identity verification record is in the database and active.
|
|
func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
|
|
verification := model.IdentityVerification{}
|
|
if err = p.db.GetContext(ctx, &verification, p.sqlSelectIdentityVerification, jti); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return false, nil
|
|
}
|
|
|
|
return false, fmt.Errorf("error selecting identity verification exists: %w", err)
|
|
}
|
|
|
|
switch {
|
|
case verification.Consumed != nil:
|
|
return false, fmt.Errorf("the token has already been consumed")
|
|
case verification.ExpiresAt.Before(time.Now()):
|
|
return false, fmt.Errorf("the token expired %s ago", time.Since(verification.ExpiresAt))
|
|
default:
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
// SaveTOTPConfiguration save a TOTP configuration of a given user in the database.
|
|
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config model.TOTPConfiguration) (err error) {
|
|
if config.Secret, err = p.encrypt(config.Secret); err != nil {
|
|
return fmt.Errorf("error encrypting the TOTP configuration secret for user '%s': %w", config.Username, err)
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
|
|
config.CreatedAt, config.LastUsedAt,
|
|
config.Username, config.Issuer,
|
|
config.Algorithm, config.Digits, config.Period, config.Secret); err != nil {
|
|
return fmt.Errorf("error upserting TOTP configuration for user '%s': %w", config.Username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateTOTPConfigurationSignIn updates a registered Webauthn devices sign in information.
|
|
func (p *SQLProvider) UpdateTOTPConfigurationSignIn(ctx context.Context, id int, lastUsedAt *time.Time) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigRecordSignIn, lastUsedAt, id); err != nil {
|
|
return fmt.Errorf("error updating TOTP configuration id %d: %w", id, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteTOTPConfiguration delete a TOTP configuration from the database given a username.
|
|
func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username); err != nil {
|
|
return fmt.Errorf("error deleting TOTP configuration for user '%s': %w", username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadTOTPConfiguration load a TOTP configuration given a username from the database.
|
|
func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *model.TOTPConfiguration, err error) {
|
|
config = &model.TOTPConfiguration{}
|
|
|
|
if err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNoTOTPConfiguration
|
|
}
|
|
|
|
return nil, fmt.Errorf("error selecting TOTP configuration for user '%s': %w", username, err)
|
|
}
|
|
|
|
if config.Secret, err = p.decrypt(config.Secret); err != nil {
|
|
return nil, fmt.Errorf("error decrypting the TOTP secret for user '%s': %w", username, err)
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// LoadTOTPConfigurations load a set of TOTP configurations.
|
|
func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []model.TOTPConfiguration, err error) {
|
|
configs = make([]model.TOTPConfiguration, 0, limit)
|
|
|
|
if err = p.db.SelectContext(ctx, &configs, p.sqlSelectTOTPConfigs, limit, limit*page); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("error selecting TOTP configurations: %w", err)
|
|
}
|
|
|
|
for i, c := range configs {
|
|
if configs[i].Secret, err = p.decrypt(c.Secret); err != nil {
|
|
return nil, fmt.Errorf("error decrypting TOTP configuration for user '%s': %w", c.Username, err)
|
|
}
|
|
}
|
|
|
|
return configs, nil
|
|
}
|
|
|
|
func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config model.TOTPConfiguration) (err error) {
|
|
switch config.ID {
|
|
case 0:
|
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
|
|
default:
|
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecret, config.Secret, config.ID)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("error updating TOTP configuration secret for user '%s': %w", config.Username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveWebauthnDevice saves a registered Webauthn device.
|
|
func (p *SQLProvider) SaveWebauthnDevice(ctx context.Context, device model.WebauthnDevice) (err error) {
|
|
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
|
|
return fmt.Errorf("error encrypting the Webauthn device public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
|
|
}
|
|
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertWebauthnDevice,
|
|
device.CreatedAt, device.LastUsedAt,
|
|
device.RPID, device.Username, device.Description,
|
|
device.KID, device.PublicKey,
|
|
device.AttestationType, device.Transport, device.AAGUID, device.SignCount, device.CloneWarning,
|
|
); err != nil {
|
|
return fmt.Errorf("error upserting Webauthn device for user '%s' kid '%x': %w", device.Username, device.KID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateWebauthnDeviceSignIn updates a registered Webauthn devices sign in information.
|
|
func (p *SQLProvider) UpdateWebauthnDeviceSignIn(ctx context.Context, id int, rpid string, lastUsedAt *time.Time, signCount uint32, cloneWarning bool) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDeviceRecordSignIn, rpid, lastUsedAt, signCount, cloneWarning, id); err != nil {
|
|
return fmt.Errorf("error updating Webauthn signin metadata for id '%x': %w", id, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadWebauthnDevices loads Webauthn device registrations.
|
|
func (p *SQLProvider) LoadWebauthnDevices(ctx context.Context, limit, page int) (devices []model.WebauthnDevice, err error) {
|
|
devices = make([]model.WebauthnDevice, 0, limit)
|
|
|
|
if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevices, limit, limit*page); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("error selecting Webauthn devices: %w", err)
|
|
}
|
|
|
|
for i, device := range devices {
|
|
if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil {
|
|
return nil, fmt.Errorf("error decrypting Webauthn public key for user '%s': %w", device.Username, err)
|
|
}
|
|
}
|
|
|
|
return devices, nil
|
|
}
|
|
|
|
// LoadWebauthnDevicesByUsername loads all webauthn devices registration for a given username.
|
|
func (p *SQLProvider) LoadWebauthnDevicesByUsername(ctx context.Context, username string) (devices []model.WebauthnDevice, err error) {
|
|
if err = p.db.SelectContext(ctx, &devices, p.sqlSelectWebauthnDevicesByUsername, username); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNoWebauthnDevice
|
|
}
|
|
|
|
return nil, fmt.Errorf("error selecting Webauthn devices for user '%s': %w", username, err)
|
|
}
|
|
|
|
for i, device := range devices {
|
|
if devices[i].PublicKey, err = p.decrypt(device.PublicKey); err != nil {
|
|
return nil, fmt.Errorf("error decrypting Webauthn public key for user '%s': %w", username, err)
|
|
}
|
|
}
|
|
|
|
return devices, nil
|
|
}
|
|
|
|
func (p *SQLProvider) updateWebauthnDevicePublicKey(ctx context.Context, device model.WebauthnDevice) (err error) {
|
|
switch device.ID {
|
|
case 0:
|
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDevicePublicKeyByUsername, device.PublicKey, device.Username, device.KID)
|
|
default:
|
|
_, err = p.db.ExecContext(ctx, p.sqlUpdateWebauthnDevicePublicKey, device.PublicKey, device.ID)
|
|
}
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("error updating Webauthn public key for user '%s' kid '%x': %w", device.Username, device.KID, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SavePreferredDuoDevice saves a Duo device.
|
|
func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device model.DuoDevice) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method); err != nil {
|
|
return fmt.Errorf("error upserting preferred duo device for user '%s': %w", device.Username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeletePreferredDuoDevice deletes a Duo device of a given user.
|
|
func (p *SQLProvider) DeletePreferredDuoDevice(ctx context.Context, username string) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlDeleteDuoDevice, username); err != nil {
|
|
return fmt.Errorf("error deleting preferred duo device for user '%s': %w", username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadPreferredDuoDevice loads a Duo device of a given user.
|
|
func (p *SQLProvider) LoadPreferredDuoDevice(ctx context.Context, username string) (device *model.DuoDevice, err error) {
|
|
device = &model.DuoDevice{}
|
|
|
|
if err = p.db.QueryRowxContext(ctx, p.sqlSelectDuoDevice, username).StructScan(device); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, ErrNoDuoDevice
|
|
}
|
|
|
|
return nil, fmt.Errorf("error selecting preferred duo device for user '%s': %w", username, err)
|
|
}
|
|
|
|
return device, nil
|
|
}
|
|
|
|
// AppendAuthenticationLog append a mark to the authentication log.
|
|
func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt model.AuthenticationAttempt) (err error) {
|
|
if _, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt,
|
|
attempt.Time, attempt.Successful, attempt.Banned, attempt.Username,
|
|
attempt.Type, attempt.RemoteIP, attempt.RequestURI, attempt.RequestMethod); err != nil {
|
|
return fmt.Errorf("error inserting authentication attempt for user '%s': %w", attempt.Username, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log.
|
|
func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []model.AuthenticationAttempt, err error) {
|
|
attempts = make([]model.AuthenticationAttempt, 0, limit)
|
|
|
|
if err = p.db.SelectContext(ctx, &attempts, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNoAuthenticationLogs
|
|
}
|
|
|
|
return nil, fmt.Errorf("error selecting authentication logs for user '%s': %w", username, err)
|
|
}
|
|
|
|
return attempts, nil
|
|
}
|