package storage

import (
	"context"
	"crypto/sha256"
	"fmt"

	"github.com/google/uuid"
	"github.com/jmoiron/sqlx"

	"github.com/authelia/authelia/v4/internal/model"
	"github.com/authelia/authelia/v4/internal/utils"
)

// SchemaEncryptionChangeKey uses the currently configured key to decrypt values in the database and the key provided
// by this command to encrypt the values again and update them using a transaction.
//
// The new key is the SHA-256 digest of encryptionKey. The transaction is opened with ctx, so cancelling ctx also
// aborts the transaction. When a step fails the transaction is rolled back and the returned error wraps the cause;
// otherwise the result of committing the transaction is returned.
func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionKey string) (err error) {
	// BeginTxx binds the transaction to ctx; Beginx would ignore cancellation and deadlines.
	tx, err := p.db.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("error beginning transaction to change encryption key: %w", err)
	}

	key := sha256.Sum256([]byte(encryptionKey))

	// Both helpers roll tx back themselves before returning an error, so it is not rolled back again here.
	if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
		return err
	}

	if err = p.schemaEncryptionChangeKeyWebauthn(ctx, tx, key); err != nil {
		return err
	}

	// The check value is written last, through tx, using the new key.
	if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
		}

		return fmt.Errorf("rollback due to error: %w", err)
	}

	return tx.Commit()
}

// schemaEncryptionChangeKeyTOTP pages through the TOTP configurations returned by LoadTOTPConfigurations, encrypts
// each Secret with key and hands the configuration to updateTOTPConfigurationSecret.
//
// On any failure tx is rolled back and the cause is returned wrapped; a rollback failure is included in the message.
// NOTE(review): tx is only used for the rollback in this function; whether updateTOTPConfigurationSecret writes
// through tx is not visible here - confirm.
func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
	const pageSize = 10

	// abort rolls the transaction back and wraps cause for the caller.
	abort := func(cause error) error {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, cause)
		}

		return fmt.Errorf("rollback due to error: %w", cause)
	}

	var configs []model.TOTPConfiguration

	for page := 0; ; page++ {
		configs, err = p.LoadTOTPConfigurations(ctx, pageSize, page)
		if err != nil {
			return abort(err)
		}

		for _, config := range configs {
			config.Secret, err = utils.Encrypt(config.Secret, &key)
			if err != nil {
				return abort(err)
			}

			err = p.updateTOTPConfigurationSecret(ctx, config)
			if err != nil {
				return abort(err)
			}
		}

		// A page that is not exactly full is the last one.
		if len(configs) != pageSize {
			return nil
		}
	}
}

// schemaEncryptionChangeKeyWebauthn pages through the Webauthn devices returned by LoadWebauthnDevices, encrypts each
// PublicKey with key and hands the device to updateWebauthnDevicePublicKey.
//
// On any failure tx is rolled back and the cause is returned wrapped; a rollback failure is included in the message.
// NOTE(review): tx is only used for the rollback in this function; whether updateWebauthnDevicePublicKey writes
// through tx is not visible here - confirm.
func (p *SQLProvider) schemaEncryptionChangeKeyWebauthn(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
	const pageSize = 10

	// abort rolls the transaction back and wraps cause for the caller.
	abort := func(cause error) error {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, cause)
		}

		return fmt.Errorf("rollback due to error: %w", cause)
	}

	var devices []model.WebauthnDevice

	for page := 0; ; page++ {
		devices, err = p.LoadWebauthnDevices(ctx, pageSize, page)
		if err != nil {
			return abort(err)
		}

		for _, device := range devices {
			device.PublicKey, err = utils.Encrypt(device.PublicKey, &key)
			if err != nil {
				return abort(err)
			}

			err = p.updateWebauthnDevicePublicKey(ctx, device)
			if err != nil {
				return abort(err)
			}
		}

		// A page that is not exactly full is the last one.
		if len(devices) != pageSize {
			return nil
		}
	}
}

// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
//
// It fails immediately if the schema version cannot be read or is below 1. Otherwise it tries to read and decrypt
// the check value; any failure there is reported as ErrSchemaEncryptionInvalidKey. When verbose is set the TOTP and
// U2F checks are run as well. All collected problems are folded into one error whose message lists them separated
// by ", "; only the first problem is wrapped with %w.
func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool) (err error) {
	version, err := p.SchemaVersion(ctx)

	switch {
	case err != nil:
		return err
	case version < 1:
		return ErrSchemaEncryptionVersionUnsupported
	}

	var problems []error

	// The underlying error is intentionally replaced by the sentinel.
	if _, checkErr := p.getEncryptionValue(ctx, encryptionNameCheck); checkErr != nil {
		problems = append(problems, ErrSchemaEncryptionInvalidKey)
	}

	if verbose {
		if totpErr := p.schemaEncryptionCheckTOTP(ctx); totpErr != nil {
			problems = append(problems, totpErr)
		}

		if u2fErr := p.schemaEncryptionCheckU2F(ctx); u2fErr != nil {
			problems = append(problems, u2fErr)
		}
	}

	if len(problems) == 0 {
		return nil
	}

	err = problems[0]

	for _, problem := range problems[1:] {
		err = fmt.Errorf("%w, %v", err, problem)
	}

	return err
}

// schemaEncryptionCheckTOTP attempts to decrypt the Secret of every TOTP configuration returned by
// p.sqlSelectTOTPConfigs using p.decrypt.
//
// Rows are fetched in pages of pageSize (the two query arguments are presumably limit and offset - confirm against
// the query text). A secret that fails to decrypt is counted rather than aborting the scan. An error is returned
// when a query, a scan or the row iteration fails, or when at least one secret could not be decrypted.
func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
	const pageSize = 10

	var (
		config  model.TOTPConfiguration
		rows    *sqlx.Rows
		row     int
		invalid int
		total   int
	)

	for page := 0; true; page++ {
		// rows is nil when the query fails, so it must not be closed on this path.
		if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
			return fmt.Errorf("error selecting TOTP configurations: %w", err)
		}

		row = 0

		for rows.Next() {
			total++
			row++

			if err = rows.StructScan(&config); err != nil {
				_ = rows.Close()

				return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
			}

			if _, decryptErr := p.decrypt(config.Secret); decryptErr != nil {
				invalid++
			}
		}

		// Next returns false on failure as well as on exhaustion; Err tells the two apart.
		err = rows.Err()

		_ = rows.Close()

		if err != nil {
			return fmt.Errorf("error iterating TOTP configurations: %w", err)
		}

		// A short page means there are no further rows to fetch.
		if row < pageSize {
			break
		}
	}

	if invalid != 0 {
		return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
	}

	return nil
}

// schemaEncryptionCheckU2F attempts to decrypt the PublicKey of every device returned by p.sqlSelectWebauthnDevices
// using p.decrypt.
//
// Rows are fetched in pages of pageSize (the two query arguments are presumably limit and offset - confirm against
// the query text). A public key that fails to decrypt is counted rather than aborting the scan. An error is
// returned when a query, a scan or the row iteration fails, or when at least one key could not be decrypted.
// NOTE(review): rows from the Webauthn device query are scanned into model.U2FDevice - confirm this is intended.
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
	const pageSize = 10

	var (
		device  model.U2FDevice
		rows    *sqlx.Rows
		row     int
		invalid int
		total   int
	)

	for page := 0; true; page++ {
		// rows is nil when the query fails, so it must not be closed on this path.
		if rows, err = p.db.QueryxContext(ctx, p.sqlSelectWebauthnDevices, pageSize, pageSize*page); err != nil {
			return fmt.Errorf("error selecting U2F devices: %w", err)
		}

		row = 0

		for rows.Next() {
			total++
			row++

			if err = rows.StructScan(&device); err != nil {
				_ = rows.Close()

				return fmt.Errorf("error scanning U2F device to struct: %w", err)
			}

			if _, decryptErr := p.decrypt(device.PublicKey); decryptErr != nil {
				invalid++
			}
		}

		// Next returns false on failure as well as on exhaustion; Err tells the two apart.
		err = rows.Err()

		_ = rows.Close()

		if err != nil {
			return fmt.Errorf("error iterating U2F devices: %w", err)
		}

		// A short page means there are no further rows to fetch.
		if row < pageSize {
			break
		}
	}

	if invalid != 0 {
		return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
	}

	return nil
}

// encrypt returns clearText encrypted with the provider's key via utils.Encrypt.
// It uses a pointer receiver like the other SQLProvider methods, so the provider is not copied on every call.
func (p *SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
	return utils.Encrypt(clearText, &p.key)
}

// decrypt returns cipherText decrypted with the provider's key via utils.Decrypt.
// It uses a pointer receiver like the other SQLProvider methods, so the provider is not copied on every call.
func (p *SQLProvider) decrypt(cipherText []byte) (clearText []byte, err error) {
	return utils.Decrypt(cipherText, &p.key)
}

// getEncryptionValue fetches the value selected by p.sqlSelectEncryptionValue for name and returns it decrypted
// with p.decrypt. A failed lookup is returned unchanged together with a nil value.
func (p *SQLProvider) getEncryptionValue(ctx context.Context, name string) (value []byte, err error) {
	var cipherText []byte

	if err = p.db.GetContext(ctx, &cipherText, p.sqlSelectEncryptionValue, name); err != nil {
		return nil, err
	}

	value, err = p.decrypt(cipherText)

	return value, err
}

// setNewEncryptionCheckValue generates a random UUID, encrypts its string form with key and executes
// p.sqlUpsertEncryptionValue with encryptionNameCheck and the encrypted value.
//
// The statement runs on e when it is non-nil and on p.db otherwise. Any error from UUID generation, encryption or
// statement execution is returned as is.
func (p *SQLProvider) setNewEncryptionCheckValue(ctx context.Context, key *[32]byte, e sqlx.ExecerContext) (err error) {
	id, err := uuid.NewRandom()
	if err != nil {
		return err
	}

	cipherText, err := utils.Encrypt([]byte(id.String()), key)
	if err != nil {
		return err
	}

	// Fall back to the provider's own connection when no executor was supplied.
	execer := e
	if execer == nil {
		execer = p.db
	}

	_, err = execer.ExecContext(ctx, p.sqlUpsertEncryptionValue, encryptionNameCheck, cipherText)

	return err
}