authelia/internal/storage/sql_provider_backend_postgres.go
James Elliott ad8e844af6
feat(totp): algorithm and digits config (#2634)
Allow users to configure the TOTP Algorithm and Digits. This should be used with caution as many TOTP applications do not support it. Some will also fail to notify the user that there is an issue, e.g. if the algorithm in the QR code is sha512, they continue to generate one time passwords with sha1. In addition, this drastically refactors TOTP in general to be more user friendly by not forcing users to register a new device if the administrator changes the period (or algorithm).

Fixes #1226.
2021-12-01 23:11:29 +11:00

78 lines
3.5 KiB
Go

package storage
import (
"fmt"
"strings"
"time"
_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// PostgreSQLProvider is the storage provider for PostgreSQL databases.
// It embeds SQLProvider, so the fields and methods of SQLProvider are promoted onto it.
type PostgreSQLProvider struct {
SQLProvider
}
// NewPostgreSQLProvider creates a PostgreSQLProvider from the supplied configuration.
//
// The embedded SQLProvider is constructed with the "pgx" driver and a data source name built
// from config.Storage.PostgreSQL. The statements that differ on PostgreSQL are then replaced
// and every statement written with '?' placeholders is rebound.
func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLProvider) {
	dataSourceName := dataSourceNamePostgreSQL(*config.Storage.PostgreSQL)

	provider = &PostgreSQLProvider{
		SQLProvider: NewSQLProvider(config, providerPostgres, "pgx", dataSourceName),
	}

	// Each provider has its own statement for selecting the existing tables.
	provider.sqlSelectExistingTables = queryPostgreSelectExistingTables

	// Provider specific statements: PostgreSQL has no UPSERT statement and offers the
	// ON CONFLICT operation instead.
	provider.sqlUpsertU2FDevice = fmt.Sprintf(queryFmtPostgresUpsertU2FDevice, tableU2FDevices)
	provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations)
	provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences)
	provider.sqlUpsertEncryptionValue = fmt.Sprintf(queryFmtPostgresUpsertEncryptionValue, tableEncryption)

	// Any statement containing a '?' placeholder has to be rebound to the '$#' placeholder
	// notation that PostgreSQL requires. The statements are rewritten in place, in this order.
	rebound := []*string{
		&provider.sqlFmtRenameTable,

		&provider.sqlSelectPreferred2FAMethod,
		&provider.sqlSelectUserInfo,

		&provider.sqlSelectExistsIdentityVerification,
		&provider.sqlInsertIdentityVerification,
		&provider.sqlDeleteIdentityVerification,

		&provider.sqlSelectTOTPConfig,
		&provider.sqlUpsertTOTPConfig,
		&provider.sqlDeleteTOTPConfig,
		&provider.sqlSelectTOTPConfigs,
		&provider.sqlUpdateTOTPConfigSecret,
		&provider.sqlUpdateTOTPConfigSecretByUsername,

		&provider.sqlSelectU2FDevice,

		&provider.sqlInsertAuthenticationAttempt,
		&provider.sqlSelectAuthenticationAttemptsByUsername,

		&provider.sqlInsertMigration,

		&provider.sqlSelectEncryptionValue,
	}

	for _, query := range rebound {
		*query = provider.db.Rebind(*query)
	}

	return provider
}
// dataSourceNamePostgreSQL builds the keyword/value connection string passed to the PostgreSQL
// driver from the PostgreSQL storage configuration.
//
// The user and password keywords are always emitted; host, port and dbname are only emitted
// when they are set. Every string value is single-quoted and escaped, so a value containing
// spaces, single quotes or backslashes cannot break out of its keyword and corrupt the rest
// of the connection string.
func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) {
	args := []string{
		"user=" + quotePostgreSQLValue(config.Username),
		"password=" + quotePostgreSQLValue(config.Password),
	}

	if config.Host != "" {
		args = append(args, "host="+quotePostgreSQLValue(config.Host))
	}

	if config.Port > 0 {
		args = append(args, fmt.Sprintf("port=%d", config.Port))
	}

	if config.Database != "" {
		args = append(args, "dbname="+quotePostgreSQLValue(config.Database))
	}

	// The timeout is written as a whole number of seconds; any sub-second remainder is dropped
	// by the integer division.
	args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second)))

	return strings.Join(args, " ")
}

// quotePostgreSQLValue wraps value in single quotes for use in a keyword/value connection
// string, escaping embedded backslashes and single quotes with a backslash.
func quotePostgreSQLValue(value string) string {
	// A single-pass replacer avoids re-escaping the backslashes it has just inserted.
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(value) + "'"
}