feat(storage): primary key for all tables and general qol refactoring (#2431)

This is a massive overhaul to the SQL Storage for Authelia. It facilitates a whole heap of utility commands to help manage the database, primary keys, ensures all database requests use a context for cancellations, and paves the way for a few other PR's which improve the database.

Fixes #1337
This commit is contained in:
James Elliott 2021-11-23 20:45:38 +11:00 committed by GitHub
parent 884dc99083
commit 3695aa8140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
90 changed files with 3602 additions and 1738 deletions

View File

@ -1,14 +1,13 @@
package main
import (
"os"
"github.com/authelia/authelia/v4/internal/commands"
"github.com/authelia/authelia/v4/internal/logging"
)
func main() {
logger := logging.Logger()
if err := commands.NewRootCmd().Execute(); err != nil {
logger.Fatal(err)
os.Exit(1)
}
}

View File

@ -264,4 +264,5 @@ In versions <= `4.24.0` not including the `username_attribute` placeholder will
and will result in session resets when the refresh interval has expired, default of 5 minutes.
[LDAP GeneralizedTime]: https://ldapwiki.com/wiki/GeneralizedTime
[username attribute]: #username_attribute
[TechNet wiki]: https://social.technet.microsoft.com/wiki/contents/articles/5392.active-directory-ldap-syntax-filters.aspx

View File

@ -0,0 +1,24 @@
---
layout: default
title: Migrations
parent: Storage Backends
grand_parent: Configuration
nav_order: 5
---
Storage migrations are important for keeping your database compatible with Authelia. Authelia will automatically upgrade
your schema on startup. However, if you wish to use an older version of Authelia you may be required to manually
downgrade your schema with a version of Authelia that supports your current schema.
## Schema Version to Authelia Version map
This table contains a list of schema versions and the corresponding release of Authelia that shipped with that version.
This means all Authelia versions between two schema versions use the first schema version.
For example for version pre1, it is used for all versions between it and the version 1 schema, so 4.0.0 to 4.32.2. In
this instance if you wanted to downgrade to pre1 you would need to use an Authelia binary with version 4.33.0 or higher.
|Schema Version|Authelia Version|Notes |
|:------------:|:--------------:|:----------------------------------------------------------:|
|pre1 |4.0.0 |Downgrading to this version requires you use the --pre1 flag|
|1 |4.33.0 | |

View File

@ -99,3 +99,36 @@ This section has the required status of the value and must be one of `yes`, `no`
depends on other configuration options. If it's situational the situational usage should be documented. This is
immediately followed by the styles `.label`, `.label-config`, and a traffic lights color label, i.e. if yes `.label-red`,
if no `.label-green`, or if situational `.label-yellow`.
### Storage
This section outlines some rules for storage contributions. Including but not limited to migrations, schema rules, etc.
#### Migrations
All migrations must have an up and down migration, preferably idempotent.
All migrations must be named in the following format:
```text
V<version>.<name>.<engine>.<direction>.sql
```
##### version
A 4 digit version number, should be in sequential order.
##### name
A name containing alphanumeric characters, underscores (treated as spaces), hyphens, and no spaces.
##### engine
The target engine for the migration, options are all, mysql, postgres, and sqlite.
#### Primary Key
All tables must have a primary key. This primary key must be an integer with auto increment enabled, or in the case of
PostgreSQL a serial type.
#### Table/Column Names
Table and Column names must be in snake case format. This means they must have only lowercase letters, and have words
seperated by underscores. The reasoning for this is that some database engines ignore case by default and this makes it
easy to be consistent with the casing.
#### Context
All database methods should include the context attribute so that database requests that are no longer needed are
terminated appropriately.

84
go.mod
View File

@ -1,15 +1,12 @@
module github.com/authelia/authelia/v4
go 1.16
go 1.17
require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4
github.com/Workiva/go-datastructures v1.0.53
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/deckarep/golang-set v1.7.1
github.com/duosecurity/duo_api_golang v0.0.0-20211027140842-72da735c6f15
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
github.com/fasthttp/router v1.4.4
github.com/fasthttp/session/v2 v2.4.4
github.com/go-ldap/ldap/v3 v3.4.1
@ -19,6 +16,7 @@ require (
github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0
github.com/jackc/pgx/v4 v4.14.0
github.com/jmoiron/sqlx v1.3.1
github.com/knadh/koanf v1.3.2
github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/mitchellh/mapstructure v1.4.2
@ -30,16 +28,88 @@ require (
github.com/simia-tech/crypt v0.5.0
github.com/sirupsen/logrus v1.8.1
github.com/spf13/cobra v1.2.1
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.7.0
github.com/tstranex/u2f v1.0.0
github.com/valyala/fasthttp v1.31.0
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b // indirect
golang.org/x/text v0.3.7
gopkg.in/square/go-jose.v2 v2.6.0
gopkg.in/yaml.v2 v2.4.0
)
require (
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect
github.com/andybalholm/brotli v1.0.2 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgraph-io/ristretto v0.1.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
github.com/fsnotify/fsnotify v1.4.9 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect
github.com/go-redis/redis/v8 v8.11.4 // indirect
github.com/gobuffalo/pop/v5 v5.3.3 // indirect
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.10.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.2.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/pgtype v1.9.0 // indirect
github.com/jandelgado/gcov2lcov v1.0.4 // indirect
github.com/klauspost/compress v1.13.4 // indirect
github.com/magiconair/properties v1.8.5 // indirect
github.com/mattn/goveralls v0.0.6 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/ory/go-acc v0.2.6 // indirect
github.com/ory/go-convenience v0.1.0 // indirect
github.com/ory/viper v1.7.5 // indirect
github.com/ory/x v0.0.288 // indirect
github.com/pborman/uuid v1.2.1 // indirect
github.com/pelletier/go-toml v1.9.3 // indirect
github.com/philhofer/fwd v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/savsgio/dictpool v0.0.0-20210921080634-84324d0689d7 // indirect
github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4 // indirect
github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 // indirect
github.com/spf13/afero v1.6.0 // indirect
github.com/spf13/cast v1.3.2-0.20200723214538-8d17101741c8 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/sqs/goreturns v0.0.0-20181028201513-538ac6014518 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/tinylib/msgp v1.1.6 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/ysmood/goob v0.3.0 // indirect
github.com/ysmood/gson v0.6.4 // indirect
github.com/ysmood/leakless v0.7.0 // indirect
go.opentelemetry.io/contrib v0.20.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.20.0 // indirect
go.opentelemetry.io/otel v0.20.0 // indirect
go.opentelemetry.io/otel/metric v0.20.0 // indirect
go.opentelemetry.io/otel/trace v0.20.0 // indirect
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
golang.org/x/mod v0.4.2 // indirect
golang.org/x/net v0.0.0-20210510120150-4163338589ed // indirect
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect
golang.org/x/tools v0.1.2 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c // indirect
google.golang.org/grpc v1.38.0 // indirect
google.golang.org/protobuf v1.26.0 // indirect
gopkg.in/ini.v1 v1.62.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
replace (
github.com/mattn/go-sqlite3 v2.0.3+incompatible => github.com/mattn/go-sqlite3 v1.14.8
github.com/mattn/go-sqlite3 v2.0.3+incompatible => github.com/mattn/go-sqlite3 v1.14.9
github.com/tidwall/gjson => github.com/tidwall/gjson v1.11.0
)

17
go.sum
View File

@ -45,8 +45,6 @@ github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzS
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/DataDog/datadog-go v4.0.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4 h1:vdT7QwBhJJEVNFMBNhRSFDRCB6O16T28VhvqRgqFyn8=
github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4/go.mod h1:SvXOG8ElV28oAiG9zv91SDe5+9PfIr7PPccpr8YyXNs=
@ -66,8 +64,6 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
github.com/Workiva/go-datastructures v1.0.53 h1:J6Y/52yX10Xc5JjXmGtWoSSxs3mZnGSaq37xZZh7Yig=
github.com/Workiva/go-datastructures v1.0.53/go.mod h1:1yZL+zfsztete+ePzZz/Zb1/t5BnDuE2Ya2MMGhzP6A=
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
github.com/ajg/form v0.0.0-20160822230020-523a5da1a92f/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
@ -91,8 +87,8 @@ github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:l
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg=
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ=
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU=
github.com/aws/aws-sdk-go v1.23.19/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
@ -825,6 +821,7 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jmoiron/sqlx v0.0.0-20180614180643-0dae4fefe7c0/go.mod h1:IiEW3SEiiErVyFdH8NTuWjSifiEQKUoyK3LNqr2kCHU=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/jmoiron/sqlx v1.3.1 h1:aLN7YINNZ7cYOPK3QC83dbM6KT0NMqVMw961TqrejlE=
github.com/jmoiron/sqlx v1.3.1/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ=
github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901/go.mod h1:Z86h9688Y0wesXCyonoVr47MasHilkuLMqGhRZ4Hpak=
github.com/joho/godotenv v1.2.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
@ -948,8 +945,9 @@ github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOq
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/mattn/goveralls v0.0.6 h1:cr8Y0VMo/MnEZBjxNN/vh6G90SZ7IMb6lms1dzMoO+Y=
github.com/mattn/goveralls v0.0.6/go.mod h1:h8b4ow6FxSPMQHF6o2ve3qsclnffZjYTNEKmLesRwqw=
@ -1304,14 +1302,12 @@ github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7V
github.com/tidwall/sjson v1.1.5 h1:wsUceI/XDyZk3J1FUvuuYlK62zJv2HO2Pzb8A5EWdUE=
github.com/tidwall/sjson v1.1.5/go.mod h1:VuJzsZnTowhSxWdOgsAnb886i4AjEyTkk7tNtsL7EYE=
github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tinylib/msgp v1.1.5/go.mod h1:eQsjooMTnV42mHu917E26IogZ2930nFyBQdofk10Udg=
github.com/tinylib/msgp v1.1.6 h1:i+SbKraHhnrf9M5MYmvQhFnbLhAXSDWF8WWsuyRdocw=
github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tstranex/u2f v1.0.0 h1:HhJkSzDDlVSVIVt7pDJwCHQj67k7A5EeBgPmeD+pVsQ=
github.com/tstranex/u2f v1.0.0/go.mod h1:eahSLaqAS0zsIEv80+vXT7WanXs7MQQDg3j3wGBSayo=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q=
github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g=
github.com/uber/jaeger-client-go v2.15.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
github.com/uber/jaeger-client-go v2.22.1+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
@ -1681,9 +1677,8 @@ golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b h1:S7hKs0Flbq0bbc9xgYt4stIEG1zNDFqyrPwAX2Wj/sE=
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -9,7 +9,6 @@ import (
"sync"
"github.com/asaskevich/govalidator"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"github.com/authelia/authelia/v4/internal/configuration/schema"
@ -208,6 +207,6 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
}
// StartupCheck implements the startup check provider interface.
func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) {
func (p *FileUserProvider) StartupCheck() (err error) {
return nil
}

View File

@ -21,7 +21,7 @@ type LDAPUserProvider struct {
configuration schema.LDAPAuthenticationBackendConfiguration
tlsConfig *tls.Config
dialOpts []ldap.DialOpt
logger *logrus.Logger
log *logrus.Logger
connectionFactory LDAPConnectionFactory
disableResetPassword bool
@ -72,7 +72,7 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura
configuration: configuration,
tlsConfig: tlsConfig,
dialOpts: dialOpts,
logger: logging.Logger(),
log: logging.Logger(),
connectionFactory: factory,
disableResetPassword: disableResetPassword,
}
@ -148,7 +148,7 @@ func (p *LDAPUserProvider) resolveUsersFilter(inputUsername string) (filter stri
filter = strings.ReplaceAll(filter, ldapPlaceholderInput, p.ldapEscape(inputUsername))
}
p.logger.Tracef("Computed user filter is %s", filter)
p.log.Tracef("Computed user filter is %s", filter)
return filter
}
@ -223,7 +223,7 @@ func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ld
}
}
p.logger.Tracef("Computed groups filter is %s", filter)
p.log.Tracef("Computed groups filter is %s", filter)
return filter, nil
}
@ -262,7 +262,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error
for _, res := range sr.Entries {
if len(res.Attributes) == 0 {
p.logger.Warningf("No groups retrieved from LDAP for user %s", inputUsername)
p.log.Warningf("No groups retrieved from LDAP for user %s", inputUsername)
break
}

View File

@ -4,13 +4,12 @@ import (
"strings"
"github.com/go-ldap/ldap/v3"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// StartupCheck implements the startup check provider interface.
func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
func (p *LDAPUserProvider) StartupCheck() (err error) {
conn, err := p.connect(p.configuration.User, p.configuration.Password)
if err != nil {
return err
@ -33,7 +32,7 @@ func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
// Iterate the attribute values to see what the server supports.
for _, attr := range sr.Entries[0].Attributes {
if attr.Name == ldapSupportedExtensionAttribute {
logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
p.log.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
for _, oid := range attr.Values {
if oid == ldapOIDPasswdModifyExtension {
@ -48,7 +47,7 @@ func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
if !p.supportExtensionPasswdModify && !p.disableResetPassword &&
p.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
logger.Warn("Your LDAP server implementation may not support a method for password hashing " +
p.log.Warn("Your LDAP server implementation may not support a method for password hashing " +
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
"attribute when users reset their password via Authelia.")
}
@ -61,7 +60,7 @@ func (p *LDAPUserProvider) parseDynamicUsersConfiguration() {
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{mail_attribute}", p.configuration.MailAttribute)
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{display_name_attribute}", p.configuration.DisplayNameAttribute)
p.logger.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter)
p.log.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter)
p.usersAttributes = []string{
p.configuration.DisplayNameAttribute,
@ -75,13 +74,13 @@ func (p *LDAPUserProvider) parseDynamicUsersConfiguration() {
p.usersBaseDN = p.configuration.BaseDN
}
p.logger.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN)
p.log.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN)
if strings.Contains(p.configuration.UsersFilter, ldapPlaceholderInput) {
p.usersFilterReplacementInput = true
}
p.logger.Tracef("Detected user filter replacements that need to be resolved per lookup are: %s=%v",
p.log.Tracef("Detected user filter replacements that need to be resolved per lookup are: %s=%v",
ldapPlaceholderInput, p.usersFilterReplacementInput)
}
@ -96,7 +95,7 @@ func (p *LDAPUserProvider) parseDynamicGroupsConfiguration() {
p.groupsBaseDN = p.configuration.BaseDN
}
p.logger.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN)
p.log.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN)
if strings.Contains(p.configuration.GroupsFilter, ldapPlaceholderInput) {
p.groupsFilterReplacementInput = true
@ -110,5 +109,5 @@ func (p *LDAPUserProvider) parseDynamicGroupsConfiguration() {
p.groupsFilterReplacementDN = true
}
p.logger.Tracef("Detected group filter replacements that need to be resolved per lookup are: input=%v, username=%v, dn=%v", p.groupsFilterReplacementInput, p.groupsFilterReplacementUsername, p.groupsFilterReplacementDN)
p.log.Tracef("Detected group filter replacements that need to be resolved per lookup are: input=%v, username=%v, dn=%v", p.groupsFilterReplacementInput, p.groupsFilterReplacementUsername, p.groupsFilterReplacementDN)
}

View File

@ -12,7 +12,6 @@ import (
"golang.org/x/text/encoding/unicode"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
)
@ -216,7 +215,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
assert.NoError(t, err)
assert.True(t, ldapClient.supportExtensionPasswdModify)
@ -273,7 +272,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
assert.NoError(t, err)
assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -306,7 +305,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()).
Return(mockConn, errors.New("could not connect"))
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
assert.EqualError(t, err, "could not connect")
assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -351,7 +350,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
assert.EqualError(t, err, "could not perform the search")
assert.False(t, ldapClient.supportExtensionPasswdModify)
@ -755,7 +754,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password")
@ -862,7 +861,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password")
@ -966,7 +965,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
err := ldapClient.StartupCheck(logging.Logger())
err := ldapClient.StartupCheck()
require.NoError(t, err)
err = ldapClient.UpdatePassword("john", "password")

View File

@ -1,14 +1,15 @@
package authentication
import (
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/models"
)
// UserProvider is the interface for checking user password and
// gathering user details.
type UserProvider interface {
models.StartupCheck
CheckUserPassword(username string, password string) (valid bool, err error)
GetDetails(username string) (details *UserDetails, err error)
UpdatePassword(username string, newPassword string) (err error)
StartupCheck(logger *logrus.Logger) (err error)
}

View File

@ -75,3 +75,8 @@ PowerShell:
PS> authelia completion powershell > authelia.ps1
# and source this file from your PowerShell profile.
`
const (
storageMigrateDirectionUp = "up"
storageMigrateDirectionDown = "down"
)

View File

@ -0,0 +1,28 @@
package commands
import (
"errors"
"github.com/authelia/authelia/v4/internal/storage"
)
func getStorageProvider() (provider storage.Provider, err error) {
switch {
case config.Storage.PostgreSQL != nil:
provider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL)
case config.Storage.MySQL != nil:
provider = storage.NewMySQLProvider(*config.Storage.MySQL)
case config.Storage.Local != nil:
provider = storage.NewSQLiteProvider(config.Storage.Local.Path)
default:
return nil, errors.New("no storage provider configured")
}
if (config.Storage.MySQL != nil && config.Storage.PostgreSQL != nil) ||
(config.Storage.MySQL != nil && config.Storage.Local != nil) ||
(config.Storage.PostgreSQL != nil && config.Storage.Local != nil) {
return nil, errors.New("multiple storage providers are configured but should only configure one")
}
return provider, err
}

View File

@ -13,6 +13,7 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/notification"
"github.com/authelia/authelia/v4/internal/ntp"
"github.com/authelia/authelia/v4/internal/oidc"
@ -46,6 +47,7 @@ func NewRootCmd() (cmd *cobra.Command) {
newCompletionCmd(),
NewHashPasswordCmd(),
NewRSACmd(),
NewStorageCmd(),
newValidateConfigCmd(),
)
@ -101,9 +103,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
storageProvider = storage.NewMySQLProvider(*config.Storage.MySQL)
case config.Storage.Local != nil:
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
default:
// TODO: Add storage provider startup check and remove this.
errors = append(errors, fmt.Errorf("unrecognized storage provider"))
}
var (
@ -162,6 +161,12 @@ func doStartupChecks(config *schema.Configuration, providers *middlewares.Provid
err error
)
if err = doStartupCheck(logger, "storage", providers.StorageProvider, false); err != nil {
logger.Errorf("Failure running the storage provider startup check: %+v", err)
failures = append(failures, "storage")
}
if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil {
logger.Errorf("Failure running the user provider startup check: %+v", err)
@ -187,7 +192,7 @@ func doStartupChecks(config *schema.Configuration, providers *middlewares.Provid
}
}
func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) {
func doStartupCheck(logger *logrus.Logger, name string, provider models.StartupCheck, disabled bool) (err error) {
if disabled {
logger.Debugf("%s provider: startup check skipped as it is disabled", name)
return nil
@ -197,7 +202,7 @@ func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.Pro
return fmt.Errorf("unrecognized provider or it is not configured properly")
}
if err = provider.StartupCheck(logger); err != nil {
if err = provider.StartupCheck(); err != nil {
return err
}

View File

@ -0,0 +1,126 @@
package commands
import (
"github.com/spf13/cobra"
)
// NewStorageCmd returns a new storage *cobra.Command.
func NewStorageCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: "storage",
Short: "Manage the Authelia storage",
Args: cobra.NoArgs,
PersistentPreRunE: storagePersistentPreRunE,
}
cmd.PersistentFlags().StringSliceP("config", "c", []string{"config.yml"}, "configuration file to load for the storage migration")
cmd.PersistentFlags().String("sqlite.path", "", "the SQLite database path")
cmd.PersistentFlags().String("mysql.host", "", "the MySQL hostname")
cmd.PersistentFlags().Int("mysql.port", 3306, "the MySQL port")
cmd.PersistentFlags().String("mysql.database", "authelia", "the MySQL database name")
cmd.PersistentFlags().String("mysql.username", "authelia", "the MySQL username")
cmd.PersistentFlags().String("mysql.password", "", "the MySQL password")
cmd.PersistentFlags().String("postgres.host", "", "the PostgreSQL hostname")
cmd.PersistentFlags().Int("postgres.port", 5432, "the PostgreSQL port")
cmd.PersistentFlags().String("postgres.database", "authelia", "the PostgreSQL database name")
cmd.PersistentFlags().String("postgres.username", "authelia", "the PostgreSQL username")
cmd.PersistentFlags().String("postgres.password", "", "the PostgreSQL password")
cmd.AddCommand(
newStorageMigrateCmd(),
newStorageSchemaInfoCmd(),
)
return cmd
}
func newStorageSchemaInfoCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: "schema-info",
Short: "Show the storage information",
RunE: storageSchemaInfoRunE,
}
return cmd
}
// NewMigrationCmd returns a new Migration Cmd.
func newStorageMigrateCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: "migrate",
Short: "Perform or list migrations",
Args: cobra.NoArgs,
}
cmd.AddCommand(
newStorageMigrateUpCmd(), newStorageMigrateDownCmd(),
newStorageMigrateListUpCmd(), newStorageMigrateListDownCmd(),
newStorageMigrateHistoryCmd(),
)
return cmd
}
func newStorageMigrateHistoryCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: "history",
Short: "Show migration history",
Args: cobra.NoArgs,
RunE: storageMigrateHistoryRunE,
}
return cmd
}
func newStorageMigrateListUpCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: "list-up",
Short: "List the up migrations available",
Args: cobra.NoArgs,
RunE: newStorageMigrateListRunE(true),
}
return cmd
}
func newStorageMigrateListDownCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: "list-down",
Short: "List the down migrations available",
Args: cobra.NoArgs,
RunE: newStorageMigrateListRunE(false),
}
return cmd
}
func newStorageMigrateUpCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: storageMigrateDirectionUp,
Short: "Perform a migration up",
Args: cobra.NoArgs,
RunE: newStorageMigrationRunE(true),
}
cmd.Flags().IntP("target", "t", 0, "sets the version to migrate to, by default this is the latest version")
return cmd
}
func newStorageMigrateDownCmd() (cmd *cobra.Command) {
cmd = &cobra.Command{
Use: storageMigrateDirectionDown,
Short: "Perform a migration down",
Args: cobra.NoArgs,
RunE: newStorageMigrationRunE(false),
}
cmd.Flags().IntP("target", "t", 0, "sets the version to migrate to")
cmd.Flags().Bool("pre1", false, "sets pre1 as the version to migrate to")
cmd.Flags().Bool("destroy-data", false, "confirms you want to destroy data with this migration")
return cmd
}

View File

@ -0,0 +1,291 @@
package commands
import (
"context"
"errors"
"fmt"
"os"
"strings"
"github.com/spf13/cobra"
"github.com/authelia/authelia/v4/internal/configuration"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator"
"github.com/authelia/authelia/v4/internal/storage"
)
func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) {
configs, err := cmd.Flags().GetStringSlice("config")
if err != nil {
return err
}
sources := make([]configuration.Source, 0, len(configs)+3)
if cmd.Flags().Changed("config") {
for _, configFile := range configs {
if _, err := os.Stat(configFile); os.IsNotExist(err) {
return fmt.Errorf("could not load the provided configuration file %s: %w", configFile, err)
}
sources = append(sources, configuration.NewYAMLFileSource(configFile))
}
} else {
if _, err := os.Stat(configs[0]); err == nil {
sources = append(sources, configuration.NewYAMLFileSource(configs[0]))
}
}
mapping := map[string]string{
"sqlite.path": "storage.local.path",
"mysql.host": "storage.mysql.host",
"mysql.port": "storage.mysql.port",
"mysql.database": "storage.mysql.database",
"mysql.username": "storage.mysql.username",
"mysql.password": "storage.mysql.password",
"postgres.host": "storage.postgres.host",
"postgres.port": "storage.postgres.port",
"postgres.database": "storage.postgres.database",
"postgres.username": "storage.postgres.username",
"postgres.password": "storage.postgres.password",
"postgres.schema": "storage.postgres.schema",
}
sources = append(sources, configuration.NewEnvironmentSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter))
sources = append(sources, configuration.NewSecretsSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter))
sources = append(sources, configuration.NewCommandLineSourceWithMapping(cmd.Flags(), mapping, true, false))
val := schema.NewStructValidator()
config = &schema.Configuration{}
_, err = configuration.LoadAdvanced(val, "storage", &config.Storage, sources...)
if err != nil {
return err
}
if val.HasErrors() {
var finalErr error
for i, err := range val.Errors() {
if i == 0 {
finalErr = err
continue
}
finalErr = fmt.Errorf("%w, %v", finalErr, err)
}
return finalErr
}
validator.ValidateStorage(config.Storage, val)
if val.HasErrors() {
var finalErr error
for i, err := range val.Errors() {
if i == 0 {
finalErr = err
continue
}
finalErr = fmt.Errorf("%w, %v", finalErr, err)
}
return finalErr
}
return nil
}
func storageMigrateHistoryRunE(_ *cobra.Command, _ []string) (err error) {
var (
provider storage.Provider
ctx = context.Background()
)
provider, err = getStorageProvider()
if err != nil {
return err
}
migrations, err := provider.SchemaMigrationHistory(ctx)
if err != nil {
return err
}
if len(migrations) == 0 {
return errors.New("no migration history found which may indicate a broken schema")
}
fmt.Printf("Migration History:\n\nID\tDate\t\t\t\tBefore\tAfter\tAuthelia Version\n")
for _, m := range migrations {
fmt.Printf("%d\t%s\t%d\t%d\t%s\n", m.ID, m.Applied.Format("2006-01-02 15:04:05 -0700"), m.Before, m.After, m.Version)
}
return nil
}
func newStorageMigrateListRunE(up bool) func(cmd *cobra.Command, args []string) (err error) {
return func(cmd *cobra.Command, args []string) (err error) {
var (
provider storage.Provider
ctx = context.Background()
migrations []storage.SchemaMigration
directionStr string
)
provider, err = getStorageProvider()
if err != nil {
return err
}
if up {
migrations, err = provider.SchemaMigrationsUp(ctx, 0)
directionStr = "Up"
} else {
migrations, err = provider.SchemaMigrationsDown(ctx, 0)
directionStr = "Down"
}
if err != nil {
if err.Error() == "cannot migrate to the same version as prior" {
fmt.Printf("No %s migrations found\n", directionStr)
return nil
}
return err
}
if len(migrations) == 0 {
fmt.Printf("Storage Schema Migration List (%s)\n\nNo Migrations Available\n", directionStr)
} else {
fmt.Printf("Storage Schema Migration List (%s)\n\nVersion\t\tDescription\n", directionStr)
for _, migration := range migrations {
fmt.Printf("%d\t\t%s\n", migration.Version, migration.Name)
}
}
return nil
}
}
func newStorageMigrationRunE(up bool) func(cmd *cobra.Command, args []string) (err error) {
return func(cmd *cobra.Command, args []string) (err error) {
var (
provider storage.Provider
ctx = context.Background()
)
provider, err = getStorageProvider()
if err != nil {
return err
}
target, err := cmd.Flags().GetInt("target")
if err != nil {
return err
}
switch {
case up:
switch cmd.Flags().Changed("target") {
case true:
return provider.SchemaMigrate(ctx, true, target)
default:
return provider.SchemaMigrate(ctx, true, storage.SchemaLatest)
}
default:
if !cmd.Flags().Changed("target") {
return errors.New("must set target")
}
if err = storageMigrateDownConfirmDestroy(cmd); err != nil {
return err
}
pre1, err := cmd.Flags().GetBool("pre1")
if err != nil {
return err
}
switch {
case pre1:
return provider.SchemaMigrate(ctx, false, -1)
default:
return provider.SchemaMigrate(ctx, false, target)
}
}
}
}
func storageMigrateDownConfirmDestroy(cmd *cobra.Command) (err error) {
destroy, err := cmd.Flags().GetBool("destroy-data")
if err != nil {
return err
}
if !destroy {
fmt.Printf("Schema Down Migrations may DESTROY data, type 'DESTROY' and press return to continue: ")
var text string
_, _ = fmt.Scanln(&text)
if text != "DESTROY" {
return errors.New("cancelling down migration due to user not accepting data destruction")
}
}
return nil
}
func storageSchemaInfoRunE(_ *cobra.Command, _ []string) (err error) {
var (
provider storage.Provider
ctx = context.Background()
upgradeStr string
tablesStr string
)
provider, err = getStorageProvider()
if err != nil {
return err
}
version, err := provider.SchemaVersion(ctx)
if err != nil && err.Error() != "unknown schema state" {
return err
}
tables, err := provider.SchemaTables(ctx)
if err != nil {
return err
}
if len(tables) == 0 {
tablesStr = "N/A"
} else {
tablesStr = strings.Join(tables, ", ")
}
latest, err := provider.SchemaLatestVersion()
if err != nil {
return err
}
if latest > version {
upgradeStr = fmt.Sprintf("yes - version %d", latest)
} else {
upgradeStr = "no"
}
fmt.Printf("Schema Version: %s\nSchema Upgrade Available: %s\nSchema Tables: %s\n", storage.SchemaVersionToString(version), upgradeStr, tablesStr)
return nil
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"strings"
"github.com/spf13/pflag"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator"
"github.com/authelia/authelia/v4/internal/utils"
@ -48,3 +50,25 @@ func koanfEnvironmentSecretsCallback(keyMap map[string]string, validator *schema
return k, v
}
}
func koanfCommandLineWithMappingCallback(mapping map[string]string, includeValidKeys, includeUnchangedKeys bool) func(flag *pflag.Flag) (string, interface{}) {
return func(flag *pflag.Flag) (string, interface{}) {
if !includeUnchangedKeys && !flag.Changed {
return "", nil
}
if actualKey, ok := mapping[flag.Name]; ok {
return actualKey, flag.Value.String()
}
if includeValidKeys {
formattedKey := strings.ReplaceAll(flag.Name, "-", "_")
if utils.IsStringInSlice(formattedKey, validator.ValidKeys) {
return formattedKey, flag.Value.String()
}
}
return "", nil
}
}

View File

@ -11,8 +11,17 @@ import (
// Load the configuration given the provided options and sources.
func Load(val *schema.StructValidator, sources ...Source) (keys []string, configuration *schema.Configuration, err error) {
configuration = &schema.Configuration{}
keys, err = LoadAdvanced(val, "", configuration, sources...)
return keys, configuration, err
}
// LoadAdvanced is intended to give more flexibility over loading a particular path to a specific interface.
func LoadAdvanced(val *schema.StructValidator, path string, result interface{}, sources ...Source) (keys []string, err error) {
if val == nil {
return keys, configuration, errNoValidator
return keys, errNoValidator
}
ko := koanf.NewWithConf(koanf.Conf{
@ -22,14 +31,12 @@ func Load(val *schema.StructValidator, sources ...Source) (keys []string, config
err = loadSources(ko, val, sources...)
if err != nil {
return ko.Keys(), configuration, err
return ko.Keys(), err
}
configuration = &schema.Configuration{}
unmarshal(ko, val, path, result)
unmarshal(ko, val, "", configuration)
return ko.Keys(), configuration, nil
return ko.Keys(), nil
}
func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o interface{}) {

View File

@ -1,12 +1,5 @@
package schema
import (
"fmt"
"reflect"
"github.com/Workiva/go-datastructures/queue"
)
// ErrorContainer represents a container where we can add errors and retrieve them.
type ErrorContainer interface {
Push(err error)
@ -17,100 +10,6 @@ type ErrorContainer interface {
Warnings() []error
}
// Validator represents the validator interface.
type Validator struct {
errors map[string][]error
}
// NewValidator create a validator.
func NewValidator() *Validator {
validator := new(Validator)
validator.errors = make(map[string][]error)
return validator
}
// QueueItem an item representing a struct field and its path.
type QueueItem struct {
value reflect.Value
path string
}
func (v *Validator) validateOne(item QueueItem, q *queue.Queue) error { //nolint:unparam
if item.value.Type().Kind() == reflect.Ptr {
if item.value.IsNil() {
return nil
}
elem := item.value.Elem()
q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
value: elem,
path: item.path,
})
} else if item.value.Kind() == reflect.Struct {
numFields := item.value.Type().NumField()
validateFn := item.value.Addr().MethodByName("Validate")
if validateFn.IsValid() {
structValidator := NewStructValidator()
validateFn.Call([]reflect.Value{reflect.ValueOf(structValidator)})
v.errors[item.path] = structValidator.Errors()
}
for i := 0; i < numFields; i++ {
field := item.value.Type().Field(i)
value := item.value.Field(i)
q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
value: value,
path: item.path + "." + field.Name,
})
}
}
return nil
}
// Validate validate a struct.
func (v *Validator) Validate(s interface{}) error {
q := queue.New(40)
q.Put(QueueItem{value: reflect.ValueOf(s), path: "root"}) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
for !q.Empty() {
val, err := q.Get(1)
if err != nil {
return err
}
item, ok := val[0].(QueueItem)
if !ok {
return fmt.Errorf("Cannot convert item into QueueItem")
}
v.validateOne(item, q) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
}
return nil
}
// PrintErrors display the errors thrown during validation.
func (v *Validator) PrintErrors() {
for path, errs := range v.errors {
fmt.Printf("Errors at %s:\n", path)
for _, err := range errs {
fmt.Printf("--> %s\n", err)
}
}
}
// Errors return the errors thrown during validation.
func (v *Validator) Errors() map[string][]error {
return v.errors
}
// StructValidator is a validator for structs.
type StructValidator struct {
errors []error

View File

@ -49,43 +49,6 @@ func (ts *TestStruct) Validate(validator *schema.StructValidator) {
}
}
func TestValidator(t *testing.T) {
validator := schema.NewValidator()
s := TestStruct{
MustBe10: 5,
NotEmpty: "",
NestedPtr: &TestNestedStruct{},
}
err := validator.Validate(&s)
if err != nil {
panic(err)
}
errs := validator.Errors()
assert.Equal(t, 4, len(errs))
assert.Equal(t, 2, len(errs["root"]))
assert.ElementsMatch(t, []error{
fmt.Errorf("MustBe10 must be 10"),
fmt.Errorf("NotEmpty must not be empty")}, errs["root"])
assert.Equal(t, 1, len(errs["root.Nested"]))
assert.ElementsMatch(t, []error{
fmt.Errorf("MustBe5 must be 5")}, errs["root.Nested"])
assert.Equal(t, 1, len(errs["root.Nested2"]))
assert.ElementsMatch(t, []error{
fmt.Errorf("MustBe5 must be 5")}, errs["root.Nested2"])
assert.Equal(t, 1, len(errs["root.NestedPtr"]))
assert.ElementsMatch(t, []error{
fmt.Errorf("MustBe5 must be 5")}, errs["root.NestedPtr"])
assert.Equal(t, "xyz", s.SetDefault)
}
func TestStructValidator(t *testing.T) {
validator := schema.NewStructValidator()
s := TestStruct{

View File

@ -8,6 +8,8 @@ import (
"github.com/knadh/koanf/parsers/yaml"
"github.com/knadh/koanf/providers/env"
"github.com/knadh/koanf/providers/file"
"github.com/knadh/koanf/providers/posflag"
"github.com/spf13/pflag"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator"
@ -112,6 +114,37 @@ func (s *SecretsSource) Load(val *schema.StructValidator) (err error) {
return s.koanf.Load(env.ProviderWithValue(s.prefix, constDelimiter, koanfEnvironmentSecretsCallback(keyMap, val)), nil)
}
// NewCommandLineSourceWithMapping creates a new command line configuration source with a map[string]string which converts
// flag names into other config key names. If includeValidKeys is true we also allow any flag with a name which matches
// the list of valid keys into the koanf.Koanf, otherwise everything not in the map is skipped. Unchanged flags are also
// skipped unless includeUnchangedKeys is set to true.
func NewCommandLineSourceWithMapping(flags *pflag.FlagSet, mapping map[string]string, includeValidKeys, includeUnchangedKeys bool) (source *CommandLineSource) {
return &CommandLineSource{
koanf: koanf.New(constDelimiter),
flags: flags,
callback: koanfCommandLineWithMappingCallback(mapping, includeValidKeys, includeUnchangedKeys),
}
}
// Name of the Source.
func (s CommandLineSource) Name() (name string) {
return "command-line"
}
// Merge the CommandLineSource koanf.Koanf into the provided one.
func (s *CommandLineSource) Merge(ko *koanf.Koanf, val *schema.StructValidator) (err error) {
return ko.Merge(s.koanf)
}
// Load the Source into the YAMLFileSource koanf.Koanf.
func (s *CommandLineSource) Load(_ *schema.StructValidator) (err error) {
if s.callback != nil {
return s.koanf.Load(posflag.ProviderWithFlag(s.flags, ".", s.koanf, s.callback), nil)
}
return s.koanf.Load(posflag.Provider(s.flags, ".", s.koanf), nil)
}
// NewDefaultSources returns a slice of Source configured to load from specified YAML files.
func NewDefaultSources(filePaths []string, prefix, delimiter string) (sources []Source) {
fileSources := NewYAMLFileSources(filePaths)

View File

@ -2,6 +2,7 @@ package configuration
import (
"github.com/knadh/koanf"
"github.com/spf13/pflag"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
@ -32,3 +33,10 @@ type SecretsSource struct {
prefix string
delimiter string
}
// CommandLineSource loads configuration from the command line flags.
type CommandLineSource struct {
koanf *koanf.Koanf
flags *pflag.FlagSet
callback func(flag *pflag.Flag) (string, interface{})
}

View File

@ -73,6 +73,12 @@ const (
pathOpenIDConnectConsent = "/api/oidc/consent"
)
const (
totpAlgoSHA1 = "SHA1"
totpAlgoSHA256 = "SHA256"
totpAlgoSHA512 = "SHA512"
)
const (
accept = "accept"
reject = "reject"

View File

@ -77,7 +77,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
return
}
bannedUntil, err := ctx.Providers.Regulator.Regulate(bodyJSON.Username)
bannedUntil, err := ctx.Providers.Regulator.Regulate(ctx, bodyJSON.Username)
if err != nil {
if err == regulation.ErrUserIsBanned {
@ -95,7 +95,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
if err != nil {
ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username)
if err := ctx.Providers.Regulator.Mark(bodyJSON.Username, false); err != nil {
if err := ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, false); err != nil {
ctx.Logger.Errorf("Unable to mark authentication: %s", err.Error())
}
@ -107,7 +107,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
if !userPasswordOk {
ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username)
if err := ctx.Providers.Regulator.Mark(bodyJSON.Username, false); err != nil {
if err := ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, false); err != nil {
ctx.Logger.Errorf("Unable to mark authentication: %s", err.Error())
}
@ -117,7 +117,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
}
ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username)
err = ctx.Providers.Regulator.Mark(bodyJSON.Username, true)
err = ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, true)
if err != nil {
handleAuthenticationUnauthorized(ctx, fmt.Errorf("unable to mark authentication: %s", err.Error()), messageAuthenticationFailed)

View File

@ -58,7 +58,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Eq(models.AuthenticationAttempt{
AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{
Username: "test",
Successful: false,
Time: s.mock.Clock.Now(),
@ -83,7 +83,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Eq(models.AuthenticationAttempt{
AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{
Username: "test",
Successful: false,
Time: s.mock.Clock.Now(),
@ -106,7 +106,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Any()).
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
Return(nil)
s.mock.UserProviderMock.
@ -133,7 +133,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() {
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Any()).
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
Return(fmt.Errorf("failed"))
s.mock.Ctx.Request.SetBodyString(`{
@ -164,7 +164,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() {
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Any()).
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
Return(nil)
s.mock.Ctx.Request.SetBodyString(`{
@ -204,7 +204,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() {
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Any()).
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
Return(nil)
s.mock.Ctx.Request.SetBodyString(`{
@ -248,7 +248,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Any()).
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
Return(nil)
s.mock.Ctx.Request.SetBodyString(`{
@ -306,7 +306,7 @@ func (s *FirstFactorRedirectionSuite) SetupTest() {
s.mock.StorageProviderMock.
EXPECT().
AppendAuthenticationLog(gomock.Any()).
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
Return(nil)
}

View File

@ -3,9 +3,11 @@ package handlers
import (
"fmt"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/session"
)
@ -37,11 +39,15 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle
})
func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
algorithm := otp.AlgorithmSHA1
key, err := totp.Generate(totp.GenerateOpts{
Issuer: ctx.Configuration.TOTP.Issuer,
AccountName: username,
SecretSize: 32,
Period: uint(ctx.Configuration.TOTP.Period),
SecretSize: 32,
Digits: otp.Digits(6),
Algorithm: algorithm,
})
if err != nil {
@ -49,7 +55,15 @@ func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username strin
return
}
err = ctx.Providers.StorageProvider.SaveTOTPSecret(username, key.Secret())
config := models.TOTPConfiguration{
Username: username,
Algorithm: otpAlgoToString(algorithm),
Digits: 6,
Secret: key.Secret(),
Period: key.Period(),
}
err = ctx.Providers.StorageProvider.SaveTOTPConfiguration(ctx, config)
if err != nil {
ctx.Error(fmt.Errorf("unable to save TOTP secret in DB: %s", err), messageUnableToRegisterOneTimePassword)
return

View File

@ -57,11 +57,11 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
s.mock.StorageProviderMock.EXPECT().
RemoveIdentityVerificationToken(gomock.Eq(token)).
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(nil)
SecondFactorU2FIdentityFinish(s.mock.Ctx)
@ -77,11 +77,11 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
s.mock.StorageProviderMock.EXPECT().
RemoveIdentityVerificationToken(gomock.Eq(token)).
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(nil)
SecondFactorU2FIdentityFinish(s.mock.Ctx)

View File

@ -7,6 +7,7 @@ import (
"github.com/tstranex/u2f"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/models"
)
// SecondFactorU2FRegister handler validating the client has successfully validated the challenge
@ -45,7 +46,12 @@ func SecondFactorU2FRegister(ctx *middlewares.AutheliaCtx) {
ctx.Logger.Debugf("Register U2F device for user %s", userSession.Username)
publicKey := elliptic.Marshal(elliptic.P256(), registration.PubKey.X, registration.PubKey.Y)
err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, registration.KeyHandle, publicKey)
err = ctx.Providers.StorageProvider.SaveU2FDevice(ctx, models.U2FDevice{
Username: userSession.Username,
KeyHandle: registration.KeyHandle,
PublicKey: publicKey},
)
if err != nil {
ctx.Error(fmt.Errorf("unable to register U2F device for user %s: %v", userSession.Username, err), messageUnableToRegisterSecurityKey)

View File

@ -19,13 +19,13 @@ func SecondFactorTOTPPost(totpVerifier TOTPVerifier) middlewares.RequestHandler
userSession := ctx.GetSession()
secret, err := ctx.Providers.StorageProvider.LoadTOTPSecret(userSession.Username)
config, err := ctx.Providers.StorageProvider.LoadTOTPConfiguration(ctx, userSession.Username)
if err != nil {
handleAuthenticationUnauthorized(ctx, fmt.Errorf("unable to load TOTP secret: %s", err), messageMFAValidationFailed)
return
}
isValid, err := totpVerifier.Verify(requestBody.Token, secret)
isValid, err := totpVerifier.Verify(config, requestBody.Token)
if err != nil {
handleAuthenticationUnauthorized(ctx, fmt.Errorf("error occurred during OTP validation for user %s: %s", userSession.Username, err), messageMFAValidationFailed)
return

View File

@ -11,6 +11,7 @@ import (
"github.com/tstranex/u2f"
"github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/session"
)
@ -37,12 +38,14 @@ func (s *HandlerSignTOTPSuite) TearDownTest() {
func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
s.mock.StorageProviderMock.EXPECT().
LoadTOTPSecret(gomock.Any()).
Return("secret", nil)
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
Return(&config, nil)
verifier.EXPECT().
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
Verify(gomock.Eq(&config), gomock.Eq("abc")).
Return(true, nil)
s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL
@ -62,12 +65,14 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
s.mock.StorageProviderMock.EXPECT().
LoadTOTPSecret(gomock.Any()).
Return("secret", nil)
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
Return(&config, nil)
verifier.EXPECT().
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
Verify(gomock.Eq(&config), gomock.Eq("abc")).
Return(true, nil)
bodyBytes, err := json.Marshal(signTOTPRequestBody{
@ -83,12 +88,14 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() {
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
s.mock.StorageProviderMock.EXPECT().
LoadTOTPSecret(gomock.Any()).
Return("secret", nil)
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
Return(&config, nil)
verifier.EXPECT().
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
Verify(gomock.Eq(&config), gomock.Eq("abc")).
Return(true, nil)
bodyBytes, err := json.Marshal(signTOTPRequestBody{
@ -108,11 +115,11 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
s.mock.StorageProviderMock.EXPECT().
LoadTOTPSecret(gomock.Any()).
Return("secret", nil)
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
Return(&models.TOTPConfiguration{Secret: "secret"}, nil)
verifier.EXPECT().
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
Verify(gomock.Eq(&models.TOTPConfiguration{Secret: "secret"}), gomock.Eq("abc")).
Return(true, nil)
bodyBytes, err := json.Marshal(signTOTPRequestBody{
@ -129,12 +136,14 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFixation() {
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
s.mock.StorageProviderMock.EXPECT().
LoadTOTPSecret(gomock.Any()).
Return("secret", nil)
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
Return(&config, nil)
verifier.EXPECT().
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
Verify(gomock.Eq(&config), gomock.Eq("abc")).
Return(true, nil)
bodyBytes, err := json.Marshal(signTOTPRequestBody{

View File

@ -34,7 +34,7 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
}
userSession := ctx.GetSession()
keyHandleBytes, publicKeyBytes, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username)
device, err := ctx.Providers.StorageProvider.LoadU2FDevice(ctx, userSession.Username)
if err != nil {
if err == storage.ErrNoU2FDeviceHandle {
@ -48,16 +48,16 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
}
var registration u2f.Registration
registration.KeyHandle = keyHandleBytes
x, y := elliptic.Unmarshal(elliptic.P256(), publicKeyBytes)
registration.KeyHandle = device.KeyHandle
x, y := elliptic.Unmarshal(elliptic.P256(), device.PublicKey)
registration.PubKey.Curve = elliptic.P256()
registration.PubKey.X = x
registration.PubKey.Y = y
// Save the challenge and registration for use in next request
userSession.U2FRegistration = &session.U2FRegistration{
KeyHandle: keyHandleBytes,
PublicKey: publicKeyBytes,
KeyHandle: device.KeyHandle,
PublicKey: device.PublicKey,
}
userSession.U2FChallenge = challenge
err = ctx.SaveSession(userSession)

View File

@ -3,97 +3,25 @@ package handlers
import (
"fmt"
"strings"
"sync"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/middlewares"
"github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/utils"
)
func loadInfo(username string, storageProvider storage.Provider, userInfo *UserInfo, logger *logrus.Entry) []error {
var wg sync.WaitGroup
wg.Add(3)
errors := make([]error, 0)
go func() {
defer wg.Done()
method, err := storageProvider.LoadPreferred2FAMethod(username)
if err != nil {
errors = append(errors, err)
logger.Error(err)
return
}
if method == "" {
userInfo.Method = authentication.PossibleMethods[0]
} else {
userInfo.Method = method
}
}()
go func() {
defer wg.Done()
_, _, err := storageProvider.LoadU2FDeviceHandle(username)
if err != nil {
if err == storage.ErrNoU2FDeviceHandle {
return
}
errors = append(errors, err)
logger.Error(err)
return
}
userInfo.HasU2F = true
}()
go func() {
defer wg.Done()
_, err := storageProvider.LoadTOTPSecret(username)
if err != nil {
if err == storage.ErrNoTOTPSecret {
return
}
errors = append(errors, err)
logger.Error(err)
return
}
userInfo.HasTOTP = true
}()
wg.Wait()
return errors
}
// UserInfoGet get the info related to the user identified by the session.
func UserInfoGet(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession()
userInfo := UserInfo{}
errors := loadInfo(userSession.Username, ctx.Providers.StorageProvider, &userInfo, ctx.Logger)
if len(errors) > 0 {
ctx.Error(fmt.Errorf("unable to load user information"), messageOperationFailed)
userInfo, err := ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username)
if err != nil {
ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed)
return
}
userInfo.DisplayName = userSession.DisplayName
err := ctx.SetJSONBody(userInfo)
err = ctx.SetJSONBody(userInfo)
if err != nil {
ctx.Logger.Errorf("Unable to set user info response in body: %s", err)
}
@ -121,7 +49,7 @@ func MethodPreferencePost(ctx *middlewares.AutheliaCtx) {
userSession := ctx.GetSession()
ctx.Logger.Debugf("Save new preferred 2FA method of user %s to %s", userSession.Username, bodyJSON.Method)
err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(userSession.Username, bodyJSON.Method)
err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, bodyJSON.Method)
if err != nil {
ctx.Error(fmt.Errorf("unable to save new preferred 2FA method: %s", err), messageOperationFailed)

View File

@ -1,6 +1,8 @@
package handlers
import (
"database/sql"
"errors"
"fmt"
"testing"
@ -11,7 +13,7 @@ import (
"github.com/stretchr/testify/suite"
"github.com/authelia/authelia/v4/internal/mocks"
"github.com/authelia/authelia/v4/internal/storage"
"github.com/authelia/authelia/v4/internal/models"
)
type FetchSuite struct {
@ -33,62 +35,59 @@ func (s *FetchSuite) TearDownTest() {
s.mock.Close()
}
func setPreferencesExpectations(preferences UserInfo, provider *storage.MockProvider) {
provider.
EXPECT().
LoadPreferred2FAMethod(gomock.Eq("john")).
Return(preferences.Method, nil)
if preferences.HasU2F {
u2fData := []byte("abc")
provider.
EXPECT().
LoadU2FDeviceHandle(gomock.Eq("john")).
Return(u2fData, u2fData, nil)
} else {
provider.
EXPECT().
LoadU2FDeviceHandle(gomock.Eq("john")).
Return(nil, nil, storage.ErrNoU2FDeviceHandle)
}
if preferences.HasTOTP {
totpSecret := "secret"
provider.
EXPECT().
LoadTOTPSecret(gomock.Eq("john")).
Return(totpSecret, nil)
} else {
provider.
EXPECT().
LoadTOTPSecret(gomock.Eq("john")).
Return("", storage.ErrNoTOTPSecret)
}
type expectedResponse struct {
db models.UserInfo
api *models.UserInfo
err error
}
func TestMethodSetToU2F(t *testing.T) {
table := []UserInfo{
expectedResponses := []expectedResponse{
{
Method: "totp",
db: models.UserInfo{
Method: "totp",
},
err: nil,
},
{
Method: "u2f",
HasU2F: true,
HasTOTP: true,
db: models.UserInfo{
Method: "u2f",
HasU2F: true,
HasTOTP: true,
},
err: nil,
},
{
Method: "u2f",
HasU2F: true,
HasTOTP: false,
db: models.UserInfo{
Method: "u2f",
HasU2F: true,
HasTOTP: false,
},
err: nil,
},
{
Method: "mobile_push",
HasU2F: false,
HasTOTP: false,
db: models.UserInfo{
Method: "mobile_push",
HasU2F: false,
HasTOTP: false,
},
err: nil,
},
{
db: models.UserInfo{},
err: sql.ErrNoRows,
},
{
db: models.UserInfo{},
err: errors.New("invalid thing"),
},
}
for _, expectedPreferences := range table {
for _, resp := range expectedResponses {
if resp.api == nil {
resp.api = &resp.db
}
mock := mocks.NewMockAutheliaCtx(t)
// Set the initial user session.
userSession := mock.Ctx.GetSession()
@ -97,64 +96,57 @@ func TestMethodSetToU2F(t *testing.T) {
err := mock.Ctx.SaveSession(userSession)
require.NoError(t, err)
setPreferencesExpectations(expectedPreferences, mock.StorageProviderMock)
mock.StorageProviderMock.
EXPECT().
LoadUserInfo(mock.Ctx, gomock.Eq("john")).
Return(resp.db, resp.err)
UserInfoGet(mock.Ctx)
actualPreferences := UserInfo{}
mock.GetResponseData(t, &actualPreferences)
if resp.err == nil {
t.Run("expected status code", func(t *testing.T) {
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
})
t.Run("expected method", func(t *testing.T) {
assert.Equal(t, expectedPreferences.Method, actualPreferences.Method)
})
actualPreferences := models.UserInfo{}
t.Run("registered u2f", func(t *testing.T) {
assert.Equal(t, expectedPreferences.HasU2F, actualPreferences.HasU2F)
})
mock.GetResponseData(t, &actualPreferences)
t.Run("expected method", func(t *testing.T) {
assert.Equal(t, resp.api.Method, actualPreferences.Method)
})
t.Run("registered u2f", func(t *testing.T) {
assert.Equal(t, resp.api.HasU2F, actualPreferences.HasU2F)
})
t.Run("registered totp", func(t *testing.T) {
assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP)
})
} else {
t.Run("expected status code", func(t *testing.T) {
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
})
errResponse := mock.GetResponseError(t)
assert.Equal(t, "KO", errResponse.Status)
assert.Equal(t, "Operation failed.", errResponse.Message)
}
t.Run("registered totp", func(t *testing.T) {
assert.Equal(t, expectedPreferences.HasTOTP, actualPreferences.HasTOTP)
})
mock.Close()
}
}
func (s *FetchSuite) TestShouldGetDefaultPreferenceIfNotInDB() {
s.mock.StorageProviderMock.
EXPECT().
LoadPreferred2FAMethod(gomock.Eq("john")).
Return("", nil)
s.mock.StorageProviderMock.
EXPECT().
LoadU2FDeviceHandle(gomock.Eq("john")).
Return(nil, nil, storage.ErrNoU2FDeviceHandle)
s.mock.StorageProviderMock.
EXPECT().
LoadTOTPSecret(gomock.Eq("john")).
Return("", storage.ErrNoTOTPSecret)
UserInfoGet(s.mock.Ctx)
s.mock.Assert200OK(s.T(), UserInfo{Method: "totp"})
}
func (s *FetchSuite) TestShouldReturnError500WhenStorageFailsToLoad() {
s.mock.StorageProviderMock.EXPECT().
LoadPreferred2FAMethod(gomock.Eq("john")).
Return("", fmt.Errorf("Failure"))
s.mock.StorageProviderMock.
EXPECT().
LoadU2FDeviceHandle(gomock.Eq("john"))
s.mock.StorageProviderMock.
EXPECT().
LoadTOTPSecret(gomock.Eq("john"))
LoadUserInfo(s.mock.Ctx, gomock.Eq("john")).
Return(models.UserInfo{}, fmt.Errorf("failure"))
UserInfoGet(s.mock.Ctx)
s.mock.Assert200KO(s.T(), "Operation failed.")
assert.Equal(s.T(), "unable to load user information", s.mock.Hook.LastEntry().Message)
assert.Equal(s.T(), "unable to load user information: failure", s.mock.Hook.LastEntry().Message)
assert.Equal(s.T(), logrus.ErrorLevel, s.mock.Hook.LastEntry().Level)
}
@ -220,7 +212,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenBadMethodProvided() {
func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() {
s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}"))
s.mock.StorageProviderMock.EXPECT().
SavePreferred2FAMethod(gomock.Eq("john"), gomock.Eq("u2f")).
SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")).
Return(fmt.Errorf("Failure"))
MethodPreferencePost(s.mock.Ctx)
@ -233,7 +225,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() {
func (s *SaveSuite) TestShouldReturn200WhenMethodIsSuccessfullySaved() {
s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}"))
s.mock.StorageProviderMock.EXPECT().
SavePreferred2FAMethod(gomock.Eq("john"), gomock.Eq("u2f")).
SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")).
Return(nil)
MethodPreferencePost(s.mock.Ctx)

View File

@ -1,15 +1,18 @@
package handlers
import (
"errors"
"time"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/authelia/authelia/v4/internal/models"
)
// TOTPVerifier is the interface for verifying TOTPs.
type TOTPVerifier interface {
Verify(token, secret string) (bool, error)
Verify(config *models.TOTPConfiguration, token string) (bool, error)
}
// TOTPVerifierImpl the production implementation for TOTP verification.
@ -19,13 +22,43 @@ type TOTPVerifierImpl struct {
}
// Verify verifies TOTPs.
func (tv *TOTPVerifierImpl) Verify(token, secret string) (bool, error) {
opts := totp.ValidateOpts{
Period: tv.Period,
Skew: tv.Skew,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
func (tv *TOTPVerifierImpl) Verify(config *models.TOTPConfiguration, token string) (bool, error) {
if config == nil {
return false, errors.New("config not provided")
}
return totp.ValidateCustom(token, secret, time.Now().UTC(), opts)
opts := totp.ValidateOpts{
Period: uint(config.Period),
Skew: tv.Skew,
Digits: otp.Digits(config.Digits),
Algorithm: otpStringToAlgo(config.Algorithm),
}
return totp.ValidateCustom(token, config.Secret, time.Now().UTC(), opts)
}
func otpAlgoToString(algorithm otp.Algorithm) (out string) {
switch algorithm {
case otp.AlgorithmSHA1:
return totpAlgoSHA1
case otp.AlgorithmSHA256:
return totpAlgoSHA256
case otp.AlgorithmSHA512:
return totpAlgoSHA512
default:
return ""
}
}
func otpStringToAlgo(in string) (algorithm otp.Algorithm) {
switch in {
case totpAlgoSHA1:
return otp.AlgorithmSHA1
case totpAlgoSHA256:
return otp.AlgorithmSHA256
case totpAlgoSHA512:
return otp.AlgorithmSHA512
default:
return otp.AlgorithmSHA1
}
}

View File

@ -5,45 +5,47 @@
package handlers
import (
reflect "reflect"
"reflect"
gomock "github.com/golang/mock/gomock"
"github.com/golang/mock/gomock"
"github.com/authelia/authelia/v4/internal/models"
)
// MockTOTPVerifier is a mock of TOTPVerifier interface
// MockTOTPVerifier is a mock of TOTPVerifier interface.
type MockTOTPVerifier struct {
ctrl *gomock.Controller
recorder *MockTOTPVerifierMockRecorder
}
// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier
// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier.
type MockTOTPVerifierMockRecorder struct {
mock *MockTOTPVerifier
}
// NewMockTOTPVerifier creates a new mock instance
// NewMockTOTPVerifier creates a new mock instance.
func NewMockTOTPVerifier(ctrl *gomock.Controller) *MockTOTPVerifier {
mock := &MockTOTPVerifier{ctrl: ctrl}
mock.recorder = &MockTOTPVerifierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTOTPVerifier) EXPECT() *MockTOTPVerifierMockRecorder {
return m.recorder
}
// Verify mocks base method
func (m *MockTOTPVerifier) Verify(token, secret string) (bool, error) {
// Verify mocks base method.
func (m *MockTOTPVerifier) Verify(arg0 *models.TOTPConfiguration, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Verify", token, secret)
ret := m.ctrl.Call(m, "Verify", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Verify indicates an expected call of Verify
func (mr *MockTOTPVerifierMockRecorder) Verify(token, secret interface{}) *gomock.Call {
// Verify indicates an expected call of Verify.
func (mr *MockTOTPVerifierMockRecorder) Verify(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), token, secret)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), arg0, arg1)
}

View File

@ -11,21 +11,6 @@ type MethodList = []string
type authorizationMatching int
// UserInfo is the model of user info and second factor preferences.
type UserInfo struct {
// The users display name.
DisplayName string `json:"display_name"`
// The preferred 2FA method.
Method string `json:"method" valid:"required"`
// True if a security key has been registered.
HasU2F bool `json:"has_u2f" valid:"required"`
// True if a TOTP device has been registered.
HasTOTP bool `json:"has_totp" valid:"required"`
}
// signTOTPRequestBody model of the request body received by TOTP authentication endpoint.
type signTOTPRequestBody struct {
Token string `json:"token" valid:"required"`

View File

@ -8,6 +8,7 @@ import (
"github.com/golang-jwt/jwt/v4"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/templates"
)
@ -47,7 +48,9 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
return
}
err = ctx.Providers.StorageProvider.SaveIdentityVerificationToken(ss)
err = ctx.Providers.StorageProvider.SaveIdentityVerification(ctx, models.IdentityVerification{
Token: ss,
})
if err != nil {
ctx.Error(err, messageOperationFailed)
return
@ -128,7 +131,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
return
}
found, err := ctx.Providers.StorageProvider.FindIdentityVerificationToken(finishBody.Token)
found, err := ctx.Providers.StorageProvider.FindIdentityVerification(ctx, finishBody.Token)
if err != nil {
ctx.Error(err, messageOperationFailed)
@ -185,7 +188,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
}
// TODO(c.michaud): find a way to garbage collect unused tokens.
err = ctx.Providers.StorageProvider.RemoveIdentityVerificationToken(finishBody.Token)
err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, finishBody.Token)
if err != nil {
ctx.Error(err, messageOperationFailed)
return

View File

@ -55,7 +55,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) {
mock.Ctx.Configuration.JWTSecret = testJWTSecret
mock.StorageProviderMock.EXPECT().
SaveIdentityVerificationToken(gomock.Any()).
SaveIdentityVerification(mock.Ctx, gomock.Any()).
Return(fmt.Errorf("cannot save"))
args := newArgs(defaultRetriever)
@ -74,7 +74,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
mock.StorageProviderMock.EXPECT().
SaveIdentityVerificationToken(gomock.Any()).
SaveIdentityVerification(mock.Ctx, gomock.Any()).
Return(nil)
mock.NotifierMock.EXPECT().
@ -96,7 +96,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
mock.StorageProviderMock.EXPECT().
SaveIdentityVerificationToken(gomock.Any()).
SaveIdentityVerification(mock.Ctx, gomock.Any()).
Return(nil)
args := newArgs(defaultRetriever)
@ -114,7 +114,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
mock.StorageProviderMock.EXPECT().
SaveIdentityVerificationToken(gomock.Any()).
SaveIdentityVerification(mock.Ctx, gomock.Any()).
Return(nil)
args := newArgs(defaultRetriever)
@ -132,7 +132,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
mock.StorageProviderMock.EXPECT().
SaveIdentityVerificationToken(gomock.Any()).
SaveIdentityVerification(mock.Ctx, gomock.Any()).
Return(nil)
mock.NotifierMock.EXPECT().
@ -209,7 +209,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB(
s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}")
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq("abc")).
FindIdentityVerification(s.mock.Ctx, gomock.Eq("abc")).
Return(false, nil)
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
@ -222,7 +222,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() {
s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}")
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq("abc")).
FindIdentityVerification(s.mock.Ctx, gomock.Eq("abc")).
Return(true, nil)
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
@ -238,7 +238,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
@ -253,7 +253,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
@ -268,7 +268,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
args := newFinishArgs()
@ -285,11 +285,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
s.mock.StorageProviderMock.EXPECT().
RemoveIdentityVerificationToken(gomock.Eq(token)).
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(fmt.Errorf("cannot remove"))
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
@ -304,11 +304,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete(
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
s.mock.StorageProviderMock.EXPECT().
FindIdentityVerificationToken(gomock.Eq(token)).
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(true, nil)
s.mock.StorageProviderMock.EXPECT().
RemoveIdentityVerificationToken(gomock.Eq(token)).
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
Return(nil)
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)

View File

@ -28,11 +28,6 @@ type AutheliaCtx struct {
Clock utils.Clock
}
// ProviderWithStartupCheck represents a provider that has a startup check.
type ProviderWithStartupCheck interface {
StartupCheck(logger *logrus.Logger) (err error)
}
// Providers contain all provider provided to Authelia.
type Providers struct {
Authorizer *authorization.Authorizer

View File

@ -183,3 +183,11 @@ func (m *MockAutheliaCtx) GetResponseData(t *testing.T, data interface{}) {
err := json.Unmarshal(m.Ctx.Response.Body(), &okResponse)
require.NoError(t, err)
}
// GetResponseError retrieves an error response from the service.
func (m *MockAutheliaCtx) GetResponseError(t *testing.T) (errResponse middlewares.ErrorResponse) {
err := json.Unmarshal(m.Ctx.Response.Body(), &errResponse)
require.NoError(t, err)
return errResponse
}

View File

@ -8,7 +8,6 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
)
// MockNotifier is a mock of Notifier interface.
@ -49,15 +48,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go
}
// StartupCheck mocks base method.
func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error {
func (m *MockNotifier) StartupCheck() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck", arg0)
ret := m.ctrl.Call(m, "StartupCheck")
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck))
}

View File

@ -5,12 +5,11 @@
package mocks
import (
"reflect"
reflect "reflect"
"github.com/golang/mock/gomock"
"github.com/sirupsen/logrus"
gomock "github.com/golang/mock/gomock"
"github.com/authelia/authelia/v4/internal/authentication"
authentication "github.com/authelia/authelia/v4/internal/authentication"
)
// MockUserProvider is a mock of UserProvider interface.
@ -66,7 +65,21 @@ func (mr *MockUserProviderMockRecorder) GetDetails(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDetails", reflect.TypeOf((*MockUserProvider)(nil).GetDetails), arg0)
}
// UpdatePassword mocks base method
// StartupCheck mocks base method.
func (m *MockUserProvider) StartupCheck() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck")
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockUserProviderMockRecorder) StartupCheck() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck))
}
// UpdatePassword mocks base method.
func (m *MockUserProvider) UpdatePassword(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
@ -79,17 +92,3 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) *
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1)
}
// StartupCheck mocks base method.
func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0)
}

View File

@ -0,0 +1,17 @@
package models
import (
"time"
)
// AuthenticationAttempt represents an authentication attempt row in the database.
type AuthenticationAttempt struct {
ID int `db:"id"`
Time time.Time `db:"time"`
Successful bool `db:"successful"`
Username string `db:"username"`
Type string `db:"auth_type"`
RemoteIP IPAddress `db:"remote_ip"`
RequestURI string `db:"request_uri"`
RequestMethod string `db:"request_method"`
}

View File

@ -0,0 +1,12 @@
package models
import (
"time"
)
// IdentityVerification represents an identity verification row in the database.
type IdentityVerification struct {
ID int `db:"id"`
Created time.Time `db:"created"`
Token string `db:"token"`
}

View File

@ -0,0 +1,14 @@
package models
import (
"time"
)
// Migration represents a migration row in the database.
type Migration struct {
ID int `db:"id"`
Applied time.Time `db:"applied"`
Before int `db:"version_before"`
After int `db:"version_after"`
Version string `db:"application_version"`
}

View File

@ -0,0 +1,11 @@
package models
// TOTPConfiguration represents a users TOTP configuration row in the database.
type TOTPConfiguration struct {
ID int `db:"id"`
Username string `db:"username"`
Algorithm string `db:"algorithm"`
Digits int `db:"digits"`
Period uint64 `db:"totp_period"`
Secret string `db:"secret"`
}

View File

@ -0,0 +1,10 @@
package models
// U2FDevice represents a users U2F device row in the database.
type U2FDevice struct {
ID int `db:"id"`
Username string `db:"username"`
Description string `db:"description"`
KeyHandle []byte `db:"key_handle"`
PublicKey []byte `db:"public_key"`
}

View File

@ -0,0 +1,16 @@
package models
// UserInfo represents the user information required by the web UI.
type UserInfo struct {
// The users display name.
DisplayName string `db:"-" json:"display_name"`
// The preferred 2FA method.
Method string `db:"second_factor_method" json:"method" valid:"required"`
// True if a security key has been registered.
HasU2F bool `db:"has_u2f" json:"has_u2f" valid:"required"`
// True if a TOTP device has been registered.
HasTOTP bool `db:"has_totp" json:"has_totp" valid:"required"`
}

View File

@ -0,0 +1,42 @@
package models
import (
"database/sql/driver"
"fmt"
"net"
)
// IPAddress is a type specific for storage of a net.IP in the database.
type IPAddress struct {
*net.IP
}
// Value is the IPAddress implementation of the databases/sql driver.Valuer.
func (ip IPAddress) Value() (value driver.Value, err error) {
if ip.IP == nil {
return driver.Value(nil), nil
}
return driver.Value(ip.IP.String()), nil
}
// Scan is the IPAddress implementation of the sql.Scanner.
func (ip *IPAddress) Scan(src interface{}) (err error) {
if src == nil {
ip.IP = nil
return nil
}
var value string
switch v := src.(type) {
case string:
value = v
default:
return fmt.Errorf("invalid type %T for IPAddress %v", src, src)
}
*ip.IP = net.ParseIP(value)
return nil
}

View File

@ -0,0 +1,6 @@
package models
// StartupCheck represents a provider that has a startup check.
type StartupCheck interface {
StartupCheck() (err error)
}

View File

@ -1,13 +0,0 @@
package models
import "time"
// AuthenticationAttempt represent an authentication attempt.
type AuthenticationAttempt struct {
// The user who tried to authenticate.
Username string
// Successful true if the attempt was successful.
Successful bool
// The time of the attempt.
Time time.Time
}

View File

@ -7,8 +7,6 @@ import (
"path/filepath"
"time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
@ -25,7 +23,7 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File
}
// StartupCheck implements the startup check provider interface.
func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) {
func (n *FileNotifier) StartupCheck() (err error) {
dir := filepath.Dir(n.path)
if _, err := os.Stat(dir); err != nil {
if os.IsNotExist(err) {

View File

@ -1,11 +1,12 @@
package notification
import (
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/models"
)
// Notifier interface for sending the identity verification link.
type Notifier interface {
models.StartupCheck
Send(recipient, subject, body, htmlBody string) (err error)
StartupCheck(logger *logrus.Logger) (err error)
}

View File

@ -10,8 +10,6 @@ import (
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
@ -223,7 +221,7 @@ func (n *SMTPNotifier) cleanup() {
}
// StartupCheck implements the startup check provider interface.
func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) {
func (n *SMTPNotifier) StartupCheck() (err error) {
if err := n.dial(); err != nil {
return err
}

View File

@ -6,22 +6,24 @@ import (
"net"
"time"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewProvider instantiate a ntp provider given a configuration.
func NewProvider(config *schema.NTPConfiguration) *Provider {
return &Provider{config}
return &Provider{
config: config,
log: logging.Logger(),
}
}
// StartupCheck implements the startup check provider interface.
func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
func (p *Provider) StartupCheck() (err error) {
conn, err := net.Dial("udp", p.config.Address)
if err != nil {
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
p.log.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
return nil
}
@ -29,7 +31,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
defer conn.Close()
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
p.log.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
return nil
}
@ -42,7 +44,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)}
if err := binary.Write(conn, binary.BigEndian, req); err != nil {
logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
p.log.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
return nil
}
@ -52,7 +54,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
resp := &ntpPacket{}
if err := binary.Read(conn, binary.BigEndian, resp); err != nil {
logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
p.log.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
return nil
}

View File

@ -7,7 +7,6 @@ import (
"github.com/authelia/authelia/v4/internal/configuration/schema"
"github.com/authelia/authelia/v4/internal/configuration/validator"
"github.com/authelia/authelia/v4/internal/logging"
)
func TestShouldCheckNTP(t *testing.T) {
@ -22,5 +21,5 @@ func TestShouldCheckNTP(t *testing.T) {
ntp := NewProvider(&config)
assert.NoError(t, ntp.StartupCheck(logging.Logger()))
assert.NoError(t, ntp.StartupCheck())
}

View File

@ -1,12 +1,15 @@
package ntp
import (
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// Provider type is the NTP provider.
type Provider struct {
config *schema.NTPConfiguration
log *logrus.Logger
}
type ntpVersion int

View File

@ -20,10 +20,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
return provider, nil
}
provider.Store, err = NewOpenIDConnectStore(configuration)
if err != nil {
return provider, err
}
provider.Store = NewOpenIDConnectStore(configuration)
composeConfiguration := &compose.Config{
AccessTokenLifespan: configuration.AccessTokenLifespan,

View File

@ -14,7 +14,7 @@ import (
)
// NewOpenIDConnectStore returns a new OpenIDConnectStore using the provided schema.OpenIDConnectConfiguration.
func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore, err error) {
func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore) {
logger := logging.Logger()
store = &OpenIDConnectStore{
@ -39,7 +39,7 @@ func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (st
store.clients[client.ID] = NewClient(client)
}
return store, nil
return store
}
// GetClientPolicy retrieves the policy from the client with the matching provided id.

View File

@ -12,7 +12,7 @@ import (
)
func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: exampleIssuerPrivateKey,
Clients: []schema.OpenIDConnectClientConfiguration{
{
@ -32,8 +32,6 @@ func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
},
})
require.NoError(t, err)
policyOne := s.GetClientPolicy("myclient")
assert.Equal(t, authorization.OneFactor, policyOne)
@ -45,7 +43,7 @@ func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
}
func TestOpenIDConnectStore_GetInternalClient(t *testing.T) {
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: exampleIssuerPrivateKey,
Clients: []schema.OpenIDConnectClientConfiguration{
{
@ -58,8 +56,6 @@ func TestOpenIDConnectStore_GetInternalClient(t *testing.T) {
},
})
require.NoError(t, err)
client, err := s.GetClient(context.Background(), "myinvalidclient")
assert.EqualError(t, err, "not_found")
assert.Nil(t, client)
@ -78,13 +74,12 @@ func TestOpenIDConnectStore_GetInternalClient_ValidClient(t *testing.T) {
Scopes: []string{"openid", "profile"},
Secret: "mysecret",
}
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: exampleIssuerPrivateKey,
Clients: []schema.OpenIDConnectClientConfiguration{c1},
})
require.NoError(t, err)
client, err := s.GetInternalClient(c1.ID)
require.NoError(t, err)
require.NotNil(t, client)
@ -106,20 +101,19 @@ func TestOpenIDConnectStore_GetInternalClient_InvalidClient(t *testing.T) {
Scopes: []string{"openid", "profile"},
Secret: "mysecret",
}
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: exampleIssuerPrivateKey,
Clients: []schema.OpenIDConnectClientConfiguration{c1},
})
require.NoError(t, err)
client, err := s.GetInternalClient("another-client")
assert.Nil(t, client)
assert.EqualError(t, err, "not_found")
}
func TestOpenIDConnectStore_IsValidClientID(t *testing.T) {
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
IssuerPrivateKey: exampleIssuerPrivateKey,
Clients: []schema.OpenIDConnectClientConfiguration{
{
@ -132,8 +126,6 @@ func TestOpenIDConnectStore_IsValidClientID(t *testing.T) {
},
})
require.NoError(t, err)
validClient := s.IsValidClientID("myclient")
invalidClient := s.IsValidClientID("myinvalidclient")

View File

@ -1,6 +1,7 @@
package regulation
import (
"context"
"fmt"
"time"
@ -11,7 +12,7 @@ import (
)
// NewRegulator create a regulator instance.
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.Provider, clock utils.Clock) *Regulator {
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
regulator := &Regulator{storageProvider: provider}
regulator.clock = clock
@ -40,30 +41,25 @@ func NewRegulator(configuration *schema.RegulationConfiguration, provider storag
return regulator
}
// Mark mark an authentication attempt.
// Mark an authentication attempt.
// We split Mark and Regulate in order to avoid timing attacks.
func (r *Regulator) Mark(username string, successful bool) error {
return r.storageProvider.AppendAuthenticationLog(models.AuthenticationAttempt{
func (r *Regulator) Mark(ctx context.Context, username string, successful bool) error {
return r.storageProvider.AppendAuthenticationLog(ctx, models.AuthenticationAttempt{
Username: username,
Successful: successful,
Time: r.clock.Now(),
})
}
// Regulate regulate the authentication attempts for a given user.
// This method returns ErrUserIsBanned if the user is banned along with the time until when
// the user is banned.
func (r *Regulator) Regulate(username string) (time.Time, error) {
// Regulate the authentication attempts for a given user.
// This method returns ErrUserIsBanned if the user is banned along with the time until when the user is banned.
func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, error) {
// If there is regulation configuration, no regulation applies.
if !r.enabled {
return time.Time{}, nil
}
now := r.clock.Now()
// TODO(c.michaud): make sure FindTime < BanTime.
attempts, err := r.storageProvider.LoadLatestAuthenticationLogs(username, now.Add(-r.banTime))
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0)
if err != nil {
return time.Time{}, nil
}

View File

@ -1,6 +1,7 @@
package regulation_test
import (
"context"
"testing"
"time"
@ -18,6 +19,7 @@ import (
type RegulatorSuite struct {
suite.Suite
ctx context.Context
ctrl *gomock.Controller
storageMock *storage.MockProvider
configuration schema.RegulationConfiguration
@ -27,6 +29,7 @@ type RegulatorSuite struct {
func (s *RegulatorSuite) SetupTest() {
s.ctrl = gomock.NewController(s.T())
s.storageMock = storage.NewMockProvider(s.ctrl)
s.ctx = context.Background()
s.configuration = schema.RegulationConfiguration{
MaxRetries: 3,
@ -50,12 +53,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
}
@ -81,12 +84,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
}
@ -117,12 +120,12 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
}
@ -150,12 +153,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
}
@ -174,12 +177,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
}
@ -206,12 +209,12 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
}
@ -242,12 +245,12 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
}
@ -277,7 +280,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
}
s.storageMock.EXPECT().
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
Return(attemptsInDB, nil)
// Check Disabled Functionality
@ -288,7 +291,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
}
regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
_, err := regulator.Regulate("john")
_, err := regulator.Regulate(s.ctx, "john")
assert.NoError(s.T(), err)
// Check Enabled Functionality
@ -299,6 +302,6 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
}
regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
_, err = regulator.Regulate("john")
_, err = regulator.Regulate(s.ctx, "john")
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
}

View File

@ -18,7 +18,7 @@ type Regulator struct {
// If a user has been banned, this duration is the timelapse during which the user is banned.
banTime time.Duration
storageProvider storage.Provider
storageProvider storage.RegulatorProvider
clock utils.Clock
}

View File

@ -1,39 +1,52 @@
package storage
import (
"fmt"
"regexp"
)
const storageSchemaCurrentVersion = SchemaVersion(1)
const storageSchemaUpgradeMessage = "Storage schema upgraded to v"
const storageSchemaUpgradeErrorText = "storage schema upgrade failed at v"
const (
tableUserPreferences = "user_preferences"
tableIdentityVerification = "identity_verification_tokens"
tableTOTPConfigurations = "totp_configurations"
tableU2FDevices = "u2f_devices"
tableDUODevices = "duo_devices"
tableAuthenticationLogs = "authentication_logs"
tableMigrations = "migrations"
// Keep table names in lower case because some DB does not support upper case.
const userPreferencesTableName = "user_preferences"
const identityVerificationTokensTableName = "identity_verification_tokens"
const totpSecretsTableName = "totp_secrets"
const u2fDeviceHandlesTableName = "u2f_devices"
const authenticationLogsTableName = "authentication_logs"
const configTableName = "config"
tablePrefixBackup = "_bkp_"
)
// sqlUpgradeCreateTableStatements is a map of the schema version number, plus a map of the table name and the statement used to create it.
// The statement is fmt.Sprintf'd with the table name as the first argument.
var sqlUpgradeCreateTableStatements = map[SchemaVersion]map[string]string{
SchemaVersion(1): {
userPreferencesTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, second_factor_method VARCHAR(11))",
identityVerificationTokensTableName: "CREATE TABLE %s (token VARCHAR(512))",
totpSecretsTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, secret VARCHAR(64))",
u2fDeviceHandlesTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, keyHandle TEXT, publicKey TEXT)",
authenticationLogsTableName: "CREATE TABLE %s (username VARCHAR(100), successful BOOL, time INTEGER)",
configTableName: "CREATE TABLE %s (category VARCHAR(32) NOT NULL, key_name VARCHAR(32) NOT NULL, value TEXT, PRIMARY KEY (category, key_name))",
},
}
// WARNING: Do not change/remove these consts. They are used for Pre1 migrations.
const (
tablePre1TOTPSecrets = "totp_secrets"
tablePre1Config = "config"
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
tableAlphaAuthenticationLogs = "AuthenticationLogs"
tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
tableAlphaPreferences = "Preferences"
tableAlphaPreferencesTableName = "PreferencesTableName"
tableAlphaSecondFactorPreferences = "SecondFactorPreferences"
tableAlphaTOTPSecrets = "TOTPSecrets"
tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
)
// sqlUpgradesCreateTableIndexesStatements is a map of t he schema version number, plus a slice of statements to create all of the indexes.
var sqlUpgradesCreateTableIndexesStatements = map[SchemaVersion][]string{
SchemaVersion(1): {
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s (username, time)", authenticationLogsTableName),
},
}
const (
providerAll = "all"
providerMySQL = "mysql"
providerPostgres = "postgres"
providerSQLite = "sqlite"
)
const unitTestUser = "john"
const (
// This is the latest schema version for the purpose of tests.
testLatestVersion = 1
)
const (
// SchemaLatest represents the value expected for a "migrate to latest" migration. It's the maximum 32bit signed integer.
SchemaLatest = 2147483647
)
var (
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`)
)

View File

@ -1,6 +1,8 @@
package storage
import "errors"
import (
"errors"
)
var (
// ErrNoU2FDeviceHandle error thrown when no U2F device handle has been found in DB.
@ -8,4 +10,35 @@ var (
// ErrNoTOTPSecret error thrown when no TOTP secret has been found in DB.
ErrNoTOTPSecret = errors.New("no TOTP secret registered")
// ErrNoAvailableMigrations is returned when no available migrations can be found.
ErrNoAvailableMigrations = errors.New("no available migrations")
// ErrSchemaAlreadyUpToDate is returned when the schema is already up to date.
ErrSchemaAlreadyUpToDate = errors.New("schema already up to date")
// ErrNoMigrationsFound is returned when no migrations were found.
ErrNoMigrationsFound = errors.New("no schema migrations found")
)
// Error formats for the storage provider.
const (
ErrFmtMigrateUpTargetLessThanCurrent = "schema up migration target version %d is less then the current version %d"
ErrFmtMigrateUpTargetGreaterThanLatest = "schema up migration target version %d is greater then the latest version %d which indicates it doesn't exist"
ErrFmtMigrateDownTargetGreaterThanCurrent = "schema down migration target version %d is greater than the current version %d"
ErrFmtMigrateDownTargetLessThanMinimum = "schema down migration target version %d is less than the minimum version"
ErrFmtMigrateAlreadyOnTargetVersion = "schema migration target version %d is the same current version %d"
)
const (
errFmtFailedMigration = "schema migration %d (%s) failed: %w"
errFmtFailedMigrationPre1 = "schema migration pre1 failed: %w"
errFmtSchemaCurrentGreaterThanLatestKnown = "current schema version is greater than the latest known schema " +
"version, you must downgrade to schema version %d before you can use this version of Authelia"
)
const (
logFmtMigrationFromTo = "Storage schema migration from %s to %s is being attempted"
logFmtMigrationComplete = "Storage schema migration from %s to %s is complete"
logFmtErrClosingConn = "Error occurred closing SQL connection: %v"
)

View File

@ -0,0 +1,204 @@
package storage
import (
"embed"
"errors"
"fmt"
"sort"
"strconv"
"strings"
)
//go:embed migrations/*
var migrationsFS embed.FS
func latestMigrationVersion(providerName string) (version int, err error) {
entries, err := migrationsFS.ReadDir("migrations")
if err != nil {
return -1, err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
m, err := scanMigration(entry.Name())
if err != nil {
return -1, err
}
if m.Provider != providerName {
continue
}
if !m.Up {
continue
}
if m.Version > version {
version = m.Version
}
}
return version, nil
}
func loadMigration(providerName string, version int, up bool) (migration *SchemaMigration, err error) {
entries, err := migrationsFS.ReadDir("migrations")
if err != nil {
return nil, err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
m, err := scanMigration(entry.Name())
if err != nil {
return nil, err
}
migration = &m
if up != migration.Up {
continue
}
if migration.Provider != providerAll && migration.Provider != providerName {
continue
}
if version != migration.Version {
continue
}
return migration, nil
}
return nil, errors.New("migration not found")
}
// loadMigrations scans the migrations fs and loads the appropriate migrations for a given providerName, prior and
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
// this indicates the database zero state.
func loadMigrations(providerName string, prior, target int) (migrations []SchemaMigration, err error) {
if prior == target && (prior != -1 || target != -1) {
return nil, errors.New("cannot migrate to the same version as prior")
}
entries, err := migrationsFS.ReadDir("migrations")
if err != nil {
return nil, err
}
up := prior < target
for _, entry := range entries {
if entry.IsDir() {
continue
}
migration, err := scanMigration(entry.Name())
if err != nil {
return nil, err
}
if skipMigration(providerName, up, target, prior, &migration) {
continue
}
migrations = append(migrations, migration)
}
if up {
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Version < migrations[j].Version
})
} else {
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Version > migrations[j].Version
})
}
return migrations, nil
}
func skipMigration(providerName string, up bool, target, prior int, migration *SchemaMigration) (skip bool) {
if migration.Provider != providerAll && migration.Provider != providerName {
// Skip if migration.Provider is not a match.
return true
}
if up {
if !migration.Up {
// Skip if we wanted an Up migration but it isn't an Up migration.
return true
}
if target != -1 && (migration.Version > target || migration.Version <= prior) {
// Skip if the migration version is greater than the target or less than or equal to the previous version.
return true
}
} else {
if migration.Up {
// Skip if we didn't want an Up migration but it is an Up migration.
return true
}
if migration.Version == 1 && target == -1 {
// Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data
// preventing a successful migration.
return true
}
if migration.Version <= target || migration.Version > prior {
// Skip the migration if we want to go down and the migration version is less than or equal to the target
// or greater than the previous version.
return true
}
}
return false
}
func scanMigration(m string) (migration SchemaMigration, err error) {
result := reMigration.FindStringSubmatch(m)
if result == nil || len(result) != 5 {
return SchemaMigration{}, errors.New("invalid migration: could not parse the format")
}
migration = SchemaMigration{
Name: strings.ReplaceAll(result[2], "_", " "),
Provider: result[3],
}
data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
if err != nil {
return SchemaMigration{}, err
}
migration.Query = string(data)
switch result[4] {
case "up":
migration.Up = true
case "down":
migration.Up = false
default:
return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4])
}
migration.Version, _ = strconv.Atoi(result[1])
switch migration.Provider {
case providerAll, providerSQLite, providerMySQL, providerPostgres:
break
default:
return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3])
}
return migration, nil
}

View File

@ -0,0 +1,6 @@
DROP TABLE IF EXISTS authentication_logs;
DROP TABLE IF EXISTS identity_verification_tokens;
DROP TABLE IF EXISTS totp_configurations;
DROP TABLE IF EXISTS u2f_devices;
DROP TABLE IF EXISTS user_preferences;
DROP TABLE IF EXISTS migrations;

View File

@ -0,0 +1,55 @@
CREATE TABLE IF NOT EXISTS authentication_logs (
id INTEGER AUTO_INCREMENT,
time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
successful BOOL NOT NULL,
username VARCHAR(100) NOT NULL,
PRIMARY KEY (id)
);
CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username);
CREATE TABLE IF NOT EXISTS identity_verification_tokens (
id INTEGER AUTO_INCREMENT,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
token VARCHAR(512),
PRIMARY KEY (id),
UNIQUE KEY (token)
);
CREATE TABLE IF NOT EXISTS totp_configurations (
id INTEGER AUTO_INCREMENT,
username VARCHAR(100) NOT NULL,
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
digits INTEGER NOT NULL DEFAULT 6,
totp_period INTEGER NOT NULL DEFAULT 30,
secret VARCHAR(64) NOT NULL,
PRIMARY KEY (id),
UNIQUE KEY (username)
);
CREATE TABLE IF NOT EXISTS u2f_devices (
id INTEGER AUTO_INCREMENT,
username VARCHAR(100) NOT NULL,
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
key_handle BLOB NOT NULL,
public_key BLOB NOT NULL,
PRIMARY KEY (id),
UNIQUE KEY (username, description)
);
CREATE TABLE IF NOT EXISTS user_preferences (
id INTEGER AUTO_INCREMENT,
username VARCHAR(100) NOT NULL,
second_factor_method VARCHAR(11) NOT NULL,
PRIMARY KEY (id),
UNIQUE KEY (username)
);
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER AUTO_INCREMENT,
applied TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
version_before INTEGER NULL DEFAULT NULL,
version_after INTEGER NOT NULL,
application_version VARCHAR(128) NOT NULL,
PRIMARY KEY (id)
);

View File

@ -0,0 +1,55 @@
CREATE TABLE IF NOT EXISTS authentication_logs (
id SERIAL,
time TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
successful BOOLEAN NOT NULL,
username VARCHAR(100) NOT NULL,
PRIMARY KEY (id)
);
CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username);
CREATE TABLE IF NOT EXISTS identity_verification_tokens (
id SERIAL,
created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
token VARCHAR(512),
PRIMARY KEY (id),
UNIQUE (token)
);
CREATE TABLE IF NOT EXISTS totp_configurations (
id SERIAL,
username VARCHAR(100) NOT NULL,
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
digits INTEGER NOT NULL DEFAULT 6,
totp_period INTEGER NOT NULL DEFAULT 30,
secret VARCHAR(64) NOT NULL,
PRIMARY KEY (id),
UNIQUE (username)
);
CREATE TABLE IF NOT EXISTS u2f_devices (
id SERIAL,
username VARCHAR(100) NOT NULL,
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
key_handle BYTEA NOT NULL,
public_key BYTEA NOT NULL,
PRIMARY KEY (id),
UNIQUE (username, description)
);
CREATE TABLE IF NOT EXISTS user_preferences (
id SERIAL,
username VARCHAR(100) NOT NULL,
second_factor_method VARCHAR(11) NOT NULL,
PRIMARY KEY (id),
UNIQUE (username)
);
CREATE TABLE IF NOT EXISTS migrations (
id SERIAL,
applied TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
version_before INTEGER NULL DEFAULT NULL,
version_after INTEGER NOT NULL,
application_version VARCHAR(128) NOT NULL,
PRIMARY KEY (id)
);

View File

@ -0,0 +1,54 @@
CREATE TABLE IF NOT EXISTS authentication_logs (
id INTEGER,
time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
successful BOOLEAN NOT NULL,
username VARCHAR(100) NOT NULL,
PRIMARY KEY (id)
);
CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username);
CREATE TABLE IF NOT EXISTS identity_verification_tokens (
id INTEGER,
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
token VARCHAR(512),
PRIMARY KEY (id),
UNIQUE (token)
);
CREATE TABLE IF NOT EXISTS totp_configurations (
id INTEGER,
username VARCHAR(100) NOT NULL,
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
digits INTEGER(1) NOT NULL DEFAULT 6,
totp_period INTEGER NOT NULL DEFAULT 30,
secret VARCHAR(64) NOT NULL,
PRIMARY KEY (id),
UNIQUE (username)
);
CREATE TABLE IF NOT EXISTS u2f_devices (
id INTEGER,
username VARCHAR(100) NOT NULL,
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
key_handle BLOB NOT NULL,
public_key BLOB NOT NULL,
PRIMARY KEY (id),
UNIQUE (username, description)
);
CREATE TABLE IF NOT EXISTS user_preferences (
id INTEGER,
username VARCHAR(100) UNIQUE NOT NULL,
second_factor_method VARCHAR(11) NOT NULL,
PRIMARY KEY (id)
);
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER,
applied TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
version_before INTEGER NULL DEFAULT NULL,
version_after INTEGER NOT NULL,
application_version VARCHAR(128) NOT NULL,
PRIMARY KEY (id)
);

View File

@ -0,0 +1,154 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
ver, err := latestMigrationVersion(providerSQLite)
require.NoError(t, err)
assert.Equal(t, testLatestVersion, ver)
migrations, err := loadMigrations(providerSQLite, 0, ver)
require.NoError(t, err)
assert.Len(t, migrations, ver)
for i := 0; i < len(migrations); i++ {
assert.Equal(t, i+1, migrations[i].Version)
}
}
func TestShouldObtainCorrectDownMigrations(t *testing.T) {
ver, err := latestMigrationVersion(providerSQLite)
require.NoError(t, err)
assert.Equal(t, testLatestVersion, ver)
migrations, err := loadMigrations(providerSQLite, ver, 0)
require.NoError(t, err)
assert.Len(t, migrations, ver)
for i := 0; i < len(migrations); i++ {
assert.Equal(t, ver-i, migrations[i].Version)
}
}
func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) {
migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest)
require.NoError(t, err)
require.NotEqual(t, 0, len(migrations))
previousUp := make([]int, len(migrations))
for i, migration := range migrations {
assert.True(t, migration.Up)
if i != 0 {
for _, v := range previousUp {
assert.NotEqual(t, v, migration.Version)
}
}
previousUp = append(previousUp, migration.Version)
}
migrations, err = loadMigrations(providerPostgres, SchemaLatest, 0)
require.NoError(t, err)
require.NotEqual(t, 0, len(migrations))
previousDown := make([]int, len(migrations))
for i, migration := range migrations {
assert.False(t, migration.Up)
if i != 0 {
for _, v := range previousDown {
assert.NotEqual(t, v, migration.Version)
}
}
previousDown = append(previousDown, migration.Version)
}
}
func TestMigrationsShouldNotBeDuplicatedMySQL(t *testing.T) {
migrations, err := loadMigrations(providerMySQL, 0, SchemaLatest)
require.NoError(t, err)
require.NotEqual(t, 0, len(migrations))
previousUp := make([]int, len(migrations))
for i, migration := range migrations {
assert.True(t, migration.Up)
if i != 0 {
for _, v := range previousUp {
assert.NotEqual(t, v, migration.Version)
}
}
previousUp = append(previousUp, migration.Version)
}
migrations, err = loadMigrations(providerMySQL, SchemaLatest, 0)
require.NoError(t, err)
require.NotEqual(t, 0, len(migrations))
previousDown := make([]int, len(migrations))
for i, migration := range migrations {
assert.False(t, migration.Up)
if i != 0 {
for _, v := range previousDown {
assert.NotEqual(t, v, migration.Version)
}
}
previousDown = append(previousDown, migration.Version)
}
}
func TestMigrationsShouldNotBeDuplicatedSQLite(t *testing.T) {
migrations, err := loadMigrations(providerSQLite, 0, SchemaLatest)
require.NoError(t, err)
require.NotEqual(t, 0, len(migrations))
previousUp := make([]int, len(migrations))
for i, migration := range migrations {
assert.True(t, migration.Up)
if i != 0 {
for _, v := range previousUp {
assert.NotEqual(t, v, migration.Version)
}
}
previousUp = append(previousUp, migration.Version)
}
migrations, err = loadMigrations(providerSQLite, SchemaLatest, 0)
require.NoError(t, err)
require.NotEqual(t, 0, len(migrations))
previousDown := make([]int, len(migrations))
for i, migration := range migrations {
assert.False(t, migration.Up)
if i != 0 {
for _, v := range previousDown {
assert.NotEqual(t, v, migration.Version)
}
}
previousDown = append(previousDown, migration.Version)
}
}

View File

@ -1,85 +0,0 @@
package storage
import (
"database/sql"
"fmt"
"time"
_ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// MySQLProvider is a MySQL provider.
type MySQLProvider struct {
SQLProvider
}
// NewMySQLProvider a MySQL provider.
func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProvider {
provider := MySQLProvider{
SQLProvider{
name: "mysql",
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName),
sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName),
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName),
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName),
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName),
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),
sqlGetExistingTables: "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema=database()",
sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName),
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName),
},
}
provider.sqlUpgradesCreateTableStatements[SchemaVersion(1)][authenticationLogsTableName] = "CREATE TABLE %s (username VARCHAR(100), successful BOOL, time INTEGER, INDEX usr_time_idx (username, time))"
connectionString := configuration.Username
if configuration.Password != "" {
connectionString += fmt.Sprintf(":%s", configuration.Password)
}
if connectionString != "" {
connectionString += "@"
}
address := configuration.Host
if configuration.Port > 0 {
address += fmt.Sprintf(":%d", configuration.Port)
}
connectionString += fmt.Sprintf("tcp(%s)", address)
if configuration.Database != "" {
connectionString += fmt.Sprintf("/%s", configuration.Database)
}
connectionString += "?"
connectionString += fmt.Sprintf("timeout=%ds", int32(configuration.Timeout/time.Second))
db, err := sql.Open("mysql", connectionString)
if err != nil {
provider.log.Fatalf("Unable to connect to SQL database: %v", err)
}
if err := provider.initialize(db); err != nil {
provider.log.Fatalf("Unable to initialize SQL database: %v", err)
}
return &provider
}

View File

@ -1,90 +0,0 @@
package storage
import (
"database/sql"
"fmt"
"strings"
"time"
_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// PostgreSQLProvider is a PostgreSQL provider.
type PostgreSQLProvider struct {
SQLProvider
}
// NewPostgreSQLProvider a PostgreSQL provider.
func NewPostgreSQLProvider(configuration schema.PostgreSQLStorageConfiguration) *PostgreSQLProvider {
provider := PostgreSQLProvider{
SQLProvider{
name: "postgres",
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements,
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=$1", userPreferencesTableName),
sqlUpsertSecondFactorPreference: fmt.Sprintf("INSERT INTO %s (username, second_factor_method) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET second_factor_method=$2", userPreferencesTableName),
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=$1)", identityVerificationTokensTableName),
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES ($1)", identityVerificationTokensTableName),
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=$1", identityVerificationTokensTableName),
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=$1", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("INSERT INTO %s (username, secret) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET secret=$2", totpSecretsTableName),
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=$1", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=$1", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, keyHandle, publicKey) VALUES ($1, $2, $3) ON CONFLICT (username) DO UPDATE SET keyHandle=$2, publicKey=$3", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES ($1, $2, $3)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>$1 AND username=$2 ORDER BY time DESC", authenticationLogsTableName),
sqlGetExistingTables: "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema='public'",
sqlConfigSetValue: fmt.Sprintf("INSERT INTO %s (category, key_name, value) VALUES ($1, $2, $3) ON CONFLICT (category, key_name) DO UPDATE SET value=$3", configTableName),
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=$1 AND key_name=$2", configTableName),
},
}
args := make([]string, 0)
if configuration.Username != "" {
args = append(args, fmt.Sprintf("user='%s'", configuration.Username))
}
if configuration.Password != "" {
args = append(args, fmt.Sprintf("password='%s'", configuration.Password))
}
if configuration.Host != "" {
args = append(args, fmt.Sprintf("host=%s", configuration.Host))
}
if configuration.Port > 0 {
args = append(args, fmt.Sprintf("port=%d", configuration.Port))
}
if configuration.Database != "" {
args = append(args, fmt.Sprintf("dbname=%s", configuration.Database))
}
if configuration.SSLMode != "" {
args = append(args, fmt.Sprintf("sslmode=%s", configuration.SSLMode))
}
args = append(args, fmt.Sprintf("connect_timeout=%d", int32(configuration.Timeout/time.Second)))
connectionString := strings.Join(args, " ")
db, err := sql.Open("pgx", connectionString)
if err != nil {
provider.log.Fatalf("Unable to connect to SQL database: %v", err)
}
if err := provider.initialize(db); err != nil {
provider.log.Fatalf("Unable to initialize SQL database: %v", err)
}
return &provider
}

View File

@ -1,28 +1,45 @@
package storage
import (
"context"
"time"
"github.com/authelia/authelia/v4/internal/models"
)
// Provider is an interface providing storage capabilities for
// persisting any kind of data related to Authelia.
// Provider is an interface providing storage capabilities for persisting any kind of data related to Authelia.
type Provider interface {
LoadPreferred2FAMethod(username string) (string, error)
SavePreferred2FAMethod(username string, method string) error
models.StartupCheck
FindIdentityVerificationToken(token string) (bool, error)
SaveIdentityVerificationToken(token string) error
RemoveIdentityVerificationToken(token string) error
RegulatorProvider
SaveTOTPSecret(username string, secret string) error
LoadTOTPSecret(username string) (string, error)
DeleteTOTPSecret(username string) error
SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error)
LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error)
LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error)
SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error
LoadU2FDeviceHandle(username string) (keyHandle []byte, publicKey []byte, err error)
SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error)
RemoveIdentityVerification(ctx context.Context, jti string) (err error)
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
AppendAuthenticationLog(attempt models.AuthenticationAttempt) error
LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error)
SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error)
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error)
SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error)
LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error)
SchemaTables(ctx context.Context) (tables []string, err error)
SchemaVersion(ctx context.Context) (version int, err error)
SchemaMigrate(ctx context.Context, up bool, version int) (err error)
SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error)
SchemaLatestVersion() (version int, err error)
SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error)
SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error)
}
// RegulatorProvider is an interface providing storage capabilities for persisting any kind of data related to the regulator.
type RegulatorProvider interface {
AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error)
LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error)
}

View File

@ -1,10 +1,10 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: internal/storage/provider.go
// Source: ./internal/storage/provider.go
// Package storage is a generated GoMock package.
package storage
import (
context "context"
reflect "reflect"
time "time"
@ -13,199 +13,331 @@ import (
models "github.com/authelia/authelia/v4/internal/models"
)
// MockProvider is a mock of Provider interface
// MockProvider is a mock of Provider interface.
type MockProvider struct {
ctrl *gomock.Controller
recorder *MockProviderMockRecorder
}
// MockProviderMockRecorder is the mock recorder for MockProvider
// MockProviderMockRecorder is the mock recorder for MockProvider.
type MockProviderMockRecorder struct {
mock *MockProvider
}
// NewMockProvider creates a new mock instance
// NewMockProvider creates a new mock instance.
func NewMockProvider(ctrl *gomock.Controller) *MockProvider {
mock := &MockProvider{ctrl: ctrl}
mock.recorder = &MockProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProvider) EXPECT() *MockProviderMockRecorder {
return m.recorder
}
// LoadPreferred2FAMethod mocks base method
func (m *MockProvider) LoadPreferred2FAMethod(username string) (string, error) {
// AppendAuthenticationLog mocks base method.
func (m *MockProvider) AppendAuthenticationLog(arg0 context.Context, arg1 models.AuthenticationAttempt) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", username)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod
func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), username)
}
// SavePreferred2FAMethod mocks base method
func (m *MockProvider) SavePreferred2FAMethod(username, method string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePreferred2FAMethod", username, method)
ret := m.ctrl.Call(m, "AppendAuthenticationLog", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod
func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(username, method interface{}) *gomock.Call {
// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog.
func (mr *MockProviderMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), username, method)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), arg0, arg1)
}
// FindIdentityVerificationToken mocks base method
func (m *MockProvider) FindIdentityVerificationToken(token string) (bool, error) {
// DeleteTOTPConfiguration mocks base method.
func (m *MockProvider) DeleteTOTPConfiguration(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindIdentityVerificationToken", token)
ret := m.ctrl.Call(m, "DeleteTOTPConfiguration", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteTOTPConfiguration indicates an expected call of DeleteTOTPConfiguration.
func (mr *MockProviderMockRecorder) DeleteTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPConfiguration), arg0, arg1)
}
// FindIdentityVerification mocks base method.
func (m *MockProvider) FindIdentityVerification(arg0 context.Context, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FindIdentityVerification", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken
func (mr *MockProviderMockRecorder) FindIdentityVerificationToken(token interface{}) *gomock.Call {
// FindIdentityVerification indicates an expected call of FindIdentityVerification.
func (mr *MockProviderMockRecorder) FindIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerificationToken), token)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerification", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerification), arg0, arg1)
}
// SaveIdentityVerificationToken mocks base method
func (m *MockProvider) SaveIdentityVerificationToken(token string) error {
// LoadAuthenticationLogs mocks base method.
func (m *MockProvider) LoadAuthenticationLogs(arg0 context.Context, arg1 string, arg2 time.Time, arg3, arg4 int) ([]models.AuthenticationAttempt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", token)
ret0, _ := ret[0].(error)
return ret0
}
// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken
func (mr *MockProviderMockRecorder) SaveIdentityVerificationToken(token interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerificationToken), token)
}
// RemoveIdentityVerificationToken mocks base method
func (m *MockProvider) RemoveIdentityVerificationToken(token string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", token)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken
func (mr *MockProviderMockRecorder) RemoveIdentityVerificationToken(token interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerificationToken), token)
}
// SaveTOTPSecret mocks base method
func (m *MockProvider) SaveTOTPSecret(username, secret string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveTOTPSecret", username, secret)
ret0, _ := ret[0].(error)
return ret0
}
// SaveTOTPSecret indicates an expected call of SaveTOTPSecret
func (mr *MockProviderMockRecorder) SaveTOTPSecret(username, secret interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockProvider)(nil).SaveTOTPSecret), username, secret)
}
// LoadTOTPSecret mocks base method
func (m *MockProvider) LoadTOTPSecret(username string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadTOTPSecret", username)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadTOTPSecret indicates an expected call of LoadTOTPSecret
func (mr *MockProviderMockRecorder) LoadTOTPSecret(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockProvider)(nil).LoadTOTPSecret), username)
}
// DeleteTOTPSecret mocks base method
func (m *MockProvider) DeleteTOTPSecret(username string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTOTPSecret", username)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteTOTPSecret indicates an expected call of DeleteTOTPSecret
func (mr *MockProviderMockRecorder) DeleteTOTPSecret(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPSecret", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPSecret), username)
}
// SaveU2FDeviceHandle mocks base method
func (m *MockProvider) SaveU2FDeviceHandle(username string, keyHandle, publicKey []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", username, keyHandle, publicKey)
ret0, _ := ret[0].(error)
return ret0
}
// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle
func (mr *MockProviderMockRecorder) SaveU2FDeviceHandle(username, keyHandle, publicKey interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).SaveU2FDeviceHandle), username, keyHandle, publicKey)
}
// LoadU2FDeviceHandle mocks base method
func (m *MockProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", username)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].([]byte)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle
func (mr *MockProviderMockRecorder) LoadU2FDeviceHandle(username interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).LoadU2FDeviceHandle), username)
}
// AppendAuthenticationLog mocks base method
func (m *MockProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AppendAuthenticationLog", attempt)
ret0, _ := ret[0].(error)
return ret0
}
// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog
func (mr *MockProviderMockRecorder) AppendAuthenticationLog(attempt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), attempt)
}
// LoadLatestAuthenticationLogs mocks base method
func (m *MockProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", username, fromDate)
ret := m.ctrl.Call(m, "LoadAuthenticationLogs", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].([]models.AuthenticationAttempt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs
func (mr *MockProviderMockRecorder) LoadLatestAuthenticationLogs(username, fromDate interface{}) *gomock.Call {
// LoadAuthenticationLogs indicates an expected call of LoadAuthenticationLogs.
func (mr *MockProviderMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadLatestAuthenticationLogs), username, fromDate)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4)
}
// LoadPreferred2FAMethod mocks base method.
func (m *MockProvider) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod.
func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), arg0, arg1)
}
// LoadTOTPConfiguration mocks base method.
func (m *MockProvider) LoadTOTPConfiguration(arg0 context.Context, arg1 string) (*models.TOTPConfiguration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadTOTPConfiguration", arg0, arg1)
ret0, _ := ret[0].(*models.TOTPConfiguration)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadTOTPConfiguration indicates an expected call of LoadTOTPConfiguration.
func (mr *MockProviderMockRecorder) LoadTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).LoadTOTPConfiguration), arg0, arg1)
}
// LoadU2FDevice mocks base method.
func (m *MockProvider) LoadU2FDevice(arg0 context.Context, arg1 string) (*models.U2FDevice, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadU2FDevice", arg0, arg1)
ret0, _ := ret[0].(*models.U2FDevice)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadU2FDevice indicates an expected call of LoadU2FDevice.
func (mr *MockProviderMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockProvider)(nil).LoadU2FDevice), arg0, arg1)
}
// LoadUserInfo mocks base method.
func (m *MockProvider) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadUserInfo", arg0, arg1)
ret0, _ := ret[0].(models.UserInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadUserInfo indicates an expected call of LoadUserInfo.
func (mr *MockProviderMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockProvider)(nil).LoadUserInfo), arg0, arg1)
}
// RemoveIdentityVerification mocks base method.
func (m *MockProvider) RemoveIdentityVerification(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification.
func (mr *MockProviderMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerification), arg0, arg1)
}
// SaveIdentityVerification mocks base method.
func (m *MockProvider) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveIdentityVerification", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveIdentityVerification indicates an expected call of SaveIdentityVerification.
func (mr *MockProviderMockRecorder) SaveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerification), arg0, arg1)
}
// SavePreferred2FAMethod mocks base method.
func (m *MockProvider) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePreferred2FAMethod", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod.
func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), arg0, arg1, arg2)
}
// SaveTOTPConfiguration mocks base method.
func (m *MockProvider) SaveTOTPConfiguration(arg0 context.Context, arg1 models.TOTPConfiguration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveTOTPConfiguration", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveTOTPConfiguration indicates an expected call of SaveTOTPConfiguration.
func (mr *MockProviderMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).SaveTOTPConfiguration), arg0, arg1)
}
// SaveU2FDevice mocks base method.
func (m *MockProvider) SaveU2FDevice(arg0 context.Context, arg1 models.U2FDevice) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveU2FDevice", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SaveU2FDevice indicates an expected call of SaveU2FDevice.
func (mr *MockProviderMockRecorder) SaveU2FDevice(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDevice", reflect.TypeOf((*MockProvider)(nil).SaveU2FDevice), arg0, arg1)
}
// SchemaLatestVersion mocks base method.
func (m *MockProvider) SchemaLatestVersion() (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaLatestVersion")
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SchemaLatestVersion indicates an expected call of SchemaLatestVersion.
func (mr *MockProviderMockRecorder) SchemaLatestVersion() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaLatestVersion", reflect.TypeOf((*MockProvider)(nil).SchemaLatestVersion))
}
// SchemaMigrate mocks base method.
func (m *MockProvider) SchemaMigrate(arg0 context.Context, arg1 bool, arg2 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaMigrate", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SchemaMigrate indicates an expected call of SchemaMigrate.
func (mr *MockProviderMockRecorder) SchemaMigrate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrate", reflect.TypeOf((*MockProvider)(nil).SchemaMigrate), arg0, arg1, arg2)
}
// SchemaMigrationHistory mocks base method.
func (m *MockProvider) SchemaMigrationHistory(arg0 context.Context) ([]models.Migration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaMigrationHistory", arg0)
ret0, _ := ret[0].([]models.Migration)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SchemaMigrationHistory indicates an expected call of SchemaMigrationHistory.
func (mr *MockProviderMockRecorder) SchemaMigrationHistory(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationHistory", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationHistory), arg0)
}
// SchemaMigrationsDown mocks base method.
func (m *MockProvider) SchemaMigrationsDown(arg0 context.Context, arg1 int) ([]SchemaMigration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaMigrationsDown", arg0, arg1)
ret0, _ := ret[0].([]SchemaMigration)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SchemaMigrationsDown indicates an expected call of SchemaMigrationsDown.
func (mr *MockProviderMockRecorder) SchemaMigrationsDown(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsDown", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsDown), arg0, arg1)
}
// SchemaMigrationsUp mocks base method.
func (m *MockProvider) SchemaMigrationsUp(arg0 context.Context, arg1 int) ([]SchemaMigration, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaMigrationsUp", arg0, arg1)
ret0, _ := ret[0].([]SchemaMigration)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SchemaMigrationsUp indicates an expected call of SchemaMigrationsUp.
func (mr *MockProviderMockRecorder) SchemaMigrationsUp(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsUp", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsUp), arg0, arg1)
}
// SchemaTables mocks base method.
func (m *MockProvider) SchemaTables(arg0 context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaTables", arg0)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SchemaTables indicates an expected call of SchemaTables.
func (mr *MockProviderMockRecorder) SchemaTables(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaTables", reflect.TypeOf((*MockProvider)(nil).SchemaTables), arg0)
}
// SchemaVersion mocks base method.
func (m *MockProvider) SchemaVersion(arg0 context.Context) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SchemaVersion", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SchemaVersion indicates an expected call of SchemaVersion.
func (mr *MockProviderMockRecorder) SchemaVersion(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaVersion", reflect.TypeOf((*MockProvider)(nil).SchemaVersion), arg0)
}
// StartupCheck mocks base method.
func (m *MockProvider) StartupCheck() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StartupCheck")
ret0, _ := ret[0].(error)
return ret0
}
// StartupCheck indicates an expected call of StartupCheck.
func (mr *MockProviderMockRecorder) StartupCheck() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockProvider)(nil).StartupCheck))
}

View File

@ -1,173 +1,199 @@
package storage
import (
"context"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/logging"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/utils"
)
// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's.
func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) {
db, err := sqlx.Open(driverName, dataSourceName)
provider = SQLProvider{
name: name,
driverName: driverName,
db: db,
log: logging.Logger(),
errOpen: err,
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences),
sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableUserPreferences),
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
sqlFmtRenameTable: queryFmtRenameTable,
}
return provider
}
// SQLProvider is a storage provider persisting data in a SQL database.
type SQLProvider struct {
db *sql.DB
log *logrus.Logger
name string
db *sqlx.DB
log *logrus.Logger
name string
driverName string
errOpen error
sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string
sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string
// Table: authentication_logs.
sqlInsertAuthenticationAttempt string
sqlSelectAuthenticationAttemptsByUsername string
sqlGetPreferencesByUsername string
sqlUpsertSecondFactorPreference string
// Table: identity_verification_tokens.
sqlInsertIdentityVerification string
sqlDeleteIdentityVerification string
sqlSelectExistsIdentityVerification string
sqlTestIdentityVerificationTokenExistence string
sqlInsertIdentityVerificationToken string
sqlDeleteIdentityVerificationToken string
// Table: totp_configurations.
sqlUpsertTOTPConfig string
sqlDeleteTOTPConfig string
sqlSelectTOTPConfig string
sqlGetTOTPSecretByUsername string
sqlUpsertTOTPSecret string
sqlDeleteTOTPSecret string
// Table: u2f_devices.
sqlUpsertU2FDevice string
sqlSelectU2FDevice string
sqlGetU2FDeviceHandleByUsername string
sqlUpsertU2FDeviceHandle string
// Table: user_preferences.
sqlUpsertPreferred2FAMethod string
sqlSelectPreferred2FAMethod string
sqlSelectUserInfo string
sqlInsertAuthenticationLog string
sqlGetLatestAuthenticationLogs string
// Table: migrations.
sqlInsertMigration string
sqlSelectMigrations string
sqlSelectLatestMigration string
sqlGetExistingTables string
sqlConfigSetValue string
sqlConfigGetValue string
// Utility.
sqlSelectExistingTables string
sqlFmtRenameTable string
}
func (p *SQLProvider) initialize(db *sql.DB) error {
p.db = db
p.log = logging.Logger()
return p.upgrade()
}
func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) {
rows, err := p.db.Query(p.sqlGetExistingTables)
if err != nil {
return version, tables, err
// StartupCheck implements the provider startup check interface.
func (p *SQLProvider) StartupCheck() (err error) {
if p.errOpen != nil {
return p.errOpen
}
defer rows.Close()
var table string
for rows.Next() {
err := rows.Scan(&table)
if err != nil {
return version, tables, err
// TODO: Decide if this is needed, or if it should be configurable.
for i := 0; i < 19; i++ {
err = p.db.Ping()
if err == nil {
break
}
tables = append(tables, table)
time.Sleep(time.Millisecond * 500)
}
if utils.IsStringInSlice(configTableName, tables) {
rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version")
if err != nil {
return version, tables, err
}
for rows.Next() {
err := rows.Scan(&version)
if err != nil {
return version, tables, err
}
}
}
return version, tables, nil
}
func (p *SQLProvider) upgrade() error {
p.log.Debug("Storage schema is being checked to verify it is up to date")
version, tables, err := p.getSchemaBasicDetails()
if err != nil {
return err
}
if version < storageSchemaCurrentVersion {
p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion)
p.log.Infof("Storage schema is being checked for updates")
tx, err := p.db.Begin()
if err != nil {
return err
}
ctx := context.Background()
switch version {
case 0:
err := p.upgradeSchemaToVersion001(tx, tables)
if err != nil {
return p.handleUpgradeFailure(tx, 1, err)
}
err = p.SchemaMigrate(ctx, true, SchemaLatest)
fallthrough
default:
err := tx.Commit()
if err != nil {
return err
}
p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion)
}
} else {
p.log.Debug("Storage schema is up to date")
switch err {
case ErrSchemaAlreadyUpToDate:
p.log.Infof("Storage schema is already up to date")
return nil
case nil:
return nil
default:
return err
}
return nil
}
func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error {
rollbackErr := tx.Rollback()
formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err)
if rollbackErr != nil {
return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr)
}
return formattedErr
}
// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) {
var method string
rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username)
if err != nil {
return "", err
}
defer rows.Close()
if !rows.Next() {
return "", nil
}
err = rows.Scan(&method)
return method, err
}
// SavePreferred2FAMethod save the preferred method for 2FA to the database.
func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error {
_, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method)
func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method)
return err
}
// FindIdentityVerificationToken look for an identity verification token in the database.
func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) {
var found bool
// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found)
switch err {
case sql.ErrNoRows:
return "", nil
case nil:
return method, err
default:
return "", err
}
}
// LoadUserInfo loads the models.UserInfo from the database.
func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) {
err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
switch {
case err == nil:
return info, nil
case errors.Is(err, sql.ErrNoRows):
_, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0])
if err != nil {
return models.UserInfo{}, err
}
err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
if err != nil {
return models.UserInfo{}, err
}
return info, nil
default:
return models.UserInfo{}, err
}
}
// SaveIdentityVerification save an identity verification record to the database.
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token)
return err
}
// RemoveIdentityVerification remove an identity verification record from the database.
func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token)
return err
}
// FindIdentityVerification checks if an identity verification record is in the database and active.
func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti)
if err != nil {
return false, err
}
@ -175,105 +201,94 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error)
return found, nil
}
// SaveIdentityVerificationToken save an identity verification token in the database.
func (p *SQLProvider) SaveIdentityVerificationToken(token string) error {
_, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token)
// SaveTOTPConfiguration save a TOTP config of a given user in the database.
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) {
// TODO: Encrypt config.Secret here.
_, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
config.Username,
config.Algorithm,
config.Digits,
config.Period,
config.Secret,
)
return err
}
// RemoveIdentityVerificationToken remove an identity verification token from the database.
func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error {
_, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token)
// DeleteTOTPConfiguration delete a TOTP secret from the database given a username.
func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username)
return err
}
// SaveTOTPSecret save a TOTP secret of a given user in the database.
func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error {
_, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret)
return err
}
// LoadTOTPConfiguration load a TOTP secret given a username from the database.
func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) {
config = &models.TOTPConfiguration{}
// LoadTOTPSecret load a TOTP secret given a username from the database.
func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
var secret string
if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil {
err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config)
if err != nil {
if err == sql.ErrNoRows {
return "", ErrNoTOTPSecret
return nil, ErrNoTOTPSecret
}
return "", err
return nil, err
}
return secret, nil
// TODO: Decrypt config.Secret here.
return config, nil
}
// DeleteTOTPSecret delete a TOTP secret from the database given a username.
func (p *SQLProvider) DeleteTOTPSecret(username string) error {
_, err := p.db.Exec(p.sqlDeleteTOTPSecret, username)
return err
}
// SaveU2FDeviceHandle save a registered U2F device registration blob.
func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error {
_, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle,
username,
base64.StdEncoding.EncodeToString(keyHandle),
base64.StdEncoding.EncodeToString(publicKey))
// SaveU2FDevice saves a registered U2F device.
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey)
return err
}
// LoadU2FDeviceHandle load a U2F device registration blob for a given username.
func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) {
var keyHandleBase64, publicKeyBase64 string
if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil {
// LoadU2FDevice loads a U2F device registration for a given username.
func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
device = &models.U2FDevice{
Username: username,
}
err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, ErrNoU2FDeviceHandle
return nil, ErrNoU2FDeviceHandle
}
return nil, nil, err
return nil, err
}
keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
if err != nil {
return nil, nil, err
}
publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
if err != nil {
return nil, nil, err
}
return keyHandle, publicKey, nil
return device, nil
}
// AppendAuthenticationLog append a mark to the authentication log.
func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error {
_, err := p.db.Exec(p.sqlInsertAuthenticationLog, attempt.Username, attempt.Successful, attempt.Time.Unix())
func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username)
return err
}
// LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log.
func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
var t int64
rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username)
// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log.
func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) {
rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page)
if err != nil {
return nil, err
}
attempts := make([]models.AuthenticationAttempt, 0, 10)
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
attempts = make([]models.AuthenticationAttempt, 0, limit)
for rows.Next() {
attempt := models.AuthenticationAttempt{
Username: username,
}
err = rows.Scan(&attempt.Successful, &t)
attempt.Time = time.Unix(t, 0)
var attempt models.AuthenticationAttempt
err = rows.StructScan(&attempt)
if err != nil {
return nil, err
}

View File

@ -0,0 +1,53 @@
package storage
import (
"fmt"
"time"
_ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// MySQLProvider is a MySQL provider.
type MySQLProvider struct {
SQLProvider
}
// NewMySQLProvider a MySQL provider.
func NewMySQLProvider(config schema.MySQLStorageConfiguration) (provider *MySQLProvider) {
provider = &MySQLProvider{
SQLProvider: NewSQLProvider(providerMySQL, providerMySQL, dataSourceNameMySQL(config)),
}
// All providers have differing SELECT existing table statements.
provider.sqlSelectExistingTables = queryMySQLSelectExistingTables
// Specific alterations to this provider.
provider.sqlFmtRenameTable = queryFmtMySQLRenameTable
return provider
}
func dataSourceNameMySQL(config schema.MySQLStorageConfiguration) (dataSourceName string) {
dataSourceName = fmt.Sprintf("%s:%s", config.Username, config.Password)
if dataSourceName != "" {
dataSourceName += "@"
}
address := config.Host
if config.Port > 0 {
address += fmt.Sprintf(":%d", config.Port)
}
dataSourceName += fmt.Sprintf("tcp(%s)", address)
if config.Database != "" {
dataSourceName += fmt.Sprintf("/%s", config.Database)
}
dataSourceName += "?"
dataSourceName += fmt.Sprintf("timeout=%ds&multiStatements=true&parseTime=true", int32(config.Timeout/time.Second))
return dataSourceName
}

View File

@ -0,0 +1,72 @@
package storage
import (
"fmt"
"strings"
"time"
_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
"github.com/authelia/authelia/v4/internal/configuration/schema"
)
// PostgreSQLProvider is a PostgreSQL provider.
type PostgreSQLProvider struct {
SQLProvider
}
// NewPostgreSQLProvider a PostgreSQL provider.
func NewPostgreSQLProvider(config schema.PostgreSQLStorageConfiguration) (provider *PostgreSQLProvider) {
provider = &PostgreSQLProvider{
SQLProvider: NewSQLProvider(providerPostgres, "pgx", dataSourceNamePostgreSQL(config)),
}
// All providers have differing SELECT existing table statements.
provider.sqlSelectExistingTables = queryPostgreSelectExistingTables
// Specific alterations to this provider.
// PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead.
provider.sqlUpsertU2FDevice = fmt.Sprintf(queryFmtPostgresUpsertU2FDevice, tableU2FDevices)
provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations)
provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences)
// PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders.
provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable)
provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod)
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification)
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification)
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
provider.sqlUpsertTOTPConfig = provider.db.Rebind(provider.sqlUpsertTOTPConfig)
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice)
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername)
provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration)
return provider
}
func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) {
args := []string{
fmt.Sprintf("user='%s'", config.Username),
fmt.Sprintf("password='%s'", config.Password),
}
if config.Host != "" {
args = append(args, fmt.Sprintf("host=%s", config.Host))
}
if config.Port > 0 {
args = append(args, fmt.Sprintf("port=%d", config.Port))
}
if config.Database != "" {
args = append(args, fmt.Sprintf("dbname=%s", config.Database))
}
args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second)))
return strings.Join(args, " ")
}

View File

@ -0,0 +1,22 @@
package storage
import (
_ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string.
)
// SQLiteProvider is a SQLite3 provider.
type SQLiteProvider struct {
SQLProvider
}
// NewSQLiteProvider constructs a SQLite provider.
func NewSQLiteProvider(path string) (provider *SQLiteProvider) {
provider = &SQLiteProvider{
SQLProvider: NewSQLProvider(providerSQLite, "sqlite3", path),
}
// All providers have differing SELECT existing table statements.
provider.sqlSelectExistingTables = querySQLiteSelectExistingTables
return provider
}

View File

@ -0,0 +1,125 @@
package storage
const (
queryFmtSelectMigrations = `
SELECT id, applied, version_before, version_after, application_version
FROM %s;`
queryFmtSelectLatestMigration = `
SELECT id, applied, version_before, version_after, application_version
FROM %s
ORDER BY id DESC
LIMIT 1;`
queryFmtInsertMigration = `
INSERT INTO %s (applied, version_before, version_after, application_version)
VALUES (?, ?, ?, ?);`
)
const (
queryMySQLSelectExistingTables = `
SELECT table_name
FROM information_schema.tables
WHERE table_type = 'BASE TABLE' AND table_schema = database();`
queryPostgreSelectExistingTables = `
SELECT table_name
FROM information_schema.tables
WHERE table_type = 'BASE TABLE' AND table_schema = 'public';`
querySQLiteSelectExistingTables = `
SELECT name
FROM sqlite_master
WHERE type = 'table';`
)
const (
queryFmtSelectUserInfo = `
SELECT second_factor_method, (SELECT EXISTS (SELECT id FROM %s WHERE username = ?)) AS has_totp, (SELECT EXISTS (SELECT id FROM %s WHERE username = ?)) AS has_u2f
FROM %s
WHERE username = ?;`
queryFmtSelectPreferred2FAMethod = `
SELECT second_factor_method
FROM %s
WHERE username = ?;`
queryFmtUpsertPreferred2FAMethod = `
REPLACE INTO %s (username, second_factor_method)
VALUES (?, ?);`
queryFmtPostgresUpsertPreferred2FAMethod = `
INSERT INTO %s (username, second_factor_method)
VALUES ($1, $2)
ON CONFLICT (username)
DO UPDATE SET second_factor_method = $2;`
)
const (
queryFmtSelectExistsIdentityVerification = `
SELECT EXISTS (
SELECT id
FROM %s
WHERE token = ?
);`
queryFmtInsertIdentityVerification = `
INSERT INTO %s (token)
VALUES (?);`
queryFmtDeleteIdentityVerification = `
DELETE FROM %s
WHERE token = ?;`
)
const (
queryFmtSelectTOTPConfiguration = `
SELECT id, username, algorithm, digits, totp_period, secret
FROM %s
WHERE username = ?;`
queryFmtUpsertTOTPConfiguration = `
REPLACE INTO %s (username, algorithm, digits, totp_period, secret)
VALUES (?, ?, ?, ?, ?);`
queryFmtPostgresUpsertTOTPConfiguration = `
INSERT INTO %s (username, algorithm, digits, totp_period, secret)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (username)
DO UPDATE SET algorithm = $2, digits = $3, totp_period = $4, secret = $5;`
queryFmtDeleteTOTPConfiguration = `
DELETE FROM %s
WHERE username = ?;`
)
const (
queryFmtSelectU2FDevice = `
SELECT key_handle, public_key
FROM %s
WHERE username = ?;`
queryFmtUpsertU2FDevice = `
REPLACE INTO %s (username, key_handle, public_key)
VALUES (?, ?, ?);`
queryFmtPostgresUpsertU2FDevice = `
INSERT INTO %s (username, key_handle, public_key)
VALUES ($1, $2, $3)
ON CONFLICT (username)
DO UPDATE SET key_handle=$2, public_key=$3;`
)
const (
queryFmtInsertAuthenticationLogEntry = `
INSERT INTO %s (time, successful, username)
VALUES (?, ?, ?);`
queryFmtSelect1FAAuthenticationLogEntryByUsername = `
SELECT time, successful, username
FROM %s
WHERE time > ? AND username = ?
ORDER BY time DESC
LIMIT ?
OFFSET ?;`
)

View File

@ -0,0 +1,109 @@
package storage
const (
queryFmtDropTableIfExists = `DROP TABLE IF EXISTS %s;`
queryFmtRenameTable = `
ALTER TABLE %s
RENAME TO %s;`
queryFmtMySQLRenameTable = `
ALTER TABLE %s
RENAME %s;`
)
// Pre1 migration constants.
const (
queryFmtPre1To1SelectAuthenticationLogs = `
SELECT username, successful, time
FROM %s
ORDER BY time ASC
LIMIT 100 OFFSET ?;`
queryFmtPre1To1InsertAuthenticationLogs = `
INSERT INTO %s (username, successful, time)
VALUES (?, ?, ?);`
queryFmtPre1InsertUserPreferencesFromSelect = `
INSERT INTO %s (username, second_factor_method)
SELECT username, second_factor_method
FROM %s
ORDER BY username ASC;`
queryFmtPre1SelectTOTPConfigurations = `
SELECT username, secret
FROM %s
ORDER BY username ASC;`
queryFmtPre1InsertTOTPConfiguration = `
INSERT INTO %s (username, secret)
VALUES (?, ?);`
queryFmtPre1To1SelectU2FDevices = `
SELECT username, keyHandle, publicKey
FROM %s
ORDER BY username ASC;`
queryFmtPre1To1InsertU2FDevice = `
INSERT INTO %s (username, key_handle, public_key)
VALUES (?, ?, ?);`
queryFmt1ToPre1InsertAuthenticationLogs = `
INSERT INTO %s (username, successful, time)
VALUES (?, ?, ?);`
queryFmt1ToPre1SelectAuthenticationLogs = `
SELECT username, successful, time
FROM %s
ORDER BY id ASC
LIMIT 100 OFFSET ?;`
queryFmt1ToPre1SelectU2FDevices = `
SELECT username, key_handle, public_key
FROM %s
ORDER BY username ASC;`
queryFmt1ToPre1InsertU2FDevice = `
INSERT INTO %s (username, keyHandle, publicKey)
VALUES (?, ?, ?);`
queryCreatePre1 = `
CREATE TABLE user_preferences (
username VARCHAR(100),
second_factor_method VARCHAR(11),
PRIMARY KEY (username)
);
CREATE TABLE identity_verification_tokens (
token VARCHAR(512)
);
CREATE TABLE totp_secrets (
username VARCHAR(100),
secret VARCHAR(64),
PRIMARY KEY (username)
);
CREATE TABLE u2f_devices (
username VARCHAR(100),
keyHandle TEXT,
publicKey TEXT,
PRIMARY KEY (username)
);
CREATE TABLE authentication_logs (
username VARCHAR(100),
successful BOOL,
time INTEGER
);
CREATE TABLE config (
category VARCHAR(32) NOT NULL,
key_name VARCHAR(32) NOT NULL,
value TEXT,
PRIMARY KEY (category, key_name)
);
INSERT INTO config (category, key_name, value)
VALUES ('schema', 'version', '1');`
)

View File

@ -0,0 +1,327 @@
package storage
import (
"context"
"fmt"
"strconv"
"time"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/utils"
)
// SchemaTables returns a list of tables.
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
if err != nil {
return tables, err
}
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
var table string
for rows.Next() {
err = rows.Scan(&table)
if err != nil {
return []string{}, err
}
tables = append(tables, table)
}
return tables, nil
}
// SchemaVersion returns the version of the schema.
func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error) {
tables, err := p.SchemaTables(ctx)
if err != nil {
return -2, err
}
if len(tables) == 0 {
return 0, nil
}
if utils.IsStringInSlice(tableMigrations, tables) {
migration, err := p.schemaLatestMigration(ctx)
if err != nil {
return -2, err
}
return migration.After, nil
}
if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) &&
utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) &&
utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) {
return -1, nil
}
// TODO: Decide if we want to support external tables.
// return -2, ErrUnknownSchemaState
return 0, nil
}
func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *models.Migration, err error) {
migration = &models.Migration{}
err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration)
if err != nil {
return nil, err
}
return migration, nil
}
// SchemaMigrationHistory returns migration history rows.
func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error) {
rows, err := p.db.QueryxContext(ctx, p.sqlSelectMigrations)
if err != nil {
return nil, err
}
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
var migration models.Migration
for rows.Next() {
err = rows.StructScan(&migration)
if err != nil {
return nil, err
}
migrations = append(migrations, migration)
}
return migrations, nil
}
// SchemaMigrate migrates from the current version to the provided version.
func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) {
currentVersion, err := p.SchemaVersion(ctx)
if err != nil {
return err
}
if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
return err
}
return p.schemaMigrate(ctx, currentVersion, version)
}
func (p *SQLProvider) schemaMigrate(ctx context.Context, prior, target int) (err error) {
migrations, err := loadMigrations(p.name, prior, target)
if err != nil {
return err
}
if len(migrations) == 0 {
return ErrNoMigrationsFound
}
switch {
case prior == -1:
p.log.Infof(logFmtMigrationFromTo, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
err = p.schemaMigratePre1To1(ctx)
if err != nil {
if errRollback := p.schemaMigratePre1To1Rollback(ctx, true); errRollback != nil {
return fmt.Errorf(errFmtFailedMigrationPre1, err)
}
return fmt.Errorf(errFmtFailedMigrationPre1, err)
}
case target == -1:
p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), "pre1")
default:
p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
}
for _, migration := range migrations {
if prior == -1 && migration.Version == 1 {
// Skip migration version 1 when upgrading from pre1 as it's applied as part of the pre1 upgrade.
continue
}
err = p.schemaMigrateApply(ctx, migration)
if err != nil {
return p.schemaMigrateRollback(ctx, prior, migration.After(), err)
}
}
switch {
case prior == -1:
p.log.Infof(logFmtMigrationComplete, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
case target == -1:
err = p.schemaMigrate1ToPre1(ctx)
if err != nil {
if errRollback := p.schemaMigratePre1To1Rollback(ctx, false); errRollback != nil {
return fmt.Errorf(errFmtFailedMigrationPre1, err)
}
return fmt.Errorf(errFmtFailedMigrationPre1, err)
}
p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), "pre1")
default:
p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
}
return nil
}
func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after int, migrateErr error) (err error) {
migrations, err := loadMigrations(p.name, after, prior)
if err != nil {
return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %+v", prior, after, err, migrateErr)
}
for _, migration := range migrations {
if prior == -1 && !migration.Up && migration.Version == 1 {
continue
}
err = p.schemaMigrateApply(ctx, migration)
if err != nil {
return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %+v", migration.Before(), migration.After(), err, migrateErr)
}
}
if prior == -1 {
if err = p.schemaMigrate1ToPre1(ctx); err != nil {
return fmt.Errorf("error applying migration version 1 to version pre1 for rollback: %+v. rollback caused by: %+v", err, migrateErr)
}
}
return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr)
}
func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration SchemaMigration) (err error) {
_, err = p.db.ExecContext(ctx, migration.Query)
if err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
// Skip the migration history insertion in a migration to v0.
if migration.Version == 1 && !migration.Up {
return nil
}
return p.schemaMigrateFinalize(ctx, migration)
}
func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration SchemaMigration) (err error) {
return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
}
func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, after int) (err error) {
_, err = p.db.ExecContext(ctx, p.sqlInsertMigration, time.Now(), before, after, utils.Version())
if err != nil {
return err
}
p.log.Debugf("Storage schema migrated from version %d to %d", before, after)
return nil
}
// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error) {
current, err := p.SchemaVersion(ctx)
if err != nil {
return migrations, err
}
if version == 0 {
version = SchemaLatest
}
if current >= version {
return migrations, ErrNoAvailableMigrations
}
return loadMigrations(p.name, current, version)
}
// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error) {
current, err := p.SchemaVersion(ctx)
if err != nil {
return migrations, err
}
if current <= version {
return migrations, ErrNoAvailableMigrations
}
return loadMigrations(p.name, current, version)
}
// SchemaLatestVersion returns the latest version available for migration..
func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
return latestMigrationVersion(p.name)
}
func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) {
if targetVersion == currentVersion {
return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion)
}
latest, err := latestMigrationVersion(providerName)
if err != nil {
return err
}
if currentVersion > latest {
return fmt.Errorf(errFmtSchemaCurrentGreaterThanLatestKnown, latest)
}
if up {
if targetVersion < currentVersion {
return fmt.Errorf(ErrFmtMigrateUpTargetLessThanCurrent, targetVersion, currentVersion)
}
if targetVersion == SchemaLatest && latest == currentVersion {
return ErrSchemaAlreadyUpToDate
}
if targetVersion != SchemaLatest && latest < targetVersion {
return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest)
}
} else {
if targetVersion < -1 {
return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion)
}
if targetVersion > currentVersion {
return fmt.Errorf(ErrFmtMigrateDownTargetGreaterThanCurrent, targetVersion, currentVersion)
}
}
return nil
}
// SchemaVersionToString returns a version string given a version number.
func SchemaVersionToString(version int) (versionStr string) {
switch version {
case -2:
return "unknown"
case -1:
return "pre1"
case 0:
return "N/A"
default:
return strconv.Itoa(version)
}
}

View File

@ -0,0 +1,449 @@
package storage
import (
"context"
"database/sql"
"encoding/base64"
"fmt"
"strings"
"time"
"github.com/authelia/authelia/v4/internal/models"
"github.com/authelia/authelia/v4/internal/utils"
)
// schemaMigratePre1To1 takes the v1 migration and migrates to this version.
func (p *SQLProvider) schemaMigratePre1To1(ctx context.Context) (err error) {
migration, err := loadMigration(p.name, 1, true)
if err != nil {
return err
}
// Get Tables list.
tables, err := p.SchemaTables(ctx)
if err != nil {
return err
}
tablesRename := []string{
tablePre1Config,
tablePre1TOTPSecrets,
tablePre1IdentityVerificationTokens,
tableU2FDevices,
tableUserPreferences,
tableAuthenticationLogs,
tableAlphaPreferences,
tableAlphaIdentityVerificationTokens,
tableAlphaAuthenticationLogs,
tableAlphaPreferencesTableName,
tableAlphaSecondFactorPreferences,
tableAlphaTOTPSecrets,
tableAlphaU2FDeviceHandles,
}
if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil {
return err
}
if _, err = p.db.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect),
tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil {
return err
}
if err = p.schemaMigratePre1To1AuthenticationLogs(ctx); err != nil {
return err
}
if err = p.schemaMigratePre1To1U2F(ctx); err != nil {
return err
}
if err = p.schemaMigratePre1To1TOTP(ctx); err != nil {
return err
}
for _, table := range tablesRename {
if _, err = p.db.Exec(fmt.Sprintf(p.db.Rebind(queryFmtDropTableIfExists), tablePrefixBackup+table)); err != nil {
return err
}
}
return p.schemaMigrateFinalizeAdvanced(ctx, -1, 1)
}
func (p *SQLProvider) schemaMigratePre1Rename(ctx context.Context, tables, tablesRename []string) (err error) {
// Rename Tables and Indexes.
for _, table := range tables {
if !utils.IsStringInSlice(table, tablesRename) {
continue
}
tableNew := tablePrefixBackup + table
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil {
return err
}
if p.name == providerPostgres {
if table == tableU2FDevices || table == tableUserPreferences {
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`,
tableNew, table, tableNew)); err != nil {
continue
}
}
}
}
return nil
}
func (p *SQLProvider) schemaMigratePre1To1Rollback(ctx context.Context, up bool) (err error) {
if up {
migration, err := loadMigration(p.name, 1, false)
if err != nil {
return err
}
if _, err = p.db.ExecContext(ctx, migration.Query); err != nil {
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
}
}
tables, err := p.SchemaTables(ctx)
if err != nil {
return err
}
for _, table := range tables {
if !strings.HasPrefix(table, tablePrefixBackup) {
continue
}
tableNew := strings.Replace(table, tablePrefixBackup, "", 1)
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil {
return err
}
if p.name == providerPostgres && (tableNew == tableU2FDevices || tableNew == tableUserPreferences) {
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`,
tableNew, table, tableNew)); err != nil {
continue
}
}
}
return nil
}
func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogs(ctx context.Context) (err error) {
for page := 0; true; page++ {
attempts, err := p.schemaMigratePre1To1AuthenticationLogsGetRows(ctx, page)
if err != nil {
if err == sql.ErrNoRows {
break
}
return err
}
for _, attempt := range attempts {
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time)
if err != nil {
return err
}
}
if len(attempts) != 100 {
break
}
}
return nil
}
func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []models.AuthenticationAttempt, err error) {
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100)
if err != nil {
return nil, err
}
attempts = make([]models.AuthenticationAttempt, 0, 100)
for rows.Next() {
var (
username string
successful bool
timestamp int64
)
err = rows.Scan(&username, &successful, &timestamp)
if err != nil {
return nil, err
}
attempts = append(attempts, models.AuthenticationAttempt{Username: username, Successful: successful, Time: time.Unix(timestamp, 0)})
}
return attempts, nil
}
func (p *SQLProvider) schemaMigratePre1To1TOTP(ctx context.Context) (err error) {
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tablePre1TOTPSecrets))
if err != nil {
return err
}
var totpConfigs []models.TOTPConfiguration
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
for rows.Next() {
var username, secret string
err = rows.Scan(&username, &secret)
if err != nil {
return err
}
// TODO: Add encryption migration here.
encryptedSecret := "encrypted:" + secret
totpConfigs = append(totpConfigs, models.TOTPConfiguration{Username: username, Secret: encryptedSecret})
}
for _, config := range totpConfigs {
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tableTOTPConfigurations), config.Username, config.Secret)
if err != nil {
return err
}
}
return nil
}
func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) {
rows, err := p.db.Queryx(fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectU2FDevices), tablePrefixBackup+tableU2FDevices))
if err != nil {
return err
}
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
var devices []models.U2FDevice
for rows.Next() {
var username, keyHandleBase64, publicKeyBase64 string
err = rows.Scan(&username, &keyHandleBase64, &publicKeyBase64)
if err != nil {
return err
}
keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
if err != nil {
return err
}
publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
if err != nil {
return err
}
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey})
}
for _, device := range devices {
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertU2FDevice), tableU2FDevices), device.Username, device.KeyHandle, device.PublicKey)
if err != nil {
return err
}
}
return nil
}
func (p *SQLProvider) schemaMigrate1ToPre1(ctx context.Context) (err error) {
tables, err := p.SchemaTables(ctx)
if err != nil {
return err
}
tablesRename := []string{
tableMigrations,
tableTOTPConfigurations,
tableIdentityVerification,
tableU2FDevices,
tableDUODevices,
tableUserPreferences,
tableAuthenticationLogs,
}
if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil {
return err
}
if _, err := p.db.ExecContext(ctx, queryCreatePre1); err != nil {
return err
}
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect),
tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil {
return err
}
if err = p.schemaMigrate1ToPre1AuthenticationLogs(ctx); err != nil {
return err
}
if err = p.schemaMigrate1ToPre1U2F(ctx); err != nil {
return err
}
if err = p.schemaMigrate1ToPre1TOTP(ctx); err != nil {
return err
}
queryFmtDropTableRebound := p.db.Rebind(queryFmtDropTableIfExists)
for _, table := range tablesRename {
if _, err = p.db.Exec(fmt.Sprintf(queryFmtDropTableRebound, tablePrefixBackup+table)); err != nil {
return err
}
}
return nil
}
func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogs(ctx context.Context) (err error) {
for page := 0; true; page++ {
attempts, err := p.schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx, page)
if err != nil {
if err == sql.ErrNoRows {
break
}
return err
}
for _, attempt := range attempts {
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time.Unix())
if err != nil {
return err
}
}
if len(attempts) != 100 {
break
}
}
return nil
}
func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []models.AuthenticationAttempt, err error) {
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100)
if err != nil {
return nil, err
}
attempts = make([]models.AuthenticationAttempt, 0, 100)
var attempt models.AuthenticationAttempt
for rows.Next() {
err = rows.StructScan(&attempt)
if err != nil {
return nil, err
}
attempts = append(attempts, attempt)
}
return attempts, nil
}
func (p *SQLProvider) schemaMigrate1ToPre1TOTP(ctx context.Context) (err error) {
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tableTOTPConfigurations))
if err != nil {
return err
}
var totpConfigs []models.TOTPConfiguration
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
for rows.Next() {
var username, encryptedSecret string
err = rows.Scan(&username, &encryptedSecret)
if err != nil {
return err
}
// TODO: Fix.
// TODO: Add DECRYPTION migration here.
decryptedSecret := strings.Replace(encryptedSecret, "encrypted:", "", 1)
totpConfigs = append(totpConfigs, models.TOTPConfiguration{Username: username, Secret: decryptedSecret})
}
for _, config := range totpConfigs {
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tablePre1TOTPSecrets), config.Username, config.Secret)
if err != nil {
return err
}
}
return nil
}
func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) {
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectU2FDevices), tablePrefixBackup+tableU2FDevices))
if err != nil {
return err
}
defer func() {
if err := rows.Close(); err != nil {
p.log.Errorf(logFmtErrClosingConn, err)
}
}()
var (
devices []models.U2FDevice
device models.U2FDevice
)
for rows.Next() {
err = rows.StructScan(&device)
if err != nil {
return err
}
devices = append(devices, device)
}
for _, device := range devices {
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertU2FDevice), tableU2FDevices), device.Username, base64.StdEncoding.EncodeToString(device.KeyHandle), base64.StdEncoding.EncodeToString(device.PublicKey))
if err != nil {
return err
}
}
return nil
}

View File

@ -0,0 +1,134 @@
package storage
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestShouldReturnErrOnTargetSameAsCurrent(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, true, 1, 1),
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, false, 1, 1),
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, false, 2, 2),
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 2, 2))
assert.EqualError(t,
schemaMigrateChecks(providerMySQL, false, 1, 1),
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, false, 1, 1),
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
}
func TestShouldReturnErrOnUpMigrationTargetVersionLessTHanCurrent(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, true, 0, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion))
assert.NoError(t,
schemaMigrateChecks(providerPostgres, true, testLatestVersion, 0))
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, true, 0, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion))
assert.NoError(t,
schemaMigrateChecks(providerSQLite, true, testLatestVersion, 0))
assert.EqualError(t,
schemaMigrateChecks(providerMySQL, true, 0, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion))
assert.NoError(t,
schemaMigrateChecks(providerMySQL, true, testLatestVersion, 0))
}
func TestMigrationUpShouldReturnErrOnAlreadyLatest(t *testing.T) {
assert.Equal(t,
ErrSchemaAlreadyUpToDate,
schemaMigrateChecks(providerPostgres, true, SchemaLatest, testLatestVersion))
assert.Equal(t,
ErrSchemaAlreadyUpToDate,
schemaMigrateChecks(providerMySQL, true, SchemaLatest, testLatestVersion))
assert.Equal(t,
ErrSchemaAlreadyUpToDate,
schemaMigrateChecks(providerSQLite, true, SchemaLatest, testLatestVersion))
}
func TestShouldReturnErrOnVersionDoesntExits(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, true, SchemaLatest-1, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion))
assert.EqualError(t,
schemaMigrateChecks(providerMySQL, true, SchemaLatest-1, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion))
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, true, SchemaLatest-1, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion))
}
func TestMigrationDownShouldReturnErrOnTargetLessThanPre1(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, false, -4, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -4))
assert.EqualError(t,
schemaMigrateChecks(providerMySQL, false, -2, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2))
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, false, -2, testLatestVersion),
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2))
assert.NoError(t,
schemaMigrateChecks(providerPostgres, false, -1, testLatestVersion))
}
func TestMigrationDownShouldReturnErrOnTargetVersionGreaterThanCurrent(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, false, testLatestVersion, 0),
fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0))
assert.EqualError(t,
schemaMigrateChecks(providerMySQL, false, testLatestVersion, 0),
fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0))
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, false, testLatestVersion, 0),
fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0))
}
func TestShouldReturnErrWhenCurrentIsGreaterThanLatest(t *testing.T) {
assert.EqualError(t,
schemaMigrateChecks(providerPostgres, true, SchemaLatest-4, SchemaLatest-5),
fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion))
assert.EqualError(t,
schemaMigrateChecks(providerMySQL, true, SchemaLatest-4, SchemaLatest-5),
fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion))
assert.EqualError(t,
schemaMigrateChecks(providerSQLite, true, SchemaLatest-4, SchemaLatest-5),
fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion))
}
func TestSchemaVersionToString(t *testing.T) {
assert.Equal(t, "unknown", SchemaVersionToString(-2))
assert.Equal(t, "pre1", SchemaVersionToString(-1))
assert.Equal(t, "N/A", SchemaVersionToString(0))
assert.Equal(t, "1", SchemaVersionToString(1))
assert.Equal(t, "2", SchemaVersionToString(2))
}

View File

@ -1,400 +0,0 @@
package storage
import (
"database/sql/driver"
"encoding/base64"
"fmt"
"sort"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/authelia/authelia/v4/internal/authentication"
"github.com/authelia/authelia/v4/internal/models"
)
const currentSchemaMockSchemaVersion = "1"
func TestSQLInitializeDatabase(t *testing.T) {
provider, mock := NewSQLMockProvider()
rows := sqlmock.NewRows([]string{"name"})
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(rows)
mock.ExpectBegin()
keys := make([]string, 0, len(sqlUpgradeCreateTableStatements[1]))
for k := range sqlUpgradeCreateTableStatements[1] {
keys = append(keys, k)
}
sort.Strings(keys)
for _, table := range keys {
mock.ExpectExec(
fmt.Sprintf("CREATE TABLE %s .*", table)).
WillReturnResult(sqlmock.NewResult(0, 0))
}
mock.ExpectExec(
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(
fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
WithArgs("schema", "version", "1").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := provider.initialize(provider.db)
assert.NoError(t, err)
}
func TestSQLUpgradeDatabase(t *testing.T) {
provider, mock := NewSQLMockProvider()
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(userPreferencesTableName).
AddRow(identityVerificationTokensTableName).
AddRow(totpSecretsTableName).
AddRow(u2fDeviceHandlesTableName).
AddRow(authenticationLogsTableName))
mock.ExpectBegin()
mock.ExpectExec(
fmt.Sprintf("CREATE TABLE %s .*", configTableName)).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec(
fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
WithArgs("schema", "version", "1").
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := provider.initialize(provider.db)
assert.NoError(t, err)
}
func TestSQLProviderMethodsAuthenticationLogs(t *testing.T) {
provider, mock := NewSQLMockProvider()
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(userPreferencesTableName).
AddRow(identityVerificationTokensTableName).
AddRow(totpSecretsTableName).
AddRow(u2fDeviceHandlesTableName).
AddRow(authenticationLogsTableName).
AddRow(configTableName))
args := []driver.Value{"schema", "version"}
mock.ExpectQuery(
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"value"}).
AddRow("1"))
err := provider.initialize(provider.db)
assert.NoError(t, err)
attempts := []models.AuthenticationAttempt{
{Username: unitTestUser, Successful: true, Time: time.Unix(1577880001, 0)},
{Username: unitTestUser, Successful: true, Time: time.Unix(1577880002, 0)},
{Username: unitTestUser, Successful: false, Time: time.Unix(1577880003, 0)},
}
rows := sqlmock.NewRows([]string{"successful", "time"})
for id, attempt := range attempts {
args = []driver.Value{attempt.Username, attempt.Successful, attempt.Time.Unix()}
mock.ExpectExec(
fmt.Sprintf("INSERT INTO %s \\(username, successful, time\\) VALUES \\(\\?, \\?, \\?\\)", authenticationLogsTableName)).
WithArgs(args...).
WillReturnResult(sqlmock.NewResult(int64(id), 1))
err := provider.AppendAuthenticationLog(attempt)
assert.NoError(t, err)
rows.AddRow(attempt.Successful, attempt.Time.Unix())
}
args = []driver.Value{1577880000, unitTestUser}
mock.ExpectQuery(
fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
WithArgs(args...).
WillReturnRows(rows)
after := time.Unix(1577880000, 0)
results, err := provider.LoadLatestAuthenticationLogs(unitTestUser, after)
assert.NoError(t, err)
require.Len(t, results, 3)
assert.Equal(t, unitTestUser, results[0].Username)
assert.Equal(t, true, results[0].Successful)
assert.Equal(t, time.Unix(1577880001, 0), results[0].Time)
assert.Equal(t, unitTestUser, results[1].Username)
assert.Equal(t, true, results[1].Successful)
assert.Equal(t, time.Unix(1577880002, 0), results[1].Time)
assert.Equal(t, unitTestUser, results[2].Username)
assert.Equal(t, false, results[2].Successful)
assert.Equal(t, time.Unix(1577880003, 0), results[2].Time)
// Test Blank Rows.
mock.ExpectQuery(
fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"successful", "time"}))
results, err = provider.LoadLatestAuthenticationLogs(unitTestUser, after)
assert.NoError(t, err)
assert.Len(t, results, 0)
}
func TestSQLProviderMethodsPreferred(t *testing.T) {
provider, mock := NewSQLMockProvider()
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(userPreferencesTableName).
AddRow(identityVerificationTokensTableName).
AddRow(totpSecretsTableName).
AddRow(u2fDeviceHandlesTableName).
AddRow(authenticationLogsTableName).
AddRow(configTableName))
args := []driver.Value{"schema", "version"}
mock.ExpectQuery(
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"value"}).
AddRow(currentSchemaMockSchemaVersion))
err := provider.initialize(provider.db)
assert.NoError(t, err)
mock.ExpectExec(
fmt.Sprintf("REPLACE INTO %s \\(username, second_factor_method\\) VALUES \\(\\?, \\?\\)", userPreferencesTableName)).
WithArgs(unitTestUser, authentication.TOTP).
WillReturnResult(sqlmock.NewResult(0, 1))
err = provider.SavePreferred2FAMethod(unitTestUser, authentication.TOTP)
assert.NoError(t, err)
mock.ExpectQuery(
fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
WithArgs(unitTestUser).
WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}).AddRow(authentication.TOTP))
method, err := provider.LoadPreferred2FAMethod(unitTestUser)
assert.NoError(t, err)
assert.Equal(t, authentication.TOTP, method)
// Test Blank Rows.
mock.ExpectQuery(
fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
WithArgs(unitTestUser).
WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}))
method, err = provider.LoadPreferred2FAMethod(unitTestUser)
assert.NoError(t, err)
assert.Equal(t, "", method)
}
func TestSQLProviderMethodsTOTP(t *testing.T) {
provider, mock := NewSQLMockProvider()
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(userPreferencesTableName).
AddRow(identityVerificationTokensTableName).
AddRow(totpSecretsTableName).
AddRow(u2fDeviceHandlesTableName).
AddRow(authenticationLogsTableName).
AddRow(configTableName))
args := []driver.Value{"schema", "version"}
mock.ExpectQuery(
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"value"}).
AddRow(currentSchemaMockSchemaVersion))
err := provider.initialize(provider.db)
assert.NoError(t, err)
pretendSecret := "abc123"
args = []driver.Value{unitTestUser, pretendSecret}
mock.ExpectExec(
fmt.Sprintf("REPLACE INTO %s \\(username, secret\\) VALUES \\(\\?, \\?\\)", totpSecretsTableName)).
WithArgs(args...).
WillReturnResult(sqlmock.NewResult(0, 1))
err = provider.SaveTOTPSecret(unitTestUser, pretendSecret)
assert.NoError(t, err)
args = []driver.Value{unitTestUser}
mock.ExpectQuery(
fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"secret"}).AddRow(pretendSecret))
secret, err := provider.LoadTOTPSecret(unitTestUser)
assert.NoError(t, err)
assert.Equal(t, pretendSecret, secret)
mock.ExpectExec(
fmt.Sprintf("DELETE FROM %s WHERE username=\\?", totpSecretsTableName)).
WithArgs(unitTestUser).
WillReturnResult(sqlmock.NewResult(0, 1))
err = provider.DeleteTOTPSecret(unitTestUser)
assert.NoError(t, err)
mock.ExpectQuery(
fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"secret"}))
// Test Blank Rows
secret, err = provider.LoadTOTPSecret(unitTestUser)
assert.EqualError(t, err, "no TOTP secret registered")
assert.Equal(t, "", secret)
}
func TestSQLProviderMethodsU2F(t *testing.T) {
provider, mock := NewSQLMockProvider()
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(userPreferencesTableName).
AddRow(identityVerificationTokensTableName).
AddRow(totpSecretsTableName).
AddRow(u2fDeviceHandlesTableName).
AddRow(authenticationLogsTableName).
AddRow(configTableName))
args := []driver.Value{"schema", "version"}
mock.ExpectQuery(
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"value"}).
AddRow(currentSchemaMockSchemaVersion))
err := provider.initialize(provider.db)
assert.NoError(t, err)
pretendKeyHandle := []byte("abc")
pretendPublicKey := []byte("123")
pretendKeyHandleB64 := base64.StdEncoding.EncodeToString(pretendKeyHandle)
pretendPublicKeyB64 := base64.StdEncoding.EncodeToString(pretendPublicKey)
args = []driver.Value{unitTestUser, pretendKeyHandleB64, pretendPublicKeyB64}
mock.ExpectExec(
fmt.Sprintf("REPLACE INTO %s \\(username, keyHandle, publicKey\\) VALUES \\(\\?, \\?, \\?\\)", u2fDeviceHandlesTableName)).
WithArgs(args...).
WillReturnResult(sqlmock.NewResult(0, 1))
err = provider.SaveU2FDeviceHandle(unitTestUser, pretendKeyHandle, pretendPublicKey)
assert.NoError(t, err)
args = []driver.Value{unitTestUser}
mock.ExpectQuery(
fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}).
AddRow(pretendKeyHandleB64, pretendPublicKeyB64))
keyHandle, publicKey, err := provider.LoadU2FDeviceHandle(unitTestUser)
assert.NoError(t, err)
assert.Equal(t, pretendKeyHandle, keyHandle)
assert.Equal(t, pretendPublicKey, publicKey)
// Test Blank Rows.
mock.ExpectQuery(
fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}))
keyHandle, publicKey, err = provider.LoadU2FDeviceHandle(unitTestUser)
assert.EqualError(t, err, "no U2F device handle found")
assert.Equal(t, []byte(nil), keyHandle)
assert.Equal(t, []byte(nil), publicKey)
}
func TestSQLProviderMethodsIdentityVerificationTokens(t *testing.T) {
provider, mock := NewSQLMockProvider()
mock.ExpectQuery(
"SELECT name FROM sqlite_master WHERE type='table'").
WillReturnRows(sqlmock.NewRows([]string{"name"}).
AddRow(userPreferencesTableName).
AddRow(identityVerificationTokensTableName).
AddRow(totpSecretsTableName).
AddRow(u2fDeviceHandlesTableName).
AddRow(authenticationLogsTableName).
AddRow(configTableName))
args := []driver.Value{"schema", "version"}
mock.ExpectQuery(
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
WithArgs(args...).
WillReturnRows(sqlmock.NewRows([]string{"value"}).
AddRow(currentSchemaMockSchemaVersion))
err := provider.initialize(provider.db)
assert.NoError(t, err)
fakeIdentityVerificationToken := "abc"
mock.ExpectExec(
fmt.Sprintf("INSERT INTO %s \\(token\\) VALUES \\(\\?\\)", identityVerificationTokensTableName)).
WithArgs(fakeIdentityVerificationToken).
WillReturnResult(sqlmock.NewResult(1, 1))
err = provider.SaveIdentityVerificationToken(fakeIdentityVerificationToken)
assert.NoError(t, err)
mock.ExpectQuery(
fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
WithArgs(fakeIdentityVerificationToken).
WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
AddRow(true))
valid, err := provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
assert.NoError(t, err)
assert.True(t, valid)
mock.ExpectExec(
fmt.Sprintf("DELETE FROM %s WHERE token=\\?", identityVerificationTokensTableName)).
WithArgs(fakeIdentityVerificationToken).
WillReturnResult(sqlmock.NewResult(0, 1))
err = provider.RemoveIdentityVerificationToken(fakeIdentityVerificationToken)
assert.NoError(t, err)
mock.ExpectQuery(
fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
WithArgs(fakeIdentityVerificationToken).
WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
AddRow(false))
valid, err = provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
assert.NoError(t, err)
assert.False(t, valid)
}

View File

@ -1,58 +0,0 @@
package storage
import (
"database/sql"
"fmt"
_ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string.
)
// SQLiteProvider is a SQLite3 provider.
type SQLiteProvider struct {
SQLProvider
}
// NewSQLiteProvider constructs a SQLite provider.
func NewSQLiteProvider(path string) *SQLiteProvider {
provider := SQLiteProvider{
SQLProvider{
name: "sqlite",
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements,
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName),
sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName),
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName),
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName),
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName),
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),
sqlGetExistingTables: "SELECT name FROM sqlite_master WHERE type='table'",
sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName),
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName),
},
}
db, err := sql.Open("sqlite3", path)
if err != nil {
provider.log.Fatalf("Unable to create SQL database %s: %s", path, err)
}
if err := provider.initialize(db); err != nil {
provider.log.Fatalf("Unable to initialize SQL database %s: %s", path, err)
}
return &provider
}

View File

@ -1,60 +0,0 @@
package storage
import (
"fmt"
"github.com/DATA-DOG/go-sqlmock"
)
// SQLMockProvider is a SQLMock provider.
type SQLMockProvider struct {
SQLProvider
}
// NewSQLMockProvider constructs a SQLMock provider.
func NewSQLMockProvider() (*SQLMockProvider, sqlmock.Sqlmock) {
provider := SQLMockProvider{
SQLProvider{
name: "sqlmock",
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements,
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName),
sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName),
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName),
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName),
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName),
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName),
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),
sqlGetExistingTables: "SELECT name FROM sqlite_master WHERE type='table'",
sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName),
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName),
},
}
db, mock, err := sqlmock.New()
if err != nil {
provider.log.Fatalf("Unable to create SQL database: %s", err)
}
provider.db = db
/*
We do initialize in the tests rather than in the new up.
*/
return &provider, mock
}

View File

@ -1,18 +1,28 @@
package storage
import (
"database/sql"
"strconv"
)
// SchemaVersion is a simple int representation of the schema version.
type SchemaVersion int
// ToString converts the schema version into a string and returns that converted value.
func (s SchemaVersion) ToString() string {
return strconv.Itoa(int(s))
// SchemaMigration represents an intended migration.
type SchemaMigration struct {
Version int
Name string
Provider string
Up bool
Query string
}
type transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
// Before returns the version the schema should be at Before the migration is applied.
func (m SchemaMigration) Before() (before int) {
if m.Up {
return m.Version - 1
}
return m.Version
}
// After returns the version the schema will be at After the migration is applied.
func (m SchemaMigration) After() (after int) {
if m.Up {
return m.Version
}
return m.Version - 1
}

View File

@ -1,76 +0,0 @@
package storage
import (
"fmt"
"sort"
"github.com/authelia/authelia/v4/internal/utils"
)
func (p *SQLProvider) upgradeCreateTableStatements(tx transaction, statements map[string]string, existingTables []string) error {
keys := make([]string, 0, len(statements))
for k := range statements {
keys = append(keys, k)
}
sort.Strings(keys)
for _, table := range keys {
if !utils.IsStringInSlice(table, existingTables) {
_, err := tx.Exec(fmt.Sprintf(statements[table], table))
if err != nil {
return fmt.Errorf("unable to create table %s: %v", table, err)
}
}
}
return nil
}
func (p *SQLProvider) upgradeRunMultipleStatements(tx transaction, statements []string) error {
for _, statement := range statements {
_, err := tx.Exec(statement)
if err != nil {
return err
}
}
return nil
}
// upgradeFinalize sets the schema version and logs a message, as well as any other future finalization tasks.
func (p *SQLProvider) upgradeFinalize(tx transaction, version SchemaVersion) error {
_, err := tx.Exec(p.sqlConfigSetValue, "schema", "version", version.ToString())
if err != nil {
return err
}
p.log.Debugf("%s%d", storageSchemaUpgradeMessage, version)
return nil
}
// upgradeSchemaToVersion001 upgrades the schema to version 1.
func (p *SQLProvider) upgradeSchemaToVersion001(tx transaction, tables []string) error {
version := SchemaVersion(1)
err := p.upgradeCreateTableStatements(tx, p.sqlUpgradesCreateTableStatements[version], tables)
if err != nil {
return err
}
// Skip mysql create index statements. It doesn't support CREATE INDEX IF NOT EXIST. May be able to work around this with an Index struct.
if p.name != "mysql" {
err = p.upgradeRunMultipleStatements(tx, p.sqlUpgradesCreateTableIndexesStatements[1])
if err != nil {
return fmt.Errorf("unable to create index: %v", err)
}
}
err = p.upgradeFinalize(tx, version)
if err != nil {
return err
}
return nil
}

View File

@ -122,7 +122,9 @@ func (s *StandaloneWebDriverSuite) TestShouldCheckUserIsAskedToRegisterDevice()
// Clean up any TOTP secret already in DB.
provider := storage.NewSQLiteProvider("/tmp/db.sqlite3")
require.NoError(s.T(), provider.DeleteTOTPSecret(username))
require.NoError(s.T(), provider.StartupCheck())
require.NoError(s.T(), provider.DeleteTOTPConfiguration(ctx, username))
// Login one factor.
s.doLoginOneFactor(s.T(), s.Context(ctx), username, password, false, "")