mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
fix(storage): postgres schema hardcoded for tables query (#2667)
This removes the hardcoded schema value from the PostgreSQL existing tables query, making it compatible with the new schema config option.
This commit is contained in:
parent
ec1cc3d64e
commit
95a5e326a5
|
@ -54,3 +54,11 @@ type StorageConfiguration struct {
|
|||
var DefaultSQLStorageConfiguration = SQLStorageConfiguration{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// DefaultPostgreSQLStorageConfiguration represents the default PostgreSQL configuration.
|
||||
var DefaultPostgreSQLStorageConfiguration = PostgreSQLStorageConfiguration{
|
||||
Schema: "public",
|
||||
SSL: PostgreSQLSSLStorageConfiguration{
|
||||
Mode: "disable",
|
||||
},
|
||||
}
|
||||
|
|
|
@ -52,13 +52,17 @@ func validateSQLConfiguration(configuration *schema.SQLStorageConfiguration, val
|
|||
func validatePostgreSQLConfiguration(configuration *schema.PostgreSQLStorageConfiguration, validator *schema.StructValidator) {
|
||||
validateSQLConfiguration(&configuration.SQLStorageConfiguration, validator, "postgres")
|
||||
|
||||
if configuration.Schema == "" {
|
||||
configuration.Schema = schema.DefaultPostgreSQLStorageConfiguration.Schema
|
||||
}
|
||||
|
||||
// Deprecated. TODO: Remove in v4.36.0.
|
||||
if configuration.SSLMode != "" && configuration.SSL.Mode == "" {
|
||||
configuration.SSL.Mode = configuration.SSLMode
|
||||
}
|
||||
|
||||
if configuration.SSL.Mode == "" {
|
||||
configuration.SSL.Mode = testModeDisabled
|
||||
configuration.SSL.Mode = schema.DefaultPostgreSQLStorageConfiguration.SSL.Mode
|
||||
} else if !utils.IsStringInSlice(configuration.SSL.Mode, storagePostgreSQLValidSSLModes) {
|
||||
validator.Push(fmt.Errorf(errFmtStoragePostgreSQLInvalidSSLMode, configuration.SSL.Mode, strings.Join(storagePostgreSQLValidSSLModes, "', '")))
|
||||
}
|
||||
|
|
|
@ -104,7 +104,7 @@ func (suite *StorageSuite) TestShouldValidatePostgreSQLHostUsernamePasswordAndDa
|
|||
suite.Assert().Len(suite.validator.Errors(), 0)
|
||||
}
|
||||
|
||||
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault() {
|
||||
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeAndSchemaDefaults() {
|
||||
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
|
||||
SQLStorageConfiguration: schema.SQLStorageConfiguration{
|
||||
Host: "db1",
|
||||
|
@ -120,6 +120,30 @@ func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault()
|
|||
suite.Assert().Len(suite.validator.Errors(), 0)
|
||||
|
||||
suite.Assert().Equal("disable", suite.configuration.PostgreSQL.SSL.Mode)
|
||||
suite.Assert().Equal("public", suite.configuration.PostgreSQL.Schema)
|
||||
}
|
||||
|
||||
func (suite *StorageSuite) TestShouldValidatePostgresDefaultsDontOverrideConfiguration() {
|
||||
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
|
||||
SQLStorageConfiguration: schema.SQLStorageConfiguration{
|
||||
Host: "db1",
|
||||
Username: "myuser",
|
||||
Password: "pass",
|
||||
Database: "database",
|
||||
},
|
||||
Schema: "authelia",
|
||||
SSL: schema.PostgreSQLSSLStorageConfiguration{
|
||||
Mode: "require",
|
||||
},
|
||||
}
|
||||
|
||||
ValidateStorage(suite.configuration, suite.validator)
|
||||
|
||||
suite.Assert().Len(suite.validator.Warnings(), 0)
|
||||
suite.Assert().Len(suite.validator.Errors(), 0)
|
||||
|
||||
suite.Assert().Equal("require", suite.configuration.PostgreSQL.SSL.Mode)
|
||||
suite.Assert().Equal("authelia", suite.configuration.PostgreSQL.Schema)
|
||||
}
|
||||
|
||||
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeMustBeValid() {
|
||||
|
|
|
@ -79,6 +79,7 @@ type SQLProvider struct {
|
|||
key [32]byte
|
||||
name string
|
||||
driverName string
|
||||
schema string
|
||||
config *schema.Configuration
|
||||
errOpen error
|
||||
|
||||
|
|
|
@ -57,6 +57,8 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
|
|||
provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration)
|
||||
provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue)
|
||||
|
||||
provider.schema = config.Storage.PostgreSQL.Schema
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
|
@ -66,20 +68,14 @@ func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dat
|
|||
fmt.Sprintf("user='%s'", config.Username),
|
||||
fmt.Sprintf("password='%s'", config.Password),
|
||||
fmt.Sprintf("dbname=%s", config.Database),
|
||||
fmt.Sprintf("search_path=%s", config.Schema),
|
||||
fmt.Sprintf("sslmode=%s", config.SSL.Mode),
|
||||
}
|
||||
|
||||
if config.Port > 0 {
|
||||
args = append(args, fmt.Sprintf("port=%d", config.Port))
|
||||
}
|
||||
|
||||
if config.Schema != "" {
|
||||
args = append(args, fmt.Sprintf("search_path=%s", config.Schema))
|
||||
}
|
||||
|
||||
if config.SSL.Mode != "" {
|
||||
args = append(args, fmt.Sprintf("sslmode=%s", config.SSL.Mode))
|
||||
}
|
||||
|
||||
if config.SSL.RootCertificate != "" {
|
||||
args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate))
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ const (
|
|||
queryPostgreSelectExistingTables = `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_type = 'BASE TABLE' AND table_schema = 'public';`
|
||||
WHERE table_type = 'BASE TABLE' AND table_schema = $1;`
|
||||
|
||||
querySQLiteSelectExistingTables = `
|
||||
SELECT name
|
||||
|
|
|
@ -7,13 +7,23 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// SchemaTables returns a list of tables.
|
||||
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
|
||||
var rows *sqlx.Rows
|
||||
|
||||
switch p.schema {
|
||||
case "":
|
||||
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
|
||||
default:
|
||||
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables, p.schema)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return tables, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user