authelia/internal/storage/sql_provider_encryption.go
2022-03-06 16:47:40 +11:00

300 lines
7.0 KiB
Go

package storage
import (
"context"
"crypto/sha256"
"fmt"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/authelia/authelia/v4/internal/model"
"github.com/authelia/authelia/v4/internal/utils"
)
// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
// by this command to encrypt the values again and update them using a transaction.
//
// The new key is the SHA-256 digest of encryptionKey. The TOTP secrets and Webauthn public keys are processed first,
// then a fresh encryption check value is written with the new key, and finally the transaction is committed.
// Any failure results in a rollback and a non-nil error.
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) {
	// Bind the transaction to ctx so that cancelling the context also aborts the transaction.
	// Nil options keep the driver defaults.
	tx, err := p.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
	}

	key := sha256.Sum256([]byte(encryptionKey))

	// Both helpers roll the transaction back themselves before returning an error, so the error is passed through
	// without a second rollback here.
	if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
		return err
	}

	if err = p.schemaEncryptionChangeKeyWebauthn(ctx, tx, key); err != nil {
		return err
	}

	if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
		}

		return fmt.Errorf("rollback due to error: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction to change encryption key: %w", err)
	}

	return nil
}
// schemaEncryptionChangeKeyTOTP loads the TOTP configurations in pages of 10, encrypts the Secret of each one with key
// and stores it via updateTOTPConfigurationSecret. On any failure tx is rolled back and the wrapped error returned.
//
// NOTE(review): tx is only used for the rollback here; the load and update calls are not handed tx, so confirm they
// actually take part in the transaction.
func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
	const pageSize = 10

	// abort rolls the transaction back and wraps cause, also reporting the rollback failure when there is one.
	abort := func(cause error) error {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, cause)
		}

		return fmt.Errorf("rollback due to error: %w", cause)
	}

	var batch []model.TOTPConfiguration

	page := 0

	for {
		if batch, err = p.LoadTOTPConfigurations(ctx, pageSize, page); err != nil {
			return abort(err)
		}

		for _, item := range batch {
			if item.Secret, err = utils.Encrypt(item.Secret, &key); err != nil {
				return abort(err)
			}

			if err = p.updateTOTPConfigurationSecret(ctx, item); err != nil {
				return abort(err)
			}
		}

		// Anything other than a full page means there is nothing further to load.
		if len(batch) != pageSize {
			return nil
		}

		page++
	}
}
// schemaEncryptionChangeKeyWebauthn loads the Webauthn devices in pages of 10, encrypts the PublicKey of each one with
// key and stores it via updateWebauthnDevicePublicKey. On any failure tx is rolled back and the wrapped error returned.
//
// NOTE(review): tx is only used for the rollback here; the load and update calls are not handed tx, so confirm they
// actually take part in the transaction.
func (p *SQLProvider) schemaEncryptionChangeKeyWebauthn(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
	const pageSize = 10

	// abort rolls the transaction back and wraps cause, also reporting the rollback failure when there is one.
	abort := func(cause error) error {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, cause)
		}

		return fmt.Errorf("rollback due to error: %w", cause)
	}

	var batch []model.WebauthnDevice

	page := 0

	for {
		if batch, err = p.LoadWebauthnDevices(ctx, pageSize, page); err != nil {
			return abort(err)
		}

		for _, item := range batch {
			if item.PublicKey, err = utils.Encrypt(item.PublicKey, &key); err != nil {
				return abort(err)
			}

			if err = p.updateWebauthnDevicePublicKey(ctx, item); err != nil {
				return abort(err)
			}
		}

		// Anything other than a full page means there is nothing further to load.
		if len(batch) != pageSize {
			return nil
		}

		page++
	}
}
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
//
// Schema versions below 1 yield ErrSchemaEncryptionVersionUnsupported. Any failure to read and decrypt the stored
// check value is reported as ErrSchemaEncryptionInvalidKey. When verbose is set the TOTP and U2F checks also run.
// All collected problems are folded into one error: the first is wrapped, the rest are appended as text.
func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) {
	version, err := p.SchemaVersion(ctx)

	switch {
	case err != nil:
		return err
	case version < 1:
		return ErrSchemaEncryptionVersionUnsupported
	}

	var problems []error

	if _, keyErr := p.getEncryptionValue(ctx, encryptionNameCheck); keyErr != nil {
		problems = append(problems, ErrSchemaEncryptionInvalidKey)
	}

	if verbose {
		checks := []func(context.Context) error{
			p.schemaEncryptionCheckTOTP,
			p.schemaEncryptionCheckU2F,
		}

		for _, check := range checks {
			if checkErr := check(ctx); checkErr != nil {
				problems = append(problems, checkErr)
			}
		}
	}

	if len(problems) == 0 {
		return nil
	}

	combined := problems[0]

	for _, extra := range problems[1:] {
		combined = fmt.Errorf("%w, %v", combined, extra)
	}

	return combined
}
// schemaEncryptionCheckTOTP pages through the rows returned by sqlSelectTOTPConfigs, 10 at a time, and tries to
// decrypt each Secret with the provider's key. It returns an error when a query, scan or row iteration fails, or an
// error stating how many of the secrets could not be decrypted; otherwise nil.
func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
	var (
		config  model.TOTPConfiguration
		row     int
		invalid int
		total   int
	)

	pageSize := 10

	var rows *sqlx.Rows

	for page := 0; true; page++ {
		// sqlx hands back a nil *Rows together with a query error, so there is nothing to close on this path;
		// calling Close on it would dereference nil.
		if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
			return fmt.Errorf("error selecting TOTP configurations: %w", err)
		}

		row = 0

		for rows.Next() {
			total++
			row++

			if err = rows.StructScan(&config); err != nil {
				// Best effort: the scan error is the one worth reporting.
				_ = rows.Close()

				return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
			}

			// A decryption failure is counted rather than returned so every row gets checked.
			if _, err = p.decrypt(config.Secret); err != nil {
				invalid++
			}
		}

		// Next reports false on failure as well as at the end of the result set, so ask which it was
		// before a short page is taken to mean "done".
		err = rows.Err()

		// Best effort: a close failure after the rows were read does not affect the result.
		_ = rows.Close()

		if err != nil {
			return fmt.Errorf("error iterating TOTP configurations: %w", err)
		}

		if row < pageSize {
			break
		}
	}

	if invalid != 0 {
		return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
	}

	return nil
}
// schemaEncryptionCheckU2F pages through the rows returned by sqlSelectWebauthnDevices, 10 at a time, and tries to
// decrypt each PublicKey with the provider's key. It returns an error when a query, scan or row iteration fails, or
// an error stating how many of the devices could not be decrypted; otherwise nil.
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
	var (
		device  model.U2FDevice
		row     int
		invalid int
		total   int
	)

	pageSize := 10

	var rows *sqlx.Rows

	for page := 0; true; page++ {
		// sqlx hands back a nil *Rows together with a query error, so there is nothing to close on this path;
		// calling Close on it would dereference nil.
		if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
			return fmt.Errorf("error selecting U2F devices: %w", err)
		}

		row = 0

		for rows.Next() {
			total++
			row++

			if err = rows.StructScan(&device); err != nil {
				// Best effort: the scan error is the one worth reporting.
				_ = rows.Close()

				return fmt.Errorf("error scanning U2F device to struct: %w", err)
			}

			// A decryption failure is counted rather than returned so every row gets checked.
			if _, err = p.decrypt(device.PublicKey); err != nil {
				invalid++
			}
		}

		// Next reports false on failure as well as at the end of the result set, so ask which it was
		// before a short page is taken to mean "done".
		err = rows.Err()

		// Best effort: a close failure after the rows were read does not affect the result.
		_ = rows.Close()

		if err != nil {
			return fmt.Errorf("error iterating U2F devices: %w", err)
		}

		if row < pageSize {
			break
		}
	}

	if invalid != 0 {
		return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
	}

	return nil
}
// encrypt passes clearText and the provider's key to utils.Encrypt and returns its result unchanged.
//
// NOTE(review): this uses a value receiver, so the provider is copied on each call, whereas most other methods
// in this file use pointer receivers.
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
	cipherText, err = utils.Encrypt(clearText, &p.key)

	return cipherText, err
}
// decrypt passes cipherText and the provider's key to utils.Decrypt and returns its result unchanged.
//
// NOTE(review): this uses a value receiver, so the provider is copied on each call, whereas most other methods
// in this file use pointer receivers.
func (p SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
	clearText, err = utils.Decrypt(cipherText, &p.key)

	return clearText, err
}
// getEncryptionValue reads the stored value for name using sqlSelectEncryptionValue and returns it decrypted with the
// provider's key. A query failure is returned as-is with a nil value.
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
	var stored []byte

	if err = p.db.GetContext(ctx, &stored, p.sqlSelectEncryptionValue, name); err != nil {
		return nil, err
	}

	return p.decrypt(stored)
}
// setNewEncryptionCheckValue generates a new UUID, encrypts its string form with key and upserts the result under
// encryptionNameCheck. The statement runs on e when it is non-nil, otherwise directly on the provider's database.
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]byte, e sqlx.ExecerContext) (err error) {
	id, err := uuid.NewUUID()
	if err != nil {
		return err
	}

	cipherText, err := utils.Encrypt([]byte(id.String()), key)
	if err != nil {
		return err
	}

	// Use the caller's executor (e.g. a transaction) when given, otherwise the provider's database handle.
	execer := e
	if execer == nil {
		execer = p.db
	}

	_, err = execer.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, cipherText)

	return err
}