mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
fix(storage): postgres schema hardcoded for tables query (#2667)
This removes the hardcoded schema value from the PostgreSQL existing tables query, making it compatible with the new schema config option.
This commit is contained in:
parent
ec1cc3d64e
commit
95a5e326a5
|
@ -54,3 +54,11 @@ type StorageConfiguration struct {
|
||||||
var DefaultSQLStorageConfiguration = SQLStorageConfiguration{
|
var DefaultSQLStorageConfiguration = SQLStorageConfiguration{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultPostgreSQLStorageConfiguration represents the default PostgreSQL configuration.
|
||||||
|
var DefaultPostgreSQLStorageConfiguration = PostgreSQLStorageConfiguration{
|
||||||
|
Schema: "public",
|
||||||
|
SSL: PostgreSQLSSLStorageConfiguration{
|
||||||
|
Mode: "disable",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
|
@ -52,13 +52,17 @@ func validateSQLConfiguration(configuration *schema.SQLStorageConfiguration, val
|
||||||
func validatePostgreSQLConfiguration(configuration *schema.PostgreSQLStorageConfiguration, validator *schema.StructValidator) {
|
func validatePostgreSQLConfiguration(configuration *schema.PostgreSQLStorageConfiguration, validator *schema.StructValidator) {
|
||||||
validateSQLConfiguration(&configuration.SQLStorageConfiguration, validator, "postgres")
|
validateSQLConfiguration(&configuration.SQLStorageConfiguration, validator, "postgres")
|
||||||
|
|
||||||
|
if configuration.Schema == "" {
|
||||||
|
configuration.Schema = schema.DefaultPostgreSQLStorageConfiguration.Schema
|
||||||
|
}
|
||||||
|
|
||||||
// Deprecated. TODO: Remove in v4.36.0.
|
// Deprecated. TODO: Remove in v4.36.0.
|
||||||
if configuration.SSLMode != "" && configuration.SSL.Mode == "" {
|
if configuration.SSLMode != "" && configuration.SSL.Mode == "" {
|
||||||
configuration.SSL.Mode = configuration.SSLMode
|
configuration.SSL.Mode = configuration.SSLMode
|
||||||
}
|
}
|
||||||
|
|
||||||
if configuration.SSL.Mode == "" {
|
if configuration.SSL.Mode == "" {
|
||||||
configuration.SSL.Mode = testModeDisabled
|
configuration.SSL.Mode = schema.DefaultPostgreSQLStorageConfiguration.SSL.Mode
|
||||||
} else if !utils.IsStringInSlice(configuration.SSL.Mode, storagePostgreSQLValidSSLModes) {
|
} else if !utils.IsStringInSlice(configuration.SSL.Mode, storagePostgreSQLValidSSLModes) {
|
||||||
validator.Push(fmt.Errorf(errFmtStoragePostgreSQLInvalidSSLMode, configuration.SSL.Mode, strings.Join(storagePostgreSQLValidSSLModes, "', '")))
|
validator.Push(fmt.Errorf(errFmtStoragePostgreSQLInvalidSSLMode, configuration.SSL.Mode, strings.Join(storagePostgreSQLValidSSLModes, "', '")))
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,7 +104,7 @@ func (suite *StorageSuite) TestShouldValidatePostgreSQLHostUsernamePasswordAndDa
|
||||||
suite.Assert().Len(suite.validator.Errors(), 0)
|
suite.Assert().Len(suite.validator.Errors(), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault() {
|
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeAndSchemaDefaults() {
|
||||||
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
|
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
|
||||||
SQLStorageConfiguration: schema.SQLStorageConfiguration{
|
SQLStorageConfiguration: schema.SQLStorageConfiguration{
|
||||||
Host: "db1",
|
Host: "db1",
|
||||||
|
@ -120,6 +120,30 @@ func (suite *StorageSuite) TestShouldValidatePostgresSSLModeIsDisableByDefault()
|
||||||
suite.Assert().Len(suite.validator.Errors(), 0)
|
suite.Assert().Len(suite.validator.Errors(), 0)
|
||||||
|
|
||||||
suite.Assert().Equal("disable", suite.configuration.PostgreSQL.SSL.Mode)
|
suite.Assert().Equal("disable", suite.configuration.PostgreSQL.SSL.Mode)
|
||||||
|
suite.Assert().Equal("public", suite.configuration.PostgreSQL.Schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *StorageSuite) TestShouldValidatePostgresDefaultsDontOverrideConfiguration() {
|
||||||
|
suite.configuration.PostgreSQL = &schema.PostgreSQLStorageConfiguration{
|
||||||
|
SQLStorageConfiguration: schema.SQLStorageConfiguration{
|
||||||
|
Host: "db1",
|
||||||
|
Username: "myuser",
|
||||||
|
Password: "pass",
|
||||||
|
Database: "database",
|
||||||
|
},
|
||||||
|
Schema: "authelia",
|
||||||
|
SSL: schema.PostgreSQLSSLStorageConfiguration{
|
||||||
|
Mode: "require",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ValidateStorage(suite.configuration, suite.validator)
|
||||||
|
|
||||||
|
suite.Assert().Len(suite.validator.Warnings(), 0)
|
||||||
|
suite.Assert().Len(suite.validator.Errors(), 0)
|
||||||
|
|
||||||
|
suite.Assert().Equal("require", suite.configuration.PostgreSQL.SSL.Mode)
|
||||||
|
suite.Assert().Equal("authelia", suite.configuration.PostgreSQL.Schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeMustBeValid() {
|
func (suite *StorageSuite) TestShouldValidatePostgresSSLModeMustBeValid() {
|
||||||
|
|
|
@ -79,6 +79,7 @@ type SQLProvider struct {
|
||||||
key [32]byte
|
key [32]byte
|
||||||
name string
|
name string
|
||||||
driverName string
|
driverName string
|
||||||
|
schema string
|
||||||
config *schema.Configuration
|
config *schema.Configuration
|
||||||
errOpen error
|
errOpen error
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,8 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
|
||||||
provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration)
|
provider.sqlSelectLatestMigration = provider.db.Rebind(provider.sqlSelectLatestMigration)
|
||||||
provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue)
|
provider.sqlSelectEncryptionValue = provider.db.Rebind(provider.sqlSelectEncryptionValue)
|
||||||
|
|
||||||
|
provider.schema = config.Storage.PostgreSQL.Schema
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,20 +68,14 @@ func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dat
|
||||||
fmt.Sprintf("user='%s'", config.Username),
|
fmt.Sprintf("user='%s'", config.Username),
|
||||||
fmt.Sprintf("password='%s'", config.Password),
|
fmt.Sprintf("password='%s'", config.Password),
|
||||||
fmt.Sprintf("dbname=%s", config.Database),
|
fmt.Sprintf("dbname=%s", config.Database),
|
||||||
|
fmt.Sprintf("search_path=%s", config.Schema),
|
||||||
|
fmt.Sprintf("sslmode=%s", config.SSL.Mode),
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Port > 0 {
|
if config.Port > 0 {
|
||||||
args = append(args, fmt.Sprintf("port=%d", config.Port))
|
args = append(args, fmt.Sprintf("port=%d", config.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.Schema != "" {
|
|
||||||
args = append(args, fmt.Sprintf("search_path=%s", config.Schema))
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.SSL.Mode != "" {
|
|
||||||
args = append(args, fmt.Sprintf("sslmode=%s", config.SSL.Mode))
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.SSL.RootCertificate != "" {
|
if config.SSL.RootCertificate != "" {
|
||||||
args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate))
|
args = append(args, fmt.Sprintf("sslrootcert=%s", config.SSL.RootCertificate))
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ const (
|
||||||
queryPostgreSelectExistingTables = `
|
queryPostgreSelectExistingTables = `
|
||||||
SELECT table_name
|
SELECT table_name
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
WHERE table_type = 'BASE TABLE' AND table_schema = 'public';`
|
WHERE table_type = 'BASE TABLE' AND table_schema = $1;`
|
||||||
|
|
||||||
querySQLiteSelectExistingTables = `
|
querySQLiteSelectExistingTables = `
|
||||||
SELECT name
|
SELECT name
|
||||||
|
|
|
@ -7,13 +7,23 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
|
||||||
"github.com/authelia/authelia/v4/internal/models"
|
"github.com/authelia/authelia/v4/internal/models"
|
||||||
"github.com/authelia/authelia/v4/internal/utils"
|
"github.com/authelia/authelia/v4/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SchemaTables returns a list of tables.
|
// SchemaTables returns a list of tables.
|
||||||
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
|
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
|
||||||
rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
|
var rows *sqlx.Rows
|
||||||
|
|
||||||
|
switch p.schema {
|
||||||
|
case "":
|
||||||
|
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
|
||||||
|
default:
|
||||||
|
rows, err = p.db.QueryxContext(ctx, p.sqlSelectExistingTables, p.schema)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tables, err
|
return tables, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user