mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
3695aa8140
This is a massive overhaul to the SQL Storage for Authelia. It facilitates a whole heap of utility commands to help manage the database, primary keys, ensures all database requests use a context for cancellations, and paves the way for a few other PRs which improve the database. Fixes #1337
205 lines
4.6 KiB
Go
205 lines
4.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"embed"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// migrationsFS embeds every file under the migrations directory into the binary at build time. The migration
// loaders in this file list it with ReadDir("migrations") and read individual files with ReadFile.
//go:embed migrations/*
|
|
var migrationsFS embed.FS
|
|
|
|
func latestMigrationVersion(providerName string) (version int, err error) {
|
|
entries, err := migrationsFS.ReadDir("migrations")
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
m, err := scanMigration(entry.Name())
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
if m.Provider != providerName {
|
|
continue
|
|
}
|
|
|
|
if !m.Up {
|
|
continue
|
|
}
|
|
|
|
if m.Version > version {
|
|
version = m.Version
|
|
}
|
|
}
|
|
|
|
return version, nil
|
|
}
|
|
|
|
func loadMigration(providerName string, version int, up bool) (migration *SchemaMigration, err error) {
|
|
entries, err := migrationsFS.ReadDir("migrations")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
m, err := scanMigration(entry.Name())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
migration = &m
|
|
|
|
if up != migration.Up {
|
|
continue
|
|
}
|
|
|
|
if migration.Provider != providerAll && migration.Provider != providerName {
|
|
continue
|
|
}
|
|
|
|
if version != migration.Version {
|
|
continue
|
|
}
|
|
|
|
return migration, nil
|
|
}
|
|
|
|
return nil, errors.New("migration not found")
|
|
}
|
|
|
|
// loadMigrations scans the migrations fs and loads the appropriate migrations for a given providerName, prior and
|
|
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
|
|
// this indicates the database zero state.
|
|
func loadMigrations(providerName string, prior, target int) (migrations []SchemaMigration, err error) {
|
|
if prior == target && (prior != -1 || target != -1) {
|
|
return nil, errors.New("cannot migrate to the same version as prior")
|
|
}
|
|
|
|
entries, err := migrationsFS.ReadDir("migrations")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
up := prior < target
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
migration, err := scanMigration(entry.Name())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if skipMigration(providerName, up, target, prior, &migration) {
|
|
continue
|
|
}
|
|
|
|
migrations = append(migrations, migration)
|
|
}
|
|
|
|
if up {
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version < migrations[j].Version
|
|
})
|
|
} else {
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version > migrations[j].Version
|
|
})
|
|
}
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
func skipMigration(providerName string, up bool, target, prior int, migration *SchemaMigration) (skip bool) {
|
|
if migration.Provider != providerAll && migration.Provider != providerName {
|
|
// Skip if migration.Provider is not a match.
|
|
return true
|
|
}
|
|
|
|
if up {
|
|
if !migration.Up {
|
|
// Skip if we wanted an Up migration but it isn't an Up migration.
|
|
return true
|
|
}
|
|
|
|
if target != -1 && (migration.Version > target || migration.Version <= prior) {
|
|
// Skip if the migration version is greater than the target or less than or equal to the previous version.
|
|
return true
|
|
}
|
|
} else {
|
|
if migration.Up {
|
|
// Skip if we didn't want an Up migration but it is an Up migration.
|
|
return true
|
|
}
|
|
|
|
if migration.Version == 1 && target == -1 {
|
|
// Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data
|
|
// preventing a successful migration.
|
|
return true
|
|
}
|
|
|
|
if migration.Version <= target || migration.Version > prior {
|
|
// Skip the migration if we want to go down and the migration version is less than or equal to the target
|
|
// or greater than the previous version.
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func scanMigration(m string) (migration SchemaMigration, err error) {
|
|
result := reMigration.FindStringSubmatch(m)
|
|
|
|
if result == nil || len(result) != 5 {
|
|
return SchemaMigration{}, errors.New("invalid migration: could not parse the format")
|
|
}
|
|
|
|
migration = SchemaMigration{
|
|
Name: strings.ReplaceAll(result[2], "_", " "),
|
|
Provider: result[3],
|
|
}
|
|
|
|
data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
|
|
if err != nil {
|
|
return SchemaMigration{}, err
|
|
}
|
|
|
|
migration.Query = string(data)
|
|
|
|
switch result[4] {
|
|
case "up":
|
|
migration.Up = true
|
|
case "down":
|
|
migration.Up = false
|
|
default:
|
|
return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4])
|
|
}
|
|
|
|
migration.Version, _ = strconv.Atoi(result[1])
|
|
|
|
switch migration.Provider {
|
|
case providerAll, providerSQLite, providerMySQL, providerPostgres:
|
|
break
|
|
default:
|
|
return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3])
|
|
}
|
|
|
|
return migration, nil
|
|
}
|