package storage

import (
	"database/sql"
	"encoding/base64"
	"fmt"
	"time"

	"github.com/clems4ever/authelia/models"
)

// SQLProvider is a storage provider persisting data in a SQL database.
//
// The sql* fields hold the statements executed by the methods of this type.
// initialize does not set them; presumably a dialect-specific constructor
// fills them in before use — TODO confirm.
type SQLProvider struct {
	db *sql.DB

	// Second factor method preferences.
	sqlGetPreferencesByUsername     string
	sqlUpsertSecondFactorPreference string

	// Identity verification tokens.
	sqlTestIdentityVerificationTokenExistence string
	sqlInsertIdentityVerificationToken        string
	sqlDeleteIdentityVerificationToken        string

	// TOTP secrets.
	sqlGetTOTPSecretByUsername string
	sqlUpsertTOTPSecret        string

	// U2F device handles.
	sqlGetU2FDeviceHandleByUsername string
	sqlUpsertU2FDeviceHandle        string

	// Authentication logs.
	sqlInsertAuthenticationLog     string
	sqlGetLatestAuthenticationLogs string
}

// initialize stores db in the provider and creates the tables and indexes it
// relies on when they do not exist yet. The statements are executed in order
// and the first error encountered is returned; later statements are skipped.
func (p *SQLProvider) initialize(db *sql.DB) error {
	p.db = db

	schema := []string{
		fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, second_factor_method VARCHAR(10))", preferencesTableName),
		fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (token VARCHAR(512))", identityVerificationTokensTableName),
		fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, secret VARCHAR(64))", totpSecretsTableName),
		// keyHandle and publicKey are stored in base64 format
		fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100) PRIMARY KEY, keyHandle TEXT, publicKey TEXT)", u2fDeviceHandlesTableName),
		fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (username VARCHAR(100), successful BOOL, time INTEGER)", authenticationLogsTableName),
		// NOTE(review): the index names below are not prefixed with the table
		// name; in databases where index names are schema-wide they could
		// collide with other indexes — verify for each supported dialect.
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS time ON %s (time);", authenticationLogsTableName),
		fmt.Sprintf("CREATE INDEX IF NOT EXISTS username ON %s (username);", authenticationLogsTableName),
	}

	for _, stmt := range schema {
		if _, err := db.Exec(stmt); err != nil {
			return err
		}
	}
	return nil
}

// LoadPrefered2FAMethod loads the preferred second factor method of a user
// from the database. It returns an empty string and a nil error when the user
// has no stored preference.
func (p *SQLProvider) LoadPrefered2FAMethod(username string) (string, error) { rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username) defer rows.Close() if err != nil { return "", err } if rows.Next() { var method string err = rows.Scan(&method) if err != nil { return "", err } return method, nil } return "", nil } // SavePrefered2FAMethod save the prefered method for 2FA in sqlite db. func (p *SQLProvider) SavePrefered2FAMethod(username string, method string) error { _, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method) return err } // FindIdentityVerificationToken look for an identity verification token in DB. func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) { var found bool err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found) if err != nil { return false, err } return found, nil } // SaveIdentityVerificationToken save an identity verification token in DB. func (p *SQLProvider) SaveIdentityVerificationToken(token string) error { _, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token) return err } // RemoveIdentityVerificationToken remove an identity verification token from the DB. func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error { _, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token) return err } // SaveTOTPSecret save a TOTP secret of a given user. func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error { _, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret) return err } // LoadTOTPSecret load a TOTP secret given a username. func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) { var secret string if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil { if err == sql.ErrNoRows { return "", nil } return "", err } return secret, nil } // SaveU2FDeviceHandle save a registered U2F device registration blob. 
func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error { _, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle, username, base64.StdEncoding.EncodeToString(keyHandle), base64.StdEncoding.EncodeToString(publicKey)) return err } // LoadU2FDeviceHandle load a U2F device registration blob for a given username. func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) { var keyHandleBase64, publicKeyBase64 string if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil { if err == sql.ErrNoRows { return nil, nil, ErrNoU2FDeviceHandle } return nil, nil, err } keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64) if err != nil { return nil, nil, err } publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64) if err != nil { return nil, nil, err } return keyHandle, publicKey, nil } // AppendAuthenticationLog append a mark to the authentication log. func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error { _, err := p.db.Exec(p.sqlInsertAuthenticationLog, attempt.Username, attempt.Successful, attempt.Time.Unix()) return err } // LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log. func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) { rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username) if err != nil { return nil, err } attempts := make([]models.AuthenticationAttempt, 0, 10) for rows.Next() { attempt := models.AuthenticationAttempt{ Username: username, } var t int64 err = rows.Scan(&attempt.Successful, &t) attempt.Time = time.Unix(t, 0) if err != nil { return nil, err } attempts = append(attempts, attempt) } return attempts, nil }