2020-07-16 12:56:08 +07:00
|
|
|
package storage
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql/driver"
|
|
|
|
"encoding/base64"
|
|
|
|
"fmt"
|
|
|
|
"sort"
|
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
|
|
|
"github.com/authelia/authelia/internal/authentication"
|
|
|
|
"github.com/authelia/authelia/internal/models"
|
|
|
|
)
|
|
|
|
|
|
|
|
// currentSchemaMockSchemaVersion is the schema version value the mocked config
// table query returns in the provider method tests in this file.
const currentSchemaMockSchemaVersion = "1"
|
|
|
|
|
|
|
|
func TestSQLInitializeDatabase(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
rows := sqlmock.NewRows([]string{"name"})
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(rows)
|
|
|
|
|
|
|
|
mock.ExpectBegin()
|
|
|
|
|
|
|
|
keys := make([]string, 0, len(sqlUpgradeCreateTableStatements[1]))
|
|
|
|
for k := range sqlUpgradeCreateTableStatements[1] {
|
|
|
|
keys = append(keys, k)
|
|
|
|
}
|
|
|
|
|
|
|
|
sort.Strings(keys)
|
|
|
|
|
|
|
|
for _, table := range keys {
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("CREATE TABLE %s .*", table)).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
|
|
|
}
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
|
|
|
|
WithArgs("schema", "version", "1").
|
|
|
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
|
|
|
|
|
|
|
mock.ExpectCommit()
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestSQLUpgradeDatabase(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
|
|
|
AddRow(userPreferencesTableName).
|
|
|
|
AddRow(identityVerificationTokensTableName).
|
|
|
|
AddRow(totpSecretsTableName).
|
|
|
|
AddRow(u2fDeviceHandlesTableName).
|
|
|
|
AddRow(authenticationLogsTableName))
|
|
|
|
|
|
|
|
mock.ExpectBegin()
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("CREATE TABLE %s .*", configTableName)).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 0))
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
|
|
|
|
WithArgs("schema", "version", "1").
|
|
|
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
|
|
|
|
|
|
|
mock.ExpectCommit()
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestSQLProviderMethodsAuthenticationLogs(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
|
|
|
AddRow(userPreferencesTableName).
|
|
|
|
AddRow(identityVerificationTokensTableName).
|
|
|
|
AddRow(totpSecretsTableName).
|
|
|
|
AddRow(u2fDeviceHandlesTableName).
|
|
|
|
AddRow(authenticationLogsTableName).
|
|
|
|
AddRow(configTableName))
|
|
|
|
|
|
|
|
args := []driver.Value{"schema", "version"}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
|
|
|
AddRow("1"))
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
attempts := []models.AuthenticationAttempt{
|
|
|
|
{Username: unitTestUser, Successful: true, Time: time.Unix(1577880001, 0)},
|
|
|
|
{Username: unitTestUser, Successful: true, Time: time.Unix(1577880002, 0)},
|
|
|
|
{Username: unitTestUser, Successful: false, Time: time.Unix(1577880003, 0)},
|
|
|
|
}
|
|
|
|
|
|
|
|
rows := sqlmock.NewRows([]string{"successful", "time"})
|
|
|
|
|
|
|
|
for id, attempt := range attempts {
|
|
|
|
args = []driver.Value{attempt.Username, attempt.Successful, attempt.Time.Unix()}
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("INSERT INTO %s \\(username, successful, time\\) VALUES \\(\\?, \\?, \\?\\)", authenticationLogsTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnResult(sqlmock.NewResult(int64(id), 1))
|
|
|
|
|
|
|
|
err := provider.AppendAuthenticationLog(attempt)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
rows.AddRow(attempt.Successful, attempt.Time.Unix())
|
|
|
|
}
|
|
|
|
|
|
|
|
args = []driver.Value{1577880000, unitTestUser}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(rows)
|
|
|
|
|
|
|
|
after := time.Unix(1577880000, 0)
|
|
|
|
results, err := provider.LoadLatestAuthenticationLogs(unitTestUser, after)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
require.Len(t, results, 3)
|
|
|
|
assert.Equal(t, unitTestUser, results[0].Username)
|
|
|
|
assert.Equal(t, true, results[0].Successful)
|
|
|
|
assert.Equal(t, time.Unix(1577880001, 0), results[0].Time)
|
|
|
|
assert.Equal(t, unitTestUser, results[1].Username)
|
|
|
|
assert.Equal(t, true, results[1].Successful)
|
|
|
|
assert.Equal(t, time.Unix(1577880002, 0), results[1].Time)
|
|
|
|
assert.Equal(t, unitTestUser, results[2].Username)
|
|
|
|
assert.Equal(t, false, results[2].Successful)
|
|
|
|
assert.Equal(t, time.Unix(1577880003, 0), results[2].Time)
|
|
|
|
|
|
|
|
// Test Blank Rows.
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"successful", "time"}))
|
|
|
|
|
|
|
|
results, err = provider.LoadLatestAuthenticationLogs(unitTestUser, after)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 0)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestSQLProviderMethodsPreferred(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
|
|
|
AddRow(userPreferencesTableName).
|
|
|
|
AddRow(identityVerificationTokensTableName).
|
|
|
|
AddRow(totpSecretsTableName).
|
|
|
|
AddRow(u2fDeviceHandlesTableName).
|
|
|
|
AddRow(authenticationLogsTableName).
|
|
|
|
AddRow(configTableName))
|
|
|
|
|
|
|
|
args := []driver.Value{"schema", "version"}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
|
|
|
AddRow(currentSchemaMockSchemaVersion))
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("REPLACE INTO %s \\(username, second_factor_method\\) VALUES \\(\\?, \\?\\)", userPreferencesTableName)).
|
|
|
|
WithArgs(unitTestUser, authentication.TOTP).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
|
|
|
|
err = provider.SavePreferred2FAMethod(unitTestUser, authentication.TOTP)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
|
|
|
|
WithArgs(unitTestUser).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}).AddRow(authentication.TOTP))
|
|
|
|
|
|
|
|
method, err := provider.LoadPreferred2FAMethod(unitTestUser)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, authentication.TOTP, method)
|
|
|
|
|
|
|
|
// Test Blank Rows.
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
|
|
|
|
WithArgs(unitTestUser).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}))
|
|
|
|
|
|
|
|
method, err = provider.LoadPreferred2FAMethod(unitTestUser)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, "", method)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestSQLProviderMethodsTOTP(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
|
|
|
AddRow(userPreferencesTableName).
|
|
|
|
AddRow(identityVerificationTokensTableName).
|
|
|
|
AddRow(totpSecretsTableName).
|
|
|
|
AddRow(u2fDeviceHandlesTableName).
|
|
|
|
AddRow(authenticationLogsTableName).
|
|
|
|
AddRow(configTableName))
|
|
|
|
|
|
|
|
args := []driver.Value{"schema", "version"}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
|
|
|
AddRow(currentSchemaMockSchemaVersion))
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
pretendSecret := "abc123"
|
|
|
|
args = []driver.Value{unitTestUser, pretendSecret}
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("REPLACE INTO %s \\(username, secret\\) VALUES \\(\\?, \\?\\)", totpSecretsTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
|
|
|
|
err = provider.SaveTOTPSecret(unitTestUser, pretendSecret)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
args = []driver.Value{unitTestUser}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"secret"}).AddRow(pretendSecret))
|
|
|
|
|
|
|
|
secret, err := provider.LoadTOTPSecret(unitTestUser)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, pretendSecret, secret)
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("DELETE FROM %s WHERE username=\\?", totpSecretsTableName)).
|
|
|
|
WithArgs(unitTestUser).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
|
|
|
|
err = provider.DeleteTOTPSecret(unitTestUser)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"secret"}))
|
|
|
|
|
2020-09-18 19:05:43 +07:00
|
|
|
// Test Blank Rows
|
2020-07-16 12:56:08 +07:00
|
|
|
secret, err = provider.LoadTOTPSecret(unitTestUser)
|
|
|
|
assert.EqualError(t, err, "No TOTP secret registered")
|
|
|
|
assert.Equal(t, "", secret)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestSQLProviderMethodsU2F(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
|
|
|
AddRow(userPreferencesTableName).
|
|
|
|
AddRow(identityVerificationTokensTableName).
|
|
|
|
AddRow(totpSecretsTableName).
|
|
|
|
AddRow(u2fDeviceHandlesTableName).
|
|
|
|
AddRow(authenticationLogsTableName).
|
|
|
|
AddRow(configTableName))
|
|
|
|
|
|
|
|
args := []driver.Value{"schema", "version"}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
|
|
|
AddRow(currentSchemaMockSchemaVersion))
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
pretendKeyHandle := []byte("abc")
|
|
|
|
pretendPublicKey := []byte("123")
|
|
|
|
pretendKeyHandleB64 := base64.StdEncoding.EncodeToString(pretendKeyHandle)
|
|
|
|
pretendPublicKeyB64 := base64.StdEncoding.EncodeToString(pretendPublicKey)
|
|
|
|
|
|
|
|
args = []driver.Value{unitTestUser, pretendKeyHandleB64, pretendPublicKeyB64}
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("REPLACE INTO %s \\(username, keyHandle, publicKey\\) VALUES \\(\\?, \\?, \\?\\)", u2fDeviceHandlesTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
|
|
|
|
err = provider.SaveU2FDeviceHandle(unitTestUser, pretendKeyHandle, pretendPublicKey)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
args = []driver.Value{unitTestUser}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}).
|
|
|
|
AddRow(pretendKeyHandleB64, pretendPublicKeyB64))
|
|
|
|
|
|
|
|
keyHandle, publicKey, err := provider.LoadU2FDeviceHandle(unitTestUser)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, pretendKeyHandle, keyHandle)
|
|
|
|
assert.Equal(t, pretendPublicKey, publicKey)
|
|
|
|
|
|
|
|
// Test Blank Rows.
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}))
|
|
|
|
|
|
|
|
keyHandle, publicKey, err = provider.LoadU2FDeviceHandle(unitTestUser)
|
|
|
|
assert.EqualError(t, err, "No U2F device handle found")
|
|
|
|
assert.Equal(t, []byte(nil), keyHandle)
|
|
|
|
assert.Equal(t, []byte(nil), publicKey)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestSQLProviderMethodsIdentityVerificationTokens(t *testing.T) {
|
|
|
|
provider, mock := NewSQLMockProvider()
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
"SELECT name FROM sqlite_master WHERE type='table'").
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
|
|
|
AddRow(userPreferencesTableName).
|
|
|
|
AddRow(identityVerificationTokensTableName).
|
|
|
|
AddRow(totpSecretsTableName).
|
|
|
|
AddRow(u2fDeviceHandlesTableName).
|
|
|
|
AddRow(authenticationLogsTableName).
|
|
|
|
AddRow(configTableName))
|
|
|
|
|
|
|
|
args := []driver.Value{"schema", "version"}
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
|
|
|
WithArgs(args...).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
|
|
|
AddRow(currentSchemaMockSchemaVersion))
|
|
|
|
|
|
|
|
err := provider.initialize(provider.db)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
fakeIdentityVerificationToken := "abc"
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("INSERT INTO %s \\(token\\) VALUES \\(\\?\\)", identityVerificationTokensTableName)).
|
|
|
|
WithArgs(fakeIdentityVerificationToken).
|
|
|
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
|
|
|
|
|
|
|
err = provider.SaveIdentityVerificationToken(fakeIdentityVerificationToken)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
|
|
|
|
WithArgs(fakeIdentityVerificationToken).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
|
|
|
|
AddRow(true))
|
|
|
|
|
|
|
|
valid, err := provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.True(t, valid)
|
|
|
|
|
|
|
|
mock.ExpectExec(
|
|
|
|
fmt.Sprintf("DELETE FROM %s WHERE token=\\?", identityVerificationTokensTableName)).
|
|
|
|
WithArgs(fakeIdentityVerificationToken).
|
|
|
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
|
|
|
|
|
|
|
err = provider.RemoveIdentityVerificationToken(fakeIdentityVerificationToken)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
|
|
|
|
mock.ExpectQuery(
|
|
|
|
fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
|
|
|
|
WithArgs(fakeIdentityVerificationToken).
|
|
|
|
WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
|
|
|
|
AddRow(false))
|
|
|
|
|
|
|
|
valid, err = provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
assert.False(t, valid)
|
|
|
|
}
|